Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:26:50

0001 #include "RecoMTD/TimingIDTools/interface/MTDTrackQualityMVA.h"
0002 
0003 MTDTrackQualityMVA::MTDTrackQualityMVA(std::string weights_file) {
0004   std::string options("!Color:Silent");
0005   std::string method("BDT");
0006 
0007   std::string vars_array[] = {MTDTRACKQUALITYMVA_VARS(MTDBDTVAR_STRING)};
0008   int nvars = sizeof(vars_array) / sizeof(vars_array[0]);
0009   vars_.assign(vars_array, vars_array + nvars);
0010 
0011   mva_ = std::make_unique<TMVAEvaluator>();
0012   mva_->initialize(options, method, weights_file, vars_, spec_vars_, true, false);  //use GBR, GradBoost
0013 }
0014 
0015 float MTDTrackQualityMVA::operator()(const reco::TrackRef& trk,
0016                                      const edm::ValueMap<int>& npixBarrels,
0017                                      const edm::ValueMap<int>& npixEndcaps,
0018                                      const edm::ValueMap<float>& btl_chi2s,
0019                                      const edm::ValueMap<float>& btl_time_chi2s,
0020                                      const edm::ValueMap<float>& etl_chi2s,
0021                                      const edm::ValueMap<float>& etl_time_chi2s,
0022                                      const edm::ValueMap<float>& tmtds,
0023                                      const edm::ValueMap<float>& trk_lengths) const {
0024   std::map<std::string, float> vars;
0025 
0026   //---training performed only above 0.5 GeV
0027   constexpr float minPtForMVA = 0.5;
0028   if (trk->pt() < minPtForMVA)
0029     return -1;
0030 
0031   //---training performed only for tracks with MTD hits
0032   if (tmtds[trk] > 0) {
0033     vars.emplace(vars_[int(VarID::pt)], trk->pt());
0034     vars.emplace(vars_[int(VarID::eta)], trk->eta());
0035     vars.emplace(vars_[int(VarID::phi)], trk->phi());
0036     vars.emplace(vars_[int(VarID::chi2)], trk->chi2());
0037     vars.emplace(vars_[int(VarID::ndof)], trk->ndof());
0038     vars.emplace(vars_[int(VarID::numberOfValidHits)], trk->numberOfValidHits());
0039     vars.emplace(vars_[int(VarID::numberOfValidPixelBarrelHits)], npixBarrels[trk]);
0040     vars.emplace(vars_[int(VarID::numberOfValidPixelEndcapHits)], npixEndcaps[trk]);
0041     vars.emplace(vars_[int(VarID::btlMatchChi2)], btl_chi2s.contains(trk.id()) ? btl_chi2s[trk] : -1);
0042     vars.emplace(vars_[int(VarID::btlMatchTimeChi2)], btl_time_chi2s.contains(trk.id()) ? btl_time_chi2s[trk] : -1);
0043     vars.emplace(vars_[int(VarID::etlMatchChi2)], etl_chi2s.contains(trk.id()) ? etl_chi2s[trk] : -1);
0044     vars.emplace(vars_[int(VarID::etlMatchTimeChi2)], etl_time_chi2s.contains(trk.id()) ? etl_time_chi2s[trk] : -1);
0045     vars.emplace(vars_[int(VarID::mtdt)], tmtds[trk]);
0046     vars.emplace(vars_[int(VarID::path_len)], trk_lengths[trk]);
0047     return 1. / (1 + sqrt(2 / (1 + mva_->evaluate(vars, false)) - 1));  //return values between 0-1 (probability)
0048   } else
0049     return -1;
0050 }