Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:48

0001 #include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
0003 #include "HeterogeneousCore/SonicTriton/interface/TritonMemResource.h"
0004 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0005 
0006 #include "model_config.pb.h"
0007 #include "triton/common/model_config.h"
0008 
0009 #include <sstream>
0010 
0011 namespace tco = triton::common;
0012 namespace tc = triton::client;
0013 
0014 //dims: kept constant, represents config.pbtxt parameters of model (converted from google::protobuf::RepeatedField to vector)
0015 //fullShape: if batching is enabled, first entry is batch size; values can be modified
0016 //shape: view into fullShape, excluding batch size entry
0017 template <typename IO>
0018 TritonData<IO>::TritonData(const std::string& name,
0019                            const TritonData<IO>::TensorMetadata& model_info,
0020                            TritonClient* client,
0021                            const std::string& pid)
0022     : name_(name),
0023       client_(client),
0024       useShm_(client_->useSharedMemory()),
0025       //ensure unique name for shared memory region
0026       shmName_(useShm_ ? pid + "_" + xput() + std::to_string(uid()) : ""),
0027       dims_(model_info.shape().begin(), model_info.shape().end()),
0028       dname_(model_info.datatype()),
0029       dtype_(tco::ProtocolStringToDataType(dname_)),
0030       byteSize_(tco::GetDataTypeByteSize(dtype_)),
0031       totalByteSize_(0) {
0032   //initialize first shape entry
0033   addEntryImpl(0);
0034   //one-time computation of some shape info
0035   variableDims_ = anyNeg(entries_.front().shape_);
0036   productDims_ = variableDims_ ? -1 : dimProduct(entries_.front().shape_);
0037   checkShm();
0038 }
0039 
0040 template <>
0041 void TritonOutputData::checkShm() {
0042   //another specialization for output: can't use shared memory if output size is not known
0043   useShm_ &= !variableDims_;
0044 }
0045 
0046 template <typename IO>
0047 void TritonData<IO>::addEntry(unsigned entry) {
0048   //ensures consistency among all inputs
0049   client_->addEntry(entry);
0050 }
0051 
0052 template <typename IO>
0053 void TritonData<IO>::addEntryImpl(unsigned entry) {
0054   if (entry >= entries_.size()) {
0055     entries_.reserve(entry + 1);
0056     for (unsigned i = entries_.size(); i < entry + 1; ++i) {
0057       entries_.emplace_back(dims_, client_->noOuterDim(), name_, dname_);
0058     }
0059   }
0060 }
0061 
0062 template <>
0063 void TritonInputData::TritonDataEntry::createObject(tc::InferInput** ioptr,
0064                                                     const std::string& name,
0065                                                     const std::string& dname) {
0066   tc::InferInput::Create(ioptr, name, fullShape_, dname);
0067 }
0068 
0069 template <>
0070 void TritonOutputData::TritonDataEntry::createObject(tc::InferRequestedOutput** ioptr,
0071                                                      const std::string& name,
0072                                                      const std::string& dname) {
0073   tc::InferRequestedOutput::Create(ioptr, name);
0074 }
0075 
0076 template <>
0077 std::string TritonInputData::xput() const {
0078   return "input";
0079 }
0080 
0081 template <>
0082 std::string TritonOutputData::xput() const {
0083   return "output";
0084 }
0085 
0086 template <typename IO>
0087 tc::InferenceServerGrpcClient* TritonData<IO>::client() {
0088   return client_->client();
0089 }
0090 
0091 //setters
0092 template <typename IO>
0093 void TritonData<IO>::setShape(const TritonData<IO>::ShapeType& newShape, unsigned entry) {
0094   addEntry(entry);
0095   for (unsigned i = 0; i < newShape.size(); ++i) {
0096     setShape(i, newShape[i], entry);
0097   }
0098 }
0099 
0100 template <typename IO>
0101 void TritonData<IO>::setShape(unsigned loc, int64_t val, unsigned entry) {
0102   addEntry(entry);
0103 
0104   unsigned locFull = fullLoc(loc);
0105 
0106   //check boundary
0107   if (locFull >= entries_[entry].fullShape_.size())
0108     throw cms::Exception("TritonDataError") << name_ << " setShape(): dimension " << locFull << " out of bounds ("
0109                                             << entries_[entry].fullShape_.size() << ")";
0110 
0111   if (val != entries_[entry].fullShape_[locFull]) {
0112     if (dims_[locFull] == -1)
0113       entries_[entry].fullShape_[locFull] = val;
0114     else
0115       throw cms::Exception("TritonDataError")
0116           << name_ << " setShape(): attempt to change value of non-variable shape dimension " << loc;
0117   }
0118 }
0119 
0120 template <typename IO>
0121 void TritonData<IO>::TritonDataEntry::computeSizes(int64_t shapeSize, int64_t byteSize, int64_t batchSize) {
0122   sizeShape_ = shapeSize;
0123   byteSizePerBatch_ = byteSize * sizeShape_;
0124   totalByteSize_ = byteSizePerBatch_ * batchSize;
0125 }
0126 
0127 template <typename IO>
0128 void TritonData<IO>::computeSizes() {
0129   totalByteSize_ = 0;
0130   unsigned outerDim = client_->outerDim();
0131   for (unsigned i = 0; i < entries_.size(); ++i) {
0132     entries_[i].computeSizes(sizeShape(i), byteSize_, outerDim);
0133     entries_[i].offset_ = totalByteSize_;
0134     totalByteSize_ += entries_[i].totalByteSize_;
0135   }
0136 }
0137 
0138 //create a memory resource if none exists;
0139 //otherwise, reuse the memory resource, resizing it if necessary
0140 template <typename IO>
0141 void TritonData<IO>::updateMem(size_t size) {
0142   if (!memResource_ or size > memResource_->size()) {
0143     if (useShm_ and client_->serverType() == TritonServerType::LocalCPU) {
0144       //avoid unnecessarily throwing in destructor
0145       if (memResource_)
0146         memResource_->close();
0147       //need to destroy before constructing new instance because shared memory key will be reused
0148       memResource_.reset();
0149       memResource_ = std::make_shared<TritonCpuShmResource<IO>>(this, shmName_, size);
0150     }
0151 #ifdef TRITON_ENABLE_GPU
0152     else if (useShm_ and client_->serverType() == TritonServerType::LocalGPU) {
0153       //avoid unnecessarily throwing in destructor
0154       if (memResource_)
0155         memResource_->close();
0156       //need to destroy before constructing new instance because shared memory key will be reused
0157       memResource_.reset();
0158       memResource_ = std::make_shared<TritonGpuShmResource<IO>>(this, shmName_, size);
0159     }
0160 #endif
0161     //for remote/heap, size increases don't matter
0162     else if (!memResource_)
0163       memResource_ = std::make_shared<TritonHeapResource<IO>>(this, shmName_, size);
0164   }
0165 }
0166 
0167 //io accessors
0168 template <>
0169 template <typename DT>
0170 TritonInputContainer<DT> TritonInputData::allocate(bool reserve) {
0171   //automatically creates a vector for each item (if batch size known)
0172   auto ptr = std::make_shared<TritonInput<DT>>(client_->batchSize());
0173   if (reserve) {
0174     computeSizes();
0175     for (auto& entry : entries_) {
0176       if (anyNeg(entry.shape_))
0177         continue;
0178       for (auto& vec : *ptr) {
0179         vec.reserve(entry.sizeShape_);
0180       }
0181     }
0182   }
0183   return ptr;
0184 }
0185 
0186 template <>
0187 template <typename DT>
0188 void TritonInputData::toServer(TritonInputContainer<DT> ptr) {
0189   //shouldn't be called twice
0190   if (done_)
0191     throw cms::Exception("TritonDataError") << name_ << " toServer() was already called for this event";
0192 
0193   const auto& data_in = *ptr;
0194 
0195   //check batch size
0196   unsigned batchSize = client_->batchSize();
0197   unsigned outerDim = client_->outerDim();
0198   if (data_in.size() != batchSize) {
0199     throw cms::Exception("TritonDataError") << name_ << " toServer(): input vector has size " << data_in.size()
0200                                             << " but specified batch size is " << batchSize;
0201   }
0202 
0203   //check type
0204   checkType<DT>();
0205 
0206   computeSizes();
0207   updateMem(totalByteSize_);
0208 
0209   unsigned offset = 0;
0210   unsigned counter = 0;
0211   for (unsigned i = 0; i < entries_.size(); ++i) {
0212     auto& entry = entries_[i];
0213 
0214     //shape must be specified for variable dims or if batch size changes
0215     if (!client_->noOuterDim())
0216       entry.fullShape_[0] = outerDim;
0217     entry.data_->SetShape(entry.fullShape_);
0218 
0219     for (unsigned i0 = 0; i0 < outerDim; ++i0) {
0220       //avoid copying empty input
0221       if (entry.byteSizePerBatch_ > 0)
0222         memResource_->copyInput(data_in[counter].data(), offset, i);
0223       offset += entry.byteSizePerBatch_;
0224       ++counter;
0225     }
0226   }
0227   memResource_->set();
0228 
0229   //keep input data in scope
0230   holder_ = ptr;
0231   done_ = true;
0232 }
0233 
0234 //sets up shared memory for outputs, if possible
0235 template <>
0236 void TritonOutputData::prepare() {
0237   computeSizes();
0238   updateMem(totalByteSize_);
0239   memResource_->set();
0240 }
0241 
0242 template <>
0243 template <typename DT>
0244 TritonOutput<DT> TritonOutputData::fromServer() const {
0245   //shouldn't be called twice
0246   if (done_)
0247     throw cms::Exception("TritonDataError") << name_ << " fromServer() was already called for this event";
0248 
0249   //check type
0250   checkType<DT>();
0251 
0252   memResource_->copyOutput();
0253 
0254   unsigned outerDim = client_->outerDim();
0255   TritonOutput<DT> dataOut;
0256   dataOut.reserve(client_->batchSize());
0257   for (unsigned i = 0; i < entries_.size(); ++i) {
0258     const auto& entry = entries_[i];
0259     const DT* r1 = reinterpret_cast<const DT*>(entry.output_);
0260 
0261     if (entry.totalByteSize_ > 0 and !entry.result_) {
0262       throw cms::Exception("TritonDataError") << name_ << " fromServer(): missing result";
0263     }
0264 
0265     for (unsigned i0 = 0; i0 < outerDim; ++i0) {
0266       auto offset = i0 * entry.sizeShape_;
0267       dataOut.emplace_back(r1 + offset, r1 + offset + entry.sizeShape_);
0268     }
0269   }
0270 
0271   done_ = true;
0272   return dataOut;
0273 }
0274 
0275 template <typename IO>
0276 void TritonData<IO>::reset() {
0277   done_ = false;
0278   holder_.reset();
0279   entries_.clear();
0280   totalByteSize_ = 0;
0281   //re-initialize first shape entry
0282   addEntryImpl(0);
0283 }
0284 
0285 template <typename IO>
0286 unsigned TritonData<IO>::fullLoc(unsigned loc) const {
0287   return loc + (client_->noOuterDim() ? 0 : 1);
0288 }
0289 
0290 //explicit template instantiation declarations
0291 template class TritonData<tc::InferInput>;
0292 template class TritonData<tc::InferRequestedOutput>;
0293 
0294 template TritonInputContainer<char> TritonInputData::allocate(bool reserve);
0295 template TritonInputContainer<uint8_t> TritonInputData::allocate(bool reserve);
0296 template TritonInputContainer<uint16_t> TritonInputData::allocate(bool reserve);
0297 template TritonInputContainer<uint32_t> TritonInputData::allocate(bool reserve);
0298 template TritonInputContainer<uint64_t> TritonInputData::allocate(bool reserve);
0299 template TritonInputContainer<int8_t> TritonInputData::allocate(bool reserve);
0300 template TritonInputContainer<int16_t> TritonInputData::allocate(bool reserve);
0301 template TritonInputContainer<int32_t> TritonInputData::allocate(bool reserve);
0302 template TritonInputContainer<int64_t> TritonInputData::allocate(bool reserve);
0303 template TritonInputContainer<float> TritonInputData::allocate(bool reserve);
0304 template TritonInputContainer<double> TritonInputData::allocate(bool reserve);
0305 
0306 template void TritonInputData::toServer(TritonInputContainer<char> data_in);
0307 template void TritonInputData::toServer(TritonInputContainer<uint8_t> data_in);
0308 template void TritonInputData::toServer(TritonInputContainer<uint16_t> data_in);
0309 template void TritonInputData::toServer(TritonInputContainer<uint32_t> data_in);
0310 template void TritonInputData::toServer(TritonInputContainer<uint64_t> data_in);
0311 template void TritonInputData::toServer(TritonInputContainer<int8_t> data_in);
0312 template void TritonInputData::toServer(TritonInputContainer<int16_t> data_in);
0313 template void TritonInputData::toServer(TritonInputContainer<int32_t> data_in);
0314 template void TritonInputData::toServer(TritonInputContainer<int64_t> data_in);
0315 template void TritonInputData::toServer(TritonInputContainer<float> data_in);
0316 template void TritonInputData::toServer(TritonInputContainer<double> data_in);
0317 
0318 template TritonOutput<char> TritonOutputData::fromServer() const;
0319 template TritonOutput<uint8_t> TritonOutputData::fromServer() const;
0320 template TritonOutput<uint16_t> TritonOutputData::fromServer() const;
0321 template TritonOutput<uint32_t> TritonOutputData::fromServer() const;
0322 template TritonOutput<uint64_t> TritonOutputData::fromServer() const;
0323 template TritonOutput<int8_t> TritonOutputData::fromServer() const;
0324 template TritonOutput<int16_t> TritonOutputData::fromServer() const;
0325 template TritonOutput<int32_t> TritonOutputData::fromServer() const;
0326 template TritonOutput<int64_t> TritonOutputData::fromServer() const;
0327 template TritonOutput<float> TritonOutputData::fromServer() const;
0328 template TritonOutput<double> TritonOutputData::fromServer() const;