#ifndef HeterogeneousCore_SonicTriton_TritonData
#define HeterogeneousCore_SonicTriton_TritonData

#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/Span.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include <vector>
#include <string>
#include <unordered_map>
#include <numeric>
#include <algorithm>
#include <memory>
#include <atomic>
#include <typeinfo>

#include "grpc_client.h"
#include "grpc_service.pb.h"

//forward declarations
class TritonClient;
template <typename IO>
class TritonMemResource;
template <typename IO>
class TritonHeapResource;
template <typename IO>
class TritonCpuShmResource;
#ifdef TRITON_ENABLE_GPU
template <typename IO>
class TritonGpuShmResource;
#endif

//aliases for local input and output types
template <typename DT>
using TritonInput = std::vector<std::vector<DT>>;
template <typename DT>
using TritonOutput = std::vector<edm::Span<const DT*>>;

//other useful typedefs
template <typename DT>
using TritonInputContainer = std::shared_ptr<TritonInput<DT>>;

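//usage sketch: each outer entry of TritonInput/TritonOutput is assumed to correspond to
//one batch element, with the inner values flattened according to the tensor shape, e.g.
//  TritonInput<float> in(2);       //batch size 2
//  in[0] = {1.f, 2.f, 3.f};        //flattened values for batch entry 0
//  in[1] = {4.f, 5.f, 6.f};        //flattened values for batch entry 1
//(the per-entry length should match the product of the tensor dimensions)
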
//store all the info needed for triton input and output
//NOTE: this class is not const-thread-safe, and should only be used with stream or one modules
//(generally recommended for SONIC, but especially necessary here)
template <typename IO>
class TritonData {
public:
  using Result = triton::client::InferResult;
  using TensorMetadata = inference::ModelMetadataResponse_TensorMetadata;
  using ShapeType = std::vector<int64_t>;
  using ShapeView = edm::Span<ShapeType::const_iterator>;

  //constructor
  TritonData(const std::string& name, const TensorMetadata& model_info, TritonClient* client, const std::string& pid);

  //some members can be modified
  void setShape(const ShapeType& newShape);
  void setShape(unsigned loc, int64_t val);

  //io accessors
  template <typename DT>
  TritonInputContainer<DT> allocate(bool reserve = true);
  template <typename DT>
  void toServer(TritonInputContainer<DT> ptr);
  void prepare();
  template <typename DT>
  TritonOutput<DT> fromServer() const;

  //const accessors
  const ShapeView& shape() const { return shape_; }
  int64_t byteSize() const { return byteSize_; }
  const std::string& dname() const { return dname_; }
  unsigned batchSize() const { return batchSize_; }

  //utilities
  bool variableDims() const { return variableDims_; }
  int64_t sizeDims() const { return productDims_; }
  //default to dims if shape isn't filled
  int64_t sizeShape() const { return variableDims_ ? dimProduct(shape_) : sizeDims(); }

private:
  friend class TritonClient;
  friend class TritonMemResource<IO>;
  friend class TritonHeapResource<IO>;
  friend class TritonCpuShmResource<IO>;
#ifdef TRITON_ENABLE_GPU
  friend class TritonGpuShmResource<IO>;
#endif

  //private accessors only used internally or by client
  unsigned fullLoc(unsigned loc) const { return loc + (noBatch_ ? 0 : 1); }
  void setBatchSize(unsigned bsize);
  void reset();
  void setResult(std::shared_ptr<Result> result) { result_ = result; }
  IO* data() { return data_.get(); }
  void updateMem(size_t size);
  void computeSizes();
  void resetSizes();
  triton::client::InferenceServerGrpcClient* client();
  template <typename DT>
  void checkType() const {
    if (!triton_utils::checkType<DT>(dtype_))
      throw cms::Exception("TritonDataError")
          << name_ << ": inconsistent data type " << typeid(DT).name() << " for " << dname_;
  }

  //helpers
  bool anyNeg(const ShapeView& vec) const {
    return std::any_of(vec.begin(), vec.end(), [](int64_t i) { return i < 0; });
  }
  int64_t dimProduct(const ShapeView& vec) const {
    return std::accumulate(vec.begin(), vec.end(), 1, std::multiplies<int64_t>());
  }
  void createObject(IO** ioptr);
  //generates a unique id number for each instance of the class
  unsigned uid() const {
    static std::atomic<unsigned> uid{0};
    return ++uid;
  }
  std::string xput() const;

  //members
  std::string name_;
  std::shared_ptr<IO> data_;
  TritonClient* client_;
  bool useShm_;
  std::string shmName_;
  const ShapeType dims_;
  bool noBatch_;
  unsigned batchSize_;
  ShapeType fullShape_;
  ShapeView shape_;
  bool variableDims_;
  int64_t productDims_;
  std::string dname_;
  inference::DataType dtype_;
  int64_t byteSize_;
  size_t sizeShape_;
  size_t byteSizePerBatch_;
  size_t totalByteSize_;
  //can be modified in otherwise-const fromServer() method in TritonMemResource::copyOutput():
  //TritonMemResource holds a non-const pointer to an instance of this class
  //so that TritonOutputGpuShmResource can store data here
  std::shared_ptr<void> holder_;
  std::shared_ptr<TritonMemResource<IO>> memResource_;
  std::shared_ptr<Result> result_;
  //can be modified in otherwise-const fromServer() method to prevent multiple calls
  CMS_SA_ALLOW mutable bool done_{};
};

using TritonInputData = TritonData<triton::client::InferInput>;
using TritonInputMap = std::unordered_map<std::string, TritonInputData>;
using TritonOutputData = TritonData<triton::client::InferRequestedOutput>;
using TritonOutputMap = std::unordered_map<std::string, TritonOutputData>;

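//usage sketch (the input()/output() accessors on TritonClient and the tensor names below
//are assumed for illustration; actual names come from the model configuration):
//  TritonInputData& input = client->input().at("input_0");
//  input.setShape(0, 3);                          //set a variable dimension, if any
//  auto data = input.allocate<float>();           //TritonInputContainer<float>
//  (*data)[0] = {1.f, 2.f, 3.f};                  //fill one inner vector per batch entry
//  input.toServer(data);
//  //...after the server call completes...
//  const TritonOutputData& output = client->output().at("output_0");
//  TritonOutput<float> results = output.fromServer<float>();   //one Span per batch entry
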
//avoid "explicit specialization after instantiation" error
template <>
std::string TritonInputData::xput() const;
template <>
std::string TritonOutputData::xput() const;
template <>
template <typename DT>
TritonInputContainer<DT> TritonInputData::allocate(bool reserve);
template <>
template <typename DT>
void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr);
template <>
void TritonOutputData::prepare();
template <>
template <typename DT>
TritonOutput<DT> TritonOutputData::fromServer() const;
template <>
void TritonInputData::reset();
template <>
void TritonOutputData::reset();
template <>
void TritonInputData::createObject(triton::client::InferInput** ioptr);
template <>
void TritonOutputData::createObject(triton::client::InferRequestedOutput** ioptr);

//explicit template instantiation declarations
extern template class TritonData<triton::client::InferInput>;
extern template class TritonData<triton::client::InferRequestedOutput>;

#endif