Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-03-23 16:00:23

0001 #ifndef RecoHGCal_TICL_TracksterInferenceByPFN_H__
0002 #define RecoHGCal_TICL_TracksterInferenceByPFN_H__
0003 
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "RecoHGCal/TICL/interface/TracksterInferenceAlgoBase.h"
#include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
0006 
0007 namespace ticl {
0008 
0009   class TracksterInferenceByPFN : public TracksterInferenceAlgoBase {
0010   public:
0011     explicit TracksterInferenceByPFN(const edm::ParameterSet& conf);
0012     void inputData(const std::vector<reco::CaloCluster>& layerClusters, std::vector<Trackster>& tracksters) override;
0013     void runInference(std::vector<Trackster>& tracksters) override;
0014 
0015     static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0016 
0017   private:
0018     const std::unique_ptr<cms::Ort::ONNXRuntime> onnxPIDRuntimeInstance_;
0019     const std::unique_ptr<cms::Ort::ONNXRuntime> onnxEnergyRuntimeInstance_;
0020     const cms::Ort::ONNXRuntime* onnxPIDSession_;
0021     const cms::Ort::ONNXRuntime* onnxEnergySession_;
0022 
0023     const std::vector<std::string> inputNames_;
0024     const std::vector<std::string> output_en_;
0025     const std::vector<std::string> output_id_;
0026     const float eidMinClusterEnergy_;
0027     const int eidNLayers_;
0028     const int eidNClusters_;
0029     static constexpr int eidNFeatures_ = 7;
0030     int doPID_;
0031     int doRegression_;
0032 
0033     hgcal::RecHitTools rhtools_;
0034     std::vector<std::vector<int64_t>> input_shapes_;
0035     std::vector<int> tracksterIndices_;
0036     std::vector<std::vector<float>> input_Data_;
0037     int batchSize_;
0038   };
0039 }  // namespace ticl
0040 
0041 #endif  // RecoHGCal_TICL_TracksterInferenceByPFN_H__