#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ParameterSet/interface/allowedValues.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"
#include "model_config.pb.h"

#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#include <algorithm>
#include <cmath>
#include <exception>
#include <experimental/iterator>
#include <fcntl.h>
#include <sstream>
#include <string>
#include <utility>
#include <tuple>
#include <unordered_set>

namespace tc = triton::client;

namespace {
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }

  std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
    std::vector<std::shared_ptr<tc::InferResult>> results;
    results.reserve(tmp.size());
    std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
      return std::shared_ptr<tc::InferResult>(ptr);
    });
    return results;
  }
}  // namespace

TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      batchMode_(TritonBatchMode::Rectangular),
      manualBatchMode_(false),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
  options_.emplace_back(params.getParameter<std::string>("modelName"));

  //get the appropriate server for this model from the TritonService
  edm::Service<TritonService> ts;

  //store the ServiceRegistry token so that framework services can be used
  //inside callbacks that may run outside of a framework thread
  token_ = edm::ServiceRegistry::instance().presentToken();

  const auto& server =
      ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url;

  //enforce sync mode for the local CPU fallback server to avoid contention
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);
  isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU;

  //create the gRPC inference context for the assigned server
  TRITON_THROW_IF_ERROR(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context",
      isLocal_);

  //set model options
  options_[0].model_version_ = params.getParameter<std::string>("modelVersion");
  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout");

  //convert the timeout to microseconds, the unit used by the Triton client
  //(e.g. timeout = 10 with timeoutUnit = "milliseconds" gives client_timeout_ = 10000)
  const auto& timeoutUnit = params.getUntrackedParameter<std::string>("timeoutUnit");
  unsigned conversion = 1;
  if (timeoutUnit == "seconds")
    conversion = 1e6;
  else if (timeoutUnit == "milliseconds")
    conversion = 1e3;
  else if (timeoutUnit == "microseconds")
    conversion = 1;
  else
    throw cms::Exception("Configuration") << "Unknown timeout unit: " << timeoutUnit;
  options_[0].client_timeout_ *= conversion;

  //parse the local copy of the model config (protobuf text format)
  inference::ModelConfig localModelConfig;
  {
    const std::string localModelConfigPath(params.getParameter<edm::FileInPath>("modelConfigPath").fullPath());
    int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY);
    if (fileDescriptor < 0)
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to open local model config: " << localModelConfigPath;
    google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor);
    localModelConfigInput.SetCloseOnDelete(true);
    if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig))
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to parse local model config: " << localModelConfigPath;
  }
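
  //For reference, an illustrative (not real) excerpt of a config.pbtxt showing the fields used below;
  //the model file name, version "1", and checksum value are hypothetical:
  //
  //  max_batch_size: 128
  //  model_repository_agents {
  //    agents [
  //      {
  //        name: "checksum"
  //        parameters {
  //          key: "1/model.onnx"
  //          value: "<hash of the model file>"
  //        }
  //      }
  //    ]
  //  }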

  //check batch size limitations:
  //Triton uses max_batch_size = 0 to denote a model with no batching via an outer dimension,
  //while a batch size of 0 in a given event just means there is nothing to process,
  //so keep the local max at >= 1 and track the "no outer dim" case separately
  maxOuterDim_ = localModelConfig.max_batch_size();
  noOuterDim_ = maxOuterDim_ == 0;
  maxOuterDim_ = std::max(1u, maxOuterDim_);
  //propagate batch size
  setBatchSize(1);

  //get the model config from the server to compare against the local copy
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model config",
                        isLocal_);
  inference::ModelConfig remoteModelConfig(modelConfigResponse.config());

  //compare the checksums stored by the "checksum" model repository agent (index 0: local, index 1: remote)
  std::map<std::string, std::array<std::string, 2>> checksums;
  size_t fileCounter = 0;
  for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) {
    const auto& agents = modelConfig.model_repository_agents().agents();
    auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; });
    if (agent != agents.end()) {
      const auto& params = agent->parameters();
      for (const auto& [key, val] : params) {
        //only consider checksums for the requested model version
        if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0)
          checksums[key][fileCounter] = val;
      }
    }
    ++fileCounter;
  }
  std::vector<std::string> incorrect;
  for (const auto& [key, val] : checksums) {
    if (val[0] != val[1])
      incorrect.push_back(key);
  }
  if (!incorrect.empty())
    throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: "
                                             << triton_utils::printColl(incorrect, ", ");

  //get the model metadata to set up inputs and outputs
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model metadata",
                        isLocal_);

  //get the input and output descriptions (which know their shapes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //a model with zero inputs or zero outputs is not usable
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //set up the input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    if (verbose_) {
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some of the outputs provided by the server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //set up the output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    if (verbose_) {
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check that all requested outputs were found on the server
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_[0].model_name_ << "\n"
              << "Model version: " << options_[0].model_version_ << "\n"
              << "Model max outer dim: " << (noOuterDim_ ? 0 : maxOuterDim_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
  //members of this class are destroyed before members of the base class;
  //in the shared memory case, the memory resources held by TritonData unregister themselves from client_
  //(a member of the base class), so clear the I/O maps explicitly here while client_ is still valid
  input_.clear();
  output_.clear();
}

void TritonClient::setBatchMode(TritonBatchMode batchMode) {
  unsigned oldBatchSize = batchSize();
  batchMode_ = batchMode;
  manualBatchMode_ = true;
  //propagate the previous batch size into the new mode's representation
  setBatchSize(oldBatchSize);
}

void TritonClient::resetBatchMode() {
  batchMode_ = TritonBatchMode::Rectangular;
  manualBatchMode_ = false;
}

unsigned TritonClient::nEntries() const { return !input_.empty() ? input_.begin()->second.entries_.size() : 0; }

unsigned TritonClient::batchSize() const { return batchMode_ == TritonBatchMode::Rectangular ? outerDim_ : nEntries(); }

bool TritonClient::setBatchSize(unsigned bsize) {
  if (batchMode_ == TritonBatchMode::Rectangular) {
    if (bsize > maxOuterDim_) {
      throw TritonException("LocalFailure")
          << "Requested batch size " << bsize << " exceeds server-specified max batch size " << maxOuterDim_ << ".";
    } else {
      outerDim_ = bsize;
      //take the min to allow resizing down to zero entries when the batch size is 0
      resizeEntries(std::min(outerDim_, 1u));
      return true;
    }
  } else {
    resizeEntries(bsize);
    outerDim_ = 1;
    return true;
  }
}
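
//Illustrative sketch (hypothetical caller, not part of this file): a producer that owns a TritonClient
//typically sets the batch size in acquire() before filling its inputs, e.g.
//  client_->setBatchSize(candidates.size());  //rectangular batching (default mode)
//or, if each entry should have its own shape:
//  client_->setBatchMode(TritonBatchMode::Ragged);
//  client_->setBatchSize(candidates.size());
//A batch size of 0 is allowed; in that case evaluate() below skips the inference call entirely.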

void TritonClient::resizeEntries(unsigned entry) {
  if (entry > nEntries())
    //addEntry(entry) extends the entry vectors to size entry+1
    addEntry(entry - 1);
  else if (entry < nEntries()) {
    for (auto& element : input_) {
      element.second.entries_.resize(entry);
    }
    for (auto& element : output_) {
      element.second.entries_.resize(entry);
    }
  }
}

void TritonClient::addEntry(unsigned entry) {
  for (auto& element : input_) {
    element.second.addEntryImpl(entry);
  }
  for (auto& element : output_) {
    element.second.addEntryImpl(entry);
  }
  if (entry > 0) {
    batchMode_ = TritonBatchMode::Ragged;
    outerDim_ = 1;
  }
}

void TritonClient::reset() {
  //restore the automatic batch mode unless it was set manually
  if (!manualBatchMode_)
    batchMode_ = TritonBatchMode::Rectangular;
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

template <typename F>
bool TritonClient::handle_exception(F&& call) {
  //exceptions are either converted to warnings (to allow retries) or propagated via finish()
  CMS_SA_ALLOW try {
    call();
    return true;
  }
  //TritonExceptions are intended to be recoverable, i.e. retries should be allowed
  catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  }
  //other exceptions are not recoverable and are propagated to the framework
  catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
  for (unsigned i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    for (auto& [oname, output] : output_) {
      //set the shape before the output data is accessed
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        TRITON_THROW_IF_ERROR(
            result->Shape(oname, &tmp_shape), "getResults(): unable to get output shape for " + oname, false);
        if (!noOuterDim_)
          tmp_shape.erase(tmp_shape.begin());
        output.setShape(tmp_shape, i);
      }
      //keep the result object alive for as long as the output data is in use
      output.setResult(result, i);
      //compute sizes once all entries have their results
      if (i == results.size() - 1)
        output.computeSizes();
    }
  }
}

void TritonClient::evaluate() {
  //a previous attempt failed with a TritonException: update the call status in the service before retrying
  if (tries_ > 0) {
    //make the framework services available in this (possibly non-framework) thread
    edm::ServiceRegistry::Operate op(token_);
    edm::Service<TritonService> ts;
    ts->notifyCallStatus(true);
  }

  //in case there is nothing to process
  if (batchSize() == 0) {
    //call getResults on an empty vector to keep the output bookkeeping consistent
    std::vector<std::shared_ptr<tc::InferResult>> empty_results;
    getResults(empty_results);
    finish(true);
    return;
  }

  //set up the input pointers for Triton: one vector of InferInput* per entry
  //(multiple entries are used for ragged batching)
  unsigned nEntriesVal = nEntries();
  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
  for (auto& inputTriton : inputsTriton) {
    inputTriton.reserve(input_.size());
  }
  for (auto& [iname, input] : input_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      inputsTriton[i].push_back(input.data(i));
    }
  }

  //set up the output pointers similarly
  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
  for (auto& outputTriton : outputsTriton) {
    outputTriton.reserve(output_.size());
  }
  for (auto& [oname, output] : output_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      outputsTriton[i].push_back(output.data(i));
    }
  }

  //prepare the outputs (e.g. set up shared memory)
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  //get the server-side status before the request is made (only used in verbose mode)
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInferMulti(
              [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
                //immediately convert to shared_ptr to guarantee cleanup of the raw pointers
                const auto& results = convertToShared(resultsTmp);
                //check the status of each result
                for (auto ptr : results) {
                  auto success = handle_exception([&]() {
                    TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)", isLocal_);
                  });
                  if (!success)
                    return;
                }

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  auto success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

                //propagate the results to the output objects
                auto success = handle_exception([&]() { getResults(results); });
                if (!success)
                  return;

                //finish
                finish(true);
              },
              options_,
              inputsTriton,
              outputsTriton,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run",
          isLocal_);
    });
    if (!success)
      return;
  } else {
    //blocking call
    std::vector<tc::InferResult*> resultsTmp;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result",
          isLocal_);
    });
    //convert to shared_ptr even on failure so the raw pointers are cleaned up
    const auto& results = convertToShared(resultsTmp);
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    //propagate the results to the output objects
    success = handle_exception([&]() { getResults(results); });
    if (!success)
      return;

    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
                          "getServerSideStatus(): unable to get model statistics",
                          isLocal_);
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //parameters that only affect how and where the inference call is made are untracked
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.ifValue(edm::ParameterDescription<std::string>("timeoutUnit", "seconds", false),
                     edm::allowedValues<std::string>("seconds", "milliseconds", "microseconds"));
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
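
//Illustrative sketch (hypothetical module and label names, not defined in this package): a producer module
//that uses TritonClient can expose the nested "Client" PSet in its fillDescriptions() roughly as follows:
//  void MyTritonProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
//    edm::ParameterSetDescription desc;
//    TritonClient::fillPSetDescription(desc);  //adds the "Client" PSet described above
//    //...add any module-specific parameters here...
//    descriptions.add("myTritonProducer", desc);
//  }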