File indexing completed on 2024-04-06 12:15:48
0001 #include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"
0003 #include "HeterogeneousCore/SonicTriton/interface/TritonMemResource.h"
0004 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0005
0006 #include "model_config.pb.h"
0007 #include "triton/common/model_config.h"
0008
0009 #include <sstream>
0010
0011 namespace tco = triton::common;
0012 namespace tc = triton::client;
0013
0014
0015
0016
// Construct a TritonData wrapper around one named model input or output tensor.
// name:       tensor name from the model configuration
// model_info: per-tensor metadata (shape, datatype) from the model config
// client:     owning TritonClient (borrowed pointer, not owned here)
// pid:        process-id string used to build a unique shared-memory key
template <typename IO>
TritonData<IO>::TritonData(const std::string& name,
                           const TritonData<IO>::TensorMetadata& model_info,
                           TritonClient* client,
                           const std::string& pid)
    : name_(name),
      client_(client),
      useShm_(client_->useSharedMemory()),
      // shared-memory segment name must be unique per process and per tensor:
      // "<pid>_<input|output><uid>"
      shmName_(useShm_ ? pid + "_" + xput() + std::to_string(uid()) : ""),
      dims_(model_info.shape().begin(), model_info.shape().end()),
      dname_(model_info.datatype()),
      dtype_(tco::ProtocolStringToDataType(dname_)),
      byteSize_(tco::GetDataTypeByteSize(dtype_)),
      totalByteSize_(0) {
  // create entry 0 up front so shape queries are valid immediately
  addEntryImpl(0);
  // a negative dimension in the model config marks a variable-size axis
  variableDims_ = anyNeg(entries_.front().shape_);
  productDims_ = variableDims_ ? -1 : dimProduct(entries_.front().shape_);
  // may disable shared memory; specialized separately for inputs and outputs
  checkShm();
}
0039
0040 template <>
0041 void TritonOutputData::checkShm() {
0042
0043 useShm_ &= !variableDims_;
0044 }
0045
// Request creation of entry `entry`, routed through the client.
// NOTE(review): presumably the client fans this back out to addEntryImpl()
// on every input/output so all tensors stay in sync -- confirm in TritonClient.
template <typename IO>
void TritonData<IO>::addEntry(unsigned entry) {
  client_->addEntry(entry);
}
0051
0052 template <typename IO>
0053 void TritonData<IO>::addEntryImpl(unsigned entry) {
0054 if (entry >= entries_.size()) {
0055 entries_.reserve(entry + 1);
0056 for (unsigned i = entries_.size(); i < entry + 1; ++i) {
0057 entries_.emplace_back(dims_, client_->noOuterDim(), name_, dname_);
0058 }
0059 }
0060 }
0061
// Create the underlying triton client object for one input tensor entry.
// Inputs carry both a full shape and a datatype.
template <>
void TritonInputData::TritonDataEntry::createObject(tc::InferInput** ioptr,
                                                    const std::string& name,
                                                    const std::string& dname) {
  tc::InferInput::Create(ioptr, name, fullShape_, dname);
}
0068
// Create the underlying triton client object for one requested-output entry.
// Outputs are requested by name only; `dname` is unused in this specialization.
template <>
void TritonOutputData::TritonDataEntry::createObject(tc::InferRequestedOutput** ioptr,
                                                     const std::string& name,
                                                     const std::string& dname) {
  tc::InferRequestedOutput::Create(ioptr, name);
}
0075
// Label used e.g. in the shared-memory name: inputs are "input".
template <>
std::string TritonInputData::xput() const {
  return "input";
}
0080
// Label used e.g. in the shared-memory name: outputs are "output".
template <>
std::string TritonOutputData::xput() const {
  return "output";
}
0085
// Accessor for the owning client's underlying gRPC client object.
template <typename IO>
tc::InferenceServerGrpcClient* TritonData<IO>::client() {
  return client_->client();
}
0090
0091
0092 template <typename IO>
0093 void TritonData<IO>::setShape(const TritonData<IO>::ShapeType& newShape, unsigned entry) {
0094 addEntry(entry);
0095 for (unsigned i = 0; i < newShape.size(); ++i) {
0096 setShape(i, newShape[i], entry);
0097 }
0098 }
0099
0100 template <typename IO>
0101 void TritonData<IO>::setShape(unsigned loc, int64_t val, unsigned entry) {
0102 addEntry(entry);
0103
0104 unsigned locFull = fullLoc(loc);
0105
0106
0107 if (locFull >= entries_[entry].fullShape_.size())
0108 throw cms::Exception("TritonDataError") << name_ << " setShape(): dimension " << locFull << " out of bounds ("
0109 << entries_[entry].fullShape_.size() << ")";
0110
0111 if (val != entries_[entry].fullShape_[locFull]) {
0112 if (dims_[locFull] == -1)
0113 entries_[entry].fullShape_[locFull] = val;
0114 else
0115 throw cms::Exception("TritonDataError")
0116 << name_ << " setShape(): attempt to change value of non-variable shape dimension " << loc;
0117 }
0118 }
0119
0120 template <typename IO>
0121 void TritonData<IO>::TritonDataEntry::computeSizes(int64_t shapeSize, int64_t byteSize, int64_t batchSize) {
0122 sizeShape_ = shapeSize;
0123 byteSizePerBatch_ = byteSize * sizeShape_;
0124 totalByteSize_ = byteSizePerBatch_ * batchSize;
0125 }
0126
0127 template <typename IO>
0128 void TritonData<IO>::computeSizes() {
0129 totalByteSize_ = 0;
0130 unsigned outerDim = client_->outerDim();
0131 for (unsigned i = 0; i < entries_.size(); ++i) {
0132 entries_[i].computeSizes(sizeShape(i), byteSize_, outerDim);
0133 entries_[i].offset_ = totalByteSize_;
0134 totalByteSize_ += entries_[i].totalByteSize_;
0135 }
0136 }
0137
0138
0139
// (Re)allocate the memory resource backing this tensor when none exists yet
// or the current one is too small. Shared memory (CPU or CUDA) is used only
// when enabled and the server is local; otherwise plain heap memory is used.
template <typename IO>
void TritonData<IO>::updateMem(size_t size) {
  if (!memResource_ or size > memResource_->size()) {
    if (useShm_ and client_->serverType() == TritonServerType::LocalCPU) {
      // close the old shared-memory region before replacing it
      if (memResource_)
        memResource_->close();
      // destroy the old resource first so the shm name is free for reuse
      memResource_.reset();
      memResource_ = std::make_shared<TritonCpuShmResource<IO>>(this, shmName_, size);
    }
#ifdef TRITON_ENABLE_GPU
    else if (useShm_ and client_->serverType() == TritonServerType::LocalGPU) {
      // same close-then-recreate sequence for the CUDA shared-memory region
      if (memResource_)
        memResource_->close();
      memResource_.reset();
      memResource_ = std::make_shared<TritonGpuShmResource<IO>>(this, shmName_, size);
    }
#endif
    // heap resource is created only once (guarded by !memResource_)
    // NOTE(review): a too-small heap resource is never replaced here --
    // presumably TritonHeapResource grows internally; confirm.
    else if (!memResource_)
      memResource_ = std::make_shared<TritonHeapResource<IO>>(this, shmName_, size);
  }
}
0166
0167
// Allocate the user-facing input container: one inner vector per batch item.
// If `reserve` is true, pre-reserve each inner vector with the per-batch
// element count (skipped for entries whose dimensions are still variable).
template <>
template <typename DT>
TritonInputContainer<DT> TritonInputData::allocate(bool reserve) {
  // vector<vector<DT>> sized to the client's batch size
  auto ptr = std::make_shared<TritonInput<DT>>(client_->batchSize());
  if (reserve) {
    computeSizes();
    for (auto& entry : entries_) {
      // variable dimensions not yet set: element count unknown, cannot reserve
      if (anyNeg(entry.shape_))
        continue;
      for (auto& vec : *ptr) {
        vec.reserve(entry.sizeShape_);
      }
    }
  }
  return ptr;
}
0185
// Copy the user-provided input data into the memory resource and register it
// with the triton request. May only be called once per event (done_ guard).
// `ptr` is retained in holder_ until reset() so the data outlives the
// (possibly asynchronous) request.
template <>
template <typename DT>
void TritonInputData::toServer(TritonInputContainer<DT> ptr) {
  // disallow duplicate calls within one event
  if (done_)
    throw cms::Exception("TritonDataError") << name_ << " toServer() was already called for this event";

  const auto& data_in = *ptr;

  // the user container must match the batch size declared on the client
  unsigned batchSize = client_->batchSize();
  unsigned outerDim = client_->outerDim();
  if (data_in.size() != batchSize) {
    throw cms::Exception("TritonDataError") << name_ << " toServer(): input vector has size " << data_in.size()
                                            << " but specified batch size is " << batchSize;
  }

  // DT must match the tensor datatype from the model config
  checkType<DT>();

  computeSizes();
  updateMem(totalByteSize_);

  unsigned offset = 0;
  unsigned counter = 0;
  for (unsigned i = 0; i < entries_.size(); ++i) {
    auto& entry = entries_[i];

    // the outer (batch) dimension is slot 0 of the full shape
    if (!client_->noOuterDim())
      entry.fullShape_[0] = outerDim;
    entry.data_->SetShape(entry.fullShape_);

    // copy one batch item at a time; `counter` walks the flat user container
    // NOTE(review): assumes batchSize == entries * outerDim so counter stays
    // in range -- confirm against TritonClient::batchSize().
    for (unsigned i0 = 0; i0 < outerDim; ++i0) {
      // skip the copy for empty batch items
      if (entry.byteSizePerBatch_ > 0)
        memResource_->copyInput(data_in[counter].data(), offset, i);
      offset += entry.byteSizePerBatch_;
      ++counter;
    }
  }
  memResource_->set();

  // keep the input data alive until reset() after the request completes
  holder_ = ptr;
  done_ = true;
}
0233
0234
// Prepare the output tensor for an inference call: compute the expected
// sizes, ensure the memory resource is large enough, and register it with
// the underlying triton request.
template <>
void TritonOutputData::prepare() {
  computeSizes();
  updateMem(totalByteSize_);
  memResource_->set();
}
0241
// Retrieve output data after the server response has arrived, returning one
// per-batch-item view over the flat result buffer. May only be called once
// per event (done_ is written in a const method, so it is presumably
// declared mutable -- see header).
template <>
template <typename DT>
TritonOutput<DT> TritonOutputData::fromServer() const {
  // disallow duplicate calls within one event
  if (done_)
    throw cms::Exception("TritonDataError") << name_ << " fromServer() was already called for this event";

  // DT must match the tensor datatype from the model config
  checkType<DT>();

  // make the raw result bytes available via entry.output_
  // NOTE(review): exact behavior depends on the TritonMemResource subclass
  // (shared memory vs heap) -- confirm in TritonMemResource.
  memResource_->copyOutput();

  unsigned outerDim = client_->outerDim();
  TritonOutput<DT> dataOut;
  dataOut.reserve(client_->batchSize());
  for (unsigned i = 0; i < entries_.size(); ++i) {
    const auto& entry = entries_[i];
    const DT* r1 = reinterpret_cast<const DT*>(entry.output_);

    // a non-empty output without a result object indicates a failed request
    if (entry.totalByteSize_ > 0 and !entry.result_) {
      throw cms::Exception("TritonDataError") << name_ << " fromServer(): missing result";
    }

    // slice the flat buffer into one span per batch item
    for (unsigned i0 = 0; i0 < outerDim; ++i0) {
      auto offset = i0 * entry.sizeShape_;
      dataOut.emplace_back(r1 + offset, r1 + offset + entry.sizeShape_);
    }
  }

  done_ = true;
  return dataOut;
}
0274
// Clear per-event state so this object can be reused for the next event:
// release the held input data, discard all entries, and recreate entry 0.
template <typename IO>
void TritonData<IO>::reset() {
  done_ = false;
  holder_.reset();
  entries_.clear();
  totalByteSize_ = 0;
  // restore the default entry so shape queries remain valid
  addEntryImpl(0);
}
0284
0285 template <typename IO>
0286 unsigned TritonData<IO>::fullLoc(unsigned loc) const {
0287 return loc + (client_->noOuterDim() ? 0 : 1);
0288 }
0289
0290
// Explicit instantiations: TritonInputData / TritonOutputData aliases.
template class TritonData<tc::InferInput>;
template class TritonData<tc::InferRequestedOutput>;

