#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>
#include <unordered_set>

namespace tc = triton::client;

namespace {
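  //map the configured compression name onto the corresponding gRPC enum value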
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }
}
TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
      options_(params.getParameter<std::string>("modelName")) {
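  //get the appropriate server for this model from the TritonService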
  edm::Service<TritonService> ts;
  const auto& server =
      ts->serverInfo(options_.model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << server.url;

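  //enforce synchronous mode when using the fallback CPU server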
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);

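  //connect to the server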
  TRITON_THROW_IF_ERROR(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context");

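  //set the remaining request options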
  options_.model_version_ = params.getParameter<std::string>("modelVersion");
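  //convert the configured timeout from seconds to microseconds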
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

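  //the model config provides the maximum batch size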
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

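  //Triton uses max_batch_size = 0 to indicate a model that does not support batching;
  //keep track of that case and use a local maximum of 1 so the batch-size handling below still applies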
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);

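  //get the model metadata (input and output names, types, and shapes)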
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model metadata");

  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

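  //collect any model problems and report them all at once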
  std::stringstream msg;
  std::string msg_str;

  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

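  //set up the input map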
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

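  //the user may request only a subset of the model outputs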
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

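  //set up the output map, keeping only the requested outputs if a subset was specified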
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

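  //check that all requested outputs were found on the server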
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

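  //propagate the initial batch size to all inputs and outputs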
  setBatchSize(1);

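  //print model information if verbose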
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
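  //members of this class are destroyed before members of the base class by default;
  //clear the I/O maps here so any resources they hold (e.g. shared memory registered with client_) are released first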
  input_.clear();
  output_.clear();
}

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                    << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  } else {
    batchSize_ = bsize;
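    //propagate the new batch size to all inputs and outputs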
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

template <typename F>
bool TritonClient::handle_exception(F&& call) {
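  //wrap the call so that any exception is converted into a failed finish() for the framework to handle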
  CMS_SA_ALLOW try {
    call();
    return true;
  }
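  //TritonExceptions are treated as recoverable: convert to a warning and report failure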
  catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  }
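  //any other exception is propagated to the framework via finish()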
  catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
  for (auto& [oname, output] : output_) {
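    //for outputs with variable dimensions, get the concrete shape from the result (dropping the batch dimension if present)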
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      TRITON_THROW_IF_ERROR(results->Shape(oname, &tmp_shape),
                            "getResults(): unable to get output shape for " + oname);
      if (!noBatch_)
        tmp_shape.erase(tmp_shape.begin());
      output.setShape(tmp_shape);
      output.computeSizes();
    }
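    //store the result so its memory stays valid while the output is in use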
    output.setResult(results);
  }
}

void TritonClient::evaluate() {
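  //nothing to process for this event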
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

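  //prepare the output buffers before sending the request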
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

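  //get the server-side model status before the request (only needed for the verbose report)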
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
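    //non-blocking call: the callback lambda runs when the server response arrives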
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInfer(
              [start_status, this](tc::InferResult* results) {
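                //take ownership of the result and check the request status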
                std::shared_ptr<tc::InferResult> results_ptr(results);
                auto success = handle_exception(
                    [&]() { TRITON_THROW_IF_ERROR(results_ptr->RequestStatus(), "evaluate(): unable to get result"); });
                if (!success)
                  return;

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

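                //fill the outputs from the result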
                success = handle_exception([&]() { getResults(results_ptr); });
                if (!success)
                  return;

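                //signal successful completion to the framework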
                finish(true);
              },
              options_,
              inputsTriton_,
              outputsTriton_,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run");
    });
    if (!success)
      return;
  } else {
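    //blocking call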
    tc::InferResult* results;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result");
    });
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<tc::InferResult> results_ptr(results);
    success = handle_exception([&]() { getResults(results_ptr); });
    if (!success)
      return;

    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  const uint64_t count = stats.success_count_;
  msg << " Inference count: " << stats.inference_count_ << "\n";
  msg << " Execution count: " << stats.execution_count_ << "\n";
  msg << " Successful request count: " << count << "\n";

  if (count > 0) {
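    //server-side timings are reported in ns; convert to average microseconds per request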
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << " Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << " (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
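  //the server reports cumulative statistics, so take the difference between the end and start snapshots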
  TritonClient::ServerSideStats server_stats;

  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ =
      end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
                          "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
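  //return an empty statistics object when verbose reporting is disabled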
  return inference::ModelStatistics{};
}
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
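  //untracked parameters: operational settings that should not affect the physics output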
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}