File indexing completed on 2025-05-09 22:40:18
0001 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/Event.h"
0002 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EventSetup.h"
0003 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/global/EDProducer.h"
0004 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EDGetToken.h"
0005 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/EDPutToken.h"
0006 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESGetToken.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0009 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0010 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0011
0012 #include "DataFormats/TrackerRecHit2D/interface/Phase2TrackerRecHit1D.h"
0013
0014 #include "FWCore/Utilities/interface/transform.h"
0015 #include "MagneticField/Engine/interface/MagneticField.h"
0016 #include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"
0017 #include "DataFormats/TrackerRecHit2D/interface/SiStripMatchedRecHit2DCollection.h"
0018 #include "DataFormats/TrajectorySeed/interface/TrajectorySeedCollection.h"
0019 #include "DataFormats/TrackReco/interface/trackFromSeedFitFailed.h"
0020 #include "TrackingTools/Records/interface/TransientRecHitRecord.h"
0021 #include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"
0022 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHitBuilder.h"
0023
0024 #include "RecoTracker/LSTCore/interface/LSTInputHostCollection.h"
0025 #include "RecoTracker/LSTCore/interface/LSTPrepareInput.h"
0026
0027 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0028
0029 class LSTInputProducer : public global::EDProducer<> {
0030 public:
0031 LSTInputProducer(edm::ParameterSet const& iConfig);
0032 ~LSTInputProducer() override = default;
0033
0034 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0035
0036 private:
0037 void produce(edm::StreamID, device::Event& iEvent, const device::EventSetup& iSetup) const override;
0038
0039 const double ptCut_;
0040
0041 const edm::EDGetTokenT<Phase2TrackerRecHit1DCollectionNew> phase2OTRecHitToken_;
0042
0043 const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> mfToken_;
0044 const edm::EDGetTokenT<reco::BeamSpot> beamSpotToken_;
0045 const std::vector<edm::EDGetTokenT<edm::View<reco::Track>>> seedTokens_;
0046 const edm::EDPutTokenT<TrajectorySeedCollection> lstPixelSeedsPutToken_;
0047
0048 const edm::EDPutTokenT<lst::LSTInputHostCollection> lstInputPutToken_;
0049 };
0050
0051 LSTInputProducer::LSTInputProducer(edm::ParameterSet const& iConfig)
0052 : EDProducer<>(iConfig),
0053 ptCut_(iConfig.getParameter<double>("ptCut")),
0054 phase2OTRecHitToken_(consumes(iConfig.getParameter<edm::InputTag>("phase2OTRecHits"))),
0055 mfToken_(esConsumes()),
0056 beamSpotToken_(consumes(iConfig.getParameter<edm::InputTag>("beamSpot"))),
0057 seedTokens_(
0058 edm::vector_transform(iConfig.getParameter<std::vector<edm::InputTag>>("seedTracks"),
0059 [&](const edm::InputTag& tag) { return consumes<edm::View<reco::Track>>(tag); })),
0060 lstPixelSeedsPutToken_(produces()),
0061 lstInputPutToken_(produces()) {}
0062
0063 void LSTInputProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0064 edm::ParameterSetDescription desc;
0065
0066 desc.add<double>("ptCut", 0.8);
0067
0068 desc.add<edm::InputTag>("phase2OTRecHits", edm::InputTag("siPhase2RecHits"));
0069
0070 desc.add<edm::InputTag>("beamSpot", edm::InputTag("offlineBeamSpot"));
0071 desc.add<std::vector<edm::InputTag>>("seedTracks",
0072 std::vector<edm::InputTag>{edm::InputTag("lstInitialStepSeedTracks"),
0073 edm::InputTag("lstHighPtTripletStepSeedTracks")});
0074
0075 descriptions.addWithDefaultLabel(desc);
0076 }
0077
0078 void LSTInputProducer::produce(edm::StreamID iID, device::Event& iEvent, const device::EventSetup& iSetup) const {
0079
0080 auto const& phase2OTHits = iEvent.get(phase2OTRecHitToken_);
0081
0082 std::vector<unsigned int> ph2_detId;
0083 ph2_detId.reserve(phase2OTHits.dataSize());
0084 std::vector<float> ph2_x;
0085 ph2_x.reserve(phase2OTHits.dataSize());
0086 std::vector<float> ph2_y;
0087 ph2_y.reserve(phase2OTHits.dataSize());
0088 std::vector<float> ph2_z;
0089 ph2_z.reserve(phase2OTHits.dataSize());
0090 std::vector<TrackingRecHit const*> ph2_hits;
0091 ph2_hits.reserve(phase2OTHits.dataSize());
0092
0093 for (auto const& it : phase2OTHits) {
0094 const DetId hitId = it.detId();
0095 for (auto const& hit : it) {
0096 ph2_detId.push_back(hitId.rawId());
0097 ph2_x.push_back(hit.globalPosition().x());
0098 ph2_y.push_back(hit.globalPosition().y());
0099 ph2_z.push_back(hit.globalPosition().z());
0100 ph2_hits.push_back(&hit);
0101 }
0102 }
0103
0104
0105 auto const& mf = iSetup.getData(mfToken_);
0106 auto const& bs = iEvent.get(beamSpotToken_);
0107
0108
0109 std::vector<float> see_px;
0110 std::vector<float> see_py;
0111 std::vector<float> see_pz;
0112 std::vector<float> see_dxy;
0113 std::vector<float> see_dz;
0114 std::vector<float> see_ptErr;
0115 std::vector<float> see_etaErr;
0116 std::vector<float> see_stateTrajGlbX;
0117 std::vector<float> see_stateTrajGlbY;
0118 std::vector<float> see_stateTrajGlbZ;
0119 std::vector<float> see_stateTrajGlbPx;
0120 std::vector<float> see_stateTrajGlbPy;
0121 std::vector<float> see_stateTrajGlbPz;
0122 std::vector<int> see_q;
0123 std::vector<std::vector<int>> see_hitIdx;
0124 TrajectorySeedCollection see_seeds;
0125
0126 for (auto const& seedToken : seedTokens_) {
0127 auto const& seedTracks = iEvent.get(seedToken);
0128
0129 if (seedTracks.empty())
0130 continue;
0131
0132
0133 edm::RefToBaseVector<reco::Track> seedTrackRefs;
0134 for (edm::View<reco::Track>::size_type i = 0; i < seedTracks.size(); ++i) {
0135 seedTrackRefs.push_back(seedTracks.refAt(i));
0136 }
0137
0138 edm::ProductID id = seedTracks[0].seedRef().id();
0139
0140 for (size_t iSeed = 0; iSeed < seedTrackRefs.size(); ++iSeed) {
0141 auto const& seedTrackRef = seedTrackRefs[iSeed];
0142 auto const& seedTrack = *seedTrackRef;
0143 auto const& seedRef = seedTrack.seedRef();
0144 auto const& seed = *seedRef;
0145
0146 if (seedRef.id() != id)
0147 throw cms::Exception("LogicError")
0148 << "All tracks in 'TracksFromSeeds' collection should point to seeds in the same collection. Now the "
0149 "element 0 had ProductID "
0150 << id << " while the element " << seedTrackRef.key() << " had " << seedTrackRef.id() << ".";
0151
0152 const bool seedFitOk = !trackFromSeedFitFailed(seedTrack);
0153
0154 const TrackingRecHit* lastRecHit = &*(seed.recHits().end() - 1);
0155 TrajectoryStateOnSurface tsos =
0156 trajectoryStateTransform::transientState(seed.startingState(), lastRecHit->surface(), &mf);
0157 auto const& stateGlobal = tsos.globalParameters();
0158
0159 std::vector<int> hitIdx;
0160 for (auto const& hit : seed.recHits()) {
0161 int subid = hit.geographicalId().subdetId();
0162 if (subid == (int)PixelSubdetector::PixelBarrel || subid == (int)PixelSubdetector::PixelEndcap) {
0163 const BaseTrackerRecHit* bhit = dynamic_cast<const BaseTrackerRecHit*>(&hit);
0164 const auto& clusterRef = bhit->firstClusterRef();
0165 const auto clusterKey = clusterRef.cluster_pixel().key();
0166 hitIdx.push_back(clusterKey);
0167 } else {
0168 throw cms::Exception("LSTInputProducer") << "Not pixel hits found!";
0169 }
0170 }
0171
0172
0173 see_px.push_back(seedFitOk ? seedTrack.px() : 0);
0174 see_py.push_back(seedFitOk ? seedTrack.py() : 0);
0175 see_pz.push_back(seedFitOk ? seedTrack.pz() : 0);
0176 see_dxy.push_back(seedFitOk ? seedTrack.dxy(bs.position()) : 0);
0177 see_dz.push_back(seedFitOk ? seedTrack.dz(bs.position()) : 0);
0178 see_ptErr.push_back(seedFitOk ? seedTrack.ptError() : 0);
0179 see_etaErr.push_back(seedFitOk ? seedTrack.etaError() : 0);
0180 see_stateTrajGlbX.push_back(stateGlobal.position().x());
0181 see_stateTrajGlbY.push_back(stateGlobal.position().y());
0182 see_stateTrajGlbZ.push_back(stateGlobal.position().z());
0183 see_stateTrajGlbPx.push_back(stateGlobal.momentum().x());
0184 see_stateTrajGlbPy.push_back(stateGlobal.momentum().y());
0185 see_stateTrajGlbPz.push_back(stateGlobal.momentum().z());
0186 see_q.push_back(seedTrack.charge());
0187 see_hitIdx.emplace_back(std::move(hitIdx));
0188 see_seeds.push_back(seed);
0189 }
0190 }
0191
0192 auto lstInputHC = lst::prepareInput(see_px,
0193 see_py,
0194 see_pz,
0195 see_dxy,
0196 see_dz,
0197 see_ptErr,
0198 see_etaErr,
0199 see_stateTrajGlbX,
0200 see_stateTrajGlbY,
0201 see_stateTrajGlbZ,
0202 see_stateTrajGlbPx,
0203 see_stateTrajGlbPy,
0204 see_stateTrajGlbPz,
0205 see_q,
0206 see_hitIdx,
0207 {},
0208 ph2_detId,
0209 ph2_x,
0210 ph2_y,
0211 ph2_z,
0212 ph2_hits,
0213 ptCut_,
0214 iEvent.queue());
0215
0216 iEvent.emplace(lstInputPutToken_, std::move(lstInputHC));
0217 iEvent.emplace(lstPixelSeedsPutToken_, std::move(see_seeds));
0218 }
0219
0220 }
0221
#include "HeterogeneousCore/AlpakaCore/interface/alpaka/MakerMacros.h"
// Registers LSTInputProducer with the framework as an Alpaka module.
DEFINE_FWK_ALPAKA_MODULE(LSTInputProducer);