File indexing completed on 2024-04-06 12:25:01
0001 #ifndef RecoEgamma_ElectronTools_EgammaDNNHelper_h
0002 #define RecoEgamma_ElectronTools_EgammaDNNHelper_h
0003
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0009
0010
0011
0012
0013
0014
0015 namespace egammaTools {
0016
/// Configuration consumed by the EgammaDNNHelper constructor.
///
/// NOTE(review): the per-model correspondence of the vectors below is
/// inferred from their names only — confirm against the implementation.
struct DNNConfiguration {
  // Tensor names used when feeding / fetching the graph.
  std::string inputTensorName;   ///< name of the input tensor
  std::string outputTensorName;  ///< name of the output tensor

  // File lists (presumably one entry per model).
  std::vector<std::string> modelsFiles;   ///< model (graph) files
  std::vector<std::string> scalersFiles;  ///< input-scaling files

  std::vector<unsigned> outputDim;  ///< output dimension(s), presumably per model
};
0024
/// Scaling recipe for a single DNN input variable.
///
/// NOTE(review): the meaning of each `type` code and of `par1` / `par2` is
/// defined by the code that reads these entries, which is not visible in this
/// header — document it here once confirmed.
///
/// The scalar members carry zero default initializers so a default-constructed
/// entry is never read with indeterminate values; the struct remains an
/// aggregate (C++14+), so brace/designated initialization still works.
struct ScalerConfiguration {
  std::string varName;   ///< name of the input variable this entry applies to
  unsigned int type{0};  ///< scaling type code (portable spelling of `uint`)
  float par1{0.f};       ///< first scaling parameter
  float par2{0.f};       ///< second scaling parameter
};
0036
0037
0038
/// Callback that, given a candidate's variables (name -> value), returns the
/// index of the model to use for that candidate.
/// Spelled with `unsigned int` rather than the non-standard `uint` typedef.
using ModelSelector = std::function<unsigned int(const std::map<std::string, float>&)>;
0040
0041 class EgammaDNNHelper {
0042 public:
0043 EgammaDNNHelper(const DNNConfiguration&, const ModelSelector& sel, const std::vector<std::string>& availableVars);
0044
0045 std::vector<tensorflow::Session*> getSessions() const;
0046
0047
0048
0049
0050 std::pair<uint, std::vector<float>> getScaledInputs(const std::map<std::string, float>& variables) const;
0051
0052 std::vector<std::pair<uint, std::vector<float>>> evaluate(
0053 const std::vector<std::map<std::string, float>>& candidates,
0054 const std::vector<tensorflow::Session*>& sessions) const;
0055
0056 private:
0057 void initTensorFlowGraphs();
0058 void initScalerFiles(const std::vector<std::string>& availableVars);
0059
0060 const DNNConfiguration cfg_;
0061 const ModelSelector modelSelector_;
0062
0063 uint nModels_;
0064
0065 std::vector<uint> nInputs_;
0066
0067 std::vector<std::unique_ptr<const tensorflow::GraphDef>> graphDefs_;
0068
0069
0070 std::vector<std::vector<ScalerConfiguration>> featuresMap_;
0071 };
0072
0073 };
0074
0075 #endif