// File indexing completed on 2024-04-06 12:24:16
0001 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_MODEL_H
0002 #define PHYSICSTOOLS_TENSORFLOWAOT_MODEL_H
0003
0004
0005
0006
0007
0008
0009
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/TensorFlowAOT/interface/Batching.h"
#include "PhysicsTools/TensorFlowAOT/interface/Util.h"
0014
0015 namespace tfaot {
0016
0017
0018 template <class W>
0019 class Model {
0020 public:
0021
0022 explicit Model() : wrapper_(std::make_unique<W>()) {}
0023
0024
0025 ~Model() { wrapper_.reset(); };
0026
0027
0028 const std::string& name() const { return wrapper_->name(); }
0029
0030
0031 void setBatchStrategy(const BatchStrategy& strategy) { batchStrategy_ = strategy; }
0032
0033
0034 const BatchStrategy& getBatchStrategy() const { return batchStrategy_; }
0035
0036
0037 void setBatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0) {
0038 batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
0039 }
0040
0041
0042 void setBatchRule(const std::string& batchRule) { batchStrategy_.setRule(BatchRule(batchRule)); }
0043
0044
0045 template <typename... Outputs, typename... Inputs>
0046 std::tuple<Outputs...> run(size_t batchSize, Inputs&&... inputs);
0047
0048 private:
0049 std::unique_ptr<W> wrapper_;
0050 BatchStrategy batchStrategy_;
0051
0052
0053
0054 const BatchRule& ensureRule(size_t batchSize);
0055
0056
0057 template <typename T>
0058 void reserveOutput(size_t batchSize, size_t resultIndex, std::vector<std::vector<T>>& data) const;
0059
0060
0061 template <typename T>
0062 void injectBatchInput(size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector<T>& batchData);
0063
0064
0065 template <typename T>
0066 void extractBatchOutput(size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector<T>& batchData) const;
0067 };
0068
0069 template <class W>
0070 const BatchRule& Model<W>::ensureRule(size_t batchSize) {
0071
0072 if (!batchStrategy_.hasRule(batchSize)) {
0073 batchStrategy_.setDefaultRule(batchSize, wrapper_->batchSizes());
0074 }
0075 return batchStrategy_.getRule(batchSize);
0076 }
0077
0078 template <class W>
0079 template <typename T>
0080 void Model<W>::reserveOutput(size_t batchSize, size_t resultIndex, std::vector<std::vector<T>>& data) const {
0081 data.resize(batchSize, std::vector<T>(wrapper_->resultCountNoBatch(resultIndex)));
0082 }
0083
0084 template <class W>
0085 template <typename T>
0086 void Model<W>::injectBatchInput(size_t batchSize,
0087 size_t batchIndex,
0088 size_t argIndex,
0089 const std::vector<T>& batchData) {
0090 size_t count = wrapper_->argCountNoBatch(argIndex);
0091 if (batchData.size() != count) {
0092 throw cms::Exception("InputMismatch")
0093 << "model '" << name() << "' received " << batchData.size() << " elements for argument " << argIndex
0094 << ", but " << count << " are expected";
0095 }
0096 T* argPtr = wrapper_->template argData<T>(batchSize, argIndex) + batchIndex * count;
0097 auto beg = batchData.cbegin();
0098 std::copy(beg, beg + count, argPtr);
0099 }
0100
0101 template <class W>
0102 template <typename T>
0103 void Model<W>::extractBatchOutput(size_t batchSize,
0104 size_t batchIndex,
0105 size_t resultIndex,
0106 std::vector<T>& batchData) const {
0107 size_t count = wrapper_->resultCountNoBatch(resultIndex);
0108 const T* resPtr = wrapper_->template resultData<T>(batchSize, resultIndex) + batchIndex * count;
0109 batchData.assign(resPtr, resPtr + count);
0110 }
0111
0112 template <class W>
0113 template <typename... Outputs, typename... Inputs>
0114 std::tuple<Outputs...> Model<W>::run(size_t batchSize, Inputs&&... inputs) {
0115
0116 size_t nInputs = sizeof...(Inputs);
0117 if (nInputs != wrapper_->nArgs()) {
0118 throw cms::Exception("InputMismatch")
0119 << "model '" << name() << "' received " << nInputs << " inputs, but " << wrapper_->nArgs() << " are expected";
0120 }
0121
0122
0123 size_t nOutputs = sizeof...(Outputs);
0124 if (nOutputs != wrapper_->nResults()) {
0125 throw cms::Exception("OutputMismatch") << "requested " << nOutputs << " from model '" << name() << "', but "
0126 << wrapper_->nResults() << " are provided";
0127 }
0128
0129
0130 const BatchRule& rule = ensureRule(batchSize);
0131
0132
0133 auto forEachOutput = createIndexLooper<sizeof...(Outputs)>();
0134
0135
0136 std::tuple<Outputs...> outputs;
0137 forEachOutput([&](auto resultIndex) { reserveOutput(batchSize, resultIndex, std::get<resultIndex>(outputs)); });
0138
0139
0140 size_t batchOffset = 0;
0141 size_t nSizes = rule.nSizes();
0142 for (size_t i = 0; i < nSizes; i++) {
0143
0144 size_t bs = rule.getSize(i);
0145 size_t padding = (i == nSizes - 1) ? rule.getLastPadding() : 0;
0146
0147
0148 for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
0149 size_t argIndex = 0;
0150 ([&] { injectBatchInput(bs, batchIndex, argIndex++, inputs[batchOffset + batchIndex]); }(), ...);
0151 }
0152
0153
0154 wrapper_->run(bs);
0155
0156
0157 for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
0158 forEachOutput([&](auto resultIndex) {
0159 extractBatchOutput(bs, batchIndex, resultIndex, std::get<resultIndex>(outputs)[batchOffset + batchIndex]);
0160 });
0161 }
0162
0163 batchOffset += bs;
0164 }
0165
0166 return outputs;
0167 }
0168
0169 }
0170
0171 #endif