File indexing completed on 2024-09-26 05:07:07
0001 #ifndef RecoHGCal_TICL_TracksterInferenceByCNNv4_H__
0002 #define RecoHGCal_TICL_TracksterInferenceByCNNv4_H__
0003
0004 #include "RecoHGCal/TICL/interface/TracksterInferenceAlgoBase.h"
0005 #include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
0006
0007 namespace ticl {
0008
// Trackster inference algorithm declared as a concrete TracksterInferenceAlgoBase.
// It overrides the two base-class hooks (inputData, runInference) and keeps a pointer to a
// cms::Ort::ONNXRuntime session together with the model path and the input/output names.
// NOTE(review): the name suggests a CNN-based model evaluated through ONNX Runtime; the actual
// inference logic lives in the implementation file, which is not visible from this header.
// NOTE(review): the leading 4-digit numbers on each line are listing artefacts of the code
// browser this file was captured from; they are not part of the C++ source.
0009 class TracksterInferenceByCNNv4 : public TracksterInferenceAlgoBase {
0010 public:
// Builds the algorithm from an EDM parameter set. `explicit` prevents implicit conversion
// from edm::ParameterSet. Presumably reads the const configuration members below from `conf`
// -- confirm against the constructor definition.
0011 explicit TracksterInferenceByCNNv4(const edm::ParameterSet& conf);
// Override of the base-class input hook. Receives the layer clusters read-only and the
// tracksters by mutable reference. Presumably fills the input buffers declared below
// (input_Data_, input_shapes_, tracksterIndices_, batchSize_) -- TODO confirm in the .cc.
0012 void inputData(const std::vector<reco::CaloCluster>& layerClusters, std::vector<Trackster>& tracksters) override;
// Override of the base-class inference hook. Takes the tracksters by mutable reference, so
// results are presumably written back into them -- TODO confirm in the .cc.
0013 void runInference(std::vector<Trackster>& tracksters) override;
0014
// Static hook taking a ParameterSetDescription by mutable reference; presumably declares the
// parameters accepted by the constructor -- verify against the definition.
0015 static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
0016
0017 private:
// Raw pointer to the ONNX Runtime session (pointee is const). Ownership and lifetime are not
// expressed by the type and cannot be determined from this header -- confirm where it is set.
0018 const cms::Ort::ONNXRuntime* onnxSession_;
0019
// Immutable-after-construction configuration (all members below are const).
// NOTE(review): member declaration order fixes the initialisation order; keep it in sync with
// the constructor's initialiser list.
0020 const std::string modelPath_;
0021 const std::vector<std::string> inputNames_;
0022 const std::vector<std::string> outputNames_;
// The "eid" members presumably describe the model input: a minimum cluster energy, a number
// of layers and a number of clusters -- units and exact meaning are not visible here.
0023 const float eidMinClusterEnergy_;
0024 const int eidNLayers_;
0025 const int eidNClusters_;
// Compile-time constant shared by all instances; fixed at 3 features.
0026 static constexpr int eidNFeatures_ = 3;
// Non-const ints; presumably on/off switches for particle-ID and energy-regression outputs.
// NOTE(review): stored as int rather than bool -- confirm how they are read in the .cc.
0027 int doPID_;
0028 int doRegression_;
0029
// Mutable working state. None of these members has an in-class initialiser, so their initial
// values depend on the constructor / inputData implementation (not visible here).
0030 hgcal::RecHitTools rhtools_;
// One shape vector (int64 dimensions) per model input, presumably matching inputNames_.
0031 std::vector<std::vector<int64_t>> input_shapes_;
// Presumably the indices of the tracksters selected for the current batch -- TODO confirm.
0032 std::vector<int> tracksterIndices_;
// One flat float buffer per model input, presumably laid out according to input_shapes_.
0033 std::vector<std::vector<float>> input_Data_;
// Presumably the number of tracksters in the current batch -- TODO confirm.
0034 int batchSize_;
0035 };
0036 }
0037
0038 #endif