Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-05-26 03:56:17

#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "FWCore/Framework/interface/global/EDProducer.h"

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Utilities/interface/do_nothing_deleter.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"

#include "DataFormats/SiPixelDetId/interface/PixelSubdetector.h"
#include "DataFormats/SiStripDetId/interface/StripSubdetector.h"
#include "DataFormats/TrajectorySeed/interface/TrajectorySeed.h"
#include "DataFormats/TrackCandidate/interface/TrackCandidateCollection.h"
#include "DataFormats/TrackReco/interface/SeedStopInfo.h"
#include "DataFormats/TrackingRecHit/interface/InvalidTrackingRecHit.h"
#include "DataFormats/TrackerRecHit2D/interface/SiStripRecHit1D.h"

#include "TrackingTools/Records/interface/TransientRecHitRecord.h"
#include "TrackingTools/TransientTrackingRecHit/interface/TransientTrackingRecHitBuilder.h"
#include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"

#include "MagneticField/Engine/interface/MagneticField.h"
#include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"

#include "TrackingTools/GeomPropagators/interface/Propagator.h"
#include "TrackingTools/Records/interface/TrackingComponentsRecord.h"
#include "TrackingTools/KalmanUpdators/interface/KFUpdator.h"
#include "TrackingTools/KalmanUpdators/interface/Chi2MeasurementEstimator.h"
#include "TrackingTools/TrackFitters/interface/KFTrajectoryFitter.h"
#include "RecoTracker/TransientTrackingRecHit/interface/TkClonerImpl.h"
#include "RecoTracker/TransientTrackingRecHit/interface/TkTransientTrackingRecHitBuilder.h"
#include "TrackingTools/MaterialEffects/interface/PropagatorWithMaterial.h"

#include "RecoTracker/MkFit/interface/MkFitEventOfHits.h"
#include "RecoTracker/MkFit/interface/MkFitClusterIndexToHit.h"
#include "RecoTracker/MkFit/interface/MkFitSeedWrapper.h"
#include "RecoTracker/MkFit/interface/MkFitOutputWrapper.h"
#include "RecoTracker/MkFit/interface/MkFitGeometry.h"
#include "RecoTracker/Record/interface/TrackerRecoGeometryRecord.h"

// mkFit includes
#include "RecoTracker/MkFitCMS/interface/LayerNumberConverter.h"
#include "RecoTracker/MkFitCore/interface/Track.h"
#include "RecoTracker/MkFitCore/interface/HitStructures.h"

