#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ParameterSet/interface/allowedValues.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"
#include "model_config.pb.h"

#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#include <algorithm>
#include <cmath>
#include <exception>
#include <experimental/iterator>
#include <fcntl.h>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>

namespace tc = triton::client;

namespace {
  //translate the configured compression name into the corresponding gRPC enum
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }

  //take ownership of the raw InferResult pointers returned by the Triton client
  std::vector<std::shared_ptr<tc::InferResult>> convertToShared(const std::vector<tc::InferResult*>& tmp) {
    std::vector<std::shared_ptr<tc::InferResult>> results;
    results.reserve(tmp.size());
    std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), [](tc::InferResult* ptr) {
      return std::shared_ptr<tc::InferResult>(ptr);
    });
    return results;
  }
}  // namespace
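//TritonClient: connects to the Triton server chosen by TritonService, validates the model
//configuration and metadata, and sets up the input and output data objects used for inference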
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      batchMode_(TritonBatchMode::Rectangular),
      manualBatchMode_(false),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))) {
  options_.emplace_back(params.getParameter<std::string>("modelName"));

  //get the server to use for this model from the TritonService
  edm::Service<TritonService> ts;
  const auto& server =
      ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << server.url;

  //enforce sync mode when the local CPU fallback server is used
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);
  //whether the server is local is propagated to the error handling below
  isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU;

  //connect to the server
  TRITON_THROW_IF_ERROR(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context",
      isLocal_);

  //set model version and timeout options
  options_[0].model_version_ = params.getParameter<std::string>("modelVersion");
  options_[0].client_timeout_ = params.getUntrackedParameter<unsigned>("timeout");

  //convert the configured timeout to microseconds, the unit expected by the Triton client
  const auto& timeoutUnit = params.getUntrackedParameter<std::string>("timeoutUnit");
  unsigned conversion = 1;
  if (timeoutUnit == "seconds")
    conversion = 1e6;
  else if (timeoutUnit == "milliseconds")
    conversion = 1e3;
  else if (timeoutUnit == "microseconds")
    conversion = 1;
  else
    throw cms::Exception("Configuration") << "Unknown timeout unit: " << timeoutUnit;
  options_[0].client_timeout_ *= conversion;

  //parse the local model configuration (protobuf text format)
  inference::ModelConfig localModelConfig;
  {
    const std::string& localModelConfigPath(params.getParameter<edm::FileInPath>("modelConfigPath").fullPath());
    int fileDescriptor = open(localModelConfigPath.c_str(), O_RDONLY);
    if (fileDescriptor < 0)
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to open local model config: " << localModelConfigPath;
    google::protobuf::io::FileInputStream localModelConfigInput(fileDescriptor);
    localModelConfigInput.SetCloseOnDelete(true);
    if (!google::protobuf::TextFormat::Parse(&localModelConfigInput, &localModelConfig))
      throw TritonException("LocalFailure")
          << "TritonClient(): unable to parse local model config: " << localModelConfigPath;
  }

  //check batch size limitations:
  //triton uses max_batch_size = 0 to denote a model with no batch (outer) dimension,
  //so keep track of that case and use 1 as the effective maximum otherwise
  maxOuterDim_ = localModelConfig.max_batch_size();
  noOuterDim_ = maxOuterDim_ == 0;
  maxOuterDim_ = std::max(1u, maxOuterDim_);
  //default batch size
  setBatchSize(1);

  //get the model configuration known to the server
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(
      client_->ModelConfig(&modelConfigResponse, options_[0].model_name_, options_[0].model_version_),
      "TritonClient(): unable to get model config",
      isLocal_);
  inference::ModelConfig remoteModelConfig(modelConfigResponse.config());

  //check that the local and remote model files agree, using the checksums provided by the "checksum" repository agent
  std::map<std::string, std::array<std::string, 2>> checksums;
  size_t fileCounter = 0;
  for (const auto& modelConfig : {localModelConfig, remoteModelConfig}) {
    const auto& agents = modelConfig.model_repository_agents().agents();
    auto agent = std::find_if(agents.begin(), agents.end(), [](auto const& a) { return a.name() == "checksum"; });
    if (agent != agents.end()) {
      const auto& params = agent->parameters();
      for (const auto& [key, val] : params) {
        //only check the files for the requested model version
        if (key.compare(0, options_[0].model_version_.size() + 1, options_[0].model_version_ + "/") == 0)
          checksums[key][fileCounter] = val;
      }
    }
    ++fileCounter;
  }
  std::vector<std::string> incorrect;
  for (const auto& [key, val] : checksums) {
    if (val[0] != val[1])
      incorrect.push_back(key);
  }
  if (!incorrect.empty())
    throw TritonException("ModelVersioning") << "The following files have incorrect checksums on the remote server: "
                                             << triton_utils::printColl(incorrect, ", ");

  //get the model metadata (input/output names, types, shapes)
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_[0].model_name_, options_[0].model_version_),
                        "TritonClient(): unable to get model metadata",
                        isLocal_);

  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //accumulate any model problems and report them all at once
  std::stringstream msg;
  std::string msg_str;

  //a model with zero inputs or zero outputs is considered malformed
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //throw if any errors were collected
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    if (verbose_) {
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some of the outputs provided by the server
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    //skip outputs that were not requested
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    if (verbose_) {
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }


  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //print model info in verbose mode
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_[0].model_name_ << "\n"
              << "Model version: " << options_[0].model_version_ << "\n"
              << "Model max outer dim: " << (noOuterDim_ ? 0 : maxOuterDim_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
  //members of this class (e.g. client_) are destroyed before members of the base class by default;
  //the memory resources held by the input/output data may need client_ to unregister,
  //so clear the input and output objects explicitly while client_ is still valid
  input_.clear();
  output_.clear();
}
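
//batch modes:
//  Rectangular (default): a single entry holds all the inputs/outputs, and outerDim_ is the batch size
//  Ragged: each entry is a separate request with its own shape, and the batch size is the number of entries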
void TritonClient::setBatchMode(TritonBatchMode batchMode) {
  unsigned oldBatchSize = batchSize();
  batchMode_ = batchMode;
  manualBatchMode_ = true;
  //reapply the previous batch size in the new mode
  setBatchSize(oldBatchSize);
}

void TritonClient::resetBatchMode() {
  batchMode_ = TritonBatchMode::Rectangular;
  manualBatchMode_ = false;
}

unsigned TritonClient::nEntries() const { return !input_.empty() ? input_.begin()->second.entries_.size() : 0; }

unsigned TritonClient::batchSize() const { return batchMode_ == TritonBatchMode::Rectangular ? outerDim_ : nEntries(); }

bool TritonClient::setBatchSize(unsigned bsize) {
  if (batchMode_ == TritonBatchMode::Rectangular) {
    if (bsize > maxOuterDim_) {
      edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                      << maxOuterDim_ << ". Batch size will remain as " << outerDim_;
      return false;
    } else {
      outerDim_ = bsize;
      //rectangular mode uses a single entry (or zero entries if the batch size is zero)
      resizeEntries(std::min(outerDim_, 1u));
      return true;
    }
  } else {
    resizeEntries(bsize);
    outerDim_ = 1;
    return true;
  }
}

void TritonClient::resizeEntries(unsigned entry) {
  if (entry > nEntries())
    //addEntry(entry) extends the entry vectors to size entry+1
    addEntry(entry - 1);
  else if (entry < nEntries()) {
    for (auto& element : input_) {
      element.second.entries_.resize(entry);
    }
    for (auto& element : output_) {
      element.second.entries_.resize(entry);
    }
  }
}

void TritonClient::addEntry(unsigned entry) {
  for (auto& element : input_) {
    element.second.addEntryImpl(entry);
  }
  for (auto& element : output_) {
    element.second.addEntryImpl(entry);
  }
  //more than one entry implies ragged batching
  if (entry > 0) {
    batchMode_ = TritonBatchMode::Ragged;
    outerDim_ = 1;
  }
}

void TritonClient::reset() {
  if (!manualBatchMode_)
    batchMode_ = TritonBatchMode::Rectangular;
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

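//wrapper that runs a callable and converts any exception into a failed call:
//a TritonException is downgraded to a warning and the call is marked as failed (allowing a retry);
//any other exception is forwarded to finish()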
template <typename F>
bool TritonClient::handle_exception(F&& call) {
  CMS_SA_ALLOW try {
    call();
    return true;
  } catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  } catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(const std::vector<std::shared_ptr<tc::InferResult>>& results) {
  for (unsigned i = 0; i < results.size(); ++i) {
    const auto& result = results[i];
    for (auto& [oname, output] : output_) {
      //set the shape for variable-dimension outputs, dropping the outer (batch) dimension if present
      if (output.variableDims()) {
        std::vector<int64_t> tmp_shape;
        TRITON_THROW_IF_ERROR(
            result->Shape(oname, &tmp_shape), "getResults(): unable to get output shape for " + oname, false);
        if (!noOuterDim_)
          tmp_shape.erase(tmp_shape.begin());
        output.setShape(tmp_shape, i);
      }
      //keep the result alive as long as the output data is in use
      output.setResult(result, i);
      //compute sizes once all entries have been filled
      if (i == results.size() - 1)
        output.computeSizes();
    }
  }
}
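//evaluate(): builds the per-entry input/output pointer lists, then dispatches the inference either
//asynchronously (AsyncInferMulti with a completion callback) or synchronously (InferMulti);
//in both cases the results are unpacked via getResults() and server-side statistics are reported if verbose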
void TritonClient::evaluate() {
  //undo the failure signal from a previous try, if any
  if (tries_ > 0) {
    edm::Service<TritonService> ts;
    ts->notifyCallStatus(true);
  }

  //nothing to process
  if (batchSize() == 0) {
    //call getResults on an empty vector
    std::vector<std::shared_ptr<tc::InferResult>> empty_results;
    getResults(empty_results);
    finish(true);
    return;
  }

  //set up the input pointers expected by the Triton client: one vector per entry/request
  unsigned nEntriesVal = nEntries();
  std::vector<std::vector<triton::client::InferInput*>> inputsTriton(nEntriesVal);
  for (auto& inputTriton : inputsTriton) {
    inputTriton.reserve(input_.size());
  }
  for (auto& [iname, input] : input_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      inputsTriton[i].push_back(input.data(i));
    }
  }

  //same for the output pointers
  std::vector<std::vector<const triton::client::InferRequestedOutput*>> outputsTriton(nEntriesVal);
  for (auto& outputTriton : outputsTriton) {
    outputTriton.reserve(output_.size());
  }
  for (auto& [oname, output] : output_) {
    for (unsigned i = 0; i < nEntriesVal; ++i) {
      outputsTriton[i].push_back(output.data(i));
    }
  }

  //prepare the output memory resources
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  //store the server-side statistics before inference (only needed in verbose mode)
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    //non-blocking call
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInferMulti(
              [start_status, this](std::vector<tc::InferResult*> resultsTmp) {
                //immediately convert to shared_ptr to take ownership of the results
                const auto& results = convertToShared(resultsTmp);
                //check the status of each result
                for (auto ptr : results) {
                  auto success = handle_exception([&]() {
                    TRITON_THROW_IF_ERROR(ptr->RequestStatus(), "evaluate(): unable to get result(s)", isLocal_);
                  });
                  if (!success)
                    return;
                }

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  auto success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

                //unpack the results into the output objects
                auto success = handle_exception([&]() { getResults(results); });
                if (!success)
                  return;

                finish(true);
              },
              options_,
              inputsTriton,
              outputsTriton,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run",
          isLocal_);
    });
    if (!success)
      return;
  } else {
    //blocking call
    std::vector<tc::InferResult*> resultsTmp;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->InferMulti(&resultsTmp, options_, inputsTriton, outputsTriton, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result",
          isLocal_);
    });
    //immediately convert to shared_ptr to take ownership of the results, even if the call failed
    const auto& results = convertToShared(resultsTmp);
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    //unpack the results into the output objects
    success = handle_exception([&]() { getResults(results); });
    if (!success)
      return;

    finish(true);
  }
}
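//print the server-side statistics accumulated between the start and end of a call:
//counts and average latencies per successful request, converted from ns to us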
void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  //model statistics are only retrieved in verbose mode
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(
        client_->ModelInferenceStatistics(&resp, options_[0].model_name_, options_[0].model_version_),
        "getServerSideStatus(): unable to get model statistics",
        isLocal_);
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}
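//define the nested "Client" parameter set used to configure this class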
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //server parameters should not affect the physics results, so they are untracked
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.ifValue(edm::ParameterDescription<std::string>("timeoutUnit", "seconds", false),
                     edm::allowedValues<std::string>("seconds", "milliseconds", "microseconds"));
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}