#ifndef HeterogeneousCore_SonicTriton_TritonMemResource
#define HeterogeneousCore_SonicTriton_TritonMemResource

#include <string>
#include <memory>
#include <cstddef>
#include <cstdint>

#include "grpc_client.h"

//forward declaration
template <typename IO>
class TritonData;

//base class for memory operations
template <typename IO>
class TritonMemResource {
public:
  TritonMemResource(TritonData<IO>* data, const std::string& name, size_t size);
  virtual ~TritonMemResource() {}
  uint8_t* addr() { return addr_; }
  size_t size() const { return size_; }
  virtual void close() {}
  //used for input
  virtual void copyInput(const void* values, size_t offset) {}
  //used for output
  virtual const uint8_t* copyOutput() { return nullptr; }
  virtual void set();

protected:
  //member variables
  TritonData<IO>* data_;
  std::string name_;
  size_t size_;
  uint8_t* addr_;
  bool closed_;
};

//memory allocated on the heap; data is copied to/from the server over the network
template <typename IO>
class TritonHeapResource : public TritonMemResource<IO> {
public:
  TritonHeapResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonHeapResource() override {}
  void copyInput(const void* values, size_t offset) override {}
  const uint8_t* copyOutput() override { return nullptr; }
  void set() override {}
};

//CPU shared memory region, usable when the server runs on the same host
template <typename IO>
class TritonCpuShmResource : public TritonMemResource<IO> {
public:
  TritonCpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonCpuShmResource() override;
  void close() override;
  void copyInput(const void* values, size_t offset) override {}
  const uint8_t* copyOutput() override { return nullptr; }
};

using TritonInputHeapResource = TritonHeapResource<triton::client::InferInput>;
using TritonInputCpuShmResource = TritonCpuShmResource<triton::client::InferInput>;
using TritonOutputHeapResource = TritonHeapResource<triton::client::InferRequestedOutput>;
using TritonOutputCpuShmResource = TritonCpuShmResource<triton::client::InferRequestedOutput>;

//avoid "explicit specialization after instantiation" error
template <>
void TritonInputHeapResource::copyInput(const void* values, size_t offset);
template <>
void TritonInputCpuShmResource::copyInput(const void* values, size_t offset);
template <>
const uint8_t* TritonOutputHeapResource::copyOutput();
template <>
const uint8_t* TritonOutputCpuShmResource::copyOutput();

#ifdef TRITON_ENABLE_GPU
#include "cuda_runtime_api.h"

//GPU (CUDA) memory shared with a local server via a CUDA IPC handle
template <typename IO>
class TritonGpuShmResource : public TritonMemResource<IO> {
public:
  TritonGpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonGpuShmResource() override;
  void close() override;
  void copyInput(const void* values, size_t offset) override {}
  const uint8_t* copyOutput() override { return nullptr; }

protected:
  int deviceId_;
  std::shared_ptr<cudaIpcMemHandle_t> handle_;
};

using TritonInputGpuShmResource = TritonGpuShmResource<triton::client::InferInput>;
using TritonOutputGpuShmResource = TritonGpuShmResource<triton::client::InferRequestedOutput>;

//avoid "explicit specialization after instantiation" error
template <>
void TritonInputGpuShmResource::copyInput(const void* values, size_t offset);
template <>
const uint8_t* TritonOutputGpuShmResource::copyOutput();
#endif

#endif
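
//--------------------------------------------------------------------------------
// Illustrative sketch only, not part of this header: one way a caller holding a
// TritonData<triton::client::InferInput>* could pick and use one of the resources
// declared above. The function writeInput() and the canUseShm flag are hypothetical
// names introduced just for this example; the actual selection between heap and
// shared memory is not shown in this file.
//
//   void writeInput(TritonData<triton::client::InferInput>* data,
//                   const void* values,
//                   size_t byteSize,
//                   bool canUseShm) {
//     std::shared_ptr<TritonMemResource<triton::client::InferInput>> mem;
//     if (canUseShm)
//       mem = std::make_shared<TritonInputCpuShmResource>(data, "triton_input_shm", byteSize);
//     else
//       mem = std::make_shared<TritonInputHeapResource>(data, "triton_input", byteSize);
//     mem->set();                 //associate the memory region with the request (no-op for heap)
//     mem->copyInput(values, 0);  //stage the caller's data starting at byte offset 0
//     mem->close();               //release/unregister the region when no longer needed
//   }
//--------------------------------------------------------------------------------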