File indexing completed on 2024-04-06 12:24:16
0001 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
0002 #define PHYSICSTOOLS_TENSORFLOWAOT_BATCHING_H
0003
0004
0005
0006
0007
0008
0009
#include <cstddef>
#include <map>
#include <ostream>
#include <string>
#include <vector>
0014
0015 namespace tfaot {
0016
0017
0018
// A rule associated with one batch size (getBatchSize()). It holds a list of sizes
// (getSizes()) and a padding value that, judging by its name, relates to the last size
// (getLastPadding()).
// NOTE(review): how sizes and padding must relate to the batch size is enforced by
// validate(), whose implementation is not visible in this header — document it here.
//
// The class follows the Rule of Zero: no special member functions are declared, so the
// compiler generates copy AND move operations. A user-declared destructor (even a
// defaulted one) would suppress the implicit move constructor/assignment and turn every
// move into a copy of sizes_.
class BatchRule {
public:
  // Constructs a rule from a batch size, a list of sizes and a last padding (default 0).
  // Presumably stores the arguments and calls validate() — TODO confirm against the .cpp.
  explicit BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0);

  // Constructs a rule by parsing its string representation. The accepted format is defined
  // by the out-of-line implementation — TODO confirm and document the format here.
  // NOTE(review): this constructor is not explicit, so a std::string converts implicitly to
  // a BatchRule; consider marking it explicit once callers relying on that are audited.
  BatchRule(const std::string& ruleString);

  // Batch size this rule is registered for.
  size_t getBatchSize() const { return batchSize_; }

  // The list of sizes stored in this rule, in stored order.
  const std::vector<size_t>& getSizes() const { return sizes_; }

  // The padding value stored in this rule (presumably applied to the last size).
  size_t getLastPadding() const { return lastPadding_; }

  // Number of entries in getSizes().
  size_t nSizes() const { return sizes_.size(); }

  // The i-th entry of getSizes().
  // Precondition: i < nSizes(); the index is not checked, so out-of-range access is UB.
  size_t getSize(size_t i) const { return sizes_[i]; }

private:
  // Default member initializers guarantee the scalars are never left indeterminate.
  size_t batchSize_ = 0;
  std::vector<size_t> sizes_;
  size_t lastPadding_ = 0;

  // Consistency check of the members above; defined out of line.
  void validate() const;
};
0054
0055
0056 std::ostream& operator<<(std::ostream& out, const BatchRule& rule);
0057
0058
0059 class BatchStrategy {
0060 public:
0061
0062 explicit BatchStrategy() = default;
0063
0064
0065 ~BatchStrategy() = default;
0066
0067
0068 void setRule(const BatchRule& rule) { rules_.insert_or_assign(rule.getBatchSize(), rule); }
0069
0070
0071 void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); }
0072
0073
0074 bool hasRule(size_t batchSize) const { return rules_.find(batchSize) != rules_.end(); }
0075
0076
0077 const BatchRule& getRule(size_t batchSize) const;
0078
0079
0080 void setDefaultRule(size_t batchSize, const std::vector<size_t>& availableBatchSizes);
0081
0082 private:
0083 std::map<size_t, BatchRule> rules_;
0084 };
0085
0086 }
0087
0088 #endif