#ifndef HeterogeneousCore_SonicTriton_TritonMemResource
#define HeterogeneousCore_SonicTriton_TritonMemResource

#include <cstdint>
#include <memory>
#include <string>

#include "grpc_client.h"

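// forward declaration: TritonData is the input/output data wrapper that each memory resource serves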
template <typename IO>
class TritonData;

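// base class managing a block of memory used to exchange input/output data with the Triton server;
// derived classes determine whether that block lives on the heap, in CPU shared memory, or in GPU shared memory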
template <typename IO>
class TritonMemResource {
public:
  TritonMemResource(TritonData<IO>* data, const std::string& name, size_t size);
  virtual ~TritonMemResource() {}

  uint8_t* addr() { return addr_; }
  size_t size() const { return size_; }

  virtual void close() {}
  virtual void copyInput(const void* values, size_t offset, unsigned entry) {}
  virtual void copyOutput() {}
  virtual void set();

protected:
  TritonData<IO>* data_;
  std::string name_;
  size_t size_;
  uint8_t* addr_;
  bool closed_;
};

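// default resource: data stays in ordinary (heap) memory and is transferred through the gRPC client interface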
template <typename IO>
class TritonHeapResource : public TritonMemResource<IO> {
public:
  TritonHeapResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonHeapResource() override {}
  void copyInput(const void* values, size_t offset, unsigned entry) override {}
  void copyOutput() override {}
  void set() override {}
};

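// resource backed by CPU shared memory registered with the server; close() releases it,
// and sizeOrig_ records the size of the underlying allocation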
template <typename IO>
class TritonCpuShmResource : public TritonMemResource<IO> {
public:
  TritonCpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonCpuShmResource() override;
  void close() override;
  void copyInput(const void* values, size_t offset, unsigned entry) override {}
  void copyOutput() override {}

protected:
  size_t sizeOrig_;
};

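// convenience aliases for the concrete gRPC client input and output types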
using TritonInputHeapResource = TritonHeapResource<triton::client::InferInput>;
using TritonInputCpuShmResource = TritonCpuShmResource<triton::client::InferInput>;
using TritonOutputHeapResource = TritonHeapResource<triton::client::InferRequestedOutput>;
using TritonOutputCpuShmResource = TritonCpuShmResource<triton::client::InferRequestedOutput>;

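// only the meaningful combinations are specialized: inputs implement copyInput(), outputs implement copyOutput()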
template <>
void TritonInputHeapResource::copyInput(const void* values, size_t offset, unsigned entry);
template <>
void TritonInputCpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
template <>
void TritonOutputHeapResource::copyOutput();
template <>
void TritonOutputCpuShmResource::copyOutput();

#ifdef TRITON_ENABLE_GPU
#include "cuda_runtime_api.h"

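// resource backed by GPU memory shared with the server via a CUDA IPC handle;
// only compiled when the client is built with TRITON_ENABLE_GPU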
template <typename IO>
class TritonGpuShmResource : public TritonMemResource<IO> {
public:
  TritonGpuShmResource(TritonData<IO>* data, const std::string& name, size_t size);
  ~TritonGpuShmResource() override;
  void close() override;
  void copyInput(const void* values, size_t offset, unsigned entry) override {}
  void copyOutput() override {}

protected:
  int deviceId_;
  std::shared_ptr<cudaIpcMemHandle_t> handle_;
};

using TritonInputGpuShmResource = TritonGpuShmResource<triton::client::InferInput>;
using TritonOutputGpuShmResource = TritonGpuShmResource<triton::client::InferRequestedOutput>;

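// as above, the copy operations are specialized only where they apply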
template <>
void TritonInputGpuShmResource::copyInput(const void* values, size_t offset, unsigned entry);
template <>
void TritonOutputGpuShmResource::copyOutput();
#endif

#endif