Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:16

0001 /*
0002  * AOT wrapper interface for interacting with models compiled for different batch sizes.
0003  *
0004  * Author: Marcel Rieger, Bogdan Wiederspan
0005  */
0006 
0007 #include <vector>
0008 #include <map>
0009 
0010 #include "PhysicsTools/TensorFlowAOT/interface/Wrapper.h"
0011 
0012 namespace tfaot {
0013 
0014   int Wrapper::argCount(size_t batchSize, size_t argIndex) const {
0015     const auto& counts = argCounts();
0016     const auto it = counts.find(batchSize);
0017     if (it == counts.end()) {
0018       unknownBatchSize(batchSize, "argCount()");
0019     }
0020     if (argIndex >= it->second.size()) {
0021       unknownArgument(argIndex, "argCount()");
0022     }
0023     return it->second.at(argIndex);
0024   }
0025 
0026   int Wrapper::argCountNoBatch(size_t argIndex) const {
0027     const auto& counts = argCountsNoBatch();
0028     if (argIndex >= counts.size()) {
0029       unknownArgument(argIndex, "argCountNoBatch()");
0030     }
0031     return counts.at(argIndex);
0032   }
0033 
0034   int Wrapper::resultCount(size_t batchSize, size_t resultIndex) const {
0035     const auto& counts = resultCounts();
0036     const auto it = counts.find(batchSize);
0037     if (it == counts.end()) {
0038       unknownBatchSize(batchSize, "resultCount()");
0039     }
0040     if (resultIndex >= it->second.size()) {
0041       unknownResult(resultIndex, "resultCount()");
0042     }
0043     return it->second.at(resultIndex);
0044   }
0045 
0046   int Wrapper::resultCountNoBatch(size_t resultIndex) const {
0047     const auto& counts = resultCountsNoBatch();
0048     if (resultIndex >= counts.size()) {
0049       unknownResult(resultIndex, "resultCountNoBatch()");
0050     }
0051     return counts[resultIndex];
0052   }
0053 
0054   void Wrapper::run(size_t batchSize) {
0055     if (!runSilent(batchSize)) {
0056       throw cms::Exception("FailedRun") << "evaluation with batch size " << batchSize << " failed for model '" << name_;
0057     }
0058   }
0059 
0060 }  // namespace tfaot