// allocate() for every datatype supported by the triton protocol
template TritonInputContainer<char> TritonInputData::allocate(bool reserve);
template TritonInputContainer<uint8_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<uint16_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<uint32_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<uint64_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<int8_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<int16_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<int32_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<int64_t> TritonInputData::allocate(bool reserve);
template TritonInputContainer<float> TritonInputData::allocate(bool reserve);
template TritonInputContainer<double> TritonInputData::allocate(bool reserve);

// toServer() for every supported datatype
template void TritonInputData::toServer(TritonInputContainer<char> data_in);
template void TritonInputData::toServer(TritonInputContainer<uint8_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<uint16_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<uint32_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<uint64_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<int8_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<int16_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<int32_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<int64_t> data_in);
template void TritonInputData::toServer(TritonInputContainer<float> data_in);
template void TritonInputData::toServer(TritonInputContainer<double> data_in);

// fromServer() for every supported datatype
template TritonOutput<char> TritonOutputData::fromServer() const;
template TritonOutput<uint8_t> TritonOutputData::fromServer() const;
template TritonOutput<uint16_t> TritonOutputData::fromServer() const;
template TritonOutput<uint32_t> TritonOutputData::fromServer() const;
template TritonOutput<uint64_t> TritonOutputData::fromServer() const;
template TritonOutput<int8_t> TritonOutputData::fromServer() const;
template TritonOutput<int16_t> TritonOutputData::fromServer() const;
template TritonOutput<int32_t> TritonOutputData::fromServer() const;
template TritonOutput<int64_t> TritonOutputData::fromServer() const;
template TritonOutput<float> TritonOutputData::fromServer() const;
template TritonOutput<double> TritonOutputData::fromServer() const;