Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:47

0001 #ifndef HeterogeneousCore_SonicTriton_TritonData
0002 #define HeterogeneousCore_SonicTriton_TritonData
0003 
0004 #include "FWCore/Utilities/interface/Exception.h"
0005 #include "FWCore/Utilities/interface/Span.h"
0006 #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
0007 
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
0016 
0017 #include "grpc_client.h"
0018 #include "grpc_service.pb.h"
0019 
//forward declarations: the client and the memory-resource classes befriended by TritonData below
class TritonClient;

template <class IO>
class TritonMemResource;

template <class IO>
class TritonHeapResource;

template <class IO>
class TritonCpuShmResource;

#ifdef TRITON_ENABLE_GPU
//the GPU shared-memory resource is only declared when TRITON_ENABLE_GPU is defined
template <class IO>
class TritonGpuShmResource;
#endif
0032 
0033 //aliases for local input and output types
0034 template <typename DT>
0035 using TritonInput = std::vector<std::vector<DT>>;
0036 template <typename DT>
0037 using TritonOutput = std::vector<edm::Span<const DT*>>;
0038 
0039 //other useful typdefs
0040 template <typename DT>
0041 using TritonInputContainer = std::shared_ptr<TritonInput<DT>>;
0042 
0043 //store all the info needed for triton input and output
0044 //NOTE: this class is not const-thread-safe, and should only be used with stream or one modules
0045 //(generally recommended for SONIC, but especially necessary here)
0046 template <typename IO>
0047 class TritonData {
0048 public:
0049   using Result = triton::client::InferResult;
0050   using TensorMetadata = inference::ModelMetadataResponse_TensorMetadata;
0051   using ShapeType = std::vector<int64_t>;
0052   using ShapeView = edm::Span<ShapeType::const_iterator>;
0053 
0054   //constructor
0055   TritonData(const std::string& name, const TensorMetadata& model_info, TritonClient* client, const std::string& pid);
0056 
0057   //some members can be modified
0058   void setShape(const ShapeType& newShape, unsigned entry = 0);
0059   void setShape(unsigned loc, int64_t val, unsigned entry = 0);
0060 
0061   //io accessors
0062   template <typename DT>
0063   TritonInputContainer<DT> allocate(bool reserve = true);
0064   template <typename DT>
0065   void toServer(TritonInputContainer<DT> ptr);
0066   void prepare();
0067   template <typename DT>
0068   TritonOutput<DT> fromServer() const;
0069 
0070   //const accessors
0071   const ShapeView& shape(unsigned entry = 0) const { return entries_.at(entry).shape_; }
0072   int64_t byteSize() const { return byteSize_; }
0073   const std::string& dname() const { return dname_; }
0074 
0075   //utilities
0076   bool variableDims() const { return variableDims_; }
0077   int64_t sizeDims() const { return productDims_; }
0078   //default to dims if shape isn't filled
0079   int64_t sizeShape(unsigned entry = 0) const {
0080     return variableDims_ ? dimProduct(entries_.at(entry).shape_) : sizeDims();
0081   }
0082 
0083 private:
0084   friend class TritonClient;
0085   friend class TritonMemResource<IO>;
0086   friend class TritonHeapResource<IO>;
0087   friend class TritonCpuShmResource<IO>;
0088 #ifdef TRITON_ENABLE_GPU
0089   friend class TritonGpuShmResource<IO>;
0090 #endif
0091 
0092   //group together all relevant information for a single request
0093   //helpful for organizing multi-request ragged batching case
0094   class TritonDataEntry {
0095   public:
0096     //constructors
0097     TritonDataEntry(const ShapeType& dims, bool noOuterDim, const std::string& name, const std::string& dname)
0098         : fullShape_(dims),
0099           shape_(fullShape_.begin() + (noOuterDim ? 0 : 1), fullShape_.end()),
0100           sizeShape_(0),
0101           byteSizePerBatch_(0),
0102           totalByteSize_(0),
0103           offset_(0),
0104           output_(nullptr) {
0105       //create input or output object
0106       IO* iotmp;
0107       createObject(&iotmp, name, dname);
0108       data_.reset(iotmp);
0109     }
0110     //default needed to be able to use std::vector resize()
0111     TritonDataEntry()
0112         : shape_(fullShape_.begin(), fullShape_.end()),
0113           sizeShape_(0),
0114           byteSizePerBatch_(0),
0115           totalByteSize_(0),
0116           offset_(0),
0117           output_(nullptr) {}
0118 
0119   private:
0120     friend class TritonData<IO>;
0121     friend class TritonClient;
0122     friend class TritonMemResource<IO>;
0123     friend class TritonHeapResource<IO>;
0124     friend class TritonCpuShmResource<IO>;
0125 #ifdef TRITON_ENABLE_GPU
0126     friend class TritonGpuShmResource<IO>;
0127 #endif
0128 
0129     //accessors
0130     void createObject(IO** ioptr, const std::string& name, const std::string& dname);
0131     void computeSizes(int64_t shapeSize, int64_t byteSize, int64_t batchSize);
0132 
0133     //members
0134     ShapeType fullShape_;
0135     ShapeView shape_;
0136     size_t sizeShape_, byteSizePerBatch_, totalByteSize_;
0137     std::shared_ptr<IO> data_;
0138     std::shared_ptr<Result> result_;
0139     unsigned offset_;
0140     const uint8_t* output_;
0141   };
0142 
0143   //private accessors only used internally or by client
0144   void checkShm() {}
0145   unsigned fullLoc(unsigned loc) const;
0146   void reset();
0147   void setResult(std::shared_ptr<Result> result, unsigned entry = 0) { entries_[entry].result_ = result; }
0148   IO* data(unsigned entry = 0) { return entries_[entry].data_.get(); }
0149   void updateMem(size_t size);
0150   void computeSizes();
0151   triton::client::InferenceServerGrpcClient* client();
0152   template <typename DT>
0153   void checkType() const {
0154     if (!triton_utils::checkType<DT>(dtype_))
0155       throw cms::Exception("TritonDataError")
0156           << name_ << ": inconsistent data type " << typeid(DT).name() << " for " << dname_;
0157   }
0158 
0159   //helpers
0160   bool anyNeg(const ShapeView& vec) const {
0161     return std::any_of(vec.begin(), vec.end(), [](int64_t i) { return i < 0; });
0162   }
0163   int64_t dimProduct(const ShapeView& vec) const {
0164     //lambda treats negative dimensions as 0 to avoid overflows
0165     return std::accumulate(
0166         vec.begin(), vec.end(), 1, [](int64_t dim1, int64_t dim2) { return dim1 * std::max(0l, dim2); });
0167   }
0168   //generates a unique id number for each instance of the class
0169   unsigned uid() const {
0170     static std::atomic<unsigned> uid{0};
0171     return ++uid;
0172   }
0173   std::string xput() const;
0174   void addEntry(unsigned entry);
0175   void addEntryImpl(unsigned entry);
0176 
0177   //members
0178   std::string name_;
0179   TritonClient* client_;
0180   bool useShm_;
0181   std::string shmName_;
0182   const ShapeType dims_;
0183   bool variableDims_;
0184   int64_t productDims_;
0185   std::string dname_;
0186   inference::DataType dtype_;
0187   int64_t byteSize_;
0188   std::vector<TritonDataEntry> entries_;
0189   size_t totalByteSize_;
0190   //can be modified in otherwise-const fromServer() method in TritonMemResource::copyOutput():
0191   //TritonMemResource holds a non-const pointer to an instance of this class
0192   //so that TritonOutputGpuShmResource can store data here
0193   std::shared_ptr<void> holder_;
0194   std::shared_ptr<TritonMemResource<IO>> memResource_;
0195   //can be modified in otherwise-const fromServer() method to prevent multiple calls
0196   CMS_SA_ALLOW mutable bool done_{};
0197 };
0198 
0199 using TritonInputData = TritonData<triton::client::InferInput>;
0200 using TritonInputMap = std::unordered_map<std::string, TritonInputData>;
0201 using TritonOutputData = TritonData<triton::client::InferRequestedOutput>;
0202 using TritonOutputMap = std::unordered_map<std::string, TritonOutputData>;
0203 
0204 //avoid "explicit specialization after instantiation" error
0205 template <>
0206 void TritonInputData::TritonDataEntry::createObject(triton::client::InferInput** ioptr,
0207                                                     const std::string& name,
0208                                                     const std::string& dname);
0209 template <>
0210 void TritonOutputData::TritonDataEntry::createObject(triton::client::InferRequestedOutput** ioptr,
0211                                                      const std::string& name,
0212                                                      const std::string& dname);
0213 template <>
0214 void TritonOutputData::checkShm();
0215 template <>
0216 std::string TritonInputData::xput() const;
0217 template <>
0218 std::string TritonOutputData::xput() const;
0219 template <>
0220 template <typename DT>
0221 TritonInputContainer<DT> TritonInputData::allocate(bool reserve);
0222 template <>
0223 template <typename DT>
0224 void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr);
0225 template <>
0226 void TritonOutputData::prepare();
0227 template <>
0228 template <typename DT>
0229 TritonOutput<DT> TritonOutputData::fromServer() const;
0230 
0231 //explicit template instantiation declarations
0232 extern template class TritonData<triton::client::InferInput>;
0233 extern template class TritonData<triton::client::InferRequestedOutput>;
0234 
0235 #endif