Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:17:24

0001 #ifndef RecoEcal_EgammaCoreTools_DeepSCGraphEvaluation_h
0002 #define RecoEcal_EgammaCoreTools_DeepSCGraphEvaluation_h
0003 
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include <array>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
0011 
0012 //author: Davide Valsecchi
0013 //description:
0014 // Handles Tensorflow DNN graphs and variables scaler configuration.
0015 // To be used for DeepSC.
0016 
0017 namespace reco {
0018 
0019   struct DeepSCConfiguration {
0020     std::string modelFile;
0021     std::string configFileClusterFeatures;
0022     std::string configFileWindowFeatures;
0023     std::string configFileHitsFeatures;
0024     uint nClusterFeatures;
0025     uint nWindowFeatures;
0026     uint nHitsFeatures;
0027     uint maxNClusters;
0028     uint maxNRechits;
0029     uint batchSize;
0030     std::string collectionStrategy;
0031   };
0032 
  /*
   * Structures representing the detector windows of a single event, to be evaluated with the DeepSC model.
   * The index structure is described below.
   */
0037 
0038   namespace DeepSCInputs {
0039     enum ScalerType {
0040       MeanRms,  // scale as (var - mean)/rms
0041       MinMax,   // scale as (var - min) (max-min)
0042       None      // do nothing
0043     };
0044     struct InputConfig {
0045       // Each input variable is represented by the tuple <varname, standardization_type, par1, par2>
0046       std::string varName;
0047       ScalerType type;
0048       float par1;
0049       float par2;
0050     };
0051     typedef std::vector<InputConfig> InputConfigs;
0052     typedef std::map<std::string, double> FeaturesMap;
0053 
0054     struct Inputs {
0055       std::vector<std::vector<std::vector<float>>> clustersX;
0056       std::vector<std::vector<std::vector<std::vector<float>>>> hitsX;
0057       std::vector<std::vector<float>> windowX;
0058       std::vector<std::vector<bool>> isSeed;
0059     };
0060 
0061   };  // namespace DeepSCInputs
0062 
0063   class DeepSCGraphEvaluation {
0064   public:
0065     DeepSCGraphEvaluation(const DeepSCConfiguration&);
0066     ~DeepSCGraphEvaluation();
0067 
0068     std::vector<float> getScaledInputs(const DeepSCInputs::FeaturesMap& variables,
0069                                        const DeepSCInputs::InputConfigs& config) const;
0070 
0071     std::vector<std::vector<float>> evaluate(const DeepSCInputs::Inputs& inputs) const;
0072 
0073     // List of input variables names used to check the variables request as
0074     // inputs in a dynamic way from configuration file.
0075     // If an input variables is not found at construction time an expection is thrown.
0076     static const std::vector<std::string> availableClusterInputs;
0077     static const std::vector<std::string> availableWindowInputs;
0078     static const std::vector<std::string> availableHitsInputs;
0079 
0080     // Configuration of the input variables including the scaling parameters.
0081     // The list is used to define the vector of input features passed to the tensorflow model.
0082     DeepSCInputs::InputConfigs inputFeaturesClusters;
0083     DeepSCInputs::InputConfigs inputFeaturesWindows;
0084     DeepSCInputs::InputConfigs inputFeaturesHits;
0085 
0086   private:
0087     void initTensorFlowGraphAndSession();
0088     DeepSCInputs::InputConfigs readInputFeaturesConfig(std::string file,
0089                                                        const std::vector<std::string>& availableInputs) const;
0090 
0091     const DeepSCConfiguration cfg_;
0092     std::unique_ptr<tensorflow::GraphDef> graphDef_;
0093     tensorflow::Session* session_;
0094   };
0095 
0096 };  // namespace reco
0097 
0098 #endif