Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-05-29 23:13:00

0001 /*
0002  * AOT batching rules and strategies.
0003  *
0004  * Author: Marcel Rieger, Bogdan Wiederspan
0005  */
0006 
#include <algorithm>
#include <cctype>
#include <charconv>
#include <limits>
#include <ostream>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>
0009 
0010 #include "PhysicsTools/TensorFlowAOT/interface/Batching.h"
0011 
0012 #include "FWCore/Utilities/interface/Exception.h"
0013 
0014 namespace tfaot {
0015 
0016   BatchRule::BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding)
0017       : batchSize_(batchSize), sizes_(sizes), lastPadding_(lastPadding) {
0018     validate();
0019   }
0020 
0021   BatchRule::BatchRule(const std::string& ruleString) {
0022     // extract the target batch size from the front
0023     std::string rule = ruleString;
0024     auto pos = rule.find(":");
0025     if (pos == std::string::npos) {
0026       throw cms::Exception("InvalidBatchRule") << "invalid batch rule format: " << ruleString;
0027     }
0028     size_t batchSize = std::stoi(rule.substr(0, pos));
0029     rule = rule.substr(pos + 1);
0030 
0031     // loop through remaining comma-separated sizes
0032     std::vector<size_t> sizes;
0033     size_t sumSizes = 0;
0034     while (!rule.empty()) {
0035       pos = rule.find(",");
0036       sizes.push_back(std::stoi(rule.substr(0, pos)));
0037       sumSizes += sizes.back();
0038       rule = pos == std::string::npos ? "" : rule.substr(pos + 1);
0039     }
0040 
0041     // the sum of composite batch sizes should never be smaller than the target batch size
0042     if (sumSizes < batchSize) {
0043       throw cms::Exception("InvalidBatchRule")
0044           << "sum of composite batch sizes is smaller than target batch size: " << ruleString;
0045     }
0046 
0047     // set members and validate
0048     batchSize_ = batchSize;
0049     sizes_ = sizes;
0050     lastPadding_ = sumSizes - batchSize;
0051     validate();
0052   }
0053 
0054   void BatchRule::validate() const {
0055     // sizes must not be empty
0056     if (sizes_.size() == 0) {
0057       throw cms::Exception("EmptySizes") << "no batch sizes provided for stitching";
0058     }
0059 
0060     // the padding must be smaller than the last size
0061     size_t lastSize = sizes_[sizes_.size() - 1];
0062     if (lastPadding_ >= lastSize) {
0063       throw cms::Exception("WrongPadding")
0064           << "padding " << lastPadding_ << " must be smaller than last size " << lastSize;
0065     }
0066 
0067     // compute the covered batch size
0068     size_t sizeSum = 0;
0069     for (const size_t& s : sizes_) {
0070       sizeSum += s;
0071     }
0072     if (lastPadding_ > sizeSum) {
0073       throw cms::Exception("WrongPadding")
0074           << "padding " << lastPadding_ << " must not be larger than sum of sizes " << sizeSum;
0075     }
0076     sizeSum -= lastPadding_;
0077 
0078     // compare to given batch size
0079     if (batchSize_ != sizeSum) {
0080       throw cms::Exception("WrongBatchSize")
0081           << "batch size " << batchSize_ << " does not match sum of sizes - padding " << sizeSum;
0082     }
0083   }
0084 
0085   const BatchRule& BatchStrategy::getRule(size_t batchSize) const {
0086     const auto it = rules_.find(batchSize);
0087     if (it == rules_.end()) {
0088       throw cms::Exception("UnknownBatchSize") << "batchSize " << batchSize << " not known to batching strategy";
0089     }
0090     return it->second;
0091   }
0092 
0093   std::ostream& operator<<(std::ostream& out, const BatchRule& rule) {
0094     out << "BatchRule(batchSize=" << rule.getBatchSize() << ", sizes=";
0095     for (size_t i = 0; i < rule.nSizes(); i++) {
0096       out << (i == 0 ? "" : ",") << rule.getSizes()[i];
0097     }
0098     return out << ", lastPadding=" << rule.getLastPadding() << ")";
0099   }
0100 
0101   void BatchStrategy::setDefaultRule(size_t batchSize, const std::vector<size_t>& availableBatchSizes) {
0102     std::vector<size_t> sizes;
0103     size_t lastPadding = 0;
0104 
0105     // many implementations are possible here, but for simplicity assume the most simple one:
0106     // if there is an exact match, use it, and otherwise repeat the smallest available size
0107     // n times and potentially add padding
0108     if (std::find(availableBatchSizes.begin(), availableBatchSizes.end(), batchSize) != availableBatchSizes.end()) {
0109       sizes.push_back(batchSize);
0110     } else {
0111       size_t smallestBatchSize = *std::min_element(availableBatchSizes.begin(), availableBatchSizes.end());
0112       size_t rest = batchSize % smallestBatchSize;
0113       size_t n = (batchSize / smallestBatchSize) + (rest == 0 ? 0 : 1);
0114       lastPadding = rest == 0 ? 0 : (smallestBatchSize - rest);
0115       sizes.resize(n, smallestBatchSize);
0116     }
0117 
0118     // create and register the rule
0119     setRule(BatchRule(batchSize, sizes, lastPadding));
0120   }
0121 
0122 }  // namespace tfaot