// File indexing completed on 2024-04-06 12:15:47
0001 #ifndef HeterogeneousCore_SonicTriton_TritonMemResource
0002 #define HeterogeneousCore_SonicTriton_TritonMemResource
0003
0004 #include <string>
0005 #include <memory>
0006
0007 #include "grpc_client.h"
0008
0009
// Forward declaration: this header only ever uses pointers to TritonData<IO>.
template <typename IO>
class TritonData;

/// Base class template for the memory-resource classes declared in this header.
///
/// Holds a pointer to the associated TritonData<IO>, a name, a size, an
/// address pointer and a "closed" flag. The constructor, closeSafe() and
/// set() are only declared here; their definitions live outside this header.
/// close(), copyInput() and copyOutput() have empty default bodies and are
/// meant to be overridden by derived classes.
template <typename IO>
class TritonMemResource {
public:
  /// @param data  pointer stored in data_ (ownership is not visible from this header)
  /// @param name  resource name
  /// @param size  resource size
  TritonMemResource(TritonData<IO>* data, const std::string& name, size_t size);
  virtual ~TritonMemResource() = default;

  /// Returns the current value of addr_.
  uint8_t* addr() { return addr_; }
  /// Returns the current value of size_.
  size_t size() const { return size_; }

  /// Overridable hook; the base version does nothing.
  virtual void close() {}
  /// Non-virtual; defined out of line.
  void closeSafe();

  /// Input-side hook; the base version does nothing.
  virtual void copyInput(const void* values, size_t offset, unsigned entry) {}

  /// Output-side hook; the base version does nothing.
  virtual void copyOutput() {}
  /// Overridable; the base version is defined out of line.
  virtual void set();

protected:
  // State shared with derived classes. Declaration order is kept as-is
  // because it determines member initialization order.
  TritonData<IO>* data_;
  std::string name_;
  size_t size_;
  uint8_t* addr_;
  bool closed_;
};
0037
0038 template <typename IO>
0039 class TritonHeapResource : public TritonMemResource<IO> {
0040 public:
0041 TritonHeapResource(TritonData<IO>* data, const std::string& name, size_t size);
0042 ~TritonHeapResource() override {}
0043 void copyInput(const void* values, size_t offset, unsigned entry) override {}
0044 void copyOutput() override {}
0045 void set() override {}
0046 };
0047
0048 template <typename IO>
0049 class TritonCpuShmResource : public TritonMemResource<IO> {
0050 public:
0051 TritonCpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0052 ~TritonCpuShmResource() override;
0053 void close() override;
0054 void copyInput(const void* values, size_t offset, unsigned entry) override {}
0055 void copyOutput() override {}
0056
0057 protected:
0058 size_t sizeOrig_;
0059 };
0060
0061 using TritonInputHeapResource = TritonHeapResource<triton::client::InferInput>;
0062 using TritonInputCpuShmResource = TritonCpuShmResource<triton::client::InferInput>;
0063 using TritonOutputHeapResource = TritonHeapResource<triton::client::InferRequestedOutput>;
0064 using TritonOutputCpuShmResource = TritonCpuShmResource<triton::client::InferRequestedOutput>;
0065
0066
0067 template <>
0068 void TritonInputHeapResource::copyInput(const void* values, size_t offset, unsigned entry);
0069 template <>
0070 void TritonInputCpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0071 template <>
0072 void TritonOutputHeapResource::copyOutput();
0073 template <>
0074 void TritonOutputCpuShmResource::copyOutput();
0075
0076 #ifdef TRITON_ENABLE_GPU
0077 #include "cuda_runtime_api.h"
0078
0079 template <typename IO>
0080 class TritonGpuShmResource : public TritonMemResource<IO> {
0081 public:
0082 TritonGpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
0083 ~TritonGpuShmResource() override;
0084 void close() override;
0085 void copyInput(const void* values, size_t offset, unsigned entry) override {}
0086 void copyOutput() override {}
0087
0088 protected:
0089 int deviceId_;
0090 std::shared_ptr<cudaIpcMemHandle_t> handle_;
0091 };
0092
0093 using TritonInputGpuShmResource = TritonGpuShmResource<triton::client::InferInput>;
0094 using TritonOutputGpuShmResource = TritonGpuShmResource<triton::client::InferRequestedOutput>;
0095
0096
0097 template <>
0098 void TritonInputGpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
0099 template <>
0100 void TritonOutputGpuShmResource::copyOutput();
0101 #endif
0102
0103 #endif