File indexing completed on 2024-04-06 12:29:14
0001 #include "RecoVertex/MultiVertexFit/interface/MultiVertexFitter.h"
0002
0003 #include <map>
0004 #include <algorithm>
0005 #include <iomanip>
0006
0007
0008 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexFitter.h"
0009 #include "RecoVertex/VertexTools/interface/LinearizedTrackStateFactory.h"
0010 #include "RecoVertex/VertexTools/interface/VertexTrackFactory.h"
0011 #include "RecoVertex/VertexPrimitives/interface/VertexState.h"
0012 #include "RecoVertex/VertexPrimitives/interface/VertexException.h"
0013 #include "RecoVertex/KalmanVertexFit/interface/KalmanVertexTrackCompatibilityEstimator.h"
0014
0015
0016 #ifdef MVFHarvestingDebug
0017 #include "Vertex/VertexSimpleVis/interface/PrimitivesHarvester.h"
0018 #endif
0019
0020 using namespace std;
0021 using namespace reco;
0022
0023 namespace {
0024 typedef MultiVertexFitter::TrackAndWeight TrackAndWeight;
0025 typedef MultiVertexFitter::TrackAndSeedToWeightMap TrackAndSeedToWeightMap;
0026 typedef MultiVertexFitter::SeedToWeightMap SeedToWeightMap;
0027 typedef CachingVertex<5>::RefCountedVertexTrack RefCountedVertexTrack;
0028
0029 int verbose() {
0030 static const int ret = 0;
0031
0032 return ret;
0033 }
0034
0035 double minWeightFraction() {
0036
0037
0038
0039
0040 static const float ret = 1e-6;
0041
0042 return ret;
0043 }
0044
0045 bool discardLightWeights() {
0046 static const bool ret = true;
0047
0048 return ret;
0049 }
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061 CachingVertex<5> createSeedFromLinPt(const GlobalPoint &gp) {
0062 return CachingVertex<5>(gp, GlobalError(), vector<RefCountedVertexTrack>(), 0.0);
0063 }
0064
0065 double validWeight(double weight) {
0066 if (weight > 1.0) {
0067 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
0068 return 1.0;
0069 };
0070
0071 if (weight < 0.0) {
0072 cout << "[MultiVertexFitter] weight=" << weight << "??" << endl;
0073 return 0.0;
0074 };
0075 return weight;
0076 }
0077 }
0078
0079 void MultiVertexFitter::clear() {
0080 theAssComp->resetAnnealing();
0081 theTracks.clear();
0082 thePrimaries.clear();
0083 theVertexStates.clear();
0084 theWeights.clear();
0085 theCache.clear();
0086 }
0087
0088
0089
0090
0091 void MultiVertexFitter::createSeed(const vector<TransientTrack> &tracks) {
0092 if (tracks.size() > 1) {
0093 CachingVertex<5> vtx = createSeedFromLinPt(theSeeder->getLinearizationPoint(tracks));
0094 int snr = seedNr();
0095 theVertexStates.push_back(pair<int, CachingVertex<5> >(snr, vtx));
0096 for (vector<TransientTrack>::const_iterator track = tracks.begin(); track != tracks.end(); ++track) {
0097 theWeights[*track][snr] = 1.;
0098 theTracks.push_back(*track);
0099 };
0100 };
0101 }
0102
0103 void MultiVertexFitter::createPrimaries(const std::vector<reco::TransientTrack> &tracks) {
0104
0105 for (vector<reco::TransientTrack>::const_iterator i = tracks.begin(); i != tracks.end(); ++i) {
0106 thePrimaries.insert(*i);
0107
0108 }
0109
0110 }
0111
0112 int MultiVertexFitter::seedNr() { return theVertexStateNr++; }
0113
0114 void MultiVertexFitter::resetSeedNr() { theVertexStateNr = 0; }
0115
0116 void MultiVertexFitter::createSeed(const vector<TrackAndWeight> &tracks) {
0117
0118 vector<RefCountedVertexTrack> newTracks;
0119
0120 for (vector<TrackAndWeight>::const_iterator track = tracks.begin(); track != tracks.end(); ++track) {
0121 double weight = validWeight(track->second);
0122 const GlobalPoint &pos = track->first.impactPointState().globalPosition();
0123 GlobalError err;
0124 VertexState realseed(pos, err);
0125
0126 RefCountedLinearizedTrackState lTrData = theCache.linTrack(pos, track->first);
0127
0128 VertexTrackFactory<5> vTrackFactory;
0129 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(lTrData, realseed, weight);
0130 newTracks.push_back(vTrData);
0131 };
0132
0133 if (newTracks.size() > 1) {
0134 CachingVertex<5> vtx = KalmanVertexFitter().vertex(newTracks);
0135 int snr = seedNr();
0136 theVertexStates.push_back(pair<int, CachingVertex<5> >(snr, vtx));
0137
0138
0139
0140 for (vector<TrackAndWeight>::const_iterator track = tracks.begin(); track != tracks.end(); ++track) {
0141 if (thePrimaries.count(track->first)) {
0142
0143
0144
0145
0146
0147
0148 theWeights[track->first][theVertexStates[0].first] = track->second;
0149 continue;
0150 };
0151 float weight = track->second;
0152 if (weight > 1.0) {
0153 cout << "[MultiVertexFitter] error weight " << weight << " > 1.0 given." << endl;
0154 cout << "[MultiVertexFitter] will revert to 1.0" << endl;
0155 weight = 1.0;
0156 };
0157 if (weight < 0.0) {
0158 cout << "[MultiVertexFitter] error weight " << weight << " < 0.0 given." << endl;
0159 cout << "[MultiVertexFitter] will revert to 0.0" << endl;
0160 weight = 0.0;
0161 };
0162 theWeights[track->first][snr] = weight;
0163 theTracks.push_back(track->first);
0164 };
0165 };
0166
0167
0168
0169
0170 sort(theTracks.begin(), theTracks.end());
0171 for (vector<TransientTrack>::iterator i = theTracks.begin(); i < theTracks.end(); ++i) {
0172 if (i != theTracks.begin()) {
0173 if ((*i) == (*(i - 1))) {
0174 theTracks.erase(i);
0175 };
0176 };
0177 };
0178 }
0179
0180 vector<CachingVertex<5> > MultiVertexFitter::vertices(const vector<TransientVertex> &vtces,
0181 const vector<TransientTrack> &primaries) {
0182
0183 if (vtces.empty()) {
0184 return vector<CachingVertex<5> >();
0185 };
0186 vector<vector<TrackAndWeight> > bundles;
0187 for (vector<TransientVertex>::const_iterator vtx = vtces.begin(); vtx != vtces.end(); ++vtx) {
0188 vector<TransientTrack> trks = vtx->originalTracks();
0189 vector<TrackAndWeight> tnws;
0190 for (vector<TransientTrack>::const_iterator trk = trks.begin(); trk != trks.end(); ++trk) {
0191 float w = vtx->trackWeight(*trk);
0192 if (w > 1e-5) {
0193 TrackAndWeight tmp(*trk, w);
0194 tnws.push_back(tmp);
0195 };
0196 };
0197 bundles.push_back(tnws);
0198 };
0199 return vertices(bundles, primaries);
0200 }
0201
0202 vector<CachingVertex<5> > MultiVertexFitter::vertices(const vector<CachingVertex<5> > &initials,
0203 const vector<TransientTrack> &primaries) {
0204 clear();
0205 createPrimaries(primaries);
0206
0207 if (initials.empty())
0208 return initials;
0209 for (vector<CachingVertex<5> >::const_iterator vtx = initials.begin(); vtx != initials.end(); ++vtx) {
0210 int snr = seedNr();
0211 theVertexStates.push_back(pair<int, CachingVertex<5> >(snr, *vtx));
0212 TransientVertex rvtx = *vtx;
0213 const vector<TransientTrack> &trks = rvtx.originalTracks();
0214 for (vector<TransientTrack>::const_iterator trk = trks.begin(); trk != trks.end(); ++trk) {
0215 if (!(thePrimaries.count(*trk))) {
0216
0217 theTracks.push_back(*trk);
0218 } else {
0219
0220 }
0221 cout << "[MultiVertexFitter] error! track weight currently set to one"
0222 << " FIXME!!!" << endl;
0223 theWeights[*trk][snr] = 1.0;
0224 };
0225 };
0226 #ifdef MVFHarvestingDebug
0227 for (vector<CachingVertex<5> >::const_iterator i = theVertexStates.begin(); i != theVertexStates.end(); ++i)
0228 PrimitivesHarvester::file()->save(*i);
0229 #endif
0230 return fit();
0231 }
0232
0233 vector<CachingVertex<5> > MultiVertexFitter::vertices(const vector<vector<TransientTrack> > &tracks,
0234 const vector<TransientTrack> &primaries) {
0235 clear();
0236 createPrimaries(primaries);
0237
0238 for (vector<vector<TransientTrack> >::const_iterator cluster = tracks.begin(); cluster != tracks.end(); ++cluster) {
0239 createSeed(*cluster);
0240 };
0241 if (verbose()) {
0242 printSeeds();
0243 };
0244 #ifdef MVFHarvestingDebug
0245 for (vector<CachingVertex<5> >::const_iterator i = theVertexStates.begin(); i != theVertexStates.end(); ++i)
0246 PrimitivesHarvester::file()->save(*i);
0247 #endif
0248 return fit();
0249 }
0250
0251 vector<CachingVertex<5> > MultiVertexFitter::vertices(const vector<vector<TrackAndWeight> > &tracks,
0252 const vector<TransientTrack> &primaries) {
0253 clear();
0254 createPrimaries(primaries);
0255
0256 for (vector<vector<TrackAndWeight> >::const_iterator cluster = tracks.begin(); cluster != tracks.end(); ++cluster) {
0257 createSeed(*cluster);
0258 };
0259 if (verbose()) {
0260 printSeeds();
0261 };
0262
0263 return fit();
0264 }
0265
0266 MultiVertexFitter::MultiVertexFitter(const AnnealingSchedule &ann,
0267 const LinearizationPointFinder &seeder,
0268 float revive_below)
0269 : theVertexStateNr(0), theReviveBelow(revive_below), theAssComp(ann.clone()), theSeeder(seeder.clone()) {}
0270
0271 MultiVertexFitter::MultiVertexFitter(const MultiVertexFitter &o)
0272 : theVertexStateNr(o.theVertexStateNr),
0273 theReviveBelow(o.theReviveBelow),
0274 theAssComp(o.theAssComp->clone()),
0275 theSeeder(o.theSeeder->clone()) {}
0276
0277 MultiVertexFitter::~MultiVertexFitter() {
0278 delete theAssComp;
0279 delete theSeeder;
0280 }
0281
0282 void MultiVertexFitter::updateWeights() {
0283 theWeights.clear();
0284 if (verbose() & 4) {
0285 cout << "[MultiVertexFitter] Start weight update." << endl;
0286 };
0287
0288 KalmanVertexTrackCompatibilityEstimator<5> theComp;
0289
0290
0291
0292
0293 for (set<TransientTrack>::const_iterator trk = thePrimaries.begin(); trk != thePrimaries.end(); ++trk) {
0294 int seednr = theVertexStates[0].first;
0295 CachingVertex<5> seed = theVertexStates[0].second;
0296 pair<bool, double> result = theComp.estimate(seed, theCache.linTrack(seed.position(), *trk));
0297 double weight = 0.;
0298 if (result.first)
0299 weight = theAssComp->phi(result.second);
0300 theWeights[*trk][seednr] = weight;
0301 }
0302
0303
0304
0305
0306 for (vector<TransientTrack>::const_iterator trk = theTracks.begin(); trk != theTracks.end(); ++trk) {
0307 double tot_weight = theAssComp->phi(theAssComp->cutoff() * theAssComp->cutoff());
0308
0309 for (vector<pair<int, CachingVertex<5> > >::const_iterator seed = theVertexStates.begin();
0310 seed != theVertexStates.end();
0311 ++seed) {
0312 pair<bool, double> result = theComp.estimate(seed->second, theCache.linTrack(seed->second.position(), *trk));
0313 double weight = 0.;
0314 if (result.first)
0315 weight = theAssComp->phi(result.second);
0316 tot_weight += weight;
0317 theWeights[*trk][seed->first] = weight;
0318
0319
0320 };
0321
0322
0323
0324 if (tot_weight > 0.0) {
0325 for (vector<pair<int, CachingVertex<5> > >::const_iterator seed = theVertexStates.begin();
0326 seed != theVertexStates.end();
0327 ++seed) {
0328 double normedweight = theWeights[*trk][seed->first] / tot_weight;
0329 if (normedweight > 1.0) {
0330 cout << "[MultiVertexFitter] he? w["
0331 << "," << seed->second.position() << "] = " << normedweight << " totw=" << tot_weight << endl;
0332 normedweight = 1.0;
0333 };
0334 if (normedweight < 0.0) {
0335 cout << "[MultiVertexFitter] he? weight=" << normedweight << " totw=" << tot_weight << endl;
0336 normedweight = 0.0;
0337 };
0338 theWeights[*trk][seed->first] = normedweight;
0339 };
0340 } else {
0341
0342 cout << "[MultiVertexFitter] track found with no assignment - ";
0343 cout << "will assign uniformly." << endl;
0344 float w = .5 / (float)theVertexStates.size();
0345 for (vector<pair<int, CachingVertex<5> > >::const_iterator seed = theVertexStates.begin();
0346 seed != theVertexStates.end();
0347 ++seed) {
0348 theWeights[*trk][seed->first] = w;
0349 };
0350 };
0351 };
0352 if (verbose() & 2)
0353 printWeights();
0354 }
0355
0356 bool MultiVertexFitter::updateSeeds() {
0357 double max_disp = 0.;
0358
0359
0360
0361
0362 vector<pair<int, CachingVertex<5> > > newSeeds;
0363
0364 for (vector<pair<int, CachingVertex<5> > >::const_iterator seed = theVertexStates.begin();
0365 seed != theVertexStates.end();
0366 ++seed) {
0367
0368
0369
0370 int snr = seed->first;
0371 VertexState realseed(seed->second.position(), seed->second.error());
0372
0373 double totweight = 0.;
0374 for (vector<TransientTrack>::const_iterator track = theTracks.begin(); track != theTracks.end(); ++track) {
0375 totweight += theWeights[*track][snr];
0376 };
0377
0378 int nr_good_trks = 0;
0379
0380
0381
0382
0383 if (discardLightWeights()) {
0384 for (vector<TransientTrack>::const_iterator track = theTracks.begin(); track != theTracks.end(); ++track) {
0385 if (theWeights[*track][snr] > totweight * minWeightFraction()) {
0386 nr_good_trks++;
0387 };
0388 };
0389 };
0390
0391 vector<RefCountedVertexTrack> newTracks;
0392 for (vector<TransientTrack>::const_iterator track = theTracks.begin(); track != theTracks.end(); ++track) {
0393 double weight = validWeight(theWeights[*track][snr]);
0394
0395
0396
0397
0398
0399
0400 if (!discardLightWeights() || weight > minWeightFraction() * totweight || nr_good_trks < 2) {
0401
0402
0403
0404
0405 RefCountedLinearizedTrackState lTrData = theCache.linTrack(seed->second.position(), *track);
0406
0407 VertexTrackFactory<5> vTrackFactory;
0408 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(lTrData, realseed, weight);
0409 newTracks.push_back(vTrData);
0410 };
0411 };
0412
0413 for (set<TransientTrack>::const_iterator track = thePrimaries.begin(); track != thePrimaries.end(); ++track) {
0414 double weight = validWeight(theWeights[*track][snr]);
0415
0416 RefCountedLinearizedTrackState lTrData = theCache.linTrack(seed->second.position(), *track);
0417
0418 VertexTrackFactory<5> vTrackFactory;
0419 RefCountedVertexTrack vTrData = vTrackFactory.vertexTrack(lTrData, realseed, weight);
0420 newTracks.push_back(vTrData);
0421 };
0422
0423 try {
0424 if (newTracks.size() < 2) {
0425 throw VertexException("less than two tracks in vector");
0426 };
0427
0428 if (verbose()) {
0429 cout << "[MultiVertexFitter] now fitting with Kalman: ";
0430 for (vector<RefCountedVertexTrack>::const_iterator i = newTracks.begin(); i != newTracks.end(); ++i) {
0431 cout << (**i).weight() << " ";
0432 };
0433 cout << endl;
0434 };
0435
0436 if (newTracks.size() > 1) {
0437 KalmanVertexFitter fitter;
0438
0439 CachingVertex<5> newVertex = fitter.vertex(newTracks);
0440 int snr = seedNr();
0441 double disp = (newVertex.position() - seed->second.position()).mag();
0442 if (disp > max_disp)
0443 max_disp = disp;
0444 newSeeds.push_back(pair<int, CachingVertex<5> >(snr, newVertex));
0445 };
0446 } catch (exception &e) {
0447 cout << "[MultiVertexFitter] exception: " << e.what() << endl;
0448 }
0449 };
0450
0451
0452 theVertexStates.clear();
0453 theWeights.clear();
0454 theVertexStates = newSeeds;
0455 #ifdef MVFHarvestingDebug
0456 for (vector<CachingVertex<5> >::const_iterator i = theVertexStates.begin(); i != theVertexStates.end(); ++i)
0457 PrimitivesHarvester::file()->save(*i);
0458 #endif
0459 updateWeights();
0460
0461 static const double disp_limit = 1e-4;
0462
0463
0464 if (verbose() & 2) {
0465 printSeeds();
0466 cout << "[MultiVertexFitter] max displacement in this iteration: " << max_disp << endl;
0467 };
0468 if (max_disp < disp_limit)
0469 return false;
0470 return true;
0471 }
0472
0473
0474 vector<CachingVertex<5> > MultiVertexFitter::fit() {
0475 if (verbose() & 2)
0476 printWeights();
0477 int ctr = 1;
0478 static const int ctr_max = 50;
0479
0480 while (updateSeeds() || !(theAssComp->isAnnealed())) {
0481 if (++ctr >= ctr_max)
0482 break;
0483 theAssComp->anneal();
0484
0485 resetSeedNr();
0486 };
0487
0488 if (verbose()) {
0489 cout << "[MultiVertexFitter] number of iterations: " << ctr << endl;
0490 cout << "[MultiVertexFitter] remaining seeds: " << theVertexStates.size() << endl;
0491 printWeights();
0492 };
0493
0494 vector<CachingVertex<5> > ret;
0495 for (vector<pair<int, CachingVertex<5> > >::const_iterator i = theVertexStates.begin(); i != theVertexStates.end();
0496 ++i) {
0497 ret.push_back(i->second);
0498 };
0499
0500 return ret;
0501 }
0502
0503 void MultiVertexFitter::printWeights(const reco::TransientTrack &t) const {
0504
0505 for (vector<pair<int, CachingVertex<5> > >::const_iterator seed = theVertexStates.begin();
0506 seed != theVertexStates.end();
0507 ++seed) {
0508 double val = 0;
0509 auto a = theWeights.find(t);
0510 if (a != theWeights.end()) {
0511 auto b = a->second.find(seed->first);
0512 if (b != a->second.end())
0513 val = b->second;
0514 }
0515 cout << " -- Vertex[" << seed->first << "] with " << setw(12) << setprecision(3) << val;
0516 };
0517 cout << endl;
0518 }
0519
0520 void MultiVertexFitter::printWeights() const {
0521 cout << endl << "Weight table: " << endl << "=================" << endl;
0522 for (set<TransientTrack>::const_iterator trk = thePrimaries.begin(); trk != thePrimaries.end(); ++trk) {
0523 printWeights(*trk);
0524 };
0525 for (vector<TransientTrack>::const_iterator trk = theTracks.begin(); trk != theTracks.end(); ++trk) {
0526 printWeights(*trk);
0527 };
0528 }
0529
0530 void MultiVertexFitter::printSeeds() const {
0531 cout << endl << "Seed table: " << endl << "=====================" << endl;
0532
0533
0534
0535
0536
0537
0538
0539 }
0540
0541 void MultiVertexFitter::lostVertexClaimer() {
0542 if (!(theReviveBelow < 0.))
0543 return;
0544
0545
0546
0547 bool has_revived = false;
0548
0549 for (vector<pair<int, CachingVertex<5> > >::const_iterator i = theVertexStates.begin(); i != theVertexStates.end();
0550 ++i) {
0551 double totweight = 0.;
0552 for (vector<TransientTrack>::const_iterator trk = theTracks.begin(); trk != theTracks.end(); ++trk) {
0553 totweight += theWeights[*trk][i->first];
0554 };
0555
0556
0557
0558
0559
0560 if (totweight < theReviveBelow && totweight > 0.0) {
0561 cout << "[MultiVertexFitter] now trying to revive vertex"
0562 << " revive_below=" << theReviveBelow << endl;
0563 has_revived = true;
0564 for (vector<TransientTrack>::const_iterator trk = theTracks.begin(); trk != theTracks.end(); ++trk) {
0565 theWeights[*trk][i->first] /= totweight;
0566 };
0567 };
0568 };
0569 if (has_revived && verbose())
0570 printWeights();
0571 }