//extra for DNN with cands
#include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
#include "TrackingTools/PatternTools/interface/TSCBLBuilderNoMaterial.h"
#include "DataFormats/BeamSpot/interface/BeamSpot.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
0051 
0052 namespace {
0053   template <typename T>
0054   bool isBarrel(T subdet) {
0055     return subdet == PixelSubdetector::PixelBarrel || subdet == StripSubdetector::TIB ||
0056            subdet == StripSubdetector::TOB;
0057   }
0058 
0059   template <typename T>
0060   bool isEndcap(T subdet) {
0061     return subdet == PixelSubdetector::PixelEndcap || subdet == StripSubdetector::TID ||
0062            subdet == StripSubdetector::TEC;
0063   }
0064 }  // namespace
0065 
0066 class MkFitOutputConverter : public edm::global::EDProducer<> {
0067 public:
0068   explicit MkFitOutputConverter(edm::ParameterSet const& iConfig);
0069   ~MkFitOutputConverter() override = default;
0070 
0071   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0072 
0073 private:
0074   void produce(edm::StreamID, edm::Event& iEvent, const edm::EventSetup& iSetup) const override;
0075 
0076   TrackCandidateCollection convertCandidates(const MkFitOutputWrapper& mkFitOutput,
0077                                              const mkfit::EventOfHits& eventOfHits,
0078                                              const MkFitClusterIndexToHit& pixelClusterIndexToHit,
0079                                              const MkFitClusterIndexToHit& stripClusterIndexToHit,
0080                                              const edm::View<TrajectorySeed>& seeds,
0081                                              const MagneticField& mf,
0082                                              const Propagator& propagatorAlong,
0083                                              const Propagator& propagatorOpposite,
0084                                              const TkClonerImpl& hitCloner,
0085                                              const std::vector<const DetLayer*>& detLayers,
0086                                              const mkfit::TrackVec& mkFitSeeds,
0087                                              const reco::BeamSpot* bs,
0088                                              const reco::VertexCollection* vertices,
0089                                              const tensorflow::Session* session) const;
0090 
0091   std::pair<TrajectoryStateOnSurface, const GeomDet*> backwardFit(const FreeTrajectoryState& fts,
0092                                                                   const edm::OwnVector<TrackingRecHit>& hits,
0093                                                                   const Propagator& propagatorAlong,
0094                                                                   const Propagator& propagatorOpposite,
0095                                                                   const TkClonerImpl& hitCloner,
0096                                                                   bool lastHitWasInvalid,
0097                                                                   bool lastHitWasChanged) const;
0098 
0099   std::pair<TrajectoryStateOnSurface, const GeomDet*> convertInnermostState(const FreeTrajectoryState& fts,
0100                                                                             const edm::OwnVector<TrackingRecHit>& hits,
0101                                                                             const Propagator& propagatorAlong,
0102                                                                             const Propagator& propagatorOpposite) const;
0103 
0104   std::vector<float> computeDNNs(TrackCandidateCollection const& tkCC,
0105                                  const std::vector<TrajectoryStateOnSurface>& states,
0106                                  const reco::BeamSpot* bs,
0107                                  const reco::VertexCollection* vertices,
0108                                  const tensorflow::Session* session,
0109                                  const std::vector<float>& chi2,
0110                                  const bool rescaledError) const;
0111 
0112   const edm::EDGetTokenT<MkFitEventOfHits> eventOfHitsToken_;
0113   const edm::EDGetTokenT<MkFitClusterIndexToHit> pixelClusterIndexToHitToken_;
0114   const edm::EDGetTokenT<MkFitClusterIndexToHit> stripClusterIndexToHitToken_;
0115   const edm::EDGetTokenT<MkFitSeedWrapper> mkfitSeedToken_;
0116   const edm::EDGetTokenT<MkFitOutputWrapper> tracksToken_;
0117   const edm::EDGetTokenT<edm::View<TrajectorySeed>> seedToken_;
0118   const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorAlongToken_;
0119   const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagatorOppositeToken_;
0120   const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> mfToken_;
0121   const edm::ESGetToken<TransientTrackingRecHitBuilder, TransientRecHitRecord> ttrhBuilderToken_;
0122   const edm::ESGetToken<MkFitGeometry, TrackerRecoGeometryRecord> mkFitGeomToken_;
0123   const edm::EDPutTokenT<TrackCandidateCollection> putTrackCandidateToken_;
0124   const edm::EDPutTokenT<std::vector<SeedStopInfo>> putSeedStopInfoToken_;
0125 
0126   const float qualityMaxInvPt_;
0127   const float qualityMinTheta_;
0128   const float qualityMaxRsq_;
0129   const float qualityMaxZ_;
0130   const float qualityMaxPosErrSq_;
0131   const bool qualitySignPt_;
0132 
0133   const bool doErrorRescale_;
0134 
0135   const int algo_;
0136   const bool algoCandSelection_;
0137   const float algoCandWorkingPoint_;
0138   const int bsize_;
0139   const edm::EDGetTokenT<reco::BeamSpot> bsToken_;
0140   const edm::EDGetTokenT<reco::VertexCollection> verticesToken_;
0141   const std::string tfDnnLabel_;
0142   const edm::ESGetToken<TfGraphDefWrapper, TfGraphRecord> tfDnnToken_;
0143 };
0144 
0145 MkFitOutputConverter::MkFitOutputConverter(edm::ParameterSet const& iConfig)
0146     : eventOfHitsToken_{consumes<MkFitEventOfHits>(iConfig.getParameter<edm::InputTag>("mkFitEventOfHits"))},
0147       pixelClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitPixelHits"))},
0148       stripClusterIndexToHitToken_{consumes(iConfig.getParameter<edm::InputTag>("mkFitStripHits"))},
0149       mkfitSeedToken_{consumes<MkFitSeedWrapper>(iConfig.getParameter<edm::InputTag>("mkFitSeeds"))},
0150       tracksToken_{consumes<MkFitOutputWrapper>(iConfig.getParameter<edm::InputTag>("tracks"))},
0151       seedToken_{consumes<edm::View<TrajectorySeed>>(iConfig.getParameter<edm::InputTag>("seeds"))},
0152       propagatorAlongToken_{
0153           esConsumes<Propagator, TrackingComponentsRecord>(iConfig.getParameter<edm::ESInputTag>("propagatorAlong"))},
0154       propagatorOppositeToken_{esConsumes<Propagator, TrackingComponentsRecord>(
0155           iConfig.getParameter<edm::ESInputTag>("propagatorOpposite"))},
0156       mfToken_{esConsumes<MagneticField, IdealMagneticFieldRecord>()},
0157       ttrhBuilderToken_{esConsumes<TransientTrackingRecHitBuilder, TransientRecHitRecord>(
0158           iConfig.getParameter<edm::ESInputTag>("ttrhBuilder"))},
0159       mkFitGeomToken_{esConsumes<MkFitGeometry, TrackerRecoGeometryRecord>()},
0160       putTrackCandidateToken_{produces<TrackCandidateCollection>()},
0161       putSeedStopInfoToken_{produces<std::vector<SeedStopInfo>>()},
0162       qualityMaxInvPt_{float(iConfig.getParameter<double>("qualityMaxInvPt"))},
0163       qualityMinTheta_{float(iConfig.getParameter<double>("qualityMinTheta"))},
0164       qualityMaxRsq_{float(pow(iConfig.getParameter<double>("qualityMaxR"), 2))},
0165       qualityMaxZ_{float(iConfig.getParameter<double>("qualityMaxZ"))},
0166       qualityMaxPosErrSq_{float(pow(iConfig.getParameter<double>("qualityMaxPosErr"), 2))},
0167       qualitySignPt_{iConfig.getParameter<bool>("qualitySignPt")},
0168       doErrorRescale_{iConfig.getParameter<bool>("doErrorRescale")},
0169       algo_{reco::TrackBase::algoByName(
0170           TString(iConfig.getParameter<edm::InputTag>("seeds").label()).ReplaceAll("Seeds", "").Data())},
0171       algoCandSelection_{bool(iConfig.getParameter<bool>("candMVASel"))},
0172       algoCandWorkingPoint_{float(iConfig.getParameter<double>("candWP"))},
0173       bsize_{int(iConfig.getParameter<int>("batchSize"))},
0174       bsToken_(algoCandSelection_ ? consumes<reco::BeamSpot>(edm::InputTag("offlineBeamSpot"))
0175                                   : edm::EDGetTokenT<reco::BeamSpot>()),
0176       verticesToken_(algoCandSelection_ ? consumes<reco::VertexCollection>(edm::InputTag("firstStepPrimaryVertices"))
0177                                         : edm::EDGetTokenT<reco::VertexCollection>()),
0178       tfDnnLabel_(iConfig.getParameter<std::string>("tfDnnLabel")),
0179       tfDnnToken_(esConsumes(edm::ESInputTag("", tfDnnLabel_))) {}
0180 
0181 void MkFitOutputConverter::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0182   edm::ParameterSetDescription desc;
0183 
0184   desc.add("mkFitEventOfHits", edm::InputTag{"mkFitEventOfHits"});
0185   desc.add("mkFitPixelHits", edm::InputTag{"mkFitSiPixelHits"});
0186   desc.add("mkFitStripHits", edm::InputTag{"mkFitSiStripHits"});
0187   desc.add("mkFitSeeds", edm::InputTag{"mkFitSeedConverter"});
0188   desc.add("tracks", edm::InputTag{"mkFitProducer"});
0189   desc.add("seeds", edm::InputTag{"initialStepSeeds"});
0190   desc.add("ttrhBuilder", edm::ESInputTag{"", "WithTrackAngle"});
0191   desc.add("propagatorAlong", edm::ESInputTag{"", "PropagatorWithMaterial"});
0192   desc.add("propagatorOpposite", edm::ESInputTag{"", "PropagatorWithMaterialOpposite"});
0193 
0194   desc.add<double>("qualityMaxInvPt", 100)->setComment("max(1/pt) for converted tracks");
0195   desc.add<double>("qualityMinTheta", 0.01)->setComment("lower bound on theta (or pi-theta) for converted tracks");
0196   desc.add<double>("qualityMaxR", 120)->setComment("max(R) for the state position for converted tracks");
0197   desc.add<double>("qualityMaxZ", 280)->setComment("max(|Z|) for the state position for converted tracks");
0198   desc.add<double>("qualityMaxPosErr", 100)->setComment("max position error for converted tracks");
0199   desc.add<bool>("qualitySignPt", true)->setComment("check sign of 1/pt for converted tracks");
0200 
0201   desc.add<bool>("doErrorRescale", true)->setComment("rescale candidate error before final fit");
0202 
0203   desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
0204 
0205   desc.add<bool>("candMVASel", false)->setComment("flag used to trigger MVA selection at cand level");
0206   desc.add<double>("candWP", 0)->setComment("MVA selection at cand level working point");
0207   desc.add<int>("batchSize", 16)->setComment("batch size for cand DNN evaluation");
0208 
0209   descriptions.addWithDefaultLabel(desc);
0210 }
0211 
0212 void MkFitOutputConverter::produce(edm::StreamID iID, edm::Event& iEvent, const edm::EventSetup& iSetup) const {
0213   const auto& seeds = iEvent.get(seedToken_);
0214   const auto& mkfitSeeds = iEvent.get(mkfitSeedToken_);
0215 
0216   const auto& ttrhBuilder = iSetup.getData(ttrhBuilderToken_);
0217   const auto* tkBuilder = dynamic_cast<TkTransientTrackingRecHitBuilder const*>(&ttrhBuilder);
0218   if (!tkBuilder) {
0219     throw cms::Exception("LogicError") << "TTRHBuilder must be of type TkTransientTrackingRecHitBuilder";
0220   }
0221   const auto& mkFitGeom = iSetup.getData(mkFitGeomToken_);
0222 
0223   // primary vertices under the algo_ because in initialStepPreSplitting there are no firstStepPrimaryVertices
0224   // beamspot as well since the producer can be used in hlt
0225   const reco::VertexCollection* vertices = nullptr;
0226   const reco::BeamSpot* beamspot = nullptr;
0227   if (algoCandSelection_) {
0228     vertices = &iEvent.get(verticesToken_);
0229     beamspot = &iEvent.get(bsToken_);
0230   }
0231 
0232   const tensorflow::Session* session = nullptr;
0233   if (algoCandSelection_)
0234     session = iSetup.getData(tfDnnToken_).getSession();
0235 
0236   // Convert mkfit presentation back to CMSSW
0237   iEvent.emplace(putTrackCandidateToken_,
0238                  convertCandidates(iEvent.get(tracksToken_),
0239                                    iEvent.get(eventOfHitsToken_).get(),
0240                                    iEvent.get(pixelClusterIndexToHitToken_),
0241                                    iEvent.get(stripClusterIndexToHitToken_),
0242                                    seeds,
0243                                    iSetup.getData(mfToken_),
0244                                    iSetup.getData(propagatorAlongToken_),
0245                                    iSetup.getData(propagatorOppositeToken_),
0246                                    tkBuilder->cloner(),
0247                                    mkFitGeom.detLayers(),
0248                                    mkfitSeeds.seeds(),
0249                                    beamspot,
0250                                    vertices,
0251                                    session));
0252 
0253   // TODO: SeedStopInfo is currently unfilled
0254   iEvent.emplace(putSeedStopInfoToken_, seeds.size());
0255 }
0256 
0257 TrackCandidateCollection MkFitOutputConverter::convertCandidates(const MkFitOutputWrapper& mkFitOutput,
0258                                                                  const mkfit::EventOfHits& eventOfHits,
0259                                                                  const MkFitClusterIndexToHit& pixelClusterIndexToHit,
0260                                                                  const MkFitClusterIndexToHit& stripClusterIndexToHit,
0261                                                                  const edm::View<TrajectorySeed>& seeds,
0262                                                                  const MagneticField& mf,
0263                                                                  const Propagator& propagatorAlong,
0264                                                                  const Propagator& propagatorOpposite,
0265                                                                  const TkClonerImpl& hitCloner,
0266                                                                  const std::vector<const DetLayer*>& detLayers,
0267                                                                  const mkfit::TrackVec& mkFitSeeds,
0268                                                                  const reco::BeamSpot* bs,
0269                                                                  const reco::VertexCollection* vertices,
0270                                                                  const tensorflow::Session* session) const {
0271   TrackCandidateCollection output;
0272   const auto& candidates = mkFitOutput.tracks();
0273   output.reserve(candidates.size());
0274 
0275   LogTrace("MkFitOutputConverter") << "Number of candidates " << candidates.size();
0276 
0277   std::vector<float> chi2;
0278   std::vector<TrajectoryStateOnSurface> states;
0279   chi2.reserve(candidates.size());
0280   states.reserve(candidates.size());
0281 
0282   int candIndex = -1;
0283   for (const auto& cand : candidates) {
0284     ++candIndex;
0285     LogTrace("MkFitOutputConverter") << "Candidate " << candIndex << " pT " << cand.pT() << " eta " << cand.momEta()
0286                                      << " phi " << cand.momPhi() << " chi2 " << cand.chi2();
0287 
0288     // state: check for basic quality first
0289     if (cand.state().invpT() > qualityMaxInvPt_ || (qualitySignPt_ && cand.state().invpT() < 0) ||
0290         cand.state().theta() < qualityMinTheta_ || (M_PI - cand.state().theta()) < qualityMinTheta_ ||
0291         cand.state().posRsq() > qualityMaxRsq_ || std::abs(cand.state().z()) > qualityMaxZ_ ||
0292         (cand.state().errors.At(0, 0) + cand.state().errors.At(1, 1) + cand.state().errors.At(2, 2)) >
0293             qualityMaxPosErrSq_) {
0294       edm::LogInfo("MkFitOutputConverter")
0295           << "Candidate " << candIndex << " failed state quality checks" << cand.state().parameters;
0296       continue;
0297     }
0298 
0299     auto state = cand.state();  // copy because have to modify
0300     state.convertFromCCSToGlbCurvilinear();
0301     const auto& param = state.parameters;
0302     const auto& err = state.errors;
0303     AlgebraicSymMatrix55 cov;
0304     for (int i = 0; i < 5; ++i) {
0305       for (int j = i; j < 5; ++j) {
0306         cov[i][j] = err.At(i, j);
0307       }
0308     }
0309 
0310     auto fts = FreeTrajectoryState(
0311         GlobalTrajectoryParameters(
0312             GlobalPoint(param[0], param[1], param[2]), GlobalVector(param[3], param[4], param[5]), state.charge, &mf),
0313         CurvilinearTrajectoryError(cov));
0314     if (!fts.curvilinearError().posDef()) {
0315       edm::LogInfo("MkFitOutputConverter")
0316           << "Curvilinear error not pos-def\n"
0317           << fts.curvilinearError().matrix() << "\ncandidate " << candIndex << "ignored";
0318       continue;
0319     }
0320 
0321     //Sylvester's criterion, start from the smaller submatrix size
0322     double det = 0;
0323     if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix22>(0, 0).Det(det)) || det < 0) {
0324       edm::LogInfo("MkFitOutputConverter")
0325           << "Fail pos-def check sub2.det for candidate " << candIndex << " with fts " << fts;
0326       continue;
0327     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix33>(0, 0).Det(det)) || det < 0) {
0328       edm::LogInfo("MkFitOutputConverter")
0329           << "Fail pos-def check sub3.det for candidate " << candIndex << " with fts " << fts;
0330       continue;
0331     } else if ((!fts.curvilinearError().matrix().Sub<AlgebraicSymMatrix44>(0, 0).Det(det)) || det < 0) {
0332       edm::LogInfo("MkFitOutputConverter")
0333           << "Fail pos-def check sub4.det for candidate " << candIndex << " with fts " << fts;
0334       continue;
0335     } else if ((!fts.curvilinearError().matrix().Det2(det)) || det < 0) {
0336       edm::LogInfo("MkFitOutputConverter")
0337           << "Fail pos-def check det for candidate " << candIndex << " with fts " << fts;
0338       continue;
0339     }
0340 
0341     // hits
0342     edm::OwnVector<TrackingRecHit> recHits;
0343     // nTotalHits() gives sum of valid hits (nFoundHits()) and invalid/missing hits.
0344     const int nhits = cand.nTotalHits();
0345     bool lastHitInvalid = false;
0346     for (int i = 0; i < nhits; ++i) {
0347       const auto& hitOnTrack = cand.getHitOnTrack(i);
0348       LogTrace("MkFitOutputConverter") << " hit on layer " << hitOnTrack.layer << " index " << hitOnTrack.index;
0349       if (hitOnTrack.index < 0) {
0350         // See index-desc.txt file in mkFit for description of negative values
0351         //
0352         // In order to use the regular InvalidTrackingRecHit I'd need
0353         // a GeomDet (and "unfortunately" that is needed in
0354         // TrackProducer).
0355         //
0356         // I guess we could take the track state and propagate it to
0357         // each layer to find the actual module the track crosses, and
0358         // check whether it is active or not to be able to mark
0359         // inactive hits
0360         const auto* detLayer = detLayers.at(hitOnTrack.layer);
0361         if (detLayer == nullptr) {
0362           throw cms::Exception("LogicError") << "DetLayer for layer index " << hitOnTrack.layer << " is null!";
0363         }
0364         // In principle an InvalidTrackingRecHitNoDet could be
0365         // inserted here, but it seems that it is best to deal with
0366         // them in the TrackProducer.
0367         lastHitInvalid = true;
0368       } else {
0369         auto const isPixel = eventOfHits[hitOnTrack.layer].is_pixel();
0370         auto const& hits = isPixel ? pixelClusterIndexToHit.hits() : stripClusterIndexToHit.hits();
0371 
0372         auto const& thit = static_cast<BaseTrackerRecHit const&>(*hits[hitOnTrack.index]);
0373         if (thit.firstClusterRef().isPixel() || thit.detUnit()->type().isEndcap()) {
0374           recHits.push_back(hits[hitOnTrack.index]->clone());
0375         } else {
0376           recHits.push_back(std::make_unique<SiStripRecHit1D>(
0377               thit.localPosition(),
0378               LocalError(thit.localPositionError().xx(), 0.f, std::numeric_limits<float>::max()),
0379               *thit.det(),
0380               thit.firstClusterRef()));
0381         }
0382         LogTrace("MkFitOutputConverter") << "  pos " << recHits.back().globalPosition().x() << " "
0383                                          << recHits.back().globalPosition().y() << " "
0384                                          << recHits.back().globalPosition().z() << " mag2 "
0385                                          << recHits.back().globalPosition().mag2() << " detid "
0386                                          << recHits.back().geographicalId().rawId() << " cluster " << hitOnTrack.index;
0387         lastHitInvalid = false;
0388       }
0389     }
0390 
0391     const auto lastHitId = recHits.back().geographicalId();
0392 
0393     // MkFit hits are *not* in the order of propagation, sort by 3D radius for now (as we don't have loopers)
0394     // TODO: Improve the sorting (extract keys? maybe even bubble sort would work well as the hits are almost in the correct order)
0395     recHits.sort([](const auto& a, const auto& b) {
0396       const auto asub = a.geographicalId().subdetId();
0397       const auto bsub = b.geographicalId().subdetId();
0398       if (asub != bsub) {
0399         // Subdetector order (BPix, FPix, TIB, TID, TOB, TEC) corresponds also the navigation
0400         return asub < bsub;
0401       }
0402 
0403       const auto& apos = a.globalPosition();
0404       const auto& bpos = b.globalPosition();
0405 
0406       if (isBarrel(asub)) {
0407         return apos.perp2() < bpos.perp2();
0408       }
0409       return std::abs(apos.z()) < std::abs(bpos.z());
0410     });
0411 
0412     const bool lastHitChanged = (recHits.back().geographicalId() != lastHitId);  // TODO: make use of the bools
0413 
0414     // seed
0415     const auto seedIndex = cand.label();
0416     LogTrace("MkFitOutputConverter") << " from seed " << seedIndex << " seed hits";
0417 
0418     // Rescale candidate error if candidate is already propagated to first layer,
0419     // to be consistent with TransientInitialStateEstimator::innerState used in CkfTrackCandidateMakerBase
0420     // Error is only rescaled for candidates propagated to first layer;
0421     // otherwise, candidates undergo backwardFit where error is already rescaled
0422     if (mkFitOutput.propagatedToFirstLayer() && doErrorRescale_)
0423       fts.rescaleError(100.);
0424     auto tsosDet =
0425         mkFitOutput.propagatedToFirstLayer()
0426             ? convertInnermostState(fts, recHits, propagatorAlong, propagatorOpposite)
0427             : backwardFit(fts, recHits, propagatorAlong, propagatorOpposite, hitCloner, lastHitInvalid, lastHitChanged);
0428     if (!tsosDet.first.isValid()) {
0429       edm::LogInfo("MkFitOutputConverter")
0430           << "Backward fit of candidate " << candIndex << " failed, ignoring the candidate";
0431       continue;
0432     }
0433 
0434     // convert to persistent, from CkfTrackCandidateMakerBase
0435     auto pstate = trajectoryStateTransform::persistentState(tsosDet.first, tsosDet.second->geographicalId().rawId());
0436 
0437     output.emplace_back(
0438         recHits,
0439         seeds.at(seedIndex),
0440         pstate,
0441         seeds.refAt(seedIndex),
0442         0,                                               // mkFit does not produce loopers, so set nLoops=0
0443         static_cast<uint8_t>(StopReason::UNINITIALIZED)  // TODO: ignore details of stopping reason as well for now
0444     );
0445 
0446     chi2.push_back(cand.chi2());
0447     states.push_back(tsosDet.first);
0448   }
0449 
0450   if (algoCandSelection_) {
0451     const auto& dnnScores = computeDNNs(
0452         output, states, bs, vertices, session, chi2, mkFitOutput.propagatedToFirstLayer() && doErrorRescale_);
0453 
0454     TrackCandidateCollection reducedOutput;
0455     reducedOutput.reserve(output.size());
0456     int scoreIndex = 0;
0457     for (const auto& score : dnnScores) {
0458       if (score > algoCandWorkingPoint_)
0459         reducedOutput.push_back(output[scoreIndex]);
0460       scoreIndex++;
0461     }
0462 
0463     output.swap(reducedOutput);
0464   }
0465 
0466   return output;
0467 }
0468 
0469 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::backwardFit(
0470     const FreeTrajectoryState& fts,
0471     const edm::OwnVector<TrackingRecHit>& hits,
0472     const Propagator& propagatorAlong,
0473     const Propagator& propagatorOpposite,
0474     const TkClonerImpl& hitCloner,
0475     bool lastHitWasInvalid,
0476     bool lastHitWasChanged) const {
0477   // First filter valid hits as in TransientInitialStateEstimator
0478   TransientTrackingRecHit::ConstRecHitContainer firstHits;
0479 
0480   for (int i = hits.size() - 1; i >= 0; --i) {
0481     if (hits[i].det()) {
0482       // TransientTrackingRecHit::ConstRecHitContainer has shared_ptr,
0483       // and it is passed to backFitter below so it is really needed
0484       // to keep the interface. Since we keep the ownership in hits,
0485       // let's disable the deleter.
0486       firstHits.emplace_back(&(hits[i]), edm::do_nothing_deleter{});
0487     }
0488   }
0489 
0490   // Then propagate along to the surface of the last hit to get a TSOS
0491   const auto& lastHitSurface = firstHits.front()->det()->surface();
0492 
0493   const Propagator* tryFirst = &propagatorAlong;
0494   const Propagator* trySecond = &propagatorOpposite;
0495   if (lastHitWasInvalid || lastHitWasChanged) {
0496     LogTrace("MkFitOutputConverter") << "Propagating first opposite, then along, because lastHitWasInvalid? "
0497                                      << lastHitWasInvalid << " or lastHitWasChanged? " << lastHitWasChanged;
0498     std::swap(tryFirst, trySecond);
0499   } else {
0500     const auto lastHitSubdet = firstHits.front()->geographicalId().subdetId();
0501     const auto& surfacePos = lastHitSurface.position();
0502     const auto& lastHitPos = firstHits.front()->globalPosition();
0503     bool doSwitch = false;
0504     if (isBarrel(lastHitSubdet)) {
0505       doSwitch = (surfacePos.perp2() < lastHitPos.perp2());
0506     } else {
0507       doSwitch = (surfacePos.z() < lastHitPos.z());
0508     }
0509     if (doSwitch) {
0510       LogTrace("MkFitOutputConverter")
0511           << "Propagating first opposite, then along, because surface is inner than the hit; surface perp2 "
0512           << surfacePos.perp() << " hit " << lastHitPos.perp2() << " surface z " << surfacePos.z() << " hit "
0513           << lastHitPos.z();
0514 
0515       std::swap(tryFirst, trySecond);
0516     }
0517   }
0518 
0519   auto tsosDouble = tryFirst->propagateWithPath(fts, lastHitSurface);
0520   if (!tsosDouble.first.isValid()) {
0521     LogDebug("MkFitOutputConverter") << "Propagating to startingState failed, trying in another direction next";
0522     tsosDouble = trySecond->propagateWithPath(fts, lastHitSurface);
0523   }
0524   auto& startingState = tsosDouble.first;
0525 
0526   if (!startingState.isValid()) {
0527     edm::LogWarning("MkFitOutputConverter")
0528         << "startingState is not valid, FTS was\n"
0529         << fts << " last hit surface surface:"
0530         << "\n position " << lastHitSurface.position() << "\n phiSpan " << lastHitSurface.phiSpan().first << ","
0531         << lastHitSurface.phiSpan().first << "\n rSpan " << lastHitSurface.rSpan().first << ","
0532         << lastHitSurface.rSpan().first << "\n zSpan " << lastHitSurface.zSpan().first << ","
0533         << lastHitSurface.zSpan().first;
0534     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0535   }
0536 
0537   // Then return back to the logic from TransientInitialStateEstimator
0538   startingState.rescaleError(100.);
0539 
0540   // avoid cloning
0541   KFUpdator const aKFUpdator;
0542   Chi2MeasurementEstimator const aChi2MeasurementEstimator(100., 3);
0543   KFTrajectoryFitter backFitter(
0544       &propagatorAlong, &aKFUpdator, &aChi2MeasurementEstimator, firstHits.size(), nullptr, &hitCloner);
0545 
0546   // assume for now that the propagation in mkfit always alongMomentum
0547   PropagationDirection backFitDirection = oppositeToMomentum;
0548 
0549   // only direction matters in this context
0550   TrajectorySeed fakeSeed(PTrajectoryStateOnDet(), edm::OwnVector<TrackingRecHit>(), backFitDirection);
0551 
0552   // ignore loopers for now
0553   Trajectory fitres = backFitter.fitOne(fakeSeed, firstHits, startingState, TrajectoryFitter::standard);
0554 
0555   LogDebug("MkFitOutputConverter") << "using a backward fit of :" << firstHits.size() << " hits, starting from:\n"
0556                                    << startingState << " to get the estimate of the initial state of the track.";
0557 
0558   if (!fitres.isValid()) {
0559     edm::LogWarning("MkFitOutputConverter") << "FitTester: first hits fit failed";
0560     return std::pair<TrajectoryStateOnSurface, const GeomDet*>();
0561   }
0562 
0563   TrajectoryMeasurement const& firstMeas = fitres.lastMeasurement();
0564 
0565   // magnetic field can be different!
0566   TrajectoryStateOnSurface firstState(firstMeas.updatedState().localParameters(),
0567                                       firstMeas.updatedState().localError(),
0568                                       firstMeas.updatedState().surface(),
0569                                       propagatorAlong.magneticField());
0570 
0571   firstState.rescaleError(100.);
0572 
0573   LogDebug("MkFitOutputConverter") << "the initial state is found to be:\n:" << firstState
0574                                    << "\n it's field pointer is: " << firstState.magneticField()
0575                                    << "\n the pointer from the state of the back fit was: "
0576                                    << firstMeas.updatedState().magneticField();
0577 
0578   return std::make_pair(firstState, firstMeas.recHit()->det());
0579 }
0580 
0581 std::pair<TrajectoryStateOnSurface, const GeomDet*> MkFitOutputConverter::convertInnermostState(
0582     const FreeTrajectoryState& fts,
0583     const edm::OwnVector<TrackingRecHit>& hits,
0584     const Propagator& propagatorAlong,
0585     const Propagator& propagatorOpposite) const {
0586   auto det = hits[0].det();
0587   if (det == nullptr) {
0588     throw cms::Exception("LogicError") << "Got nullptr from the first hit det()";
0589   }
0590 
0591   const auto& firstHitSurface = det->surface();
0592 
0593   auto tsosDouble = propagatorAlong.propagateWithPath(fts, firstHitSurface);
0594   if (!tsosDouble.first.isValid()) {
0595     LogDebug("MkFitOutputConverter") << "Propagating to startingState along momentum failed, trying opposite next";
0596     tsosDouble = propagatorOpposite.propagateWithPath(fts, firstHitSurface);
0597   }
0598 
0599   return std::make_pair(tsosDouble.first, det);
0600 }
0601 
0602 std::vector<float> MkFitOutputConverter::computeDNNs(TrackCandidateCollection const& tkCC,
0603                                                      const std::vector<TrajectoryStateOnSurface>& states,
0604                                                      const reco::BeamSpot* bs,
0605                                                      const reco::VertexCollection* vertices,
0606                                                      const tensorflow::Session* session,
0607                                                      const std::vector<float>& chi2,
0608                                                      const bool rescaledError) const {
0609   int size_in = (int)tkCC.size();
0610   int nbatches = size_in / bsize_;
0611 
0612   std::vector<float> output(size_in, 0);
0613 
0614   TSCBLBuilderNoMaterial tscblBuilder;
0615 
0616   tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
0617   tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
0618 
0619   for (auto nb = 0; nb < nbatches + 1; nb++) {
0620     std::vector<bool> invalidProp(bsize_, false);
0621 
0622     for (auto nt = 0; nt < bsize_; nt++) {
0623       int itrack = nt + bsize_ * nb;
0624       if (itrack >= size_in)
0625         continue;
0626 
0627       auto const& tkC = tkCC.at(itrack);
0628 
0629       TrajectoryStateOnSurface state = states.at(itrack);
0630 
0631       if (rescaledError)
0632         state.rescaleError(1 / 100.f);
0633 
0634       TrajectoryStateClosestToBeamLine tsAtClosestApproachTrackCand =
0635           tscblBuilder(*state.freeState(), *bs);  //as in TrackProducerAlgorithm
0636 
0637       if (!(tsAtClosestApproachTrackCand.isValid())) {
0638         edm::LogVerbatim("TrackBuilding") << "TrajectoryStateClosestToBeamLine not valid";
0639         invalidProp[nt] = true;
0640         continue;
0641       }
0642 
0643       auto const& stateAtPCA = tsAtClosestApproachTrackCand.trackStateAtPCA();
0644       auto v0 = stateAtPCA.position();
0645       auto p = stateAtPCA.momentum();
0646       math::XYZPoint pos(v0.x(), v0.y(), v0.z());
0647       math::XYZVector mom(p.x(), p.y(), p.z());
0648 
0649       //pseudo track for access to easy methods
0650       reco::Track trk(0, 0, pos, mom, stateAtPCA.charge(), stateAtPCA.curvilinearError());
0651 
0652       // get best vertex
0653       float dzmin = std::numeric_limits<float>::max();
0654       float dxy_zmin = 0;
0655 
0656       for (auto const& vertex : *vertices) {
0657         if (std::abs(trk.dz(vertex.position())) < dzmin) {
0658           dzmin = trk.dz(vertex.position());
0659           dxy_zmin = trk.dxy(vertex.position());
0660         }
0661       }
0662 
0663       // loop over the RecHits
0664       int ndof = 0;
0665       int pix = 0;
0666       int strip = 0;
0667       for (auto const& recHit : tkC.recHits()) {
0668         ndof += recHit.dimension();
0669         auto const subdet = recHit.geographicalId().subdetId();
0670         if (subdet == PixelSubdetector::PixelBarrel || subdet == PixelSubdetector::PixelEndcap)
0671           pix++;
0672         else
0673           strip++;
0674       }
0675       ndof = ndof - 5;
0676 
0677       input1.matrix<float>()(nt, 0) = trk.pt();  //using inner track only
0678       input1.matrix<float>()(nt, 1) = p.x();
0679       input1.matrix<float>()(nt, 2) = p.y();
0680       input1.matrix<float>()(nt, 3) = p.z();
0681       input1.matrix<float>()(nt, 4) = p.perp();
0682       input1.matrix<float>()(nt, 5) = p.x();
0683       input1.matrix<float>()(nt, 6) = p.y();
0684       input1.matrix<float>()(nt, 7) = p.z();
0685       input1.matrix<float>()(nt, 8) = p.perp();
0686       input1.matrix<float>()(nt, 9) = trk.ptError();
0687       input1.matrix<float>()(nt, 10) = dxy_zmin;
0688       input1.matrix<float>()(nt, 11) = dzmin;
0689       input1.matrix<float>()(nt, 12) = trk.dxy(bs->position());
0690       input1.matrix<float>()(nt, 13) = trk.dz(bs->position());
0691       input1.matrix<float>()(nt, 14) = trk.dxyError();
0692       input1.matrix<float>()(nt, 15) = trk.dzError();
0693       input1.matrix<float>()(nt, 16) = ndof > 0 ? chi2[itrack] / ndof : chi2[itrack] * 1e6;
0694       input1.matrix<float>()(nt, 17) = trk.eta();
0695       input1.matrix<float>()(nt, 18) = trk.phi();
0696       input1.matrix<float>()(nt, 19) = trk.etaError();
0697       input1.matrix<float>()(nt, 20) = trk.phiError();
0698       input1.matrix<float>()(nt, 21) = pix;
0699       input1.matrix<float>()(nt, 22) = strip;
0700       input1.matrix<float>()(nt, 23) = ndof;
0701       input1.matrix<float>()(nt, 24) = 0;
0702       input1.matrix<float>()(nt, 25) = 0;
0703       input1.matrix<float>()(nt, 26) = 0;
0704       input1.matrix<float>()(nt, 27) = 0;
0705       input1.matrix<float>()(nt, 28) = 0;
0706 
0707       input2.matrix<float>()(nt, 0) = algo_;
0708     }
0709 
0710     //inputs finalized
0711     tensorflow::NamedTensorList inputs;
0712     inputs.resize(2);
0713     inputs[0] = tensorflow::NamedTensor("x", input1);
0714     inputs[1] = tensorflow::NamedTensor("y", input2);
0715 
0716     //eval and rescale
0717     std::vector<tensorflow::Tensor> outputs;
0718     tensorflow::run(const_cast<tensorflow::Session*>(session), inputs, {"Identity"}, &outputs);
0719 
0720     for (auto nt = 0; nt < bsize_; nt++) {
0721       int itrack = nt + bsize_ * nb;
0722       if (itrack >= size_in)
0723         continue;
0724 
0725       float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
0726       if (invalidProp[nt])
0727         out0 = -1;
0728 
0729       output[itrack] = out0;
0730     }
0731   }
0732 
0733   return output;
0734 }
0735 
// Register MkFitOutputConverter with the framework's module plugin system
// (macro provided by FWCore/Framework/interface/MakerMacros.h).
DEFINE_FWK_MODULE(MkFitOutputConverter);