Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-05-09 22:40:18

0001 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/Event.h"
0002 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EventSetup.h"
0003 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/global/EDProducer.h"
0004 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EDGetToken.h"
0005 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EDPutToken.h"
0006 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESGetToken.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0009 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0010 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0011 
0012 #include "DataFormats/TrackerRecHit2D/interface/Phase2TrackerRecHit1D.h"
0013 
0014 #include "FWCore/Utilities/interface/transform.h"
0015 #include "MagneticField/Engine/interface/MagneticField.h"
0016 #include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"
0017 #include "DataFormats/TrackerRecHit2D/interface/SiStripMatchedRecHit2DCollection.h"
0018 #include "DataFormats/TrajectorySeed/interface/TrajectorySeedCollection.h"
0019 #include "DataFormats/TrackReco/interface/trackFromSeedFitFailed.h"
0020 #include "TrackingTools/Records/interface/TransientRecHitRecord.h"
0021 #include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"
0022 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHitBuilder.h"
0023 
0024 #include "RecoTracker/LSTCore/interface/LSTInputHostCollection.h"
0025 #include "RecoTracker/LSTCore/interface/LSTPrepareInput.h"
0026 
0027 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0028 
0029   class LSTInputProducer : public global::EDProducer<> {
0030   public:
0031     LSTInputProducer(edm::ParameterSet const& iConfig);
0032     ~LSTInputProducer() override = default;
0033 
0034     static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0035 
0036   private:
0037     void produce(edm::StreamID, device::Event& iEvent, const device::EventSetup& iSetup) const override;
0038 
0039     const double ptCut_;
0040 
0041     const edm::EDGetTokenT<Phase2TrackerRecHit1DCollectionNew> phase2OTRecHitToken_;
0042 
0043     const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> mfToken_;
0044     const edm::EDGetTokenT<reco::BeamSpot> beamSpotToken_;
0045     const std::vector<edm::EDGetTokenT<edm::View<reco::Track>>> seedTokens_;
0046     const edm::EDPutTokenT<TrajectorySeedCollection> lstPixelSeedsPutToken_;
0047 
0048     const edm::EDPutTokenT<lst::LSTInputHostCollection> lstInputPutToken_;
0049   };
0050 
0051   LSTInputProducer::LSTInputProducer(edm::ParameterSet const& iConfig)
0052       : EDProducer<>(iConfig),
0053         ptCut_(iConfig.getParameter<double>("ptCut")),
0054         phase2OTRecHitToken_(consumes(iConfig.getParameter<edm::InputTag>("phase2OTRecHits"))),
0055         mfToken_(esConsumes()),
0056         beamSpotToken_(consumes(iConfig.getParameter<edm::InputTag>("beamSpot"))),
0057         seedTokens_(
0058             edm::vector_transform(iConfig.getParameter<std::vector<edm::InputTag>>("seedTracks"),
0059                                   [&](const edm::InputTag& tag) { return consumes<edm::View<reco::Track>>(tag); })),
0060         lstPixelSeedsPutToken_(produces()),
0061         lstInputPutToken_(produces()) {}
0062 
0063   void LSTInputProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0064     edm::ParameterSetDescription desc;
0065 
0066     desc.add<double>("ptCut", 0.8);
0067 
0068     desc.add<edm::InputTag>("phase2OTRecHits", edm::InputTag("siPhase2RecHits"));
0069 
0070     desc.add<edm::InputTag>("beamSpot", edm::InputTag("offlineBeamSpot"));
0071     desc.add<std::vector<edm::InputTag>>("seedTracks",
0072                                          std::vector<edm::InputTag>{edm::InputTag("lstInitialStepSeedTracks"),
0073                                                                     edm::InputTag("lstHighPtTripletStepSeedTracks")});
0074 
0075     descriptions.addWithDefaultLabel(desc);
0076   }
0077 
0078   void LSTInputProducer::produce(edm::StreamID iID, device::Event& iEvent, const device::EventSetup& iSetup) const {
0079     // Get the phase2OTRecHits
0080     auto const& phase2OTHits = iEvent.get(phase2OTRecHitToken_);
0081 
0082     std::vector<unsigned int> ph2_detId;
0083     ph2_detId.reserve(phase2OTHits.dataSize());
0084     std::vector<float> ph2_x;
0085     ph2_x.reserve(phase2OTHits.dataSize());
0086     std::vector<float> ph2_y;
0087     ph2_y.reserve(phase2OTHits.dataSize());
0088     std::vector<float> ph2_z;
0089     ph2_z.reserve(phase2OTHits.dataSize());
0090     std::vector<TrackingRecHit const*> ph2_hits;
0091     ph2_hits.reserve(phase2OTHits.dataSize());
0092 
0093     for (auto const& it : phase2OTHits) {
0094       const DetId hitId = it.detId();
0095       for (auto const& hit : it) {
0096         ph2_detId.push_back(hitId.rawId());
0097         ph2_x.push_back(hit.globalPosition().x());
0098         ph2_y.push_back(hit.globalPosition().y());
0099         ph2_z.push_back(hit.globalPosition().z());
0100         ph2_hits.push_back(&hit);
0101       }
0102     }
0103 
0104     // Get the pixel seeds
0105     auto const& mf = iSetup.getData(mfToken_);
0106     auto const& bs = iEvent.get(beamSpotToken_);
0107 
0108     // Vector definitions
0109     std::vector<float> see_px;
0110     std::vector<float> see_py;
0111     std::vector<float> see_pz;
0112     std::vector<float> see_dxy;
0113     std::vector<float> see_dz;
0114     std::vector<float> see_ptErr;
0115     std::vector<float> see_etaErr;
0116     std::vector<float> see_stateTrajGlbX;
0117     std::vector<float> see_stateTrajGlbY;
0118     std::vector<float> see_stateTrajGlbZ;
0119     std::vector<float> see_stateTrajGlbPx;
0120     std::vector<float> see_stateTrajGlbPy;
0121     std::vector<float> see_stateTrajGlbPz;
0122     std::vector<int> see_q;
0123     std::vector<std::vector<int>> see_hitIdx;
0124     TrajectorySeedCollection see_seeds;
0125 
0126     for (auto const& seedToken : seedTokens_) {
0127       auto const& seedTracks = iEvent.get(seedToken);
0128 
0129       if (seedTracks.empty())
0130         continue;
0131 
0132       // Get seed track refs
0133       edm::RefToBaseVector<reco::Track> seedTrackRefs;
0134       for (edm::View<reco::Track>::size_type i = 0; i < seedTracks.size(); ++i) {
0135         seedTrackRefs.push_back(seedTracks.refAt(i));
0136       }
0137 
0138       edm::ProductID id = seedTracks[0].seedRef().id();
0139 
0140       for (size_t iSeed = 0; iSeed < seedTrackRefs.size(); ++iSeed) {
0141         auto const& seedTrackRef = seedTrackRefs[iSeed];
0142         auto const& seedTrack = *seedTrackRef;
0143         auto const& seedRef = seedTrack.seedRef();
0144         auto const& seed = *seedRef;
0145 
0146         if (seedRef.id() != id)
0147           throw cms::Exception("LogicError")
0148               << "All tracks in 'TracksFromSeeds' collection should point to seeds in the same collection. Now the "
0149                  "element 0 had ProductID "
0150               << id << " while the element " << seedTrackRef.key() << " had " << seedTrackRef.id() << ".";
0151 
0152         const bool seedFitOk = !trackFromSeedFitFailed(seedTrack);
0153 
0154         const TrackingRecHit* lastRecHit = &*(seed.recHits().end() - 1);
0155         TrajectoryStateOnSurface tsos =
0156             trajectoryStateTransform::transientState(seed.startingState(), lastRecHit->surface(), &mf);
0157         auto const& stateGlobal = tsos.globalParameters();
0158 
0159         std::vector<int> hitIdx;
0160         for (auto const& hit : seed.recHits()) {
0161           int subid = hit.geographicalId().subdetId();
0162           if (subid == (int)PixelSubdetector::PixelBarrel || subid == (int)PixelSubdetector::PixelEndcap) {
0163             const BaseTrackerRecHit* bhit = dynamic_cast<const BaseTrackerRecHit*>(&hit);
0164             const auto& clusterRef = bhit->firstClusterRef();
0165             const auto clusterKey = clusterRef.cluster_pixel().key();
0166             hitIdx.push_back(clusterKey);
0167           } else {
0168             throw cms::Exception("LSTInputProducer") << "Not pixel hits found!";
0169           }
0170         }
0171 
0172         // Fill output
0173         see_px.push_back(seedFitOk ? seedTrack.px() : 0);
0174         see_py.push_back(seedFitOk ? seedTrack.py() : 0);
0175         see_pz.push_back(seedFitOk ? seedTrack.pz() : 0);
0176         see_dxy.push_back(seedFitOk ? seedTrack.dxy(bs.position()) : 0);
0177         see_dz.push_back(seedFitOk ? seedTrack.dz(bs.position()) : 0);
0178         see_ptErr.push_back(seedFitOk ? seedTrack.ptError() : 0);
0179         see_etaErr.push_back(seedFitOk ? seedTrack.etaError() : 0);
0180         see_stateTrajGlbX.push_back(stateGlobal.position().x());
0181         see_stateTrajGlbY.push_back(stateGlobal.position().y());
0182         see_stateTrajGlbZ.push_back(stateGlobal.position().z());
0183         see_stateTrajGlbPx.push_back(stateGlobal.momentum().x());
0184         see_stateTrajGlbPy.push_back(stateGlobal.momentum().y());
0185         see_stateTrajGlbPz.push_back(stateGlobal.momentum().z());
0186         see_q.push_back(seedTrack.charge());
0187         see_hitIdx.emplace_back(std::move(hitIdx));
0188         see_seeds.push_back(seed);
0189       }
0190     }
0191 
0192     auto lstInputHC = lst::prepareInput(see_px,
0193                                         see_py,
0194                                         see_pz,
0195                                         see_dxy,
0196                                         see_dz,
0197                                         see_ptErr,
0198                                         see_etaErr,
0199                                         see_stateTrajGlbX,
0200                                         see_stateTrajGlbY,
0201                                         see_stateTrajGlbZ,
0202                                         see_stateTrajGlbPx,
0203                                         see_stateTrajGlbPy,
0204                                         see_stateTrajGlbPz,
0205                                         see_q,
0206                                         see_hitIdx,
0207                                         {},
0208                                         ph2_detId,
0209                                         ph2_x,
0210                                         ph2_y,
0211                                         ph2_z,
0212                                         ph2_hits,
0213                                         ptCut_,
0214                                         iEvent.queue());
0215 
0216     iEvent.emplace(lstInputPutToken_, std::move(lstInputHC));
0217     iEvent.emplace(lstPixelSeedsPutToken_, std::move(see_seeds));
0218   }
0219 
0220 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE
0221 
0222 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/MakerMacros.h"
0223 DEFINE_FWK_ALPAKA_MODULE(LSTInputProducer);