File indexing completed on 2024-12-20 03:14:01
0001
0002
0003
0004
0005
0006
#include <algorithm>
#include <exception>
#include <ostream>
#include <string>

#include "PhysicsTools/TensorFlowAOT/interface/Batching.h"

#include "FWCore/Utilities/interface/Exception.h"
0013
0014 namespace tfaot {
0015
0016 BatchRule::BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding)
0017 : batchSize_(batchSize), sizes_(sizes), lastPadding_(lastPadding) {
0018 validate();
0019 }
0020
0021 BatchRule::BatchRule(const std::string& ruleString) {
0022
0023 std::string rule = ruleString;
0024 auto pos = rule.find(':');
0025 if (pos == std::string::npos) {
0026 throw cms::Exception("InvalidBatchRule") << "invalid batch rule format: " << ruleString;
0027 }
0028 size_t batchSize = std::stoi(rule.substr(0, pos));
0029 rule = rule.substr(pos + 1);
0030
0031
0032 std::vector<size_t> sizes;
0033 size_t sumSizes = 0;
0034 while (!rule.empty()) {
0035 pos = rule.find(',');
0036 sizes.push_back(std::stoi(rule.substr(0, pos)));
0037 sumSizes += sizes.back();
0038 rule = pos == std::string::npos ? "" : rule.substr(pos + 1);
0039 }
0040
0041
0042 if (sumSizes < batchSize) {
0043 throw cms::Exception("InvalidBatchRule")
0044 << "sum of composite batch sizes is smaller than target batch size: " << ruleString;
0045 }
0046
0047
0048 batchSize_ = batchSize;
0049 sizes_ = sizes;
0050 lastPadding_ = sumSizes - batchSize;
0051 validate();
0052 }
0053
0054 void BatchRule::validate() const {
0055
0056 if (sizes_.empty()) {
0057 throw cms::Exception("EmptySizes") << "no batch sizes provided for stitching";
0058 }
0059
0060
0061 size_t lastSize = sizes_[sizes_.size() - 1];
0062 if (lastPadding_ >= lastSize) {
0063 throw cms::Exception("WrongPadding")
0064 << "padding " << lastPadding_ << " must be smaller than last size " << lastSize;
0065 }
0066
0067
0068 size_t sizeSum = 0;
0069 for (const size_t& s : sizes_) {
0070 sizeSum += s;
0071 }
0072 if (lastPadding_ > sizeSum) {
0073 throw cms::Exception("WrongPadding")
0074 << "padding " << lastPadding_ << " must not be larger than sum of sizes " << sizeSum;
0075 }
0076 sizeSum -= lastPadding_;
0077
0078
0079 if (batchSize_ != sizeSum) {
0080 throw cms::Exception("WrongBatchSize")
0081 << "batch size " << batchSize_ << " does not match sum of sizes - padding " << sizeSum;
0082 }
0083 }
0084
0085 const BatchRule& BatchStrategy::getRule(size_t batchSize) const {
0086 const auto it = rules_.find(batchSize);
0087 if (it == rules_.end()) {
0088 throw cms::Exception("UnknownBatchSize") << "batchSize " << batchSize << " not known to batching strategy";
0089 }
0090 return it->second;
0091 }
0092
0093 std::ostream& operator<<(std::ostream& out, const BatchRule& rule) {
0094 out << "BatchRule(batchSize=" << rule.getBatchSize() << ", sizes=";
0095 for (size_t i = 0; i < rule.nSizes(); i++) {
0096 out << (i == 0 ? "" : ",") << rule.getSizes()[i];
0097 }
0098 return out << ", lastPadding=" << rule.getLastPadding() << ")";
0099 }
0100
0101 void BatchStrategy::setDefaultRule(size_t batchSize, const std::vector<size_t>& availableBatchSizes) {
0102 std::vector<size_t> sizes;
0103 size_t lastPadding = 0;
0104
0105
0106
0107
0108 if (std::find(availableBatchSizes.begin(), availableBatchSizes.end(), batchSize) != availableBatchSizes.end()) {
0109 sizes.push_back(batchSize);
0110 } else {
0111 size_t smallestBatchSize = *std::min_element(availableBatchSizes.begin(), availableBatchSizes.end());
0112 size_t rest = batchSize % smallestBatchSize;
0113 size_t n = (batchSize / smallestBatchSize) + (rest == 0 ? 0 : 1);
0114 lastPadding = rest == 0 ? 0 : (smallestBatchSize - rest);
0115 sizes.resize(n, smallestBatchSize);
0116 }
0117
0118
0119 setRule(BatchRule(batchSize, sizes, lastPadding));
0120 }
0121
0122 }