Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:16

0001 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_MODEL_H
0002 #define PHYSICSTOOLS_TENSORFLOWAOT_MODEL_H
0003 
0004 /*
0005  * AOT model interface.
0006  *
0007  * Author: Marcel Rieger, Bogdan Wiederspan
0008  */
0009 
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "FWCore/Utilities/interface/Exception.h"

#include "PhysicsTools/TensorFlowAOT/interface/Util.h"
#include "PhysicsTools/TensorFlowAOT/interface/Batching.h"
0014 
0015 namespace tfaot {
0016 
0017   // model interface receiving the AOT wrapper type as a template argument
0018   template <class W>
0019   class Model {
0020   public:
0021     // constructor
0022     explicit Model() : wrapper_(std::make_unique<W>()) {}
0023 
0024     // destructor
0025     ~Model() { wrapper_.reset(); };
0026 
0027     // getter for the name
0028     const std::string& name() const { return wrapper_->name(); }
0029 
0030     // setter for the batch strategy
0031     void setBatchStrategy(const BatchStrategy& strategy) { batchStrategy_ = strategy; }
0032 
0033     // getter for the batch strategy
0034     const BatchStrategy& getBatchStrategy() const { return batchStrategy_; }
0035 
0036     // adds a new batch rule to the strategy
0037     void setBatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0) {
0038       batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
0039     }
0040 
0041     // adds a new batch rule to the strategy, given a rule string (see BatchRule constructor)
0042     void setBatchRule(const std::string& batchRule) { batchStrategy_.setRule(BatchRule(batchRule)); }
0043 
0044     // evaluates the model for multiple inputs and outputs of different types
0045     template <typename... Outputs, typename... Inputs>
0046     std::tuple<Outputs...> run(size_t batchSize, Inputs&&... inputs);
0047 
0048   private:
0049     std::unique_ptr<W> wrapper_;
0050     BatchStrategy batchStrategy_;
0051 
0052     // ensures that a batch rule exists for a certain batch size, and if not, registers a new one
0053     // based on the default algorithm
0054     const BatchRule& ensureRule(size_t batchSize);
0055 
0056     // reserves memory in a nested (batched) vector to accomodate the result output at an index
0057     template <typename T>
0058     void reserveOutput(size_t batchSize, size_t resultIndex, std::vector<std::vector<T>>& data) const;
0059 
0060     // injects data of a specific batch element into the argument data at an index
0061     template <typename T>
0062     void injectBatchInput(size_t batchSize, size_t batchIndex, size_t argIndex, const std::vector<T>& batchData);
0063 
0064     // extracts result data at an index into a specific batch
0065     template <typename T>
0066     void extractBatchOutput(size_t batchSize, size_t batchIndex, size_t resultIndex, std::vector<T>& batchData) const;
0067   };
0068 
0069   template <class W>
0070   const BatchRule& Model<W>::ensureRule(size_t batchSize) {
0071     // register a default rule if there is none yet for that batch size
0072     if (!batchStrategy_.hasRule(batchSize)) {
0073       batchStrategy_.setDefaultRule(batchSize, wrapper_->batchSizes());
0074     }
0075     return batchStrategy_.getRule(batchSize);
0076   }
0077 
0078   template <class W>
0079   template <typename T>
0080   void Model<W>::reserveOutput(size_t batchSize, size_t resultIndex, std::vector<std::vector<T>>& data) const {
0081     data.resize(batchSize, std::vector<T>(wrapper_->resultCountNoBatch(resultIndex)));
0082   }
0083 
0084   template <class W>
0085   template <typename T>
0086   void Model<W>::injectBatchInput(size_t batchSize,
0087                                   size_t batchIndex,
0088                                   size_t argIndex,
0089                                   const std::vector<T>& batchData) {
0090     size_t count = wrapper_->argCountNoBatch(argIndex);
0091     if (batchData.size() != count) {
0092       throw cms::Exception("InputMismatch")
0093           << "model '" << name() << "' received " << batchData.size() << " elements for argument " << argIndex
0094           << ", but " << count << " are expected";
0095     }
0096     T* argPtr = wrapper_->template argData<T>(batchSize, argIndex) + batchIndex * count;
0097     auto beg = batchData.cbegin();
0098     std::copy(beg, beg + count, argPtr);
0099   }
0100 
0101   template <class W>
0102   template <typename T>
0103   void Model<W>::extractBatchOutput(size_t batchSize,
0104                                     size_t batchIndex,
0105                                     size_t resultIndex,
0106                                     std::vector<T>& batchData) const {
0107     size_t count = wrapper_->resultCountNoBatch(resultIndex);
0108     const T* resPtr = wrapper_->template resultData<T>(batchSize, resultIndex) + batchIndex * count;
0109     batchData.assign(resPtr, resPtr + count);
0110   }
0111 
  // evaluates the model for an arbitrary batch size by decomposing the request,
  // according to the applicable batch rule, into one or more evaluations at the
  // statically compiled batch sizes; inputs/outputs are nested vectors with one
  // inner vector per batch element
  template <class W>
  template <typename... Outputs, typename... Inputs>
  std::tuple<Outputs...> Model<W>::run(size_t batchSize, Inputs&&... inputs) {
    // check number of inputs against the number of arguments the wrapper expects
    size_t nInputs = sizeof...(Inputs);
    if (nInputs != wrapper_->nArgs()) {
      throw cms::Exception("InputMismatch")
          << "model '" << name() << "' received " << nInputs << " inputs, but " << wrapper_->nArgs() << " are expected";
    }

    // check number of outputs against the number of results the wrapper provides
    size_t nOutputs = sizeof...(Outputs);
    if (nOutputs != wrapper_->nResults()) {
      throw cms::Exception("OutputMismatch") << "requested " << nOutputs << " from model '" << name() << "', but "
                                             << wrapper_->nResults() << " are provided";
    }

    // get the corresponding batch rule, creating a default one if necessary
    const BatchRule& rule = ensureRule(batchSize);

    // create a callback that invokes lambdas over all outputs with normal indices
    // (presumably a compile-time index sequence helper from Util.h — allows
    // indexing std::get<I> with a constexpr-convertible argument)
    auto forEachOutput = createIndexLooper<sizeof...(Outputs)>();

    // reserve output arrays, one inner vector per batch element and result
    std::tuple<Outputs...> outputs;
    forEachOutput([&](auto resultIndex) { reserveOutput(batchSize, resultIndex, std::get<resultIndex>(outputs)); });

    // loop over particular batch sizes, copy input, evaluate and compose the output
    size_t batchOffset = 0;
    size_t nSizes = rule.nSizes();
    for (size_t i = 0; i < nSizes; i++) {
      // get actual model batch size and optional padding (padding only ever
      // applies to the last chunk of the rule)
      size_t bs = rule.getSize(i);
      size_t padding = (i == nSizes - 1) ? rule.getLastPadding() : 0;

      // fill inputs separately per batch element; the fold expression expands
      // once per input, with argIndex advancing across the expansion
      for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
        size_t argIndex = 0;
        ([&] { injectBatchInput(bs, batchIndex, argIndex++, inputs[batchOffset + batchIndex]); }(), ...);
      }

      // model evaluation at the compiled batch size
      wrapper_->run(bs);

      // fill outputs separately per batch element, skipping padded elements
      for (size_t batchIndex = 0; batchIndex < bs - padding; batchIndex++) {
        forEachOutput([&](auto resultIndex) {
          extractBatchOutput(bs, batchIndex, resultIndex, std::get<resultIndex>(outputs)[batchOffset + batchIndex]);
        });
      }

      // advance by the full chunk size; on the last chunk padded elements were
      // neither read from inputs nor written to outputs, so this is safe
      batchOffset += bs;
    }

    return outputs;
  }
0168 
0169 }  // namespace tfaot
0170 
0171 #endif  // PHYSICSTOOLS_TENSORFLOWAOT_MODEL_H