Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-26 05:07:07

0001 #ifndef RecoHGCal_TICL_TracksterInferenceByCNNv4_H__
0002 #define RecoHGCal_TICL_TracksterInferenceByCNNv4_H__
0003 
0004 #include "RecoHGCal/TICL/interface/TracksterInferenceAlgoBase.h"
0005 #include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
0006 
0007 namespace ticl {
0008 
0009   class TracksterInferenceByCNNv4 : public TracksterInferenceAlgoBase {
0010   public:
0011     explicit TracksterInferenceByCNNv4(const edm::ParameterSet& conf);
0012     void inputData(const std::vector<reco::CaloCluster>& layerClusters, std::vector<Trackster>& tracksters) override;
0013     void runInference(std::vector<Trackster>& tracksters) override;
0014 
0015     static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0016 
0017   private:
0018     const cms::Ort::ONNXRuntime* onnxSession_;
0019 
0020     const std::string modelPath_;
0021     const std::vector<std::string> inputNames_;
0022     const std::vector<std::string> outputNames_;
0023     const float eidMinClusterEnergy_;
0024     const int eidNLayers_;
0025     const int eidNClusters_;
0026     static constexpr int eidNFeatures_ = 3;
0027     int doPID_;
0028     int doRegression_;
0029 
0030     hgcal::RecHitTools rhtools_;
0031     std::vector<std::vector<int64_t>> input_shapes_;
0032     std::vector<int> tracksterIndices_;
0033     std::vector<std::vector<float>> input_Data_;
0034     int batchSize_;
0035   };
0036 }  // namespace ticl
0037 
0038 #endif  // RecoHGCal_TICL_TracksterInferenceByDNN_H__