Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:05:50

0001 #ifndef HeterogeneousCore_SonicTriton_TritonClient
0002 #define HeterogeneousCore_SonicTriton_TritonClient
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0006 #include "HeterogeneousCore/SonicCore/interface/SonicClient.h"
0007 #include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
0008 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0009 
#include <cstdint>
#include <exception>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
0015 
0016 #include "grpc_client.h"
0017 #include "grpc_service.pb.h"
0018 
//batch mode selector used by TritonClient (see batchMode() / setBatchMode())
//the numeric values are assigned explicitly and start at 1
enum class TritonBatchMode {
  Rectangular = 1,
  Ragged = 2,
};
0020 
0021 class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
0022 public:
0023   struct ServerSideStats {
0024     uint64_t inference_count_;
0025     uint64_t execution_count_;
0026     uint64_t success_count_;
0027     uint64_t cumm_time_ns_;
0028     uint64_t queue_time_ns_;
0029     uint64_t compute_input_time_ns_;
0030     uint64_t compute_infer_time_ns_;
0031     uint64_t compute_output_time_ns_;
0032   };
0033 
0034   //constructor

0035   TritonClient(const edm::ParameterSet& params, const std::string& debugName);
0036 
0037   //destructor

0038   ~TritonClient() override;
0039 
0040   //accessors

0041   unsigned batchSize() const;
0042   TritonBatchMode batchMode() const { return batchMode_; }
0043   bool verbose() const { return verbose_; }
0044   bool useSharedMemory() const { return useSharedMemory_; }
0045   void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; }
0046   bool setBatchSize(unsigned bsize);
0047   void setBatchMode(TritonBatchMode batchMode);
0048   void resetBatchMode();
0049   void reset() override;
0050   TritonServerType serverType() const { return serverType_; }
0051 
0052   //for fillDescriptions

0053   static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0054 
0055 protected:
0056   //helpers

0057   bool noOuterDim() const { return noOuterDim_; }
0058   unsigned outerDim() const { return outerDim_; }
0059   unsigned nEntries() const;
0060   void getResults(const std::vector<std::shared_ptr<triton::client::InferResult>>& results);
0061   void evaluate() override;
0062   template <typename F>
0063   bool handle_exception(F&& call);
0064 
0065   void reportServerSideStats(const ServerSideStats& stats) const;
0066   ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status,
0067                                        const inference::ModelStatistics& end_status) const;
0068 
0069   inference::ModelStatistics getServerSideStatus() const;
0070 
0071   //members

0072   unsigned maxOuterDim_;
0073   unsigned outerDim_;
0074   bool noOuterDim_;
0075   unsigned nEntries_;
0076   TritonBatchMode batchMode_;
0077   bool manualBatchMode_;
0078   bool verbose_;
0079   bool useSharedMemory_;
0080   TritonServerType serverType_;
0081   grpc_compression_algorithm compressionAlgo_;
0082   triton::client::Headers headers_;
0083 
0084   std::unique_ptr<triton::client::InferenceServerGrpcClient> client_;
0085   //stores timeout, model name and version

0086   std::vector<triton::client::InferOptions> options_;
0087 
0088 private:
0089   friend TritonInputData;
0090   friend TritonOutputData;
0091 
0092   //private accessors only used by data

0093   auto client() { return client_.get(); }
0094   void addEntry(unsigned entry);
0095   void resizeEntries(unsigned entry);
0096 };
0097 
0098 #endif