Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-10-17 22:59:00

0001 #include "FWCore/Framework/interface/global/EDProducer.h"
0002 
0003 #include "FWCore/Framework/interface/Event.h"
0004 #include "FWCore/Framework/interface/MakerMacros.h"
0005 #include "FWCore/Utilities/interface/do_nothing_deleter.h"
0006 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0007 
0008 #include "Geometry/CommonTopologies/interface/GeomDetEnumerators.h"
0009 
0010 #include "DataFormats/SiPixelDetId/interface/PixelSubdetector.h"
0011 #include "DataFormats/SiStripDetId/interface/StripSubdetector.h"
0012 #include "DataFormats/TrajectorySeed/interface/TrajectorySeed.h"
0013 #include "DataFormats/TrackCandidate/interface/TrackCandidateCollection.h"
0014 #include "DataFormats/TrackReco/interface/SeedStopInfo.h"
0015 #include "DataFormats/TrackingRecHit/interface/InvalidTrackingRecHit.h"
0016 #include "DataFormats/TrackerRecHit2D/interface/SiStripRecHit1D.h"
0017 #include "DataFormats/TrackerRecHit2D/interface/Phase2TrackerRecHit1D.h"
0018 #include "DataFormats/TrackerCommon/interface/TrackerTopology.h"
0019 
0020 #include "TrackingTools/Records/interface/TransientRecHitRecord.h"
0021 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHitBuilder.h"
0022 #include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"
0023 
0024 #include "MagneticField/Engine/interface/MagneticField.h"
0025 #include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"
0026 
0027 #include "TrackingTools/GeomPropagators/interface/Propagator.h"
0028 #include "TrackingTools/Records/interface/TrackingComponentsRecord.h"
0029 #include "TrackingTools/KalmanUpdators/interface/KFUpdator.h"
0030 #include "TrackingTools/KalmanUpdators/interface/Chi2MeasurementEstimator.h"
0031 #include "TrackingTools/TrackFitters/interface/KFTrajectoryFitter.h"
0032 #include "RecoTracker/TransientTrackingRecHit/interface/TkClonerImpl.h"
0033 #include "RecoTracker/TransientTrackingRecHit/interface/TkTransientTrackingRecHitBuilder.h"
0034 #include "TrackingTools/MaterialEffects/interface/PropagatorWithMaterial.h"
0035 
0036 #include "RecoTracker/MkFit/interface/MkFitEventOfHits.h"
0037 #include "RecoTracker/MkFit/interface/MkFitClusterIndexToHit.h"
0038 #include "RecoTracker/MkFit/interface/MkFitSeedWrapper.h"
0039 #include "RecoTracker/MkFit/interface/MkFitOutputWrapper.h"
0040 #include "RecoTracker/MkFit/interface/MkFitGeometry.h"
0041 #include "RecoTracker/Record/interface/TrackerRecoGeometryRecord.h"
0042 
// mkFit includes
0044 #include "RecoTracker/MkFitCMS/interface/LayerNumberConverter.h"
0045 #include "RecoTracker/MkFitCore/interface/Track.h"
0046 #include "RecoTracker/MkFitCore/interface/HitStructures.h"
0047 
// Extra includes for the candidate-level DNN selection
0049 #include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
0050 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0051 #include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0052 #include "TrackingTools/PatternTools/interface/TSCBLBuilderNoMaterial.h"
0053 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0054 #include "DataFormats/VertexReco/interface/Vertex.h"
0055 
0056 namespace {
0057   template <typename T>
0058   bool isPhase1Barrel(T subdet) {
0059     return subdet == PixelSubdetector::PixelBarrel || subdet == StripSubdetector::TIB ||
0060            subdet == StripSubdetector::TOB;
0061   }
0062 
0063   template <typename T>
0064   bool isPhase1Endcap(T subdet) {
0065     return subdet == PixelSubdetector::PixelEndcap || subdet == StripSubdetector::TID ||
0066            subdet == StripSubdetector::TEC;
0067   }
0068 }  // namespace
0069 
/// Converts the mkFit output tracks (MkFitOutputWrapper) back to the CMSSW representation.
///
/// Produces a TrackCandidateCollection and a std::vector<SeedStopInfo>. The SeedStopInfo
/// vector is currently only sized to the number of seeds and otherwise left unfilled.
/// Optionally applies a candidate-level DNN selection (configuration flag "candMVASel").
class MkFitOutputConverter : public edm::global::EDProducer<> {
public:
  explicit MkFitOutputConverter(edm::ParameterSet const& iConfig);
  ~MkFitOutputConverter() override = default;

  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);

