File indexing completed on 2024-04-06 12:24:16
0001
0002
0003
0004
0005
0006
0007 #include <vector>
0008 #include <map>
0009
0010 #include "PhysicsTools/TensorFlowAOT/interface/Wrapper.h"
0011
0012 namespace tfaot {
0013
0014 int Wrapper::argCount(size_t batchSize, size_t argIndex) const {
0015 const auto& counts = argCounts();
0016 const auto it = counts.find(batchSize);
0017 if (it == counts.end()) {
0018 unknownBatchSize(batchSize, "argCount()");
0019 }
0020 if (argIndex >= it->second.size()) {
0021 unknownArgument(argIndex, "argCount()");
0022 }
0023 return it->second.at(argIndex);
0024 }
0025
0026 int Wrapper::argCountNoBatch(size_t argIndex) const {
0027 const auto& counts = argCountsNoBatch();
0028 if (argIndex >= counts.size()) {
0029 unknownArgument(argIndex, "argCountNoBatch()");
0030 }
0031 return counts.at(argIndex);
0032 }
0033
0034 int Wrapper::resultCount(size_t batchSize, size_t resultIndex) const {
0035 const auto& counts = resultCounts();
0036 const auto it = counts.find(batchSize);
0037 if (it == counts.end()) {
0038 unknownBatchSize(batchSize, "resultCount()");
0039 }
0040 if (resultIndex >= it->second.size()) {
0041 unknownResult(resultIndex, "resultCount()");
0042 }
0043 return it->second.at(resultIndex);
0044 }
0045
0046 int Wrapper::resultCountNoBatch(size_t resultIndex) const {
0047 const auto& counts = resultCountsNoBatch();
0048 if (resultIndex >= counts.size()) {
0049 unknownResult(resultIndex, "resultCountNoBatch()");
0050 }
0051 return counts[resultIndex];
0052 }
0053
0054 void Wrapper::run(size_t batchSize) {
0055 if (!runSilent(batchSize)) {
0056 throw cms::Exception("FailedRun") << "evaluation with batch size " << batchSize << " failed for model '" << name_;
0057 }
0058 }
0059
0060 }