Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:40

0001 #include "RecoEcal/EgammaClusterAlgos/interface/SCEnergyCorrectorDRN.h"
0002 
0003 #include "FWCore/Utilities/interface/isFinite.h"
0004 #include "FWCore/Utilities/interface/Transition.h"
0005 #include "DataFormats/EcalDetId/interface/EcalSubdetector.h"
0006 #include "DataFormats/VertexReco/interface/Vertex.h"
0007 #include "DataFormats/Math/interface/deltaPhi.h"
0008 #include "RecoEcal/EgammaCoreTools/interface/EcalTools.h"
0009 #include "RecoEcal/EgammaCoreTools/interface/EcalClusterTools.h"
0010 #include "RecoEgamma/EgammaTools/interface/EgammaHGCALIDParamDefaults.h"
0011 
0012 #include <vdt/vdtMath.h>
0013 
// Normalization constants for the DRN network inputs (file-local).
// makeInput() maps each coordinate as (c + C_MAX) / C_RANGE, so a coordinate in
// [-C_MAX, +C_MAX] lands in [0, 1]; each RANGE is exactly twice its MAX.
// Units follow CaloGeometry positions (presumably cm — TODO confirm).
namespace {
  constexpr float X_MAX = 150.0f;
  constexpr float X_RANGE = 2.0f * X_MAX;  // 300
  constexpr float Y_MAX = 150.0f;
  constexpr float Y_RANGE = 2.0f * Y_MAX;  // 300
  constexpr float Z_MAX = 330.0f;
  constexpr float Z_RANGE = 2.0f * Z_MAX;  // 660
  // Divisor applied to (rec-hit energy * fraction) for each hit.
  constexpr float E_RANGE = 250.0f;
  // Divisor applied to the event rho value.
  constexpr float RHO_MAX = 15.0f;
}  // namespace
0022 
0023 SCEnergyCorrectorDRN::SCEnergyCorrectorDRN() : caloTopo_(nullptr), caloGeom_(nullptr) {}
0024 
0025 SCEnergyCorrectorDRN::SCEnergyCorrectorDRN(const edm::ParameterSet& iConfig, edm::ConsumesCollector cc)
0026     : SCEnergyCorrectorDRN() {
0027   setTokens(iConfig, cc);
0028 }
0029 
0030 void SCEnergyCorrectorDRN::fillPSetDescription(edm::ParameterSetDescription& desc) {
0031   desc.add<edm::InputTag>("ecalRecHitsEE", edm::InputTag("ecalRecHit", "reducedEcalRecHitsEE"));
0032   desc.add<edm::InputTag>("ecalRecHitsEB", edm::InputTag("ecalRecHit", "reducedEcalRecHitsEB"));
0033   desc.add<edm::InputTag>("rhoFastJet", edm::InputTag("fixedGridRhoAll"));
0034 }
0035 
0036 edm::ParameterSetDescription SCEnergyCorrectorDRN::makePSetDescription() {
0037   edm::ParameterSetDescription desc;
0038   fillPSetDescription(desc);
0039   return desc;
0040 }
0041 
0042 void SCEnergyCorrectorDRN::setEventSetup(const edm::EventSetup& es) {
0043   caloTopo_ = &es.getData(caloTopoToken_);
0044   caloGeom_ = &es.getData(caloGeomToken_);
0045 }
0046 
0047 void SCEnergyCorrectorDRN::setEvent(const edm::Event& event) {
0048   event.getByToken(tokenEBRecHits_, recHitsEB_);
0049   event.getByToken(tokenEERecHits_, recHitsEE_);
0050   event.getByToken(rhoToken_, rhoHandle_);
0051 }
0052 
0053 void SCEnergyCorrectorDRN::makeInput(const edm::Event& iEvent,
0054                                      TritonInputMap& iInput,
0055                                      const reco::SuperClusterCollection& inputSCs) const {
0056   std::vector<unsigned> nHits;
0057   nHits.reserve(inputSCs.size());
0058   unsigned totalHits = 0;
0059   unsigned n;
0060   for (const auto& inputSC : inputSCs) {
0061     n = inputSC.hitsAndFractions().size();
0062     totalHits += n;
0063     nHits.push_back(n);
0064   }
0065 
0066   //set shapes
0067   auto& input1 = iInput.at("x__0");
0068   input1.setShape(0, totalHits);
0069   auto data1 = input1.allocate<float>();
0070   auto& vdata1 = (*data1)[0];
0071 
0072   auto& input2 = iInput.at("batch__1");
0073   input2.setShape(0, totalHits);
0074   auto data2 = input2.allocate<int64_t>();
0075   auto& vdata2 = (*data2)[0];
0076 
0077   auto& input3 = iInput.at("graphx__2");
0078   input3.setShape(0, 2 * nHits.size());
0079   auto data3 = input3.allocate<float>();
0080   auto& vdata3 = (*data3)[0];
0081 
0082   //fill
0083   unsigned batchNum = 0;
0084   float En, frac, x, y, z;
0085   for (const auto& inputSC : inputSCs) {
0086     const auto& hits = inputSC.hitsAndFractions();
0087     const bool isEB = hits[0].first.subdetId() == EcalBarrel;
0088     const auto& recHitsProduct = isEB ? recHitsEB_.product() : recHitsEE_.product();
0089     for (const auto& hit : hits) {
0090       En = EcalClusterTools::recHitEnergy(hit.first, recHitsProduct);
0091       frac = hit.second;
0092       GlobalPoint position = caloGeom_->getGeometry(hit.first)->getPosition();
0093       x = (position.x() + X_MAX) / X_RANGE;
0094       y = (position.y() + Y_MAX) / Y_RANGE;
0095       z = (position.z() + Z_MAX) / Z_RANGE;
0096       vdata1.push_back(x);
0097       vdata1.push_back(y);
0098       vdata1.push_back(z);
0099       vdata1.push_back(En * frac / E_RANGE);
0100       //Triton does not currently support batching for pytorch GNNs
0101       //We pass batch indices explicitely
0102       vdata2.push_back(batchNum);
0103     }
0104     vdata3.push_back(*rhoHandle_ / RHO_MAX);
0105     vdata3.push_back(0.0);
0106     ++batchNum;
0107   }
0108 
0109   // convert to server format
0110   input1.toServer(data1);
0111   input2.toServer(data2);
0112   input3.toServer(data3);
0113 }
0114 
0115 TritonOutput<float> SCEnergyCorrectorDRN::getOutput(const TritonOutputMap& iOutput) {
0116   //check the results
0117   const auto& output1 = iOutput.begin()->second;
0118   // convert from server format
0119   const auto& serverout = output1.fromServer<float>();
0120 
0121   return serverout;
0122 }