Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:47

0001 #ifndef HeterogeneousCore_SonicTriton_TritonMemResource
0002 #define HeterogeneousCore_SonicTriton_TritonMemResource
0003 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "grpc_client.h"
0008 
// Forward declaration: this header only refers to TritonData through pointers,
// so the full definition is not needed here.
template <typename IO>
class TritonData;

// Base class for memory operations.
//
// Stores the bookkeeping shared by the memory resources declared in this
// header (a TritonData pointer, a name, a size and a raw address) and declares
// the virtual hooks that the concrete resources override. The hooks defined
// inline here are no-ops; the constructor, closeSafe() and set() are only
// declared here and are defined out of line.
template <typename IO>
class TritonMemResource {
public:
  // Defined out of line. NOTE(review): presumably stores its arguments in
  // data_, name_ and size_ and initializes addr_/closed_ - confirm in the
  // implementation file.
  TritonMemResource(TritonData<IO>* data, const std::string& name, size_t size);
  // Virtual so derived resources can be destroyed through a base pointer.
  virtual ~TritonMemResource() = default;

  // Returns the stored address (addr_) for read/write access.
  uint8_t* addr() { return addr_; }
  // Read-only overload so the address can also be queried through a const
  // reference, consistent with size() below.
  const uint8_t* addr() const { return addr_; }
  // Returns the stored size (size_).
  size_t size() const { return size_; }

  // Release hook; a no-op unless a derived class overrides it.
  virtual void close() {}
  // Non-virtual companion of close(), defined out of line.
  // NOTE(review): presumably uses closed_ to call close() at most once -
  // confirm in the implementation file.
  void closeSafe();

  // Used for input; a no-op in the base class.
  virtual void copyInput(const void* values, size_t offset, unsigned entry) {}
  // Used for output; a no-op in the base class.
  virtual void copyOutput() {}
  // Defined out of line; derived classes may override it.
  virtual void set();

protected:
  // Member variables.
  TritonData<IO>* data_;  // non-owning as far as this header shows (never deleted here)
  std::string name_;
  size_t size_;
  uint8_t* addr_;
  bool closed_;
};
0037 
0038 template <typename IO>
0039 class TritonHeapResource : public TritonMemResource<IO> {
0040 public:
0041   TritonHeapResource(TritonData<IO>* data, const std::string& name, size_t size);
0042   ~TritonHeapResource() override {}
0043   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0044   void copyOutput() override {}
0045   void set() override {}
0046 };
0047 
0048 template <typename IO>
0049 class TritonCpuShmResource : public TritonMemResource<IO> {
0050 public:
0051   TritonCpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0052   ~TritonCpuShmResource() override;
0053   void close() override;
0054   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0055   void copyOutput() override {}
0056 
0057 protected:
0058   size_t sizeOrig_;
0059 };
0060 
0061 using TritonInputHeapResource = TritonHeapResource<triton::client::InferInput>;
0062 using TritonInputCpuShmResource = TritonCpuShmResource<triton::client::InferInput>;
0063 using TritonOutputHeapResource = TritonHeapResource<triton::client::InferRequestedOutput>;
0064 using TritonOutputCpuShmResource = TritonCpuShmResource<triton::client::InferRequestedOutput>;
0065 
0066 //avoid "explicit specialization after instantiation" error
0067 template <>
0068 void TritonInputHeapResource::copyInput(const void* values, size_t offset, unsigned entry);
0069 template <>
0070 void TritonInputCpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0071 template <>
0072 void TritonOutputHeapResource::copyOutput();
0073 template <>
0074 void TritonOutputCpuShmResource::copyOutput();
0075 
0076 #ifdef TRITON_ENABLE_GPU
0077 #include "cuda_runtime_api.h"
0078 
0079 template <typename IO>
0080 class TritonGpuShmResource : public TritonMemResource<IO> {
0081 public:
0082   TritonGpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0083   ~TritonGpuShmResource() override;
0084   void close() override;
0085   void copyInput(const void* values, size_t offset, unsigned entry) override {}
0086   void copyOutput() override {}
0087 
0088 protected:
0089   int deviceId_;
0090   std::shared_ptr<cudaIpcMemHandle_t> handle_;
0091 };
0092 
0093 using TritonInputGpuShmResource = TritonGpuShmResource<triton::client::InferInput>;
0094 using TritonOutputGpuShmResource = TritonGpuShmResource<triton::client::InferRequestedOutput>;
0095 
0096 //avoid "explicit specialization after instantiation" error
0097 template <>
0098 void TritonInputGpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0099 template <>
0100 void TritonOutputGpuShmResource::copyOutput();
0101 #endif
0102 
0103 #endif