#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <string>
#include <cmath>
#include <exception>
#include <sstream>
#include <utility>
#include <tuple>

namespace tc = triton::client;

namespace {
  grpc_compression_algorithm getCompressionAlgo(const std::string& name) {
    if (name.empty() or name.compare("none") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
    else if (name.compare("deflate") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
    else if (name.compare("gzip") == 0)
      return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
    else
      throw cms::Exception("GrpcCompression")
          << "Unknown compression algorithm requested: " << name << " (choices: none, deflate, gzip)";
  }
}  // namespace

//based on https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/examples/simple_grpc_async_infer_client.cc
//and https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/perf_client.cc

TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& debugName)
    : SonicClient(params, debugName, "TritonClient"),
      verbose_(params.getUntrackedParameter<bool>("verbose")),
      useSharedMemory_(params.getUntrackedParameter<bool>("useSharedMemory")),
      compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter<std::string>("compression"))),
      options_(params.getParameter<std::string>("modelName")) {
  //get appropriate server for this model
  edm::Service<TritonService> ts;
  const auto& server =
      ts->serverInfo(options_.model_name_, params.getUntrackedParameter<std::string>("preferredServer"));
  serverType_ = server.type;
  if (verbose_)
    edm::LogInfo(fullDebugName_) << "Using server: " << server.url;
  //enforce sync mode for fallback CPU server to avoid contention
  //todo: could enforce async mode otherwise (unless mode was specified by user?)
  if (serverType_ == TritonServerType::LocalCPU)
    setMode(SonicMode::Sync);

  //connect to the server
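  //(per the Triton client API, the third argument to Create() is the client library's own verbose flag, kept off here;
  // SSL settings come from the server chosen by TritonService)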
  TRITON_THROW_IF_ERROR(
      tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions),
      "TritonClient(): unable to create inference context");

  //set options
  options_.model_version_ = params.getParameter<std::string>("modelVersion");
  //convert seconds to microseconds
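  //(the Triton client interprets a client_timeout_ value of 0 as no client-side timeout)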
  options_.client_timeout_ = params.getUntrackedParameter<unsigned>("timeout") * 1e6;

  //config needed for batch size
  inference::ModelConfigResponse modelConfigResponse;
  TRITON_THROW_IF_ERROR(client_->ModelConfig(&modelConfigResponse, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model config");
  inference::ModelConfig modelConfig(modelConfigResponse.config());

  //check batch size limitations (after i/o setup)
  //triton uses max batch size = 0 to denote a model that does not support batching
  //but for models that do support batching, a given event may set batch size 0 to indicate no valid input is present
  //so set the local max to 1 and keep track of "no batch" case
  maxBatchSize_ = modelConfig.max_batch_size();
  noBatch_ = maxBatchSize_ == 0;
  maxBatchSize_ = std::max(1u, maxBatchSize_);
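  //e.g. a model whose config specifies max_batch_size = 0 gets noBatch_ = true and an effective local maximum of 1,
  //while a model with max_batch_size = 8 keeps maxBatchSize_ = 8 and noBatch_ = false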

  //get model info
  inference::ModelMetadataResponse modelMetadata;
  TRITON_THROW_IF_ERROR(client_->ModelMetadata(&modelMetadata, options_.model_name_, options_.model_version_),
                        "TritonClient(): unable to get model metadata");

  //get input and output (which know their sizes)
  const auto& nicInputs = modelMetadata.inputs();
  const auto& nicOutputs = modelMetadata.outputs();

  //report all model errors at once
  std::stringstream msg;
  std::string msg_str;

  //currently no use case is foreseen for a model with zero inputs or outputs
  if (nicInputs.empty())
    msg << "Model on server appears malformed (zero inputs)\n";

  if (nicOutputs.empty())
    msg << "Model on server appears malformed (zero outputs)\n";

  //stop if errors
  msg_str = msg.str();
  if (!msg_str.empty())
    throw cms::Exception("ModelErrors") << msg_str;

  //setup input map
  std::stringstream io_msg;
  if (verbose_)
    io_msg << "Model inputs: "
           << "\n";
  inputsTriton_.reserve(nicInputs.size());
  for (const auto& nicInput : nicInputs) {
    const auto& iname = nicInput.name();
    auto [curr_itr, success] = input_.emplace(std::piecewise_construct,
                                              std::forward_as_tuple(iname),
                                              std::forward_as_tuple(iname, nicInput, this, ts->pid()));
    auto& curr_input = curr_itr->second;
    inputsTriton_.push_back(curr_input.data());
    if (verbose_) {
      io_msg << "  " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
             << " b) : " << triton_utils::printColl(curr_input.shape()) << "\n";
    }
  }

  //allow selecting only some outputs from server
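  //(an empty "outputs" parameter keeps every output advertised by the model; otherwise only the listed names are used)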
  const auto& v_outputs = params.getUntrackedParameter<std::vector<std::string>>("outputs");
  std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());

  //setup output map
  if (verbose_)
    io_msg << "Model outputs: "
           << "\n";
  outputsTriton_.reserve(nicOutputs.size());
  for (const auto& nicOutput : nicOutputs) {
    const auto& oname = nicOutput.name();
    if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
      continue;
    auto [curr_itr, success] = output_.emplace(std::piecewise_construct,
                                               std::forward_as_tuple(oname),
                                               std::forward_as_tuple(oname, nicOutput, this, ts->pid()));
    auto& curr_output = curr_itr->second;
    outputsTriton_.push_back(curr_output.data());
    if (verbose_) {
      io_msg << "  " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
             << " b) : " << triton_utils::printColl(curr_output.shape()) << "\n";
    }
    if (!s_outputs.empty())
      s_outputs.erase(oname);
  }

  //check if any requested outputs were not available
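  //(names were erased from s_outputs as they were matched in the loop above, so anything left over is missing)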
  if (!s_outputs.empty())
    throw cms::Exception("MissingOutput")
        << "Some requested outputs were not available on the server: " << triton_utils::printColl(s_outputs);

  //propagate batch size to inputs and outputs
  setBatchSize(1);

  //print model info
  std::stringstream model_msg;
  if (verbose_) {
    model_msg << "Model name: " << options_.model_name_ << "\n"
              << "Model version: " << options_.model_version_ << "\n"
              << "Model max batch size: " << (noBatch_ ? 0 : maxBatchSize_) << "\n";
    edm::LogInfo(fullDebugName_) << model_msg.str() << io_msg.str();
  }
}

TritonClient::~TritonClient() {
  //by default: members of this class destroyed before members of base class
  //in shared memory case, TritonMemResource (member of TritonData) unregisters from client_ in its destructor
  //but input/output objects are member of base class, so destroyed after client_ (member of this class)
  //therefore, clear the maps here
  input_.clear();
  output_.clear();
}

bool TritonClient::setBatchSize(unsigned bsize) {
  if (bsize > maxBatchSize_) {
    edm::LogWarning(fullDebugName_) << "Requested batch size " << bsize << " exceeds server-specified max batch size "
                                    << maxBatchSize_ << ". Batch size will remain as " << batchSize_;
    return false;
  } else {
    batchSize_ = bsize;
    //set for input and output
    for (auto& element : input_) {
      element.second.setBatchSize(bsize);
    }
    for (auto& element : output_) {
      element.second.setBatchSize(bsize);
    }
    return true;
  }
}

void TritonClient::reset() {
  for (auto& element : input_) {
    element.second.reset();
  }
  for (auto& element : output_) {
    element.second.reset();
  }
}

template <typename F>
bool TritonClient::handle_exception(F&& call) {
  //caught exceptions will be propagated to edm::WaitingTaskWithArenaHolder
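  //(CMS_SA_ALLOW exempts the blanket try/catch below from CMSSW static analyzer checks)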
  CMS_SA_ALLOW try {
    call();
    return true;
  }
  //TritonExceptions are intended/expected to be recoverable, i.e. retries should be allowed
  catch (TritonException& e) {
    e.convertToWarning();
    finish(false);
    return false;
  }
  //other exceptions are not: execution should stop if they are encountered
  catch (...) {
    finish(false, std::current_exception());
    return false;
  }
}

void TritonClient::getResults(std::shared_ptr<tc::InferResult> results) {
  for (auto& [oname, output] : output_) {
    //set shape here before output becomes const
    if (output.variableDims()) {
      std::vector<int64_t> tmp_shape;
      TRITON_THROW_IF_ERROR(results->Shape(oname, &tmp_shape), "getResults(): unable to get output shape for " + oname);
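      //the shape returned by the server includes the batch dimension when batching is supported;
      //the locally stored shape omits it, so strip the leading entry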
      if (!noBatch_)
        tmp_shape.erase(tmp_shape.begin());
      output.setShape(tmp_shape);
      output.computeSizes();
    }
    //extend lifetime
    output.setResult(results);
  }
}

//default case for sync and pseudo async
void TritonClient::evaluate() {
  //in case there is nothing to process
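  //(an empty batch finishes successfully without contacting the server)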
  if (batchSize_ == 0) {
    finish(true);
    return;
  }

  //set up shared memory for output
  auto success = handle_exception([&]() {
    for (auto& element : output_) {
      element.second.prepare();
    }
  });
  if (!success)
    return;

  // Get the status of the server prior to the request being made.
  inference::ModelStatistics start_status;
  success = handle_exception([&]() {
    if (verbose())
      start_status = getServerSideStatus();
  });
  if (!success)
    return;

  if (mode_ == SonicMode::Async) {
    //non-blocking call
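    //(the callback passed to AsyncInfer runs when the request completes: it checks the request status,
    // optionally collects server-side statistics, fills the outputs via getResults(), and calls finish())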
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->AsyncInfer(
              [start_status, this](tc::InferResult* results) {
                //get results
                std::shared_ptr<tc::InferResult> results_ptr(results);
                auto success = handle_exception(
                    [&]() { TRITON_THROW_IF_ERROR(results_ptr->RequestStatus(), "evaluate(): unable to get result"); });
                if (!success)
                  return;

                if (verbose()) {
                  inference::ModelStatistics end_status;
                  success = handle_exception([&]() { end_status = getServerSideStatus(); });
                  if (!success)
                    return;

                  const auto& stats = summarizeServerStats(start_status, end_status);
                  reportServerSideStats(stats);
                }

                //check result
                success = handle_exception([&]() { getResults(results_ptr); });
                if (!success)
                  return;

                //finish
                finish(true);
              },
              options_,
              inputsTriton_,
              outputsTriton_,
              headers_,
              compressionAlgo_),
          "evaluate(): unable to launch async run");
    });
    if (!success)
      return;
  } else {
    //blocking call
    tc::InferResult* results;
    success = handle_exception([&]() {
      TRITON_THROW_IF_ERROR(
          client_->Infer(&results, options_, inputsTriton_, outputsTriton_, headers_, compressionAlgo_),
          "evaluate(): unable to run and/or get result");
    });
    if (!success)
      return;

    if (verbose()) {
      inference::ModelStatistics end_status;
      success = handle_exception([&]() { end_status = getServerSideStatus(); });
      if (!success)
        return;

      const auto& stats = summarizeServerStats(start_status, end_status);
      reportServerSideStats(stats);
    }

    std::shared_ptr<tc::InferResult> results_ptr(results);
    success = handle_exception([&]() { getResults(results_ptr); });
    if (!success)
      return;

    finish(true);
  }
}

void TritonClient::reportServerSideStats(const TritonClient::ServerSideStats& stats) const {
  std::stringstream msg;

  // https://github.com/triton-inference-server/server/blob/v2.3.0/src/clients/c++/perf_client/inference_profiler.cc
  const uint64_t count = stats.success_count_;
  msg << "  Inference count: " << stats.inference_count_ << "\n";
  msg << "  Execution count: " << stats.execution_count_ << "\n";
  msg << "  Successful request count: " << count << "\n";

  if (count > 0) {
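    //the server reports cumulative times in nanoseconds; convert each to an average in microseconds per successful request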
    auto get_avg_us = [count](uint64_t tval) {
      constexpr uint64_t us_to_ns = 1000;
      return tval / us_to_ns / count;
    };

    const uint64_t cumm_avg_us = get_avg_us(stats.cumm_time_ns_);
    const uint64_t queue_avg_us = get_avg_us(stats.queue_time_ns_);
    const uint64_t compute_input_avg_us = get_avg_us(stats.compute_input_time_ns_);
    const uint64_t compute_infer_avg_us = get_avg_us(stats.compute_infer_time_ns_);
    const uint64_t compute_output_avg_us = get_avg_us(stats.compute_output_time_ns_);
    const uint64_t compute_avg_us = compute_input_avg_us + compute_infer_avg_us + compute_output_avg_us;
    const uint64_t overhead =
        (cumm_avg_us > queue_avg_us + compute_avg_us) ? (cumm_avg_us - queue_avg_us - compute_avg_us) : 0;

    msg << "  Avg request latency: " << cumm_avg_us << " usec"
        << "\n"
        << "  (overhead " << overhead << " usec + "
        << "queue " << queue_avg_us << " usec + "
        << "compute input " << compute_input_avg_us << " usec + "
        << "compute infer " << compute_infer_avg_us << " usec + "
        << "compute output " << compute_output_avg_us << " usec)" << std::endl;
  }

  if (!debugName_.empty())
    edm::LogInfo(fullDebugName_) << msg.str();
}

TritonClient::ServerSideStats TritonClient::summarizeServerStats(const inference::ModelStatistics& start_status,
                                                                 const inference::ModelStatistics& end_status) const {
  TritonClient::ServerSideStats server_stats;

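  //the server reports cumulative statistics, so differencing the snapshots taken before and after the request
  //isolates the activity that occurred in between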
  server_stats.inference_count_ = end_status.inference_count() - start_status.inference_count();
  server_stats.execution_count_ = end_status.execution_count() - start_status.execution_count();
  server_stats.success_count_ =
      end_status.inference_stats().success().count() - start_status.inference_stats().success().count();
  server_stats.cumm_time_ns_ =
      end_status.inference_stats().success().ns() - start_status.inference_stats().success().ns();
  server_stats.queue_time_ns_ = end_status.inference_stats().queue().ns() - start_status.inference_stats().queue().ns();
  server_stats.compute_input_time_ns_ =
      end_status.inference_stats().compute_input().ns() - start_status.inference_stats().compute_input().ns();
  server_stats.compute_infer_time_ns_ =
      end_status.inference_stats().compute_infer().ns() - start_status.inference_stats().compute_infer().ns();
  server_stats.compute_output_time_ns_ =
      end_status.inference_stats().compute_output().ns() - start_status.inference_stats().compute_output().ns();

  return server_stats;
}

inference::ModelStatistics TritonClient::getServerSideStatus() const {
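  //only query the server when verbose output is requested; otherwise return a default-constructed (empty) statistics object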
  if (verbose_) {
    inference::ModelStatisticsResponse resp;
    TRITON_THROW_IF_ERROR(client_->ModelInferenceStatistics(&resp, options_.model_name_, options_.model_version_),
                          "getServerSideStatus(): unable to get model statistics");
    return *(resp.model_stats().begin());
  }
  return inference::ModelStatistics{};
}

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
  edm::ParameterSetDescription descClient;
  fillBasePSetDescription(descClient);
  descClient.add<std::string>("modelName");
  descClient.add<std::string>("modelVersion", "");
  descClient.add<edm::FileInPath>("modelConfigPath");
  //server parameters should not affect the physics results
  descClient.addUntracked<std::string>("preferredServer", "");
  descClient.addUntracked<unsigned>("timeout");
  descClient.addUntracked<bool>("useSharedMemory", true);
  descClient.addUntracked<std::string>("compression", "");
  descClient.addUntracked<std::vector<std::string>>("outputs", {});
  iDesc.add<edm::ParameterSetDescription>("Client", descClient);
}
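
//example (hypothetical) python configuration providing the "Client" parameters declared above;
//names match this file, but the model name, file path, and values are illustrative placeholders,
//and parameters added by fillBasePSetDescription (e.g. "verbose") are omitted:
//  Client = cms.PSet(
//    modelName = cms.string("my_model"),
//    modelVersion = cms.string(""),
//    modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/my_model/config.pbtxt"),
//    preferredServer = cms.untracked.string(""),
//    timeout = cms.untracked.uint32(60),
//    useSharedMemory = cms.untracked.bool(True),
//    compression = cms.untracked.string(""),
//    outputs = cms.untracked.vstring(),
//  )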