Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:25:01

0001 #ifndef RecoEgamma_ElectronTools_EgammaDNNHelper_h
0002 #define RecoEgamma_ElectronTools_EgammaDNNHelper_h
0003 
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0009 
0010 //author: Davide Valsecchi
0011 //description:
// Handles TensorFlow DNN graphs and the input-variable scaler configuration.
// To be used for the PFID e/gamma DNNs.
0014 
0015 namespace egammaTools {
0016 
0017   struct DNNConfiguration {
0018     std::string inputTensorName;
0019     std::string outputTensorName;
0020     std::vector<std::string> modelsFiles;
0021     std::vector<std::string> scalersFiles;
0022     std::vector<unsigned int> outputDim;
0023   };
0024 
0025   struct ScalerConfiguration {
0026     /* Each input variable is represented by the tuple <varname, standardization_type, par1, par2>
0027     * The standardization_type can be:
0028     * 0 = Do not scale the variable
0029     * 1 = standard norm. par1=mean, par2=std
0030     * 2 = MinMax. par1=min, par2=max */
0031     std::string varName;
0032     uint type;
0033     float par1;
0034     float par2;
0035   };
0036 
0037   // Model for function to be used on the specific candidate to get the model
0038   // index to be used for the evaluation.
0039   typedef std::function<uint(const std::map<std::string, float>&)> ModelSelector;
0040 
0041   class EgammaDNNHelper {
0042   public:
0043     EgammaDNNHelper(const DNNConfiguration&, const ModelSelector& sel, const std::vector<std::string>& availableVars);
0044 
0045     std::vector<tensorflow::Session*> getSessions() const;
0046     // Function getting the input vector for a specific electron, already scaled
0047     // together with the model index it has to be used.
0048     // The model index is determined by the ModelSelector functor passed in the constructor
0049     // which has access to all the variables.
0050     std::pair<uint, std::vector<float>> getScaledInputs(const std::map<std::string, float>& variables) const;
0051 
0052     std::vector<std::pair<uint, std::vector<float>>> evaluate(
0053         const std::vector<std::map<std::string, float>>& candidates,
0054         const std::vector<tensorflow::Session*>& sessions) const;
0055 
0056   private:
0057     void initTensorFlowGraphs();
0058     void initScalerFiles(const std::vector<std::string>& availableVars);
0059 
0060     const DNNConfiguration cfg_;
0061     const ModelSelector modelSelector_;
0062     // Number of models handled by the object
0063     uint nModels_;
0064     // Number of inputs for each loaded model
0065     std::vector<uint> nInputs_;
0066 
0067     std::vector<std::unique_ptr<const tensorflow::GraphDef>> graphDefs_;
0068 
0069     // List of input variables for each of the model;
0070     std::vector<std::vector<ScalerConfiguration>> featuresMap_;
0071   };
0072 
0073 };  // namespace egammaTools
0074 
0075 #endif