File indexing completed on 2023-03-17 11:17:24
0001 #ifndef RecoEcal_EgammaCoreTools_DeepSCGraphEvaluation_h
0002 #define RecoEcal_EgammaCoreTools_DeepSCGraphEvaluation_h
0003
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include <array>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
0011
0012
0013
0014
0015
0016
0017 namespace reco {
0018
/// Plain configuration record handed to DeepSCGraphEvaluation's constructor.
///
/// All counters default to zero and all strings to empty, so a
/// default-constructed configuration never exposes indeterminate values.
/// The struct remains an aggregate (C++14 and later), so brace initialisation
/// listing the members in declaration order keeps working.
///
/// NOTE(review): the per-field descriptions below are inferred from the member
/// names only; confirm them against the code that fills and reads this struct.
struct DeepSCConfiguration {
  std::string modelFile;                  // presumably the location of the serialized model
  std::string configFileClusterFeatures;  // presumably the cluster-feature scaling config file
  std::string configFileWindowFeatures;   // presumably the window-feature scaling config file
  std::string configFileHitsFeatures;     // presumably the hit-feature scaling config file
  // `unsigned int` replaces the non-standard `uint` typedef, which is not part
  // of ISO C++ and is only available where a system header happens to declare it.
  unsigned int nClusterFeatures = 0;
  unsigned int nWindowFeatures = 0;
  unsigned int nHitsFeatures = 0;
  unsigned int maxNClusters = 0;
  unsigned int maxNRechits = 0;
  unsigned int batchSize = 0;
  std::string collectionStrategy;
};
0032
0033
0034
0035
0036
0037
/// Helper types describing the inputs passed to DeepSCGraphEvaluation.
namespace DeepSCInputs {
  /// Kind of scaling associated with one input variable.
  /// Kept as an unscoped enum so existing `DeepSCInputs::MeanRms`-style uses
  /// and the implicit integer values (0, 1, 2) are unchanged.
  enum ScalerType {
    MeanRms,
    MinMax,
    None
  };

  /// Scaling description for a single named input variable.
  struct InputConfig {
    std::string varName;
    // NOTE(review): `None` is assumed to mean "apply no scaling", which makes it
    // the safe default for a default-constructed entry; confirm against the
    // implementation of getScaledInputs.
    ScalerType type = None;
    // Two scaler parameters; their meaning presumably depends on `type`
    // (e.g. mean/rms or min/max) -- TODO confirm in the implementation.
    float par1 = 0.f;
    float par2 = 0.f;
  };

  /// Ordered list of per-variable scaling descriptions.
  using InputConfigs = std::vector<InputConfig>;
  /// Variable name -> raw (double precision) value.
  using FeaturesMap = std::map<std::string, double>;

  /// Bundle of nested containers consumed by DeepSCGraphEvaluation::evaluate.
  /// NOTE(review): the meaning of each nesting level (batch/window, cluster,
  /// hit, feature) is not established by this header; confirm with the code
  /// that fills these containers.
  struct Inputs {
    std::vector<std::vector<std::vector<float>>> clustersX;           // 3 nesting levels of float
    std::vector<std::vector<std::vector<std::vector<float>>>> hitsX;  // 4 nesting levels of float
    std::vector<std::vector<float>> windowX;                          // 2 nesting levels of float
    std::vector<std::vector<bool>> isSeed;                            // 2 nesting levels of bool
  };

}  // namespace DeepSCInputs
0062
0063 class DeepSCGraphEvaluation {
0064 public:
0065 DeepSCGraphEvaluation(const DeepSCConfiguration&);
0066 ~DeepSCGraphEvaluation();
0067
0068 std::vector<float> getScaledInputs(const DeepSCInputs::FeaturesMap& variables,
0069 const DeepSCInputs::InputConfigs& config) const;
0070
0071 std::vector<std::vector<float>> evaluate(const DeepSCInputs::Inputs& inputs) const;
0072
0073
0074
0075
0076 static const std::vector<std::string> availableClusterInputs;
0077 static const std::vector<std::string> availableWindowInputs;
0078 static const std::vector<std::string> availableHitsInputs;
0079
0080
0081
0082 DeepSCInputs::InputConfigs inputFeaturesClusters;
0083 DeepSCInputs::InputConfigs inputFeaturesWindows;
0084 DeepSCInputs::InputConfigs inputFeaturesHits;
0085
0086 private:
0087 void initTensorFlowGraphAndSession();
0088 DeepSCInputs::InputConfigs readInputFeaturesConfig(std::string file,
0089 const std::vector<std::string>& availableInputs) const;
0090
0091 const DeepSCConfiguration cfg_;
0092 std::unique_ptr<tensorflow::GraphDef> graphDef_;
0093 tensorflow::Session* session_;
0094 };
0095
0096 };
0097
0098 #endif