Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:16

0001 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
0002 #define PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
0003 
0004 /*
0005  * AOT batching rules and strategies.
0006  *
0007  * Author: Marcel Rieger, Bogdan Wiederspan
0008  */
0009 
#include <cstddef>
#include <map>
#include <ostream>
#include <string>
#include <vector>
0014 
0015 namespace tfaot {
0016 
0017   // rule defining how a certain batch size should be composed of various smaller sizes plus an
0018   // optional padding that is applied to the last size
0019   class BatchRule {
0020   public:
0021     // constructor
0022     explicit BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0);
0023 
0024     // constructor taking a string in the format "batchSize:size1,...,sizeN" with lastPadding being
0025     // inferred from the sum of sizes
0026     BatchRule(const std::string& ruleString);
0027 
0028     // destructor
0029     ~BatchRule() = default;
0030 
0031     // getter for the batch size
0032     size_t getBatchSize() const { return batchSize_; }
0033 
0034     // getter for available sizes
0035     const std::vector<size_t>& getSizes() const { return sizes_; }
0036 
0037     // getter for the last padding value
0038     size_t getLastPadding() const { return lastPadding_; }
0039 
0040     // returns the number of available sizes
0041     size_t nSizes() const { return sizes_.size(); }
0042 
0043     // getter for the registered size at index i
0044     size_t getSize(size_t i) const { return sizes_[i]; }
0045 
0046   private:
0047     size_t batchSize_;
0048     std::vector<size_t> sizes_;
0049     size_t lastPadding_;
0050 
0051     // validation helper
0052     void validate() const;
0053   };
0054 
0055   // stream operator
0056   std::ostream& operator<<(std::ostream& out, const BatchRule& rule);
0057 
0058   // the batch strategy is a collection of batch rules registered to certain batch sizes
0059   class BatchStrategy {
0060   public:
0061     // constructor
0062     explicit BatchStrategy() = default;
0063 
0064     // destructor
0065     ~BatchStrategy() = default;
0066 
0067     // registers a new rule for a batch size
0068     void setRule(const BatchRule& rule) { rules_.insert_or_assign(rule.getBatchSize(), rule); }
0069 
0070     // registers a new rule for a batch size, given a rule string (see BatchRule constructor)
0071     void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); }
0072 
0073     // returns whether a rule was already registered for a certain batch size
0074     bool hasRule(size_t batchSize) const { return rules_.find(batchSize) != rules_.end(); }
0075 
0076     // returns a rule registered previously for a certain batch size
0077     const BatchRule& getRule(size_t batchSize) const;
0078 
0079     // registers a new rule for a certain batch size according to a certain algorithm
0080     void setDefaultRule(size_t batchSize, const std::vector<size_t>& availableBatchSizes);
0081 
0082   private:
0083     std::map<size_t, BatchRule> rules_;
0084   };
0085 
0086 }  // namespace tfaot
0087 
0088 #endif  // PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H