Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:47

0001 #ifndef HeterogeneousCore_SonicTriton_TritonClient
0002 #define HeterogeneousCore_SonicTriton_TritonClient
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0006 #include "HeterogeneousCore/SonicCore/interface/SonicClient.h"
0007 #include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
0008 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0009 
0010 #include <map>
0011 #include <vector>
0012 #include <string>
0013 #include <exception>
0014 #include <unordered_map>
0015 
0016 #include "grpc_client.h"
0017 #include "grpc_service.pb.h"
0018 
// Batch layout selector stored by TritonClient (see batchMode()/setBatchMode()).
// NOTE(review): presumably Rectangular means every batch entry shares one shape and
// Ragged allows per-entry shapes; confirm against the implementation file.
// The explicit values 1 and 2 are kept as-is in case callers depend on them.
enum class TritonBatchMode { Rectangular = 1, Ragged = 2 };
0020 
0021 class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
0022 public:
0023   struct ServerSideStats {
0024     uint64_t inference_count_;
0025     uint64_t execution_count_;
0026     uint64_t success_count_;
0027     uint64_t cumm_time_ns_;
0028     uint64_t queue_time_ns_;
0029     uint64_t compute_input_time_ns_;
0030     uint64_t compute_infer_time_ns_;
0031     uint64_t compute_output_time_ns_;
0032   };
0033 
0034   //constructor

0035   TritonClient(const edm::ParameterSet& params, const std::string& debugName);
0036 
0037   //destructor

0038   ~TritonClient() override;
0039 
0040   //accessors

0041   unsigned batchSize() const;
0042   TritonBatchMode batchMode() const { return batchMode_; }
0043   bool verbose() const { return verbose_; }
0044   bool useSharedMemory() const { return useSharedMemory_; }
0045   void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; }
0046   bool setBatchSize(unsigned bsize);
0047   void setBatchMode(TritonBatchMode batchMode);
0048   void resetBatchMode();
0049   void reset() override;
0050   TritonServerType serverType() const { return serverType_; }
0051   bool isLocal() const { return isLocal_; }
0052 
0053   //for fillDescriptions

0054   static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0055 
0056 protected:
0057   //helpers

0058   bool noOuterDim() const { return noOuterDim_; }
0059   unsigned outerDim() const { return outerDim_; }
0060   unsigned nEntries() const;
0061   void getResults(const std::vector<std::shared_ptr<triton::client::InferResult>>& results);
0062   void evaluate() override;
0063   template <typename F>
0064   bool handle_exception(F&& call);
0065 
0066   void reportServerSideStats(const ServerSideStats& stats) const;
0067   ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status,
0068                                        const inference::ModelStatistics& end_status) const;
0069 
0070   inference::ModelStatistics getServerSideStatus() const;
0071 
0072   //members

0073   unsigned maxOuterDim_;
0074   unsigned outerDim_;
0075   bool noOuterDim_;
0076   unsigned nEntries_;
0077   TritonBatchMode batchMode_;
0078   bool manualBatchMode_;
0079   bool verbose_;
0080   bool useSharedMemory_;
0081   TritonServerType serverType_;
0082   bool isLocal_;
0083   grpc_compression_algorithm compressionAlgo_;
0084   triton::client::Headers headers_;
0085 
0086   std::unique_ptr<triton::client::InferenceServerGrpcClient> client_;
0087   //stores timeout, model name and version

0088   std::vector<triton::client::InferOptions> options_;
0089 
0090 private:
0091   friend TritonInputData;
0092   friend TritonOutputData;
0093 
0094   //private accessors only used by data

0095   auto client() { return client_.get(); }
0096   void addEntry(unsigned entry);
0097   void resizeEntries(unsigned entry);
0098 };
0099 
0100 #endif