Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-04-13 22:49:35

0001 #ifndef KDTreeLinkerAlgoTemplated_h
0002 #define KDTreeLinkerAlgoTemplated_h
0003 
0004 #include "DataFormats/Math/interface/logic.h"
0005 
0006 #include <cassert>
0007 #include <vector>
0008 #include <array>
0009 #include <algorithm>
0010 
// Axis-aligned box in DIM dimensions (DIM = 2 by default).
// It's used in the KDTree building step to describe the detector
// space (ECAL, HCAL...) and in the searching step as the bounding
// box around the demanded point (Track collision point, PS projection...).
//
// dimmin[i] / dimmax[i] hold the lower / upper bound along dimension i.
template <unsigned DIM = 2>
struct KDTreeBox {
  // Zero-initialised so that a default-constructed box never holds
  // indeterminate values.
  std::array<float, DIM> dimmin{};
  std::array<float, DIM> dimmax{};

  // Builds the box from exactly 2*DIM values given as
  // (min0, max0, min1, max1, ...). Any arithmetic argument type is accepted;
  // each value is converted to float explicitly.
  template <typename... Ts>
  KDTreeBox(Ts... dimargs) {
    static_assert(sizeof...(dimargs) == 2 * DIM, "Constructor requires 2*DIM args");
    // Unpack into a fixed-size array: no heap allocation per box.
    const std::array<float, 2 * DIM> dims{{static_cast<float>(dimargs)...}};
    for (unsigned i = 0; i < DIM; ++i) {
      dimmin[i] = dims[2 * i];
      dimmax[i] = dims[2 * i + 1];
    }
  }

  // Default box: all bounds are 0.
  KDTreeBox() = default;
};
0031 
// Input element of the KDTree: a payload plus its DIM coordinates.
// The coordinates are usually the duplication of some PFRecHit values
// (eta/phi or x/y). But in some situations, the phi field is shifted by +-2.Pi
template <typename DATA, unsigned DIM = 2>
struct KDTreeNodeInfo {
  DATA data;                    // user payload returned by the search
  std::array<float, DIM> dims;  // coordinates of the point, one per dimension

public:
  // Value-initialises the payload and sets every coordinate to 0.
  KDTreeNodeInfo() : data(), dims{} {}

  // Stores a copy of d together with up to DIM coordinates; coordinates that
  // are not supplied are zero. Arguments are converted to float explicitly.
  template <typename... Ts>
  KDTreeNodeInfo(const DATA &d, Ts... dimargs) : data(d), dims{{static_cast<float>(dimargs)...}} {
    static_assert(sizeof...(dimargs) <= DIM, "Constructor accepts at most DIM coordinates");
  }

  // Orders elements by payload only; coordinates do not take part.
  bool operator>(const KDTreeNodeInfo &rhs) const { return (data > rhs.data); }
};
0050 
// Flat (structure-of-arrays) storage for all nodes of one KDTree.
// Node i is described by dims[d][i], right[i] and data[i]. Nodes are handed
// out sequentially by getNextNode(); the tree builder relies on the left child
// of node i being node i + 1, so only the right child index is stored.
template <typename DATA, unsigned DIM = 2>
struct KDTreeNodes {
  std::array<std::vector<float>, DIM> dims;  // per-dimension coordinate / split value
  std::vector<int> right;                    // right child index, or a leaf marker (< 2)
  std::vector<DATA> data;                    // payload (filled for leaves)

  int poolSize;  // number of allocated slots, -1 when nothing is allocated
  int poolPos;   // index of the last slot handed out, -1 when none

  constexpr KDTreeNodes() : poolSize(-1), poolPos(-1) {}

  // True when no node has been handed out yet.
  bool empty() const { return poolPos == -1; }
  // Number of nodes handed out so far.
  int size() const { return poolPos + 1; }

  // Releases all storage and returns to the freshly constructed state.
  void clear() {
    for (auto &dim : dims) {
      dim.clear();
      dim.shrink_to_fit();
    }
    right.clear();
    right.shrink_to_fit();
    data.clear();
    data.shrink_to_fit();
    poolSize = -1;
    poolPos = -1;
  }

  // Hands out the next free slot and returns its index.
  int getNextNode() {
    ++poolPos;
    return poolPos;
  }

