File indexing completed on 2024-09-26 05:07:07
0001 #ifndef RecoHGCal_TICL_TracksterInferenceByDNN_H__
0002 #define RecoHGCal_TICL_TracksterInferenceByDNN_H__
0003
#include <cstdint>
#include <string>
#include <vector>

#include "RecoHGCal/TICL/interface/TracksterInferenceAlgoBase.h"
#include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
0006
0007 namespace ticl {
0008
0009 class TracksterInferenceByDNN : public TracksterInferenceAlgoBase {
0010 public:
0011 explicit TracksterInferenceByDNN(const edm::ParameterSet& conf);
0012 void inputData(const std::vector<reco::CaloCluster>& layerClusters, std::vector<Trackster>& tracksters) override;
0013 void runInference(std::vector<Trackster>& tracksters) override;
0014
0015 static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0016
0017 private:
0018 const cms::Ort::ONNXRuntime* onnxPIDSession_;
0019 const cms::Ort::ONNXRuntime* onnxEnergySession_;
0020
0021 const std::string id_modelPath_;
0022 const std::string en_modelPath_;
0023 const std::vector<std::string> inputNames_;
0024 const std::vector<std::string> output_en_;
0025 const std::vector<std::string> output_id_;
0026 const float eidMinClusterEnergy_;
0027 const int eidNLayers_;
0028 const int eidNClusters_;
0029 static constexpr int eidNFeatures_ = 3;
0030 int doPID_;
0031 int doRegression_;
0032
0033 hgcal::RecHitTools rhtools_;
0034 std::vector<std::vector<int64_t>> input_shapes_;
0035 std::vector<int> tracksterIndices_;
0036 std::vector<std::vector<float>> input_Data_;
0037 int batchSize_;
0038 };
0039 }
0040
0041 #endif