Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-22 02:45:11

0001 #ifndef HeterogeneousCore_SonicTriton_TritonClient
0002 #define HeterogeneousCore_SonicTriton_TritonClient
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0006 #include "FWCore/ServiceRegistry/interface/ServiceToken.h"
0007 #include "HeterogeneousCore/SonicCore/interface/SonicClient.h"
0008 #include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
0009 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0010 
0011 #include <map>
0012 #include <vector>
0013 #include <string>
0014 #include <exception>
0015 #include <unordered_map>
0016 
0017 #include "grpc_client.h"
0018 #include "grpc_service.pb.h"
0019 
// Batch layout selector stored in TritonClient::batchMode_ and exposed through
// batchMode()/setBatchMode(). The enumerators carry explicit values 1 and 2, so no
// enumerator maps to 0.
// NOTE(review): presumably Rectangular means every entry shares one shape and Ragged
// means per-entry shapes — confirm against the implementation file.
0020 enum class TritonBatchMode { Rectangular = 1, Ragged = 2 };
0021 
// Triton inference client: a SonicClient specialized on TritonInputMap/TritonOutputMap
// that holds a gRPC client object (client_) together with batching and configuration
// state. This header only declares the interface; the notes below are limited to what
// the declarations themselves show, and anything else is marked for confirmation.
0022 class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
0023 public:
       // Plain aggregate of eight unsigned 64-bit server-side counters. It is the return
       // type of summarizeServerStats() and the argument type of reportServerSideStats().
       // NOTE(review): the *_ns_ suffix suggests nanosecond durations — confirm in the .cc.
       // NOTE(review): no default member initializers, so a default-initialized instance
       // holds indeterminate values until every field has been assigned.
0024   struct ServerSideStats {
0025     uint64_t inference_count_;
0026     uint64_t execution_count_;
0027     uint64_t success_count_;
0028     uint64_t cumm_time_ns_;
0029     uint64_t queue_time_ns_;
0030     uint64_t compute_input_time_ns_;
0031     uint64_t compute_infer_time_ns_;
0032     uint64_t compute_output_time_ns_;
0033   };
0034 
0035   // constructor: takes the module's parameter set and a name used for debugging
       // output (parameter name: debugName); defined out of line

0036   TritonClient(const edm::ParameterSet& params, const std::string& debugName);
0037 
0038   // destructor: overrides a virtual destructor from the base hierarchy

0039   ~TritonClient() override;
0040 
0041   // accessors and mutators; the ones defined inline simply read or write the
       // member of the matching name, the others are defined out of line
       // NOTE(review): setBatchSize() returns bool, presumably false when the requested
       // size is rejected — confirm in the .cc.
       // reset() overrides a virtual from the SonicClient base hierarchy.

0042   unsigned batchSize() const;
0043   TritonBatchMode batchMode() const { return batchMode_; }
0044   bool verbose() const { return verbose_; }
0045   bool useSharedMemory() const { return useSharedMemory_; }
0046   void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; }
0047   bool setBatchSize(unsigned bsize);
0048   void setBatchMode(TritonBatchMode batchMode);
0049   void resetBatchMode();
0050   void reset() override;
0051   TritonServerType serverType() const { return serverType_; }
0052   bool isLocal() const { return isLocal_; }
0053 
0054   // for fillDescriptions: static helper that receives the ParameterSetDescription
       // to operate on by non-const reference

0055   static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0056 
0057 protected:
0058   // helpers available to derived classes
       // noOuterDim()/outerDim() return the members of the same name; nEntries() is
       // defined out of line.
       // getResults() receives the shared InferResult handles by const reference.
       // evaluate() overrides a virtual from the SonicClient base hierarchy.
       // handle_exception() is a member template taking a callable by forwarding
       // reference and returning bool; its definition is not in this header.
       // NOTE(review): presumably the bool reports whether the call completed without
       // throwing — confirm in the .cc.

0059   bool noOuterDim() const { return noOuterDim_; }
0060   unsigned outerDim() const { return outerDim_; }
0061   unsigned nEntries() const;
0062   void getResults(const std::vector<std::shared_ptr<triton::client::InferResult>>& results);
0063   void evaluate() override;
0064   template <typename F>
0065   bool handle_exception(F&& call);
0066 
       // server-side statistics: summarizeServerStats() combines two ModelStatistics
       // snapshots (start and end) into one ServerSideStats, reportServerSideStats()
       // consumes such a summary, and getServerSideStatus() returns one snapshot by value
0067   void reportServerSideStats(const ServerSideStats& stats) const;
0068   ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status,
0069                                        const inference::ModelStatistics& end_status) const;
0070 
0071   inference::ModelStatistics getServerSideStatus() const;
0072 
0073   // members
       // None of these has an in-class initializer, so they rely on the constructor
       // (defined elsewhere) to be set. Declaration order here is also the order in
       // which they are constructed.

0074   unsigned maxOuterDim_;
0075   unsigned outerDim_;
0076   bool noOuterDim_;
0077   unsigned nEntries_;
0078   TritonBatchMode batchMode_;
0079   bool manualBatchMode_;
0080   bool verbose_;
0081   bool useSharedMemory_;
0082   TritonServerType serverType_;
0083   bool isLocal_;
0084   grpc_compression_algorithm compressionAlgo_;
0085   triton::client::Headers headers_;
0086 
       // sole owner of the gRPC client object; client() below hands out a raw pointer
       // NOTE(review): std::unique_ptr (and std::shared_ptr in getResults) are used, but
       // <memory> is not among this header's direct includes — presumably it arrives
       // transitively; confirm, or include it explicitly.
0087   std::unique_ptr<triton::client::InferenceServerGrpcClient> client_;
0088   // stores timeout, model name and version (one InferOptions per vector element)

0089   std::vector<triton::client::InferOptions> options_;
0090   edm::ServiceToken token_;
0091 
0092 private:
       // the input and output data classes are friends so that they can call the
       // private accessors declared below
0093   friend TritonInputData;
0094   friend TritonOutputData;
0095 
0096   // private accessors only used by data
       // client() returns a non-owning raw pointer obtained from client_.get();
       // ownership stays with client_, so the pointer must not be deleted by the caller

0097   auto client() { return client_.get(); }
0098   void addEntry(unsigned entry);
0099   void resizeEntries(unsigned entry);
0100 };
0101 
0102 #endif