File indexing completed on 2024-09-26 05:07:08
#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "RecoHGCal/TICL/interface/TracksterInferenceByDNN.h"
#include "RecoHGCal/TICL/interface/TracksterInferenceAlgoFactory.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "RecoHGCal/TICL/interface/PatternRecognitionAlgoBase.h"
#include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
#include "TrackstersPCA.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>
0009
0010 namespace ticl {
0011 using namespace cms::Ort;
0012
0013
0014 TracksterInferenceByDNN::TracksterInferenceByDNN(const edm::ParameterSet& conf)
0015 : TracksterInferenceAlgoBase(conf),
0016 id_modelPath_(
0017 conf.getParameter<edm::FileInPath>("onnxPIDModelPath").fullPath()),
0018 en_modelPath_(
0019 conf.getParameter<edm::FileInPath>("onnxEnergyModelPath").fullPath()),
0020 inputNames_(conf.getParameter<std::vector<std::string>>("inputNames")),
0021 output_en_(conf.getParameter<std::vector<std::string>>("output_en")),
0022 output_id_(conf.getParameter<std::vector<std::string>>("output_id")),
0023 eidMinClusterEnergy_(conf.getParameter<double>("eid_min_cluster_energy")),
0024 eidNLayers_(conf.getParameter<int>("eid_n_layers")),
0025 eidNClusters_(conf.getParameter<int>("eid_n_clusters")),
0026 doPID_(conf.getParameter<int>("doPID")),
0027 doRegression_(conf.getParameter<int>("doRegression"))
0028 {
0029
0030 static std::unique_ptr<cms::Ort::ONNXRuntime> onnxPIDRuntimeInstance =
0031 std::make_unique<cms::Ort::ONNXRuntime>(id_modelPath_.c_str());
0032 onnxPIDSession_ = onnxPIDRuntimeInstance.get();
0033 static std::unique_ptr<cms::Ort::ONNXRuntime> onnxEnergyRuntimeInstance =
0034 std::make_unique<cms::Ort::ONNXRuntime>(en_modelPath_.c_str());
0035 onnxEnergySession_ = onnxEnergyRuntimeInstance.get();
0036 }
0037
0038
0039 void TracksterInferenceByDNN::inputData(const std::vector<reco::CaloCluster>& layerClusters,
0040 std::vector<Trackster>& tracksters) {
0041 tracksterIndices_.clear();
0042 for (int i = 0; i < static_cast<int>(tracksters.size()); i++) {
0043 float sumClusterEnergy = 0.;
0044 for (const unsigned int& vertex : tracksters[i].vertices()) {
0045 sumClusterEnergy += static_cast<float>(layerClusters[vertex].energy());
0046 if (sumClusterEnergy >= eidMinClusterEnergy_) {
0047 tracksters[i].setRegressedEnergy(0.f);
0048 tracksters[i].zeroProbabilities();
0049 tracksterIndices_.push_back(i);
0050 break;
0051 }
0052 }
0053 }
0054
0055
0056 batchSize_ = static_cast<int>(tracksterIndices_.size());
0057 if (batchSize_ == 0)
0058 return;
0059
0060 std::vector<int64_t> inputShape = {batchSize_, eidNLayers_, eidNClusters_, eidNFeatures_};
0061 input_shapes_ = {inputShape};
0062
0063 input_Data_.clear();
0064 input_Data_.emplace_back(batchSize_ * eidNLayers_ * eidNClusters_ * eidNFeatures_, 0);
0065
0066 for (int i = 0; i < batchSize_; i++) {
0067 const Trackster& trackster = tracksters[tracksterIndices_[i]];
0068
0069
0070 std::vector<int> clusterIndices(trackster.vertices().size());
0071 for (int k = 0; k < static_cast<int>(trackster.vertices().size()); k++) {
0072 clusterIndices[k] = k;
0073 }
0074
0075 std::sort(clusterIndices.begin(), clusterIndices.end(), [&layerClusters, &trackster](const int& a, const int& b) {
0076 return layerClusters[trackster.vertices(a)].energy() > layerClusters[trackster.vertices(b)].energy();
0077 });
0078
0079 std::vector<int> seenClusters(eidNLayers_, 0);
0080
0081
0082 for (const int& k : clusterIndices) {
0083 const reco::CaloCluster& cluster = layerClusters[trackster.vertices(k)];
0084 int j = rhtools_.getLayerWithOffset(cluster.hitsAndFractions()[0].first) - 1;
0085 if (j < eidNLayers_ && seenClusters[j] < eidNClusters_) {
0086 auto index = (i * eidNLayers_ + j) * eidNFeatures_ * eidNClusters_ + seenClusters[j] * eidNFeatures_;
0087 input_Data_[0][index] =
0088 static_cast<float>(cluster.energy() / static_cast<float>(trackster.vertex_multiplicity(k)));
0089 input_Data_[0][index + 1] = static_cast<float>(std::abs(cluster.eta()));
0090 input_Data_[0][index + 2] = static_cast<float>(cluster.phi());
0091 seenClusters[j]++;
0092 }
0093 }
0094 }
0095 }
0096
0097
0098 void TracksterInferenceByDNN::runInference(std::vector<Trackster>& tracksters) {
0099 if (batchSize_ == 0)
0100 return;
0101
0102 if (doPID_ and doRegression_) {
0103
0104 auto result = onnxEnergySession_->run(inputNames_, input_Data_, input_shapes_, output_en_, batchSize_);
0105 auto& energyOutputTensor = result[0];
0106 if (!output_en_.empty()) {
0107 for (int i = 0; i < static_cast<int>(batchSize_); i++) {
0108 const float energy = energyOutputTensor[i];
0109 tracksters[tracksterIndices_[i]].setRegressedEnergy(energy);
0110 }
0111 }
0112 }
0113
0114 if (doPID_) {
0115
0116 auto pidOutput = onnxPIDSession_->run(inputNames_, input_Data_, input_shapes_, output_id_, batchSize_);
0117 auto pidOutputTensor = pidOutput[0];
0118 float* probs = pidOutputTensor.data();
0119 if (!output_id_.empty()) {
0120 for (int i = 0; i < batchSize_; i++) {
0121 tracksters[tracksterIndices_[i]].setProbabilities(probs);
0122 probs += tracksters[tracksterIndices_[i]].id_probabilities().size();
0123 }
0124 }
0125 }
0126 }
0127
0128 void TracksterInferenceByDNN::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
0129 iDesc.add<int>("algo_verbosity", 0);
0130 iDesc
0131 .add<edm::FileInPath>("onnxPIDModelPath",
0132 edm::FileInPath("RecoHGCal/TICL/data/ticlv5/onnx_models/patternrecognition/id_v0.onnx"))
0133 ->setComment("Path to ONNX PID model CLU3D");
0134 iDesc
0135 .add<edm::FileInPath>(
0136 "onnxEnergyModelPath",
0137 edm::FileInPath("RecoHGCal/TICL/data/ticlv5/onnx_models/patternrecognition/energy_v0.onnx"))
0138 ->setComment("Path to ONNX Energy model CLU3D");
0139 iDesc.add<std::vector<std::string>>("inputNames", {"input"});
0140 iDesc.add<std::vector<std::string>>("output_en", {"enreg_output"});
0141 iDesc.add<std::vector<std::string>>("output_id", {"pid_output"});
0142 iDesc.add<double>("eid_min_cluster_energy", 1.0);
0143 iDesc.add<int>("eid_n_layers", 50);
0144 iDesc.add<int>("eid_n_clusters", 10);
0145 iDesc.add<int>("doPID", 1);
0146 iDesc.add<int>("doRegression", 1);
0147 }
0148 }