Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-10-25 09:50:29

0001 #ifndef HeterogeneousCore_SonicTriton_TritonMemResource
0002 #define HeterogeneousCore_SonicTriton_TritonMemResource
0003 
0004 #include <string>
0005 #include <memory>
0006 
0007 #include "grpc_client.h"
0008 
// Forward declaration: only pointers to TritonData<IO> appear in this header,
// so the complete type is not required here.
template <typename IO>
class TritonData;

// Base class for memory operations.
//
// Keeps a pointer to a TritonData<IO> object, a name, a size and a buffer
// address. The close/copy hooks have empty default bodies so a derived
// resource only overrides the ones it needs; set() is declared here and
// defined out of line.
template <typename IO>
class TritonMemResource {
public:
  // Defined out of line. The arguments presumably seed data_, name_ and
  // size_ -- confirm against the definition.
  TritonMemResource(TritonData<IO>* data, const std::string& name, size_t size);
  // Virtual so that derived resources are destroyed correctly through a base
  // pointer; the base itself has nothing to release.
  virtual ~TritonMemResource() = default;

  // Returns the stored buffer address (addr_).
  uint8_t* addr() { return addr_; }
  // Returns the stored size (size_).
  size_t size() const { return size_; }

  // No-op by default.
  virtual void close() {}
  // Used for input; no-op by default.
  virtual void copyInput(const void* values, size_t offset, unsigned entry) {}
  // Used for output; no-op by default.
  virtual void copyOutput() {}
  // Defined out of line.
  virtual void set();

protected:
  // Member variables. The in-class initializers keep every member in a
  // determinate state even if a constructor's initializer list omits it.
  TritonData<IO>* data_{nullptr};
  std::string name_;
  size_t size_{0};
  uint8_t* addr_{nullptr};
  bool closed_{false};
};
0036 
0037 template <typename IO>
0038 class TritonHeapResource : public TritonMemResource<IO> {
0039 public:
0040   TritonHeapResource(TritonData<IO>* data, const std::string& name, size_t size);
0041   ~TritonHeapResource() override {}
0042   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0043   void copyOutput() override {}
0044   void set() override {}
0045 };
0046 
0047 template <typename IO>
0048 class TritonCpuShmResource : public TritonMemResource<IO> {
0049 public:
0050   TritonCpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0051   ~TritonCpuShmResource() override;
0052   void close() override;
0053   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0054   void copyOutput() override {}
0055 
0056 protected:
0057   size_t sizeOrig_;
0058 };
0059 
0060 using TritonInputHeapResource = TritonHeapResource<triton::client::InferInput>;
0061 using TritonInputCpuShmResource = TritonCpuShmResource<triton::client::InferInput>;
0062 using TritonOutputHeapResource = TritonHeapResource<triton::client::InferRequestedOutput>;
0063 using TritonOutputCpuShmResource = TritonCpuShmResource<triton::client::InferRequestedOutput>;
0064 
0065 //avoid "explicit specialization after instantiation" error
0066 template <>
0067 void TritonInputHeapResource::copyInput(const void* values, size_t offset, unsigned entry);
0068 template <>
0069 void TritonInputCpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0070 template <>
0071 void TritonOutputHeapResource::copyOutput();
0072 template <>
0073 void TritonOutputCpuShmResource::copyOutput();
0074 
0075 #ifdef TRITON_ENABLE_GPU
0076 #include "cuda_runtime_api.h"
0077 
0078 template <typename IO>
0079 class TritonGpuShmResource : public TritonMemResource<IO> {
0080 public:
0081   TritonGpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0082   ~TritonGpuShmResource() override;
0083   void close() override;
0084   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0085   void copyOutput() override {}
0086 
0087 protected:
0088   int deviceId_;
0089   std::shared_ptr<cudaIpcMemHandle_t> handle_;
0090 };
0091 
0092 using TritonInputGpuShmResource = TritonGpuShmResource<triton::client::InferInput>;
0093 using TritonOutputGpuShmResource = TritonGpuShmResource<triton::client::InferRequestedOutput>;
0094 
0095 //avoid "explicit specialization after instantiation" error
0096 template <>
0097 void TritonInputGpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0098 template <>
0099 void TritonOutputGpuShmResource::copyOutput();
0100 #endif
0101 
0102 #endif