private:
  void produce(edm::StreamID, edm::Event& iEvent, const edm::EventSetup& iSetup) const override;

  // Converts every mkFit candidate into a TrackCandidate. Candidates failing the state quality
  // cuts, the covariance checks or the final state determination are skipped.
  // bs, vertices and session are only non-null when the candidate DNN selection is enabled.
  // NOTE(review): mkFitSeeds does not appear to be used by the current implementation.
  TrackCandidateCollection convertCandidates(const MkFitOutputWrapper& mkFitOutput,
                                             const mkfit::EventOfHits& eventOfHits,
                                             const MkFitClusterIndexToHit& pixelClusterIndexToHit,
                                             const MkFitClusterIndexToHit& stripClusterIndexToHit,
                                             const edm::View<TrajectorySeed>& seeds,
                                             const MagneticField& mf,
                                             const Propagator& propagatorAlong,
                                             const Propagator& propagatorOpposite,
                                             const MkFitGeometry& mkFitGeom,
                                             const TrackerTopology& tTopo,
                                             const TkClonerImpl& hitCloner,
                                             const std::vector<const DetLayer*>& detLayers,
                                             const mkfit::TrackVec& mkFitSeeds,
                                             const reco::BeamSpot* bs,
                                             const reco::VertexCollection* vertices,
                                             const tensorflow::Session* session) const;

  // Determines the innermost state (and its GeomDet) by fitting the hits backwards.
  // Used when the mkFit output is NOT already propagated to the first layer.
  // The two bools select which propagator is tried first.
  std::pair<TrajectoryStateOnSurface, const GeomDet*> backwardFit(const FreeTrajectoryState& fts,
                                                                  const edm::OwnVector<TrackingRecHit>& hits,
                                                                  const Propagator& propagatorAlong,
                                                                  const Propagator& propagatorOpposite,
                                                                  const TkClonerImpl& hitCloner,
                                                                  const bool isPhase1,
                                                                  bool lastHitWasInvalid,
                                                                  bool lastHitWasChanged) const;

  // Determines the innermost state (and its GeomDet) directly from the mkFit state.
  // Used when the mkFit output is already propagated to the first layer.
  std::pair<TrajectoryStateOnSurface, const GeomDet*> convertInnermostState(const FreeTrajectoryState& fts,
                                                                            const edm::OwnVector<TrackingRecHit>& hits,
                                                                            const Propagator& propagatorAlong,
                                                                            const Propagator& propagatorOpposite) const;

  // Evaluates the candidate-level DNN. convertCandidates() indexes the returned scores in
  // parallel with tkCC (presumably one score per candidate; confirm in the definition).
  std::vector<float> computeDNNs(TrackCandidateCollection const& tkCC,
                                 const std::vector<TrajectoryStateOnSurface>& states,
                                 const reco::BeamSpot* bs,
                                 const reco::VertexCollection* vertices,
                                 const tensorflow::Session* session,
                                 const std::vector<float>& chi2,
                                 const bool rescaledError) const;

  // NOTE: the member declaration order matters. The constructor initialises bsToken_ and
  // verticesToken_ from algoCandSelection_, and tfDnnToken_ from tfDnnLabel_.

  // Event data inputs
  const edm::EDGetTokenT<MkFitEventOfHits> eventOfHitsToken_;
  const edm::EDGetTokenT<MkFitClusterIndexToHit> pixelClusterIndexToHitToken_;
  const edm::EDGetTokenT<MkFitClusterIndexToHit> stripClusterIndexToHitToken_;
  const edm::EDGetTokenT<MkFitSeedWrapper> mkfitSeedToken_;
  const edm::EDGetTokenT<MkFitOutputWrapper> tracksToken_;
  const edm::EDGetTokenT<edm::View<TrajectorySeed>> seedToken_;
  // EventSetup inputs
  const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorAlongToken_;
  const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorOppositeToken_;
  const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> mfToken_;
  const edm::ESGetToken<TransientTrackingRecHitBuilder, TransientRecHitRecord> ttrhBuilderToken_;
  const edm::ESGetToken<MkFitGeometry, TrackerRecoGeometryRecord> mkFitGeomToken_;
  const edm::ESGetToken<TrackerTopology, TrackerTopologyRcd> tTopoToken_;
  // Outputs
  const edm::EDPutTokenT<TrackCandidateCollection> putTrackCandidateToken_;
  const edm::EDPutTokenT<std::vector<SeedStopInfo>> putSeedStopInfoToken_;

  // Quality cuts on the mkFit candidate state. The R and position-error cuts are stored squared.
  const float qualityMaxInvPt_;
  const float qualityMinTheta_;
  const float qualityMaxRsq_;
  const float qualityMaxZ_;
  const float qualityMaxPosErrSq_;
  const bool qualitySignPt_;

  // Rescale the candidate error when the mkFit output is already propagated to the first layer
  const bool doErrorRescale_;

  // Candidate-level DNN selection
  const int algo_;  // track algorithm derived from the seed module label
  const bool algoCandSelection_;
  const float algoCandWorkingPoint_;
  const int bsize_;
  const edm::EDGetTokenT<reco::BeamSpot> bsToken_;               // only consumed if algoCandSelection_
  const edm::EDGetTokenT<reco::VertexCollection> verticesToken_;  // only consumed if algoCandSelection_
  const std::string tfDnnLabel_;
  const edm::ESGetToken<TfGraphDefWrapper, TfGraphRecord> tfDnnToken_;
};
0152 
0153 MkFitOutputConverter::MkFitOutputConverter(edm::ParameterSet const& iConfig)
0154     : eventOfHitsToken_{consumes<MkFitEventOfHits>(iConfig.getParameter<edm::InputTag>("mkFitEventOfHits"))},
0155       pixelClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitPixelHits"))},
0156       stripClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitStripHits"))},
0157       mkfitSeedToken_{consumes<MkFitSeedWrapper>(iConfig.getParameter<edm::InputTag>("mkFitSeeds"))},
0158       tracksToken_{consumes<MkFitOutputWrapper>(iConfig.getParameter<edm::InputTag>("tracks"))},
0159       seedToken_{consumes<edm::View<TrajectorySeed>>(iConfig.getParameter<edm::InputTag>("seeds"))},
0160       propagatorAlongToken_{
0161           esConsumes<Propagator, TrackingComponentsRecord>(iConfig.getParameter<edm::ESInputTag>("propagatorAlong"))},
0162       propagatorOppositeToken_{esConsumes<Propagator, TrackingComponentsRecord>(
0163           iConfig.getParameter<edm::ESInputTag>("propagatorOpposite"))},
0164       mfToken_{esConsumes<MagneticField, IdealMagneticFieldRecord>()},
0165       ttrhBuilderToken_{esConsumes<TransientTrackingRecHitBuilder, TransientRecHitRecord>(
0166           iConfig.getParameter<edm::ESInputTag>("ttrhBuilder"))},
0167       mkFitGeomToken_{esConsumes<MkFitGeometry, TrackerRecoGeometryRecord>()},
0168       tTopoToken_{esConsumes<TrackerTopology, TrackerTopologyRcd>()},
0169       putTrackCandidateToken_{produces<TrackCandidateCollection>()},
0170       putSeedStopInfoToken_{produces<std::vector<SeedStopInfo>>()},
0171       qualityMaxInvPt_{float(iConfig.getParameter<double>("qualityMaxInvPt"))},
0172       qualityMinTheta_{float(iConfig.getParameter<double>("qualityMinTheta"))},
0173       qualityMaxRsq_{float(pow(iConfig.getParameter<double>("qualityMaxR"), 2))},
0174       qualityMaxZ_{float(iConfig.getParameter<double>("qualityMaxZ"))},
0175       qualityMaxPosErrSq_{float(pow(iConfig.getParameter<double>("qualityMaxPosErr"), 2))},
0176       qualitySignPt_{iConfig.getParameter<bool>("qualitySignPt")},
0177       doErrorRescale_{iConfig.getParameter<bool>("doErrorRescale")},
0178       algo_{reco::TrackBase::algoByName(
0179           TString(iConfig.getParameter<edm::InputTag>("seeds").label()).ReplaceAll("Seeds", "").Data())},
0180       algoCandSelection_{bool(iConfig.getParameter<bool>("candMVASel"))},
0181       algoCandWorkingPoint_{float(iConfig.getParameter<double>("candWP"))},
0182       bsize_{int(iConfig.getParameter<int>("batchSize"))},
0183       bsToken_(algoCandSelection_ ? consumes<reco::BeamSpot>(edm::InputTag("offlineBeamSpot"))
0184                                   : edm::EDGetTokenT<reco::BeamSpot>()),
0185       verticesToken_(algoCandSelection_ ? consumes<reco::VertexCollection>(edm::InputTag("firstStepPrimaryVertices"))
0186                                         : edm::EDGetTokenT<reco::VertexCollection>()),
0187       tfDnnLabel_(iConfig.getParameter<std::string>("tfDnnLabel")),
0188       tfDnnToken_(esConsumes(edm::ESInputTag("", tfDnnLabel_))) {}
0189 
0190 void MkFitOutputConverter::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0191   edm::ParameterSetDescription desc;
0192 
0193   desc.add("mkFitEventOfHits", edm::InputTag{"mkFitEventOfHits"});
0194   desc.add("mkFitPixelHits", edm::InputTag{"mkFitSiPixelHits"});
0195   desc.add("mkFitStripHits", edm::InputTag{"mkFitSiStripHits"});
0196   desc.add("mkFitSeeds", edm::InputTag{"mkFitSeedConverter"});
0197   desc.add("tracks", edm::InputTag{"mkFitProducer"});
0198   desc.add("seeds", edm::InputTag{"initialStepSeeds"});
0199   desc.add("ttrhBuilder", edm::ESInputTag{"", "WithTrackAngle"});
0200   desc.add("propagatorAlong", edm::ESInputTag{"", "PropagatorWithMaterial"});
0201   desc.add("propagatorOpposite", edm::ESInputTag{"", "PropagatorWithMaterialOpposite"});
0202 
0203   desc.add<double>("qualityMaxInvPt", 100)->setComment("max(1/pt) for converted tracks");
0204   desc.add<double>("qualityMinTheta", 0.01)->setComment("lower bound on theta (or pi-theta) for converted tracks");
0205   desc.add<double>("qualityMaxR", 120)->setComment("max(R) for the state position for converted tracks");
0206   desc.add<double>("qualityMaxZ", 280)->setComment("max(|Z|) for the state position for converted tracks");
0207   desc.add<double>("qualityMaxPosErr", 100)->setComment("max position error for converted tracks");
0208   desc.add<bool>("qualitySignPt", true)->setComment("check sign of 1/pt for converted tracks");
0209 
0210   desc.add<bool>("doErrorRescale", true)->setComment("rescale candidate error before final fit");
0211 
0212   desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
0213 
0214   desc.add<bool>("candMVASel", false)->setComment("flag used to trigger MVA selection at cand level");
0215   desc.add<double>("candWP", 0)->setComment("MVA selection at cand level working point");
0216   desc.add<int>("batchSize", 16)->setComment("batch size for cand DNN evaluation");
0217 
0218   descriptions.addWithDefaultLabel(desc);
0219 }
0220 
0221 void MkFitOutputConverter::produce(edm::StreamID iID, edm::Event& iEvent, const edm::EventSetup& iSetup) const {
0222   const auto& seeds = iEvent.get(seedToken_);
0223   const auto& mkfitSeeds = iEvent.get(mkfitSeedToken_);
0224 
0225   const auto& ttrhBuilder = iSetup.getData(ttrhBuilderToken_);
0226   const auto* tkBuilder = dynamic_cast<TkTransientTrackingRecHitBuilder const*>(&ttrhBuilder);
0227   if (!tkBuilder) {
0228     throw cms::Exception("LogicError") << "TTRHBuilder must be of type TkTransientTrackingRecHitBuilder";
0229   }
0230   const auto& mkFitGeom = iSetup.getData(mkFitGeomToken_);
0231   // primary vertices under the algo_ because in initialStepPreSplitting there are no firstStepPrimaryVertices
0232   // beamspot as well since the producer can be used in hlt
0233   const reco::VertexCollection* vertices = nullptr;
0234   const reco::BeamSpot* beamspot = nullptr;
0235   if (algoCandSelection_) {
0236     vertices = &iEvent.get(verticesToken_);
0237     beamspot = &iEvent.get(bsToken_);
0238   }
0239 
0240   const tensorflow::Session* session = nullptr;
0241   if (algoCandSelection_)
0242     session = iSetup.getData(tfDnnToken_).getSession();
0243 
0244   // Convert mkfit presentation back to CMSSW
0245   iEvent.emplace(putTrackCandidateToken_,
0246                  convertCandidates(iEvent.get(tracksToken_),
0247                                    iEvent.get(eventOfHitsToken_).get(),
0248                                    iEvent.get(pixelClusterIndexToHitToken_),
0249                                    iEvent.get(stripClusterIndexToHitToken_),
0250                                    seeds,
0251                                    iSetup.getData(mfToken_),
0252                                    iSetup.getData(propagatorAlongToken_),
0253                                    iSetup.getData(propagatorOppositeToken_),
0254                                    iSetup.getData(mkFitGeomToken_),
0255                                    iSetup.getData(tTopoToken_),
0256                                    tkBuilder->cloner(),
0257                                    mkFitGeom.detLayers(),
0258                                    mkfitSeeds.seeds(),
0259                                    beamspot,
0260                                    vertices,
0261                                    session));
0262 
0263   // TODO: SeedStopInfo is currently unfilled
0264   iEvent.emplace(putSeedStopInfoToken_, seeds.size());
0265 }
0266 
0267 TrackCandidateCollection MkFitOutputConverter::convertCandidates(const MkFitOutputWrapper& mkFitOutput,
0268                                                                  const mkfit::EventOfHits& eventOfHits,
0269                                                                  const MkFitClusterIndexToHit& pixelClusterIndexToHit,
0270                                                                  const MkFitClusterIndexToHit& stripClusterIndexToHit,
0271                                                                  const edm::View<TrajectorySeed>& seeds,
0272                                                                  const MagneticField& mf,
0273                                                                  const Propagator& propagatorAlong,
0274                                                                  const Propagator& propagatorOpposite,
0275                                                                  const MkFitGeometry& mkFitGeom,
0276                                                                  const TrackerTopology& tTopo,
0277                                                                  const TkClonerImpl& hitCloner,
0278                                                                  const std::vector<const DetLayer*>& detLayers,
0279                                                                  const mkfit::TrackVec& mkFitSeeds,
0280                                                                  const reco::BeamSpot* bs,
0281                                                                  const reco::VertexCollection* vertices,
0282                                                                  const tensorflow::Session* session) const {
0283   TrackCandidateCollection output;
0284   const auto& candidates = mkFitOutput.tracks();
0285   output.reserve(candidates.size());
0286 
0287   LogTrace("MkFitOutputConverter") << "Number of candidates " << candidates.size();
0288 
0289   std::vector<float> chi2;
0290   std::vector<TrajectoryStateOnSurface> states;
0291   chi2.reserve(candidates.size());
0292   states.reserve(candidates.size());
0293 
0294   int candIndex = -1;
0295   for (const auto& cand : candidates) {
0296     ++candIndex;
0297     LogTrace("MkFitOutputConverter") << "Candidate " << candIndex << " pT " << cand.pT() << " eta " << cand.momEta()
0298                                      << " phi " << cand.momPhi() << " chi2 " << cand.chi2();
0299 
0300     // state: check for basic quality first
0301     if (cand.state().invpT() > qualityMaxInvPt_ || (qualitySignPt_ && cand.state().invpT() < 0) ||
0302         cand.state().theta() < qualityMinTheta_ || (M_PI - cand.state().theta()) < qualityMinTheta_ ||
0303         cand.state().posRsq() > qualityMaxRsq_ || std::abs(cand.state().z()) > qualityMaxZ_ ||
0304         (cand.state().errors.At(0, 0) + cand.state().errors.At(1, 1) + cand.state().errors.At(2, 2)) >
0305             qualityMaxPosErrSq_) {
0306       edm::LogInfo("MkFitOutputConverter")
0307           << "Candidate " << candIndex << " failed state quality checks" << cand.state().parameters;
0308       continue;
0309     }
0310 
0311     auto state = cand.state();  // copy because have to modify
0312     state.convertFromCCSToGlbCurvilinear();
0313     const auto& param = state.parameters;
0314     const auto& err = state.errors;
0315     AlgebraicSymMatrix55 cov;
0316     for (int i = 0; i < 5; ++i) {
0317       for (int j = i; j < 5; ++j) {
0318         cov[i][j] = err.At(i, j);
0319       }
0320     }
0321 
0322     auto fts = FreeTrajectoryState(
0323         GlobalTrajectoryParameters(
0324             GlobalPoint(param[0], param[1], param[2]), GlobalVector(param[3], param[4], param[5]), state.charge, &mf),
0325         CurvilinearTrajectoryError(cov));
0326     if (!fts.curvilinearError().posDef()) {
0327       edm::LogInfo("MkFitOutputConverter")
0328           << "Curvilinear error not pos-def\n"
0329           << fts.curvilinearError().matrix() << "\ncandidate " << candIndex << "ignored";
0330       continue;
0331     }
0332 
0333     //Sylvester's criterion, start from the smaller submatrix size
0334     double det = 0;
0335     if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix22>(0, 0).Det(det)) || det < 0) {
0336       edm::LogInfo("MkFitOutputConverter")
0337           << "Fail pos-def check sub2.det for candidate " << candIndex << " with fts " << fts;
0338       continue;
0339     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix33>(0, 0).Det(det)) || det < 0) {
0340       edm::LogInfo("MkFitOutputConverter")
0341           << "Fail pos-def check sub3.det for candidate " << candIndex << " with fts " << fts;
0342       continue;
0343     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix44>(0, 0).Det(det)) || det < 0) {
0344       edm::LogInfo("MkFitOutputConverter")
0345           << "Fail pos-def check sub4.det for candidate " << candIndex << " with fts " << fts;
0346       continue;
0347     } else if ((!fts.curvilinearError().matrix().Det2(det)) || det < 0) {
0348       edm::LogInfo("MkFitOutputConverter")
0349           << "Fail pos-def check det for candidate " << candIndex << " with fts " << fts;
0350       continue;
0351     }
0352 
0353     // hits
0354     edm::OwnVector<TrackingRecHit> recHits;
0355     // nTotalHits() gives sum of valid hits (nFoundHits()) and invalid/missing hits.
0356     const int nhits = cand.nTotalHits();
0357     //std::cout << candIndex << ": " << nhits << " " << cand.nFoundHits() << std::endl;
0358     bool lastHitInvalid = false;
0359     const auto isPhase1 = mkFitGeom.isPhase1();
0360     for (int i = 0; i < nhits; ++i) {
0361       const auto& hitOnTrack = cand.getHitOnTrack(i);
0362       LogTrace("MkFitOutputConverter") << " hit on layer " << hitOnTrack.layer << " index " << hitOnTrack.index;
0363       if (hitOnTrack.index < 0) {
0364         // See index-desc.txt file in mkFit for description of negative values
0365         //
0366         // In order to use the regular InvalidTrackingRecHit I'd need
0367         // a GeomDet (and "unfortunately" that is needed in
0368         // TrackProducer).
0369         //
0370         // I guess we could take the track state and propagate it to
0371         // each layer to find the actual module the track crosses, and
0372         // check whether it is active or not to be able to mark
0373         // inactive hits
0374         const auto* detLayer = detLayers.at(hitOnTrack.layer);
0375         if (detLayer == nullptr) {
0376           throw cms::Exception("LogicError") << "DetLayer for layer index " << hitOnTrack.layer << " is null!";
0377         }
0378         // In principle an InvalidTrackingRecHitNoDet could be
0379         // inserted here, but it seems that it is best to deal with
0380         // them in the TrackProducer.
0381         lastHitInvalid = true;
0382       } else {
0383         auto const isPixel = eventOfHits[hitOnTrack.layer].is_pixel();
0384         auto const& hits = isPixel ? pixelClusterIndexToHit.hits() : stripClusterIndexToHit.hits();
0385 
0386         auto const& thit = static_cast<BaseTrackerRecHit const&>(*hits[hitOnTrack.index]);
0387         if (isPhase1) {
0388           if (thit.firstClusterRef().isPixel() || thit.detUnit()->type().isEndcap()) {
0389             recHits.push_back(hits[hitOnTrack.index]->clone());
0390           } else {
0391             recHits.push_back(std::make_unique<SiStripRecHit1D>(
0392                 thit.localPosition(),
0393                 LocalError(thit.localPositionError().xx(), 0.f, std::numeric_limits<float>::max()),
0394                 *thit.det(),
0395                 thit.firstClusterRef()));
0396           }
0397         } else {
0398           if (thit.firstClusterRef().isPixel()) {
0399             recHits.push_back(hits[hitOnTrack.index]->clone());
0400           } else if (thit.firstClusterRef().isPhase2()) {
0401             recHits.push_back(std::make_unique<Phase2TrackerRecHit1D>(
0402                 thit.localPosition(),
0403                 LocalError(thit.localPositionError().xx(), 0.f, std::numeric_limits<float>::max()),
0404                 *thit.det(),
0405                 thit.firstClusterRef().cluster_phase2OT()));
0406           }
0407         }
0408         LogTrace("MkFitOutputConverter") << "  pos " << recHits.back().globalPosition().x() << " "
0409                                          << recHits.back().globalPosition().y() << " "
0410                                          << recHits.back().globalPosition().z() << " mag2 "
0411                                          << recHits.back().globalPosition().mag2() << " detid "
0412                                          << recHits.back().geographicalId().rawId() << " cluster " << hitOnTrack.index;
0413         lastHitInvalid = false;
0414       }
0415     }
0416 
0417     const auto lastHitId = recHits.back().geographicalId();
0418     // MkFit hits are *not* in the order of propagation, sort by 3D radius for now (as we don't have loopers)
0419     // TODO: Improve the sorting (extract keys? maybe even bubble sort would work well as the hits are almost in the correct order)
0420     recHits.sort([&tTopo, &isPhase1](const auto& a, const auto& b) {
0421       //const GeomDetEnumerators::SubDetector asub = a.det()->subDetector();
0422       //const GeomDetEnumerators::SubDetector bsub = b.det()->subDetector();
0423       //const auto& apos = a.globalPosition();
0424       //const auto& bpos = b.globalPosition();
0425       // For Phase-1, can rely on subdetector index
0426       if (isPhase1) {
0427         const auto asub_ph1 = a.geographicalId().subdetId();
0428         const auto bsub_ph1 = b.geographicalId().subdetId();
0429         const auto& apos_ph1 = a.globalPosition();
0430         const auto& bpos_ph1 = b.globalPosition();
0431         if (asub_ph1 != bsub_ph1) {
0432           // Subdetector order (BPix, FPix, TIB, TID, TOB, TEC) corresponds also the navigation
0433           return asub_ph1 < bsub_ph1;
0434         } else {
0435           //if (GeomDetEnumerators::isBarrel(asub)) {
0436           if (isPhase1Barrel(asub_ph1)) {
0437             return apos_ph1.perp2() < bpos_ph1.perp2();
0438           } else {
0439             return std::abs(apos_ph1.z()) < std::abs(bpos_ph1.z());
0440           }
0441         }
0442       }
0443 
0444       // For Phase-2, can not rely uniquely on subdetector index
0445       const GeomDetEnumerators::SubDetector asub = a.det()->subDetector();
0446       const GeomDetEnumerators::SubDetector bsub = b.det()->subDetector();
0447       const auto& apos = a.globalPosition();
0448       const auto& bpos = b.globalPosition();
0449       const auto aid = a.geographicalId().rawId();
0450       const auto bid = b.geographicalId().rawId();
0451       const auto asubid = a.geographicalId().subdetId();
0452       const auto bsubid = b.geographicalId().subdetId();
0453       if (GeomDetEnumerators::isBarrel(asub) || GeomDetEnumerators::isBarrel(bsub)) {
0454         // For barrel tilted modules, or in case (only) one of the two modules is barrel, use 3D position
0455         if ((asubid == StripSubdetector::TOB && tTopo.tobSide(aid) < 3) ||
0456             (bsubid == StripSubdetector::TOB && tTopo.tobSide(bid) < 3) ||
0457             !(GeomDetEnumerators::isBarrel(asub) && GeomDetEnumerators::isBarrel(bsub))) {
0458           return apos.mag2() < bpos.mag2();
0459         }
0460         // For fully barrel comparisons and no tilt, use 2D position
0461         else {
0462           return apos.perp2() < bpos.perp2();
0463         }
0464       }
0465       // For fully endcap comparisons, use z position
0466       else {
0467         return std::abs(apos.z()) < std::abs(bpos.z());
0468       }
0469     });
0470 
0471     const bool lastHitChanged = (recHits.back().geographicalId() != lastHitId);  // TODO: make use of the bools
0472 
0473     // seed
0474     const auto seedIndex = cand.label();
0475     LogTrace("MkFitOutputConverter") << " from seed " << seedIndex << " seed hits";
0476 
0477     // Rescale candidate error if candidate is already propagated to first layer,
0478     // to be consistent with TransientInitialStateEstimator::innerState used in CkfTrackCandidateMakerBase
0479     // Error is only rescaled for candidates propagated to first layer;
0480     // otherwise, candidates undergo backwardFit where error is already rescaled
0481     if (mkFitOutput.propagatedToFirstLayer() && doErrorRescale_)
0482       fts.rescaleError(100.);
0483     auto tsosDet = mkFitOutput.propagatedToFirstLayer()
0484                        ? convertInnermostState(fts, recHits, propagatorAlong, propagatorOpposite)
0485                        : backwardFit(fts,
0486                                      recHits,
0487                                      propagatorAlong,
0488                                      propagatorOpposite,
0489                                      hitCloner,
0490                                      isPhase1,
0491                                      lastHitInvalid,
0492                                      lastHitChanged);
0493     if (!tsosDet.first.isValid()) {
0494       edm::LogInfo("MkFitOutputConverter")
0495           << "Backward fit of candidate " << candIndex << " failed, ignoring the candidate";
0496       continue;
0497     }
0498 
0499     // convert to persistent, from CkfTrackCandidateMakerBase
0500     auto pstate = trajectoryStateTransform::persistentState(tsosDet.first, tsosDet.second->geographicalId().rawId());
0501 
0502     output.emplace_back(
0503         recHits,
0504         seeds.at(seedIndex),
0505         pstate,
0506         seeds.refAt(seedIndex),
0507         0,                                               // mkFit does not produce loopers, so set nLoops=0
0508         static_cast<uint8_t>(StopReason::UNINITIALIZED)  // TODO: ignore details of stopping reason as well for now
0509     );
0510 
0511     chi2.push_back(cand.chi2());
0512     states.push_back(tsosDet.first);
0513   }
0514 
0515   if (algoCandSelection_) {
0516     const auto& dnnScores = computeDNNs(
0517         output, states, bs, vertices, session, chi2, mkFitOutput.propagatedToFirstLayer() && doErrorRescale_);
0518 
0519     TrackCandidateCollection reducedOutput;
0520     reducedOutput.reserve(output.size());
0521     int scoreIndex = 0;
0522     for (const auto& score : dnnScores) {
0523       if (score > algoCandWorkingPoint_)
0524         reducedOutput.push_back(output[scoreIndex]);
0525       scoreIndex++;
0526     }
0527 
0528     output.swap(reducedOutput);
0529   }
0530 
0531   return output;
0532 }
0533 
0534 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::backwardFit(
0535     const FreeTrajectoryState& fts,
0536     const edm::OwnVector<TrackingRecHit>& hits,
0537     const Propagator& propagatorAlong,
0538     const Propagator& propagatorOpposite,
0539     const TkClonerImpl& hitCloner,
0540     const bool isPhase1,
0541     bool lastHitWasInvalid,
0542     bool lastHitWasChanged) const {
0543   // First filter valid hits as in TransientInitialStateEstimator
0544   TransientTrackingRecHit::ConstRecHitContainer firstHits;
0545 
0546   for (int i = hits.size() - 1; i >= 0; --i) {
0547     if (hits[i].det()) {
0548       // TransientTrackingRecHit::ConstRecHitContainer has shared_ptr,
0549       // and it is passed to backFitter below so it is really needed
0550       // to keep the interface. Since we keep the ownership in hits,
0551       // let's disable the deleter.
0552       firstHits.emplace_back(&(hits[i]), edm::do_nothing_deleter{});
0553     }
0554   }
0555 
0556   // Then propagate along to the surface of the last hit to get a TSOS
0557   const auto& lastHitSurface = firstHits.front()->det()->surface();
0558 
0559   const Propagator* tryFirst = &propagatorAlong;
0560   const Propagator* trySecond = &propagatorOpposite;
0561   if (lastHitWasInvalid || lastHitWasChanged) {
0562     LogTrace("MkFitOutputConverter") << "Propagating first opposite, then along, because lastHitWasInvalid? "
0563                                      << lastHitWasInvalid << " or lastHitWasChanged? " << lastHitWasChanged;
0564     std::swap(tryFirst, trySecond);
0565   } else {
0566     bool doSwitch = false;
0567     const auto& surfacePos = lastHitSurface.position();
0568     const auto& lastHitPos = firstHits.front()->globalPosition();
0569     if (isPhase1) {
0570       const auto lastHitSubdet = firstHits.front()->geographicalId().subdetId();
0571       if (isPhase1Barrel(lastHitSubdet)) {
0572         doSwitch = (surfacePos.perp2() < lastHitPos.perp2());
0573       } else {
0574         doSwitch = (surfacePos.z() < lastHitPos.z());
0575       }
0576     } else {
0577       const auto lastHitSubdet = firstHits.front()->det()->subDetector();
0578       if (GeomDetEnumerators::isBarrel(lastHitSubdet)) {
0579         doSwitch = (surfacePos.perp2() < lastHitPos.perp2());
0580       } else {
0581         doSwitch = (surfacePos.z() < lastHitPos.z());
0582       }
0583     }
0584     if (doSwitch) {
0585       LogTrace("MkFitOutputConverter")
0586           << "Propagating first opposite, then along, because surface is inner than the hit; surface perp2 "
0587           << surfacePos.perp() << " hit " << lastHitPos.perp2() << " surface z " << surfacePos.z() << " hit "
0588           << lastHitPos.z();
0589 
0590       std::swap(tryFirst, trySecond);
0591     }
0592   }
0593 
0594   auto tsosDouble = tryFirst->propagateWithPath(fts, lastHitSurface);
0595   if (!tsosDouble.first.isValid()) {
0596     LogDebug("MkFitOutputConverter") << "Propagating to startingState failed, trying in another direction next";
0597     tsosDouble = trySecond->propagateWithPath(fts, lastHitSurface);
0598   }
0599   auto& startingState = tsosDouble.first;
0600 
0601   if (!startingState.isValid()) {
0602     edm::LogWarning("MkFitOutputConverter")
0603         << "startingState is not valid, FTS was\n"
0604         << fts << " last hit surface surface:"
0605         << "\n position " << lastHitSurface.position() << "\n phiSpan " << lastHitSurface.phiSpan().first << ","
0606         << lastHitSurface.phiSpan().first << "\n rSpan " << lastHitSurface.rSpan().first << ","
0607         << lastHitSurface.rSpan().first << "\n zSpan " << lastHitSurface.zSpan().first << ","
0608         << lastHitSurface.zSpan().first;
0609     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0610   }
0611 
0612   // Then return back to the logic from TransientInitialStateEstimator
0613   startingState.rescaleError(100.);
0614 
0615   // avoid cloning
0616   KFUpdator const aKFUpdator;
0617   Chi2MeasurementEstimator const aChi2MeasurementEstimator(100., 3);
0618   KFTrajectoryFitter backFitter(
0619       &propagatorAlong, &aKFUpdator, &aChi2MeasurementEstimator, firstHits.size(), nullptr, &hitCloner);
0620 
0621   // assume for now that the propagation in mkfit always alongMomentum
0622   PropagationDirection backFitDirection = oppositeToMomentum;
0623 
0624   // only direction matters in this context
0625   TrajectorySeed fakeSeed(PTrajectoryStateOnDet(), edm::OwnVector<TrackingRecHit>(), backFitDirection);
0626 
0627   // ignore loopers for now
0628   Trajectory fitres = backFitter.fitOne(fakeSeed, firstHits, startingState, TrajectoryFitter::standard);
0629 
0630   LogDebug("MkFitOutputConverter") << "using a backward fit of :" << firstHits.size() << " hits, starting from:\n"
0631                                    << startingState << " to get the estimate of the initial state of the track.";
0632 
0633   if (!fitres.isValid()) {
0634     edm::LogWarning("MkFitOutputConverter") << "FitTester: first hits fit failed";
0635     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0636   }
0637 
0638   TrajectoryMeasurement const& firstMeas = fitres.lastMeasurement();
0639 
0640   // magnetic field can be different!
0641   TrajectoryStateOnSurface firstState(firstMeas.updatedState().localParameters(),
0642                                       firstMeas.updatedState().localError(),
0643                                       firstMeas.updatedState().surface(),
0644                                       propagatorAlong.magneticField());
0645 
0646   firstState.rescaleError(100.);
0647 
0648   LogDebug("MkFitOutputConverter") << "the initial state is found to be:\n:" << firstState
0649                                    << "\n it's field pointer is: " << firstState.magneticField()
0650                                    << "\n the pointer from the state of the back fit was: "
0651                                    << firstMeas.updatedState().magneticField();
0652 
0653   return std::make_pair(firstState, firstMeas.recHit()->det());
0654 }
0655 
0656 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::convertInnermostState(
0657     const FreeTrajectoryState& fts,
0658     const edm::OwnVector<TrackingRecHit>& hits,
0659     const Propagator& propagatorAlong,
0660     const Propagator& propagatorOpposite) const {
0661   auto det = hits[0].det();
0662   if (det == nullptr) {
0663     throw cms::Exception("LogicError") << "Got nullptr from the first hit det()";
0664   }
0665 
0666   const auto& firstHitSurface = det->surface();
0667 
0668   auto tsosDouble = propagatorAlong.propagateWithPath(fts, firstHitSurface);
0669   if (!tsosDouble.first.isValid()) {
0670     LogDebug("MkFitOutputConverter") << "Propagating to startingState along momentum failed, trying opposite next";
0671     tsosDouble = propagatorOpposite.propagateWithPath(fts, firstHitSurface);
0672   }
0673 
0674   return std::make_pair(tsosDouble.first, det);
0675 }
0676 
0677 std::vector<float> MkFitOutputConverter::computeDNNs(TrackCandidateCollection const& tkCC,
0678                                                      const std::vector<TrajectoryStateOnSurface>& states,
0679                                                      const reco::BeamSpot* bs,
0680                                                      const reco::VertexCollection* vertices,
0681                                                      const tensorflow::Session* session,
0682                                                      const std::vector<float>& chi2,
0683                                                      const bool rescaledError) const {
0684   int size_in = (int)tkCC.size();
0685   int nbatches = size_in / bsize_;
0686 
0687   std::vector<float> output(size_in, 0);
0688 
0689   TSCBLBuilderNoMaterial tscblBuilder;
0690 
0691   tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
0692   tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
0693 
0694   for (auto nb = 0; nb < nbatches + 1; nb++) {
0695     std::vector<bool> invalidProp(bsize_, false);
0696 
0697     for (auto nt = 0; nt < bsize_; nt++) {
0698       int itrack = nt + bsize_ * nb;
0699       if (itrack >= size_in)
0700         continue;
0701 
0702       auto const& tkC = tkCC.at(itrack);
0703 
0704       TrajectoryStateOnSurface state = states.at(itrack);
0705 
0706       if (rescaledError)
0707         state.rescaleError(1 / 100.f);
0708 
0709       TrajectoryStateClosestToBeamLine tsAtClosestApproachTrackCand =
0710           tscblBuilder(*state.freeState(), *bs);  //as in TrackProducerAlgorithm
0711 
0712       if (!(tsAtClosestApproachTrackCand.isValid())) {
0713         edm::LogVerbatim("TrackBuilding") << "TrajectoryStateClosestToBeamLine not valid";
0714         invalidProp[nt] = true;
0715         continue;
0716       }
0717 
0718       auto const& stateAtPCA = tsAtClosestApproachTrackCand.trackStateAtPCA();
0719       auto v0 = stateAtPCA.position();
0720       auto p = stateAtPCA.momentum();
0721       math::XYZPoint pos(v0.x(), v0.y(), v0.z());
0722       math::XYZVector mom(p.x(), p.y(), p.z());
0723 
0724       //pseudo track for access to easy methods
0725       reco::Track trk(0, 0, pos, mom, stateAtPCA.charge(), stateAtPCA.curvilinearError());
0726 
0727       // get best vertex
0728       float dzmin = std::numeric_limits<float>::max();
0729       float dxy_zmin = 0;
0730 
0731       for (auto const& vertex : *vertices) {
0732         if (std::abs(trk.dz(vertex.position())) < dzmin) {
0733           dzmin = trk.dz(vertex.position());
0734           dxy_zmin = trk.dxy(vertex.position());
0735         }
0736       }
0737 
0738       // loop over the RecHits
0739       int ndof = 0;
0740       int pix = 0;
0741       int strip = 0;
0742       for (auto const& recHit : tkC.recHits()) {
0743         ndof += recHit.dimension();
0744         auto const subdet = recHit.geographicalId().subdetId();
0745         if (subdet == PixelSubdetector::PixelBarrel || subdet == PixelSubdetector::PixelEndcap)
0746           pix++;
0747         else
0748           strip++;
0749       }
0750       ndof = ndof - 5;
0751 
0752       input1.matrix<float>()(nt, 0) = trk.pt();  //using inner track only
0753       input1.matrix<float>()(nt, 1) = p.x();
0754       input1.matrix<float>()(nt, 2) = p.y();
0755       input1.matrix<float>()(nt, 3) = p.z();
0756       input1.matrix<float>()(nt, 4) = p.perp();
0757       input1.matrix<float>()(nt, 5) = p.x();
0758       input1.matrix<float>()(nt, 6) = p.y();
0759       input1.matrix<float>()(nt, 7) = p.z();
0760       input1.matrix<float>()(nt, 8) = p.perp();
0761       input1.matrix<float>()(nt, 9) = trk.ptError();
0762       input1.matrix<float>()(nt, 10) = dxy_zmin;
0763       input1.matrix<float>()(nt, 11) = dzmin;
0764       input1.matrix<float>()(nt, 12) = trk.dxy(bs->position());
0765       input1.matrix<float>()(nt, 13) = trk.dz(bs->position());
0766       input1.matrix<float>()(nt, 14) = trk.dxyError();
0767       input1.matrix<float>()(nt, 15) = trk.dzError();
0768       input1.matrix<float>()(nt, 16) = ndof > 0 ? chi2[itrack] / ndof : chi2[itrack] * 1e6;
0769       input1.matrix<float>()(nt, 17) = trk.eta();
0770       input1.matrix<float>()(nt, 18) = trk.phi();
0771       input1.matrix<float>()(nt, 19) = trk.etaError();
0772       input1.matrix<float>()(nt, 20) = trk.phiError();
0773       input1.matrix<float>()(nt, 21) = pix;
0774       input1.matrix<float>()(nt, 22) = strip;
0775       input1.matrix<float>()(nt, 23) = ndof;
0776       input1.matrix<float>()(nt, 24) = 0;
0777       input1.matrix<float>()(nt, 25) = 0;
0778       input1.matrix<float>()(nt, 26) = 0;
0779       input1.matrix<float>()(nt, 27) = 0;
0780       input1.matrix<float>()(nt, 28) = 0;
0781 
0782       input2.matrix<float>()(nt, 0) = algo_;
0783     }
0784 
0785     //inputs finalized
0786     tensorflow::NamedTensorList inputs;
0787     inputs.resize(2);
0788     inputs[0] = tensorflow::NamedTensor("x", input1);
0789     inputs[1] = tensorflow::NamedTensor("y", input2);
0790 
0791     //eval and rescale
0792     std::vector<tensorflow::Tensor> outputs;
0793     tensorflow::run(session, inputs, {"Identity"}, &outputs);
0794 
0795     for (auto nt = 0; nt < bsize_; nt++) {
0796       int itrack = nt + bsize_ * nb;
0797       if (itrack >= size_in)
0798         continue;
0799 
0800       float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
0801       if (invalidProp[nt])
0802         out0 = -1;
0803 
0804       output[itrack] = out0;
0805     }
0806   }
0807 
0808   return output;
0809 }
0810 
// Register MkFitOutputConverter as a framework module
// (macro provided by FWCore/Framework/interface/MakerMacros.h).
0811 DEFINE_FWK_MODULE(MkFitOutputConverter);