  // Allocates room for a tree over sizeData elements (2 * sizeData - 1 slots)
  // and restarts slot allocation from index 0.
  // A non-positive sizeData leaves the pool cleared.
  void build(int sizeData) {
    if (sizeData <= 0) {
      // 2 * sizeData - 1 would be negative and turn into a huge resize().
      clear();
      return;
    }
    poolSize = sizeData * 2 - 1;
    // Restart the cursor so that a rebuild without clear() cannot hand out
    // indices beyond poolSize.
    poolPos = -1;
    for (auto &dim : dims) {
      dim.resize(poolSize);
    }
    right.resize(poolSize);
    data.resize(poolSize);
  }

  // Tells whether a stored "right" value marks a leaf.
  constexpr bool isLeaf(int rightValue) const {
    // Valid right-child indices are always >= 2:
    // index 0 is the root, and 1 is the first left node.
    // Values 0 and 1 are therefore free to be used as leaf markers.
    return rightValue < 2;
  }

  // Tells whether the node stored at index is a leaf.
  bool isLeafIndex(int index) const { return isLeaf(right[index]); }
};
0102 
0103 // Class that implements the KDTree partition of 2D space and
0104 // a closest point search algorithme.
0105 
0106 template <typename DATA, unsigned int DIM = 2>
0107 class KDTreeLinkerAlgo {
0108 public:
0109   // Dtor calls clear()
0110   ~KDTreeLinkerAlgo() { clear(); }
0111 
0112   // Here we build the KD tree from the "eltList" in the space define by "region".
0113   void build(std::vector<KDTreeNodeInfo<DATA, DIM> > &eltList, const KDTreeBox<DIM> &region);
0114 
0115   // Here we search in the KDTree for all points that would be
0116   // contained in the given searchbox. The founded points are stored in resRecHitList.
0117   void search(const KDTreeBox<DIM> &searchBox, std::vector<DATA> &resRecHitList);
0118 
0119   // This reurns true if the tree is empty
0120   bool empty() { return nodePool_.empty(); }
0121 
0122   // This returns the number of nodes + leaves in the tree
0123   // (nElements should be (size() +1)/2)
0124   int size() { return nodePool_.size(); }
0125 
0126   // This method clears all allocated structures.
0127   void clear() { clearTree(); }
0128 
0129 private:
0130   // The node pool allow us to do just 1 call to new for each tree building.
0131   KDTreeNodes<DATA, DIM> nodePool_;
0132 
0133   std::vector<DATA> *closestNeighbour;
0134   std::vector<KDTreeNodeInfo<DATA, DIM> > *initialEltList;
0135 
0136   //Fast median search with Wirth algorithm in eltList between low and high indexes.
0137   int medianSearch(int low, int high, int treeDepth) const;
0138 
0139   // Recursif kdtree builder. Is called by build()
0140   int recBuild(int low, int hight, int depth);
0141 
0142   // Recursif kdtree search. Is called by search()
0143   void recSearch(int current, const KDTreeBox<DIM> &trackBox, int depth = 0) const;
0144 
0145   // This method frees the KDTree.
0146   void clearTree() { nodePool_.clear(); }
0147 };
0148 
0149 //Implementation
0150 
0151 template <typename DATA, unsigned int DIM>
0152 void KDTreeLinkerAlgo<DATA, DIM>::build(std::vector<KDTreeNodeInfo<DATA, DIM> > &eltList,
0153                                         const KDTreeBox<DIM> &region) {
0154   if (!eltList.empty()) {
0155     initialEltList = &eltList;
0156 
0157     size_t size = initialEltList->size();
0158     nodePool_.build(size);
0159 
0160     // Here we build the KDTree
0161     int root = recBuild(0, size, 0);
0162     assert(root == 0);
0163 
0164     initialEltList = nullptr;
0165   }
0166 }
0167 
0168 //Fast median search with Wirth algorithm in eltList between low and high indexes.
0169 template <typename DATA, unsigned int DIM>
0170 int KDTreeLinkerAlgo<DATA, DIM>::medianSearch(int low, int high, int treeDepth) const {
0171   int nbrElts = high - low;
0172   int median = (nbrElts & 1) ? nbrElts / 2 : nbrElts / 2 - 1;
0173   median += low;
0174 
0175   int l = low;
0176   int m = high - 1;
0177 
0178   while (l < m) {
0179     KDTreeNodeInfo<DATA, DIM> elt = (*initialEltList)[median];
0180     int i = l;
0181     int j = m;
0182 
0183     do {
0184       // The even depth is associated to dim1 dimension
0185       // The odd one to dim2 dimension
0186       const unsigned thedim = treeDepth % DIM;
0187       while ((*initialEltList)[i].dims[thedim] < elt.dims[thedim])
0188         ++i;
0189       while ((*initialEltList)[j].dims[thedim] > elt.dims[thedim])
0190         --j;
0191 
0192       if (i <= j) {
0193         std::swap((*initialEltList)[i], (*initialEltList)[j]);
0194         i++;
0195         j--;
0196       }
0197     } while (i <= j);
0198     if (j < median)
0199       l = i;
0200     if (i > median)
0201       m = j;
0202   }
0203 
0204   return median;
0205 }
0206 
0207 template <typename DATA, unsigned int DIM>
0208 void KDTreeLinkerAlgo<DATA, DIM>::search(const KDTreeBox<DIM> &trackBox, std::vector<DATA> &recHits) {
0209   if (!empty()) {
0210     closestNeighbour = &recHits;
0211     recSearch(0, trackBox, 0);
0212     closestNeighbour = nullptr;
0213   }
0214 }
0215 
0216 template <typename DATA, unsigned int DIM>
0217 void KDTreeLinkerAlgo<DATA, DIM>::recSearch(int current, const KDTreeBox<DIM> &trackBox, int depth) const {
0218   // Iterate until leaf is found, or there are no children in the
0219   // search window. If search has to proceed on both children, proceed
0220   // the search to left child via recursion. Swap search window
0221   // dimension on alternate levels.
0222   while (true) {
0223     const int dimIndex = depth % DIM;
0224     int right = nodePool_.right[current];
0225     if (nodePool_.isLeaf(right)) {
0226       // If point inside the rectangle/area
0227       // Use intentionally bit-wise & instead of logical && for better
0228       // performance. It is faster to always do all comparisons than to
0229       // allow use of branches to not do some if any of the first ones
0230       // is false.
0231       bool isInside = true;
0232       for (unsigned i = 0; i < DIM; ++i) {
0233         float dimCurr = nodePool_.dims[i][current];
0234         isInside &= reco::branchless_and(dimCurr >= trackBox.dimmin[i], dimCurr <= trackBox.dimmax[i]);
0235       }
0236       if (isInside) {
0237         closestNeighbour->push_back(nodePool_.data[current]);
0238       }
0239       break;
0240     } else {
0241       float median = nodePool_.dims[dimIndex][current];
0242 
0243       bool goLeft = (trackBox.dimmin[dimIndex] <= median);
0244       bool goRight = (trackBox.dimmax[dimIndex] >= median);
0245 
0246       ++depth;
0247       if (goLeft & goRight) {
0248         int left = current + 1;
0249         recSearch(left, trackBox, depth);
0250         // continue with right
0251         current = right;
0252       } else if (goLeft) {
0253         ++current;
0254       } else if (goRight) {
0255         current = right;
0256       } else {
0257         break;
0258       }
0259     }
0260   }
0261 }
0262 
0263 template <typename DATA, unsigned int DIM>
0264 int KDTreeLinkerAlgo<DATA, DIM>::recBuild(int low, int high, int depth) {
0265   int portionSize = high - low;
0266 
0267   if (portionSize == 1) {  // Leaf case
0268     int leaf = nodePool_.getNextNode();
0269     const KDTreeNodeInfo<DATA, DIM> &info = (*initialEltList)[low];
0270     nodePool_.right[leaf] = 0;
0271     for (unsigned i = 0; i < DIM; ++i) {
0272       nodePool_.dims[i][leaf] = info.dims[i];
0273     }
0274     nodePool_.data[leaf] = info.data;
0275     return leaf;
0276 
0277   } else {  // Node case
0278 
0279     // The even depth is associated to dim1 dimension
0280     // The odd one to dim2 dimension
0281     int medianId = medianSearch(low, high, depth);
0282     int dimIndex = depth % DIM;
0283     float medianVal = (*initialEltList)[medianId].dims[dimIndex];
0284 
0285     // We create the node
0286     int nodeInd = nodePool_.getNextNode();
0287 
0288     ++depth;
0289     ++medianId;
0290 
0291     // We recursively build the son nodes
0292     int left = recBuild(low, medianId, depth);
0293     assert(nodeInd + 1 == left);
0294     int right = recBuild(medianId, high, depth);
0295     nodePool_.right[nodeInd] = right;
0296 
0297     nodePool_.dims[dimIndex][nodeInd] = medianVal;
0298 
0299     return nodeInd;
0300   }
0301 }
0302 
0303 #endif