Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:31:29

0001 #include "TrackingTools/GsfTools/interface/GsfMatrixTools.h"
0002 
0003 namespace KullbackLeiblerDistanceDetails {
0004 
0005   template <unsigned int N>
0006   double compute(SingleGaussianState<N> const& sgs1, SingleGaussianState<N> const& sgs2) {
0007     using Vector = ROOT::Math::SVector<double, N>;
0008     using Matrix = ROOT::Math::SMatrix<double, N, N, ROOT::Math::MatRepSym<double, N>>;
0009 
0010     const Vector& mu1 = sgs1.mean();
0011     const Matrix& V1 = sgs1.covariance();
0012     const Vector& mu2 = sgs2.mean();
0013     const Matrix& V2 = sgs2.covariance();
0014 
0015     const Matrix& G1 = sgs1.weightMatrix();
0016     const Matrix& G2 = sgs2.weightMatrix();
0017     Vector mudiff = mu1 - mu2;
0018     Matrix Vdiff = V1 - V2;
0019     Matrix Gdiff = G2 - G1;
0020     Matrix Gsum = G1 + G2;
0021 
0022     /*  for sse above is faster (less instructions, less CPI) for avx2 equivelent maybe below faster
0023     Vector mudiff = sgs1.mean();
0024     Matrix Vdiff = sgs1.covariance();
0025     const Vector& mu2 = sgs2.mean();
0026     const Matrix& V2 = sgs2.covariance();
0027     
0028     Matrix Gsum = sgs1.weightMatrix();
0029     const Matrix& G2 = sgs2.weightMatrix();
0030     mudiff -= mu2;
0031     Vdiff -= V2;
0032     Matrix Gdiff = G2 - Gsum;
0033     Gsum += G2;
0034 */
0035 
0036     //   double dist = (Vdiff * Gdiff).trace() + Gsum.similarity(mudiff);
0037     double dist = GsfMatrixTools::trace(Vdiff, Gdiff) + ROOT::Math::Similarity(mudiff, Gsum);
0038 
0039     return dist;
0040   }
0041 }  // namespace KullbackLeiblerDistanceDetails
0042 
0043 template <unsigned int N>
0044 double KullbackLeiblerDistance<N>::operator()(const SingleGaussianState<N>& sgs1,
0045                                               const SingleGaussianState<N>& sgs2) const {
0046   // compute inverse here (if not yet done)
0047   sgs1.weightMatrix();
0048   sgs2.weightMatrix();
0049 
0050   return KullbackLeiblerDistanceDetails::compute<N>(sgs1, sgs2);
0051 }