Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-26 05:07:07

0001 #ifndef RecoHGCal_TICL_TracksterInferenceByDNN_H__
0002 #define RecoHGCal_TICL_TracksterInferenceByDNN_H__
0003 
0004 #include "RecoHGCal/TICL/interface/TracksterInferenceAlgoBase.h"
0005 #include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
0006 
0007 namespace ticl {
0008 
0009   class TracksterInferenceByDNN : public TracksterInferenceAlgoBase {
0010   public:
0011     explicit TracksterInferenceByDNN(const edm::ParameterSet& conf);
0012     void inputData(const std::vector<reco::CaloCluster>& layerClusters, std::vector<Trackster>& tracksters) override;
0013     void runInference(std::vector<Trackster>& tracksters) override;
0014 
0015     static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0016 
0017   private:
0018     const cms::Ort::ONNXRuntime* onnxPIDSession_;
0019     const cms::Ort::ONNXRuntime* onnxEnergySession_;
0020 
0021     const std::string id_modelPath_;
0022     const std::string en_modelPath_;
0023     const std::vector<std::string> inputNames_;
0024     const std::vector<std::string> output_en_;
0025     const std::vector<std::string> output_id_;
0026     const float eidMinClusterEnergy_;
0027     const int eidNLayers_;
0028     const int eidNClusters_;
0029     static constexpr int eidNFeatures_ = 3;
0030     int doPID_;
0031     int doRegression_;
0032 
0033     hgcal::RecHitTools rhtools_;
0034     std::vector<std::vector<int64_t>> input_shapes_;
0035     std::vector<int> tracksterIndices_;
0036     std::vector<std::vector<float>> input_Data_;
0037     int batchSize_;
0038   };
0039 }  // namespace ticl
0040 
0041 #endif  // RecoHGCal_TICL_TracksterInferenceByDNN_H__