Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-25 02:14:11

0001 #include "FWCore/Framework/interface/global/EDProducer.h"
0002 
0003 #include "FWCore/Framework/interface/Event.h"
0004 #include "FWCore/Framework/interface/MakerMacros.h"
0005 #include "FWCore/Utilities/interface/do_nothing_deleter.h"
0006 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0007 
0008 #include "DataFormats/SiPixelDetId/interface/PixelSubdetector.h"
0009 #include "DataFormats/SiStripDetId/interface/StripSubdetector.h"
0010 #include "DataFormats/TrajectorySeed/interface/TrajectorySeed.h"
0011 #include "DataFormats/TrackCandidate/interface/TrackCandidateCollection.h"
0012 #include "DataFormats/TrackReco/interface/SeedStopInfo.h"
0013 #include "DataFormats/TrackingRecHit/interface/InvalidTrackingRecHit.h"
0014 #include "DataFormats/TrackerRecHit2D/interface/SiStripRecHit1D.h"
0015 #include "DataFormats/TrackerRecHit2D/interface/Phase2TrackerRecHit1D.h"
0016 
0017 #include "TrackingTools/Records/interface/TransientRecHitRecord.h"
0018 #include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHitBuilder.h"
0019 #include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"
0020 
0021 #include "MagneticField/Engine/interface/MagneticField.h"
0022 #include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"
0023 
0024 #include "TrackingTools/GeomPropagators/interface/Propagator.h"
0025 #include "TrackingTools/Records/interface/TrackingComponentsRecord.h"
0026 #include "TrackingTools/KalmanUpdators/interface/KFUpdator.h"
0027 #include "TrackingTools/KalmanUpdators/interface/Chi2MeasurementEstimator.h"
0028 #include "TrackingTools/TrackFitters/interface/KFTrajectoryFitter.h"
0029 #include "RecoTracker/TransientTrackingRecHit/interface/TkClonerImpl.h"
0030 #include "RecoTracker/TransientTrackingRecHit/interface/TkTransientTrackingRecHitBuilder.h"
0031 #include "TrackingTools/MaterialEffects/interface/PropagatorWithMaterial.h"
0032 
0033 #include "RecoTracker/MkFit/interface/MkFitEventOfHits.h"
0034 #include "RecoTracker/MkFit/interface/MkFitClusterIndexToHit.h"
0035 #include "RecoTracker/MkFit/interface/MkFitSeedWrapper.h"
0036 #include "RecoTracker/MkFit/interface/MkFitOutputWrapper.h"
0037 #include "RecoTracker/MkFit/interface/MkFitGeometry.h"
0038 #include "RecoTracker/Record/interface/TrackerRecoGeometryRecord.h"
0039 
0040 // mkFit includes
0041 #include "RecoTracker/MkFitCMS/interface/LayerNumberConverter.h"
0042 #include "RecoTracker/MkFitCore/interface/Track.h"
0043 #include "RecoTracker/MkFitCore/interface/HitStructures.h"
0044 
0045 //extra for DNN with cands
0046 #include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
0047 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0048 #include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0049 #include "TrackingTools/PatternTools/interface/TSCBLBuilderNoMaterial.h"
0050 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0051 #include "DataFormats/VertexReco/interface/Vertex.h"
0052 
0053 namespace {
0054   template <typename T>
0055   bool isBarrel(T subdet) {
0056     return subdet == PixelSubdetector::PixelBarrel || subdet == StripSubdetector::TIB ||
0057            subdet == StripSubdetector::TOB;
0058   }
0059 
0060   template <typename T>
0061   bool isEndcap(T subdet) {
0062     return subdet == PixelSubdetector::PixelEndcap || subdet == StripSubdetector::TID ||
0063            subdet == StripSubdetector::TEC;
0064   }
0065 }  // namespace
0066 
0067 class MkFitOutputConverter : public edm::global::EDProducer<> {
0068 public:
0069   explicit MkFitOutputConverter(edm::ParameterSet const& iConfig);
0070   ~MkFitOutputConverter() override = default;
0071 
0072   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0073 
0074 private:
0075   void produce(edm::StreamID, edm::Event& iEvent, const edm::EventSetup& iSetup) const override;
0076 
0077   TrackCandidateCollection convertCandidates(const MkFitOutputWrapper& mkFitOutput,
0078                                              const mkfit::EventOfHits& eventOfHits,
0079                                              const MkFitClusterIndexToHit& pixelClusterIndexToHit,
0080                                              const MkFitClusterIndexToHit& stripClusterIndexToHit,
0081                                              const edm::View<TrajectorySeed>& seeds,
0082                                              const MagneticField& mf,
0083                                              const Propagator& propagatorAlong,
0084                                              const Propagator& propagatorOpposite,
0085                                              const MkFitGeometry& mkFitGeom,
0086                                              const TkClonerImpl& hitCloner,
0087                                              const std::vector<const DetLayer*>& detLayers,
0088                                              const mkfit::TrackVec& mkFitSeeds,
0089                                              const reco::BeamSpot* bs,
0090                                              const reco::VertexCollection* vertices,
0091                                              const tensorflow::Session* session) const;
0092 
0093   std::pair<TrajectoryStateOnSurface, const GeomDet*> backwardFit(const FreeTrajectoryState& fts,
0094                                                                   const edm::OwnVector<TrackingRecHit>& hits,
0095                                                                   const Propagator& propagatorAlong,
0096                                                                   const Propagator& propagatorOpposite,
0097                                                                   const TkClonerImpl& hitCloner,
0098                                                                   bool lastHitWasInvalid,
0099                                                                   bool lastHitWasChanged) const;
0100 
0101   std::pair<TrajectoryStateOnSurface, const GeomDet*> convertInnermostState(const FreeTrajectoryState& fts,
0102                                                                             const edm::OwnVector<TrackingRecHit>& hits,
0103                                                                             const Propagator& propagatorAlong,
0104                                                                             const Propagator& propagatorOpposite) const;
0105 
0106   std::vector<float> computeDNNs(TrackCandidateCollection const& tkCC,
0107                                  const std::vector<TrajectoryStateOnSurface>& states,
0108                                  const reco::BeamSpot* bs,
0109                                  const reco::VertexCollection* vertices,
0110                                  const tensorflow::Session* session,
0111                                  const std::vector<float>& chi2,
0112                                  const bool rescaledError) const;
0113 
0114   const edm::EDGetTokenT<MkFitEventOfHits> eventOfHitsToken_;
0115   const edm::EDGetTokenT<MkFitClusterIndexToHit> pixelClusterIndexToHitToken_;
0116   const edm::EDGetTokenT<MkFitClusterIndexToHit> stripClusterIndexToHitToken_;
0117   const edm::EDGetTokenT<MkFitSeedWrapper> mkfitSeedToken_;
0118   const edm::EDGetTokenT<MkFitOutputWrapper> tracksToken_;
0119   const edm::EDGetTokenT<edm::View<TrajectorySeed>> seedToken_;
0120   const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorAlongToken_;
0121   const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorOppositeToken_;
0122   const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> mfToken_;
0123   const edm::ESGetToken<TransientTrackingRecHitBuilder, TransientRecHitRecord> ttrhBuilderToken_;
0124   const edm::ESGetToken<MkFitGeometry, TrackerRecoGeometryRecord> mkFitGeomToken_;
0125   const edm::EDPutTokenT<TrackCandidateCollection> putTrackCandidateToken_;
0126   const edm::EDPutTokenT<std::vector<SeedStopInfo>> putSeedStopInfoToken_;
0127 
0128   const float qualityMaxInvPt_;
0129   const float qualityMinTheta_;
0130   const float qualityMaxRsq_;
0131   const float qualityMaxZ_;
0132   const float qualityMaxPosErrSq_;
0133   const bool qualitySignPt_;
0134 
0135   const bool doErrorRescale_;
0136 
0137   const int algo_;
0138   const bool algoCandSelection_;
0139   const float algoCandWorkingPoint_;
0140   const int bsize_;
0141   const edm::EDGetTokenT<reco::BeamSpot> bsToken_;
0142   const edm::EDGetTokenT<reco::VertexCollection> verticesToken_;
0143   const std::string tfDnnLabel_;
0144   const edm::ESGetToken<TfGraphDefWrapper, TfGraphRecord> tfDnnToken_;
0145 };
0146 
0147 MkFitOutputConverter::MkFitOutputConverter(edm::ParameterSet const& iConfig)
0148     : eventOfHitsToken_{consumes<MkFitEventOfHits>(iConfig.getParameter<edm::InputTag>("mkFitEventOfHits"))},
0149       pixelClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitPixelHits"))},
0150       stripClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitStripHits"))},
0151       mkfitSeedToken_{consumes<MkFitSeedWrapper>(iConfig.getParameter<edm::InputTag>("mkFitSeeds"))},
0152       tracksToken_{consumes<MkFitOutputWrapper>(iConfig.getParameter<edm::InputTag>("tracks"))},
0153       seedToken_{consumes<edm::View<TrajectorySeed>>(iConfig.getParameter<edm::InputTag>("seeds"))},
0154       propagatorAlongToken_{
0155           esConsumes<Propagator, TrackingComponentsRecord>(iConfig.getParameter<edm::ESInputTag>("propagatorAlong"))},
0156       propagatorOppositeToken_{esConsumes<Propagator, TrackingComponentsRecord>(
0157           iConfig.getParameter<edm::ESInputTag>("propagatorOpposite"))},
0158       mfToken_{esConsumes<MagneticField, IdealMagneticFieldRecord>()},
0159       ttrhBuilderToken_{esConsumes<TransientTrackingRecHitBuilder, TransientRecHitRecord>(
0160           iConfig.getParameter<edm::ESInputTag>("ttrhBuilder"))},
0161       mkFitGeomToken_{esConsumes<MkFitGeometry, TrackerRecoGeometryRecord>()},
0162       putTrackCandidateToken_{produces<TrackCandidateCollection>()},
0163       putSeedStopInfoToken_{produces<std::vector<SeedStopInfo>>()},
0164       qualityMaxInvPt_{float(iConfig.getParameter<double>("qualityMaxInvPt"))},
0165       qualityMinTheta_{float(iConfig.getParameter<double>("qualityMinTheta"))},
0166       qualityMaxRsq_{float(pow(iConfig.getParameter<double>("qualityMaxR"), 2))},
0167       qualityMaxZ_{float(iConfig.getParameter<double>("qualityMaxZ"))},
0168       qualityMaxPosErrSq_{float(pow(iConfig.getParameter<double>("qualityMaxPosErr"), 2))},
0169       qualitySignPt_{iConfig.getParameter<bool>("qualitySignPt")},
0170       doErrorRescale_{iConfig.getParameter<bool>("doErrorRescale")},
0171       algo_{reco::TrackBase::algoByName(
0172           TString(iConfig.getParameter<edm::InputTag>("seeds").label()).ReplaceAll("Seeds", "").Data())},
0173       algoCandSelection_{bool(iConfig.getParameter<bool>("candMVASel"))},
0174       algoCandWorkingPoint_{float(iConfig.getParameter<double>("candWP"))},
0175       bsize_{int(iConfig.getParameter<int>("batchSize"))},
0176       bsToken_(algoCandSelection_ ? consumes<reco::BeamSpot>(edm::InputTag("offlineBeamSpot"))
0177                                   : edm::EDGetTokenT<reco::BeamSpot>()),
0178       verticesToken_(algoCandSelection_ ? consumes<reco::VertexCollection>(edm::InputTag("firstStepPrimaryVertices"))
0179                                         : edm::EDGetTokenT<reco::VertexCollection>()),
0180       tfDnnLabel_(iConfig.getParameter<std::string>("tfDnnLabel")),
0181       tfDnnToken_(esConsumes(edm::ESInputTag("", tfDnnLabel_))) {}
0182 
0183 void MkFitOutputConverter::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0184   edm::ParameterSetDescription desc;
0185 
0186   desc.add("mkFitEventOfHits", edm::InputTag{"mkFitEventOfHits"});
0187   desc.add("mkFitPixelHits", edm::InputTag{"mkFitSiPixelHits"});
0188   desc.add("mkFitStripHits", edm::InputTag{"mkFitSiStripHits"});
0189   desc.add("mkFitSeeds", edm::InputTag{"mkFitSeedConverter"});
0190   desc.add("tracks", edm::InputTag{"mkFitProducer"});
0191   desc.add("seeds", edm::InputTag{"initialStepSeeds"});
0192   desc.add("ttrhBuilder", edm::ESInputTag{"", "WithTrackAngle"});
0193   desc.add("propagatorAlong", edm::ESInputTag{"", "PropagatorWithMaterial"});
0194   desc.add("propagatorOpposite", edm::ESInputTag{"", "PropagatorWithMaterialOpposite"});
0195 
0196   desc.add<double>("qualityMaxInvPt", 100)->setComment("max(1/pt) for converted tracks");
0197   desc.add<double>("qualityMinTheta", 0.01)->setComment("lower bound on theta (or pi-theta) for converted tracks");
0198   desc.add<double>("qualityMaxR", 120)->setComment("max(R) for the state position for converted tracks");
0199   desc.add<double>("qualityMaxZ", 280)->setComment("max(|Z|) for the state position for converted tracks");
0200   desc.add<double>("qualityMaxPosErr", 100)->setComment("max position error for converted tracks");
0201   desc.add<bool>("qualitySignPt", true)->setComment("check sign of 1/pt for converted tracks");
0202 
0203   desc.add<bool>("doErrorRescale", true)->setComment("rescale candidate error before final fit");
0204 
0205   desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
0206 
0207   desc.add<bool>("candMVASel", false)->setComment("flag used to trigger MVA selection at cand level");
0208   desc.add<double>("candWP", 0)->setComment("MVA selection at cand level working point");
0209   desc.add<int>("batchSize", 16)->setComment("batch size for cand DNN evaluation");
0210 
0211   descriptions.addWithDefaultLabel(desc);
0212 }
0213 
0214 void MkFitOutputConverter::produce(edm::StreamID iID, edm::Event& iEvent, const edm::EventSetup& iSetup) const {
0215   const auto& seeds = iEvent.get(seedToken_);
0216   const auto& mkfitSeeds = iEvent.get(mkfitSeedToken_);
0217 
0218   const auto& ttrhBuilder = iSetup.getData(ttrhBuilderToken_);
0219   const auto* tkBuilder = dynamic_cast<TkTransientTrackingRecHitBuilder const*>(&ttrhBuilder);
0220   if (!tkBuilder) {
0221     throw cms::Exception("LogicError") << "TTRHBuilder must be of type TkTransientTrackingRecHitBuilder";
0222   }
0223   const auto& mkFitGeom = iSetup.getData(mkFitGeomToken_);
0224 
0225   // primary vertices under the algo_ because in initialStepPreSplitting there are no firstStepPrimaryVertices
0226   // beamspot as well since the producer can be used in hlt
0227   const reco::VertexCollection* vertices = nullptr;
0228   const reco::BeamSpot* beamspot = nullptr;
0229   if (algoCandSelection_) {
0230     vertices = &iEvent.get(verticesToken_);
0231     beamspot = &iEvent.get(bsToken_);
0232   }
0233 
0234   const tensorflow::Session* session = nullptr;
0235   if (algoCandSelection_)
0236     session = iSetup.getData(tfDnnToken_).getSession();
0237 
0238   // Convert mkfit presentation back to CMSSW
0239   iEvent.emplace(putTrackCandidateToken_,
0240                  convertCandidates(iEvent.get(tracksToken_),
0241                                    iEvent.get(eventOfHitsToken_).get(),
0242                                    iEvent.get(pixelClusterIndexToHitToken_),
0243                                    iEvent.get(stripClusterIndexToHitToken_),
0244                                    seeds,
0245                                    iSetup.getData(mfToken_),
0246                                    iSetup.getData(propagatorAlongToken_),
0247                                    iSetup.getData(propagatorOppositeToken_),
0248                                    iSetup.getData(mkFitGeomToken_),
0249                                    tkBuilder->cloner(),
0250                                    mkFitGeom.detLayers(),
0251                                    mkfitSeeds.seeds(),
0252                                    beamspot,
0253                                    vertices,
0254                                    session));
0255 
0256   // TODO: SeedStopInfo is currently unfilled
0257   iEvent.emplace(putSeedStopInfoToken_, seeds.size());
0258 }
0259 
0260 TrackCandidateCollection MkFitOutputConverter::convertCandidates(const MkFitOutputWrapper& mkFitOutput,
0261                                                                  const mkfit::EventOfHits& eventOfHits,
0262                                                                  const MkFitClusterIndexToHit& pixelClusterIndexToHit,
0263                                                                  const MkFitClusterIndexToHit& stripClusterIndexToHit,
0264                                                                  const edm::View<TrajectorySeed>& seeds,
0265                                                                  const MagneticField& mf,
0266                                                                  const Propagator& propagatorAlong,
0267                                                                  const Propagator& propagatorOpposite,
0268                                                                  const MkFitGeometry& mkFitGeom,
0269                                                                  const TkClonerImpl& hitCloner,
0270                                                                  const std::vector<const DetLayer*>& detLayers,
0271                                                                  const mkfit::TrackVec& mkFitSeeds,
0272                                                                  const reco::BeamSpot* bs,
0273                                                                  const reco::VertexCollection* vertices,
0274                                                                  const tensorflow::Session* session) const {
0275   TrackCandidateCollection output;
0276   const auto& candidates = mkFitOutput.tracks();
0277   output.reserve(candidates.size());
0278 
0279   LogTrace("MkFitOutputConverter") << "Number of candidates " << candidates.size();
0280 
0281   std::vector<float> chi2;
0282   std::vector<TrajectoryStateOnSurface> states;
0283   chi2.reserve(candidates.size());
0284   states.reserve(candidates.size());
0285 
0286   int candIndex = -1;
0287   for (const auto& cand : candidates) {
0288     ++candIndex;
0289     LogTrace("MkFitOutputConverter") << "Candidate " << candIndex << " pT " << cand.pT() << " eta " << cand.momEta()
0290                                      << " phi " << cand.momPhi() << " chi2 " << cand.chi2();
0291 
0292     // state: check for basic quality first
0293     if (cand.state().invpT() > qualityMaxInvPt_ || (qualitySignPt_ && cand.state().invpT() < 0) ||
0294         cand.state().theta() < qualityMinTheta_ || (M_PI - cand.state().theta()) < qualityMinTheta_ ||
0295         cand.state().posRsq() > qualityMaxRsq_ || std::abs(cand.state().z()) > qualityMaxZ_ ||
0296         (cand.state().errors.At(0, 0) + cand.state().errors.At(1, 1) + cand.state().errors.At(2, 2)) >
0297             qualityMaxPosErrSq_) {
0298       edm::LogInfo("MkFitOutputConverter")
0299           << "Candidate " << candIndex << " failed state quality checks" << cand.state().parameters;
0300       continue;
0301     }
0302 
0303     auto state = cand.state();  // copy because have to modify
0304     state.convertFromCCSToGlbCurvilinear();
0305     const auto& param = state.parameters;
0306     const auto& err = state.errors;
0307     AlgebraicSymMatrix55 cov;
0308     for (int i = 0; i < 5; ++i) {
0309       for (int j = i; j < 5; ++j) {
0310         cov[i][j] = err.At(i, j);
0311       }
0312     }
0313 
0314     auto fts = FreeTrajectoryState(
0315         GlobalTrajectoryParameters(
0316             GlobalPoint(param[0], param[1], param[2]), GlobalVector(param[3], param[4], param[5]), state.charge, &mf),
0317         CurvilinearTrajectoryError(cov));
0318     if (!fts.curvilinearError().posDef()) {
0319       edm::LogInfo("MkFitOutputConverter")
0320           << "Curvilinear error not pos-def\n"
0321           << fts.curvilinearError().matrix() << "\ncandidate " << candIndex << "ignored";
0322       continue;
0323     }
0324 
0325     //Sylvester's criterion, start from the smaller submatrix size
0326     double det = 0;
0327     if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix22>(0, 0).Det(det)) || det < 0) {
0328       edm::LogInfo("MkFitOutputConverter")
0329           << "Fail pos-def check sub2.det for candidate " << candIndex << " with fts " << fts;
0330       continue;
0331     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix33>(0, 0).Det(det)) || det < 0) {
0332       edm::LogInfo("MkFitOutputConverter")
0333           << "Fail pos-def check sub3.det for candidate " << candIndex << " with fts " << fts;
0334       continue;
0335     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix44>(0, 0).Det(det)) || det < 0) {
0336       edm::LogInfo("MkFitOutputConverter")
0337           << "Fail pos-def check sub4.det for candidate " << candIndex << " with fts " << fts;
0338       continue;
0339     } else if ((!fts.curvilinearError().matrix().Det2(det)) || det < 0) {
0340       edm::LogInfo("MkFitOutputConverter")
0341           << "Fail pos-def check det for candidate " << candIndex << " with fts " << fts;
0342       continue;
0343     }
0344 
0345     // hits
0346     edm::OwnVector<TrackingRecHit> recHits;
0347     // nTotalHits() gives sum of valid hits (nFoundHits()) and invalid/missing hits.
0348     const int nhits = cand.nTotalHits();
0349     bool lastHitInvalid = false;
0350     for (int i = 0; i < nhits; ++i) {
0351       const auto& hitOnTrack = cand.getHitOnTrack(i);
0352       LogTrace("MkFitOutputConverter") << " hit on layer " << hitOnTrack.layer << " index " << hitOnTrack.index;
0353       if (hitOnTrack.index < 0) {
0354         // See index-desc.txt file in mkFit for description of negative values
0355         //
0356         // In order to use the regular InvalidTrackingRecHit I'd need
0357         // a GeomDet (and "unfortunately" that is needed in
0358         // TrackProducer).
0359         //
0360         // I guess we could take the track state and propagate it to
0361         // each layer to find the actual module the track crosses, and
0362         // check whether it is active or not to be able to mark
0363         // inactive hits
0364         const auto* detLayer = detLayers.at(hitOnTrack.layer);
0365         if (detLayer == nullptr) {
0366           throw cms::Exception("LogicError") << "DetLayer for layer index " << hitOnTrack.layer << " is null!";
0367         }
0368         // In principle an InvalidTrackingRecHitNoDet could be
0369         // inserted here, but it seems that it is best to deal with
0370         // them in the TrackProducer.
0371         lastHitInvalid = true;
0372       } else {
0373         auto const isPixel = eventOfHits[hitOnTrack.layer].is_pixel();
0374         auto const& hits = isPixel ? pixelClusterIndexToHit.hits() : stripClusterIndexToHit.hits();
0375 
0376         auto const& thit = static_cast<BaseTrackerRecHit const&>(*hits[hitOnTrack.index]);
0377         if (mkFitGeom.isPhase1()) {
0378           if (thit.firstClusterRef().isPixel() || thit.detUnit()->type().isEndcap()) {
0379             recHits.push_back(hits[hitOnTrack.index]->clone());
0380           } else {
0381             recHits.push_back(std::make_unique<SiStripRecHit1D>(
0382                 thit.localPosition(),
0383                 LocalError(thit.localPositionError().xx(), 0.f, std::numeric_limits<float>::max()),
0384                 *thit.det(),
0385                 thit.firstClusterRef()));
0386           }
0387         } else {
0388           recHits.push_back(hits[hitOnTrack.index]->clone());
0389         }
0390         LogTrace("MkFitOutputConverter") << "  pos " << recHits.back().globalPosition().x() << " "
0391                                          << recHits.back().globalPosition().y() << " "
0392                                          << recHits.back().globalPosition().z() << " mag2 "
0393                                          << recHits.back().globalPosition().mag2() << " detid "
0394                                          << recHits.back().geographicalId().rawId() << " cluster " << hitOnTrack.index;
0395         lastHitInvalid = false;
0396       }
0397     }
0398 
0399     const auto lastHitId = recHits.back().geographicalId();
0400 
0401     // MkFit hits are *not* in the order of propagation, sort by 3D radius for now (as we don't have loopers)
0402     // TODO: Improve the sorting (extract keys? maybe even bubble sort would work well as the hits are almost in the correct order)
0403     recHits.sort([](const auto& a, const auto& b) {
0404       const auto asub = a.geographicalId().subdetId();
0405       const auto bsub = b.geographicalId().subdetId();
0406       if (asub != bsub) {
0407         // Subdetector order (BPix, FPix, TIB, TID, TOB, TEC) corresponds also the navigation
0408         return asub < bsub;
0409       }
0410 
0411       const auto& apos = a.globalPosition();
0412       const auto& bpos = b.globalPosition();
0413 
0414       if (isBarrel(asub)) {
0415         return apos.perp2() < bpos.perp2();
0416       }
0417       return std::abs(apos.z()) < std::abs(bpos.z());
0418     });
0419 
0420     const bool lastHitChanged = (recHits.back().geographicalId() != lastHitId);  // TODO: make use of the bools
0421 
0422     // seed
0423     const auto seedIndex = cand.label();
0424     LogTrace("MkFitOutputConverter") << " from seed " << seedIndex << " seed hits";
0425 
0426     // Rescale candidate error if candidate is already propagated to first layer,
0427     // to be consistent with TransientInitialStateEstimator::innerState used in CkfTrackCandidateMakerBase
0428     // Error is only rescaled for candidates propagated to first layer;
0429     // otherwise, candidates undergo backwardFit where error is already rescaled
0430     if (mkFitOutput.propagatedToFirstLayer() && doErrorRescale_)
0431       fts.rescaleError(100.);
0432     auto tsosDet =
0433         mkFitOutput.propagatedToFirstLayer()
0434             ? convertInnermostState(fts, recHits, propagatorAlong, propagatorOpposite)
0435             : backwardFit(fts, recHits, propagatorAlong, propagatorOpposite, hitCloner, lastHitInvalid, lastHitChanged);
0436     if (!tsosDet.first.isValid()) {
0437       edm::LogInfo("MkFitOutputConverter")
0438           << "Backward fit of candidate " << candIndex << " failed, ignoring the candidate";
0439       continue;
0440     }
0441 
0442     // convert to persistent, from CkfTrackCandidateMakerBase
0443     auto pstate = trajectoryStateTransform::persistentState(tsosDet.first, tsosDet.second->geographicalId().rawId());
0444 
0445     output.emplace_back(
0446         recHits,
0447         seeds.at(seedIndex),
0448         pstate,
0449         seeds.refAt(seedIndex),
0450         0,                                               // mkFit does not produce loopers, so set nLoops=0
0451         static_cast<uint8_t>(StopReason::UNINITIALIZED)  // TODO: ignore details of stopping reason as well for now
0452     );
0453 
0454     chi2.push_back(cand.chi2());
0455     states.push_back(tsosDet.first);
0456   }
0457 
0458   if (algoCandSelection_) {
0459     const auto& dnnScores = computeDNNs(
0460         output, states, bs, vertices, session, chi2, mkFitOutput.propagatedToFirstLayer() && doErrorRescale_);
0461 
0462     TrackCandidateCollection reducedOutput;
0463     reducedOutput.reserve(output.size());
0464     int scoreIndex = 0;
0465     for (const auto& score : dnnScores) {
0466       if (score > algoCandWorkingPoint_)
0467         reducedOutput.push_back(output[scoreIndex]);
0468       scoreIndex++;
0469     }
0470 
0471     output.swap(reducedOutput);
0472   }
0473 
0474   return output;
0475 }
0476 
0477 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::backwardFit(
0478     const FreeTrajectoryState& fts,
0479     const edm::OwnVector<TrackingRecHit>& hits,
0480     const Propagator& propagatorAlong,
0481     const Propagator& propagatorOpposite,
0482     const TkClonerImpl& hitCloner,
0483     bool lastHitWasInvalid,
0484     bool lastHitWasChanged) const {
0485   // First filter valid hits as in TransientInitialStateEstimator
0486   TransientTrackingRecHit::ConstRecHitContainer firstHits;
0487 
0488   for (int i = hits.size() - 1; i >= 0; --i) {
0489     if (hits[i].det()) {
0490       // TransientTrackingRecHit::ConstRecHitContainer has shared_ptr,
0491       // and it is passed to backFitter below so it is really needed
0492       // to keep the interface. Since we keep the ownership in hits,
0493       // let's disable the deleter.
0494       firstHits.emplace_back(&(hits[i]), edm::do_nothing_deleter{});
0495     }
0496   }
0497 
0498   // Then propagate along to the surface of the last hit to get a TSOS
0499   const auto& lastHitSurface = firstHits.front()->det()->surface();
0500 
0501   const Propagator* tryFirst = &propagatorAlong;
0502   const Propagator* trySecond = &propagatorOpposite;
0503   if (lastHitWasInvalid || lastHitWasChanged) {
0504     LogTrace("MkFitOutputConverter") << "Propagating first opposite, then along, because lastHitWasInvalid? "
0505                                      << lastHitWasInvalid << " or lastHitWasChanged? " << lastHitWasChanged;
0506     std::swap(tryFirst, trySecond);
0507   } else {
0508     const auto lastHitSubdet = firstHits.front()->geographicalId().subdetId();
0509     const auto& surfacePos = lastHitSurface.position();
0510     const auto& lastHitPos = firstHits.front()->globalPosition();
0511     bool doSwitch = false;
0512     if (isBarrel(lastHitSubdet)) {
0513       doSwitch = (surfacePos.perp2() < lastHitPos.perp2());
0514     } else {
0515       doSwitch = (surfacePos.z() < lastHitPos.z());
0516     }
0517     if (doSwitch) {
0518       LogTrace("MkFitOutputConverter")
0519           << "Propagating first opposite, then along, because surface is inner than the hit; surface perp2 "
0520           << surfacePos.perp() << " hit " << lastHitPos.perp2() << " surface z " << surfacePos.z() << " hit "
0521           << lastHitPos.z();
0522 
0523       std::swap(tryFirst, trySecond);
0524     }
0525   }
0526 
0527   auto tsosDouble = tryFirst->propagateWithPath(fts, lastHitSurface);
0528   if (!tsosDouble.first.isValid()) {
0529     LogDebug("MkFitOutputConverter") << "Propagating to startingState failed, trying in another direction next";
0530     tsosDouble = trySecond->propagateWithPath(fts, lastHitSurface);
0531   }
0532   auto& startingState = tsosDouble.first;
0533 
0534   if (!startingState.isValid()) {
0535     edm::LogWarning("MkFitOutputConverter")
0536         << "startingState is not valid, FTS was\n"
0537         << fts << " last hit surface surface:"
0538         << "\n position " << lastHitSurface.position() << "\n phiSpan " << lastHitSurface.phiSpan().first << ","
0539         << lastHitSurface.phiSpan().first << "\n rSpan " << lastHitSurface.rSpan().first << ","
0540         << lastHitSurface.rSpan().first << "\n zSpan " << lastHitSurface.zSpan().first << ","
0541         << lastHitSurface.zSpan().first;
0542     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0543   }
0544 
0545   // Then return back to the logic from TransientInitialStateEstimator
0546   startingState.rescaleError(100.);
0547 
0548   // avoid cloning
0549   KFUpdator const aKFUpdator;
0550   Chi2MeasurementEstimator const aChi2MeasurementEstimator(100., 3);
0551   KFTrajectoryFitter backFitter(
0552       &propagatorAlong, &aKFUpdator, &aChi2MeasurementEstimator, firstHits.size(), nullptr, &hitCloner);
0553 
0554   // assume for now that the propagation in mkfit always alongMomentum
0555   PropagationDirection backFitDirection = oppositeToMomentum;
0556 
0557   // only direction matters in this context
0558   TrajectorySeed fakeSeed(PTrajectoryStateOnDet(), edm::OwnVector<TrackingRecHit>(), backFitDirection);
0559 
0560   // ignore loopers for now
0561   Trajectory fitres = backFitter.fitOne(fakeSeed, firstHits, startingState, TrajectoryFitter::standard);
0562 
0563   LogDebug("MkFitOutputConverter") << "using a backward fit of :" << firstHits.size() << " hits, starting from:\n"
0564                                    << startingState << " to get the estimate of the initial state of the track.";
0565 
0566   if (!fitres.isValid()) {
0567     edm::LogWarning("MkFitOutputConverter") << "FitTester: first hits fit failed";
0568     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0569   }
0570 
0571   TrajectoryMeasurement const& firstMeas = fitres.lastMeasurement();
0572 
0573   // magnetic field can be different!
0574   TrajectoryStateOnSurface firstState(firstMeas.updatedState().localParameters(),
0575                                       firstMeas.updatedState().localError(),
0576                                       firstMeas.updatedState().surface(),
0577                                       propagatorAlong.magneticField());
0578 
0579   firstState.rescaleError(100.);
0580 
0581   LogDebug("MkFitOutputConverter") << "the initial state is found to be:\n:" << firstState
0582                                    << "\n it's field pointer is: " << firstState.magneticField()
0583                                    << "\n the pointer from the state of the back fit was: "
0584                                    << firstMeas.updatedState().magneticField();
0585 
0586   return std::make_pair(firstState, firstMeas.recHit()->det());
0587 }
0588 
0589 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::convertInnermostState(
0590     const FreeTrajectoryState& fts,
0591     const edm::OwnVector<TrackingRecHit>& hits,
0592     const Propagator& propagatorAlong,
0593     const Propagator& propagatorOpposite) const {
0594   auto det = hits[0].det();
0595   if (det == nullptr) {
0596     throw cms::Exception("LogicError") << "Got nullptr from the first hit det()";
0597   }
0598 
0599   const auto& firstHitSurface = det->surface();
0600 
0601   auto tsosDouble = propagatorAlong.propagateWithPath(fts, firstHitSurface);
0602   if (!tsosDouble.first.isValid()) {
0603     LogDebug("MkFitOutputConverter") << "Propagating to startingState along momentum failed, trying opposite next";
0604     tsosDouble = propagatorOpposite.propagateWithPath(fts, firstHitSurface);
0605   }
0606 
0607   return std::make_pair(tsosDouble.first, det);
0608 }
0609 
0610 std::vector<float> MkFitOutputConverter::computeDNNs(TrackCandidateCollection const& tkCC,
0611                                                      const std::vector<TrajectoryStateOnSurface>& states,
0612                                                      const reco::BeamSpot* bs,
0613                                                      const reco::VertexCollection* vertices,
0614                                                      const tensorflow::Session* session,
0615                                                      const std::vector<float>& chi2,
0616                                                      const bool rescaledError) const {
0617   int size_in = (int)tkCC.size();
0618   int nbatches = size_in / bsize_;
0619 
0620   std::vector<float> output(size_in, 0);
0621 
0622   TSCBLBuilderNoMaterial tscblBuilder;
0623 
0624   tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
0625   tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
0626 
0627   for (auto nb = 0; nb < nbatches + 1; nb++) {
0628     std::vector<bool> invalidProp(bsize_, false);
0629 
0630     for (auto nt = 0; nt < bsize_; nt++) {
0631       int itrack = nt + bsize_ * nb;
0632       if (itrack >= size_in)
0633         continue;
0634 
0635       auto const& tkC = tkCC.at(itrack);
0636 
0637       TrajectoryStateOnSurface state = states.at(itrack);
0638 
0639       if (rescaledError)
0640         state.rescaleError(1 / 100.f);
0641 
0642       TrajectoryStateClosestToBeamLine tsAtClosestApproachTrackCand =
0643           tscblBuilder(*state.freeState(), *bs);  //as in TrackProducerAlgorithm
0644 
0645       if (!(tsAtClosestApproachTrackCand.isValid())) {
0646         edm::LogVerbatim("TrackBuilding") << "TrajectoryStateClosestToBeamLine not valid";
0647         invalidProp[nt] = true;
0648         continue;
0649       }
0650 
0651       auto const& stateAtPCA = tsAtClosestApproachTrackCand.trackStateAtPCA();
0652       auto v0 = stateAtPCA.position();
0653       auto p = stateAtPCA.momentum();
0654       math::XYZPoint pos(v0.x(), v0.y(), v0.z());
0655       math::XYZVector mom(p.x(), p.y(), p.z());
0656 
0657       //pseudo track for access to easy methods
0658       reco::Track trk(0, 0, pos, mom, stateAtPCA.charge(), stateAtPCA.curvilinearError());
0659 
0660       // get best vertex
0661       float dzmin = std::numeric_limits<float>::max();
0662       float dxy_zmin = 0;
0663 
0664       for (auto const& vertex : *vertices) {
0665         if (std::abs(trk.dz(vertex.position())) < dzmin) {
0666           dzmin = trk.dz(vertex.position());
0667           dxy_zmin = trk.dxy(vertex.position());
0668         }
0669       }
0670 
0671       // loop over the RecHits
0672       int ndof = 0;
0673       int pix = 0;
0674       int strip = 0;
0675       for (auto const& recHit : tkC.recHits()) {
0676         ndof += recHit.dimension();
0677         auto const subdet = recHit.geographicalId().subdetId();
0678         if (subdet == PixelSubdetector::PixelBarrel || subdet == PixelSubdetector::PixelEndcap)
0679           pix++;
0680         else
0681           strip++;
0682       }
0683       ndof = ndof - 5;
0684 
0685       input1.matrix<float>()(nt, 0) = trk.pt();  //using inner track only
0686       input1.matrix<float>()(nt, 1) = p.x();
0687       input1.matrix<float>()(nt, 2) = p.y();
0688       input1.matrix<float>()(nt, 3) = p.z();
0689       input1.matrix<float>()(nt, 4) = p.perp();
0690       input1.matrix<float>()(nt, 5) = p.x();
0691       input1.matrix<float>()(nt, 6) = p.y();
0692       input1.matrix<float>()(nt, 7) = p.z();
0693       input1.matrix<float>()(nt, 8) = p.perp();
0694       input1.matrix<float>()(nt, 9) = trk.ptError();
0695       input1.matrix<float>()(nt, 10) = dxy_zmin;
0696       input1.matrix<float>()(nt, 11) = dzmin;
0697       input1.matrix<float>()(nt, 12) = trk.dxy(bs->position());
0698       input1.matrix<float>()(nt, 13) = trk.dz(bs->position());
0699       input1.matrix<float>()(nt, 14) = trk.dxyError();
0700       input1.matrix<float>()(nt, 15) = trk.dzError();
0701       input1.matrix<float>()(nt, 16) = ndof > 0 ? chi2[itrack] / ndof : chi2[itrack] * 1e6;
0702       input1.matrix<float>()(nt, 17) = trk.eta();
0703       input1.matrix<float>()(nt, 18) = trk.phi();
0704       input1.matrix<float>()(nt, 19) = trk.etaError();
0705       input1.matrix<float>()(nt, 20) = trk.phiError();
0706       input1.matrix<float>()(nt, 21) = pix;
0707       input1.matrix<float>()(nt, 22) = strip;
0708       input1.matrix<float>()(nt, 23) = ndof;
0709       input1.matrix<float>()(nt, 24) = 0;
0710       input1.matrix<float>()(nt, 25) = 0;
0711       input1.matrix<float>()(nt, 26) = 0;
0712       input1.matrix<float>()(nt, 27) = 0;
0713       input1.matrix<float>()(nt, 28) = 0;
0714 
0715       input2.matrix<float>()(nt, 0) = algo_;
0716     }
0717 
0718     //inputs finalized
0719     tensorflow::NamedTensorList inputs;
0720     inputs.resize(2);
0721     inputs[0] = tensorflow::NamedTensor("x", input1);
0722     inputs[1] = tensorflow::NamedTensor("y", input2);
0723 
0724     //eval and rescale
0725     std::vector<tensorflow::Tensor> outputs;
0726     tensorflow::run(session, inputs, {"Identity"}, &outputs);
0727 
0728     for (auto nt = 0; nt < bsize_; nt++) {
0729       int itrack = nt + bsize_ * nb;
0730       if (itrack >= size_in)
0731         continue;
0732 
0733       float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
0734       if (invalidProp[nt])
0735         out0 = -1;
0736 
0737       output[itrack] = out0;
0738     }
0739   }
0740 
0741   return output;
0742 }
0743 
// Register MkFitOutputConverter with the framework's module factory (macro from MakerMacros.h).
DEFINE_FWK_MODULE(MkFitOutputConverter);