File indexing completed on 2024-04-06 12:26:50
0001 #include "RecoMTD/TimingIDTools/interface/MTDTrackQualityMVA.h"
0002
0003 MTDTrackQualityMVA::MTDTrackQualityMVA(std::string weights_file) {
0004 std::string options("!Color:Silent");
0005 std::string method("BDT");
0006
0007 std::string vars_array[] = {MTDTRACKQUALITYMVA_VARS(MTDBDTVAR_STRING)};
0008 int nvars = sizeof(vars_array) / sizeof(vars_array[0]);
0009 vars_.assign(vars_array, vars_array + nvars);
0010
0011 mva_ = std::make_unique<TMVAEvaluator>();
0012 mva_->initialize(options, method, weights_file, vars_, spec_vars_, true, false);
0013 }
0014
0015 float MTDTrackQualityMVA::operator()(const reco::TrackRef& trk,
0016 const edm::ValueMap<int>& npixBarrels,
0017 const edm::ValueMap<int>& npixEndcaps,
0018 const edm::ValueMap<float>& btl_chi2s,
0019 const edm::ValueMap<float>& btl_time_chi2s,
0020 const edm::ValueMap<float>& etl_chi2s,
0021 const edm::ValueMap<float>& etl_time_chi2s,
0022 const edm::ValueMap<float>& tmtds,
0023 const edm::ValueMap<float>& trk_lengths) const {
0024 std::map<std::string, float> vars;
0025
0026
0027 constexpr float minPtForMVA = 0.5;
0028 if (trk->pt() < minPtForMVA)
0029 return -1;
0030
0031
0032 if (tmtds[trk] > 0) {
0033 vars.emplace(vars_[int(VarID::pt)], trk->pt());
0034 vars.emplace(vars_[int(VarID::eta)], trk->eta());
0035 vars.emplace(vars_[int(VarID::phi)], trk->phi());
0036 vars.emplace(vars_[int(VarID::chi2)], trk->chi2());
0037 vars.emplace(vars_[int(VarID::ndof)], trk->ndof());
0038 vars.emplace(vars_[int(VarID::numberOfValidHits)], trk->numberOfValidHits());
0039 vars.emplace(vars_[int(VarID::numberOfValidPixelBarrelHits)], npixBarrels[trk]);
0040 vars.emplace(vars_[int(VarID::numberOfValidPixelEndcapHits)], npixEndcaps[trk]);
0041 vars.emplace(vars_[int(VarID::btlMatchChi2)], btl_chi2s.contains(trk.id()) ? btl_chi2s[trk] : -1);
0042 vars.emplace(vars_[int(VarID::btlMatchTimeChi2)], btl_time_chi2s.contains(trk.id()) ? btl_time_chi2s[trk] : -1);
0043 vars.emplace(vars_[int(VarID::etlMatchChi2)], etl_chi2s.contains(trk.id()) ? etl_chi2s[trk] : -1);
0044 vars.emplace(vars_[int(VarID::etlMatchTimeChi2)], etl_time_chi2s.contains(trk.id()) ? etl_time_chi2s[trk] : -1);
0045 vars.emplace(vars_[int(VarID::mtdt)], tmtds[trk]);
0046 vars.emplace(vars_[int(VarID::path_len)], trk_lengths[trk]);
0047 return 1. / (1 + sqrt(2 / (1 + mva_->evaluate(vars, false)) - 1));
0048 } else
0049 return -1;
0050 }