Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 10:45:25

0001 #ifndef KDTreeLinkerAlgoTemplated_h
0002 #define KDTreeLinkerAlgoTemplated_h
0003 
0004 #include <cassert>
0005 #include <vector>
0006 #include <array>
0007 #include <algorithm>
0008 
// Axis-aligned box in DIM dimensions, stored as one [min, max] interval per
// axis. It is passed to KDTreeLinkerAlgo::build() to describe the region the
// tree covers, and to KDTreeLinkerAlgo::search() as the search window built
// around the point of interest (e.g. a track impact point or a PS projection).
template <unsigned DIM = 2>
struct KDTreeBox {
  // Value-initialized so a default-constructed box never holds garbage.
  std::array<float, DIM> dimmin{}, dimmax{};

  // Takes exactly 2*DIM bounds, interleaved per axis:
  //   KDTreeBox(min0, max0, min1, max1, ...)
  // Each bound is explicitly converted to float, so double or integer
  // arguments are accepted without brace-init narrowing diagnostics. The
  // bounds are staged in a stack array, so construction does no heap
  // allocation (the previous std::vector staging allocated every time).
  template <typename... Ts>
  KDTreeBox(Ts... dimargs) {
    static_assert(sizeof...(dimargs) == 2 * DIM, "Constructor requires 2*DIM args");
    const std::array<float, 2 * DIM> bounds = {{static_cast<float>(dimargs)...}};
    for (unsigned i = 0; i < DIM; ++i) {
      dimmin[i] = bounds[2 * i];
      dimmax[i] = bounds[2 * i + 1];
    }
  }

  KDTreeBox() {}
};
0029 
// Element handed to KDTreeLinkerAlgo::build(): a user payload plus its DIM
// coordinates. The coordinates usually duplicate values of the payload object
// (e.g. eta/phi or x/y of a PFRecHit), but callers may store adjusted values
// instead (e.g. phi shifted by +-2*Pi).
template <typename DATA, unsigned DIM = 2>
struct KDTreeNodeInfo {
  DATA data;
  std::array<float, DIM> dims;

  // Value-initializes the payload and zeroes the coordinates so a
  // default-constructed element never holds indeterminate values.
  KDTreeNodeInfo() : data(), dims{} {}

  // Builds an element from its payload and exactly DIM coordinates.
  // Each coordinate is explicitly converted to float, so double or integer
  // arguments are accepted without brace-init narrowing diagnostics.
  template <typename... Ts>
  KDTreeNodeInfo(const DATA &d, Ts... dimargs) : data(d), dims{{static_cast<float>(dimargs)...}} {
    // Mirrors KDTreeBox: too few coordinates would otherwise be silently zero-filled.
    static_assert(sizeof...(dimargs) == DIM, "Constructor requires DIM coordinate args");
  }

  // Orders elements by payload only; coordinates are not compared.
  bool operator>(const KDTreeNodeInfo &rhs) const { return (data > rhs.data); }
};
0048 
// Flat node pool for a KDTree built by KDTreeLinkerAlgo.
// Nodes live in parallel arrays indexed by node id:
//  - dims[d][i]: for a leaf, coordinate d of the stored point; for an internal
//    node, the split value on the axis chosen at its depth (other axes unused);
//  - right[i]:   index of the right child of an internal node, 0 for a leaf;
//  - data[i]:    payload of a leaf.
// The left child of internal node i is always stored at index i + 1.
template <typename DATA, unsigned DIM = 2>
struct KDTreeNodes {
  std::array<std::vector<float>, DIM> dims;
  std::vector<int> right;
  std::vector<DATA> data;

  int poolSize;  // number of node slots sized by build(); -1 when cleared
  int poolPos;   // index of the last node handed out; -1 when none is used

  constexpr KDTreeNodes() : poolSize(-1), poolPos(-1) {}

  bool empty() const { return poolPos == -1; }
  int size() const { return poolPos + 1; }

  // Drops every node (vector capacity is kept) and resets both counters.
  void clear() {
    for (auto &dim : dims) {
      dim.clear();
    }
    right.clear();
    data.clear();
    poolSize = -1;
    poolPos = -1;
  }

  // Hands out the next free node index.
  int getNextNode() {
    ++poolPos;
    assert(poolPos < poolSize && "KDTreeNodes pool exhausted");
    return poolPos;
  }

  // Sizes the pool for a tree over sizeData elements: a binary tree with
  // sizeData leaves has exactly 2*sizeData - 1 nodes. The node counter is
  // reset, so rebuilding without an intermediate clear() starts again at
  // index 0 instead of running past the end of the pool. A non-positive
  // sizeData yields an empty pool rather than a resize to a negative
  // (i.e. huge unsigned) length.
  void build(int sizeData) {
    if (sizeData <= 0) {
      clear();
      return;
    }
    poolSize = sizeData * 2 - 1;
    poolPos = -1;
    for (auto &dim : dims) {
      dim.resize(poolSize);
    }
    right.resize(poolSize);
    data.resize(poolSize);
  }

  // Leaves store 0 in right[]. Every valid right-child index is >= 2, because
  // index 0 is the root and index 1 is always its left child.
  constexpr bool isLeaf(int rightIndex) const { return rightIndex < 2; }

  bool isLeafIndex(int index) const { return isLeaf(right[index]); }
};
0097 
0098 // Class that implements the KDTree partition of 2D space and
0099 // a closest point search algorithme.
0100 
0101 template <typename DATA, unsigned int DIM = 2>
0102 class KDTreeLinkerAlgo {
0103 public:
0104   // Dtor calls clear()
0105   ~KDTreeLinkerAlgo() { clear(); }
0106 
0107   // Here we build the KD tree from the "eltList" in the space define by "region".
0108   void build(std::vector<KDTreeNodeInfo<DATA, DIM> > &eltList, const KDTreeBox<DIM> &region);
0109 
0110   // Here we search in the KDTree for all points that would be
0111   // contained in the given searchbox. The founded points are stored in resRecHitList.
0112   void search(const KDTreeBox<DIM> &searchBox, std::vector<DATA> &resRecHitList);
0113 
0114   // This reurns true if the tree is empty
0115   bool empty() { return nodePool_.empty(); }
0116 
0117   // This returns the number of nodes + leaves in the tree
0118   // (nElements should be (size() +1)/2)
0119   int size() { return nodePool_.size(); }
0120 
0121   // This method clears all allocated structures.
0122   void clear() { clearTree(); }
0123 
0124 private:
0125   // The node pool allow us to do just 1 call to new for each tree building.
0126   KDTreeNodes<DATA, DIM> nodePool_;
0127 
0128   std::vector<DATA> *closestNeighbour;
0129   std::vector<KDTreeNodeInfo<DATA, DIM> > *initialEltList;
0130 
0131   //Fast median search with Wirth algorithm in eltList between low and high indexes.
0132   int medianSearch(int low, int high, int treeDepth) const;
0133 
0134   // Recursif kdtree builder. Is called by build()
0135   int recBuild(int low, int hight, int depth);
0136 
0137   // Recursif kdtree search. Is called by search()
0138   void recSearch(int current, const KDTreeBox<DIM> &trackBox, int depth = 0) const;
0139 
0140   // This method frees the KDTree.
0141   void clearTree() { nodePool_.clear(); }
0142 };
0143 
0144 //Implementation
0145 
0146 template <typename DATA, unsigned int DIM>
0147 void KDTreeLinkerAlgo<DATA, DIM>::build(std::vector<KDTreeNodeInfo<DATA, DIM> > &eltList,
0148                                         const KDTreeBox<DIM> &region) {
0149   if (!eltList.empty()) {
0150     initialEltList = &eltList;
0151 
0152     size_t size = initialEltList->size();
0153     nodePool_.build(size);
0154 
0155     // Here we build the KDTree
0156     int root = recBuild(0, size, 0);
0157     assert(root == 0);
0158 
0159     initialEltList = nullptr;
0160   }
0161 }
0162 
0163 //Fast median search with Wirth algorithm in eltList between low and high indexes.
0164 template <typename DATA, unsigned int DIM>
0165 int KDTreeLinkerAlgo<DATA, DIM>::medianSearch(int low, int high, int treeDepth) const {
0166   int nbrElts = high - low;
0167   int median = (nbrElts & 1) ? nbrElts / 2 : nbrElts / 2 - 1;
0168   median += low;
0169 
0170   int l = low;
0171   int m = high - 1;
0172 
0173   while (l < m) {
0174     KDTreeNodeInfo<DATA, DIM> elt = (*initialEltList)[median];
0175     int i = l;
0176     int j = m;
0177 
0178     do {
0179       // The even depth is associated to dim1 dimension
0180       // The odd one to dim2 dimension
0181       const unsigned thedim = treeDepth % DIM;
0182       while ((*initialEltList)[i].dims[thedim] < elt.dims[thedim])
0183         ++i;
0184       while ((*initialEltList)[j].dims[thedim] > elt.dims[thedim])
0185         --j;
0186 
0187       if (i <= j) {
0188         std::swap((*initialEltList)[i], (*initialEltList)[j]);
0189         i++;
0190         j--;
0191       }
0192     } while (i <= j);
0193     if (j < median)
0194       l = i;
0195     if (i > median)
0196       m = j;
0197   }
0198 
0199   return median;
0200 }
0201 
0202 template <typename DATA, unsigned int DIM>
0203 void KDTreeLinkerAlgo<DATA, DIM>::search(const KDTreeBox<DIM> &trackBox, std::vector<DATA> &recHits) {
0204   if (!empty()) {
0205     closestNeighbour = &recHits;
0206     recSearch(0, trackBox, 0);
0207     closestNeighbour = nullptr;
0208   }
0209 }
0210 
0211 template <typename DATA, unsigned int DIM>
0212 void KDTreeLinkerAlgo<DATA, DIM>::recSearch(int current, const KDTreeBox<DIM> &trackBox, int depth) const {
0213   // Iterate until leaf is found, or there are no children in the
0214   // search window. If search has to proceed on both children, proceed
0215   // the search to left child via recursion. Swap search window
0216   // dimension on alternate levels.
0217   while (true) {
0218     const int dimIndex = depth % DIM;
0219     int right = nodePool_.right[current];
0220     if (nodePool_.isLeaf(right)) {
0221       // If point inside the rectangle/area
0222       // Use intentionally bit-wise & instead of logical && for better
0223       // performance. It is faster to always do all comparisons than to
0224       // allow use of branches to not do some if any of the first ones
0225       // is false.
0226       bool isInside = true;
0227       for (unsigned i = 0; i < DIM; ++i) {
0228         float dimCurr = nodePool_.dims[i][current];
0229         isInside &= (dimCurr >= trackBox.dimmin[i]) && (dimCurr <= trackBox.dimmax[i]);
0230       }
0231       if (isInside) {
0232         closestNeighbour->push_back(nodePool_.data[current]);
0233       }
0234       break;
0235     } else {
0236       float median = nodePool_.dims[dimIndex][current];
0237 
0238       bool goLeft = (trackBox.dimmin[dimIndex] <= median);
0239       bool goRight = (trackBox.dimmax[dimIndex] >= median);
0240 
0241       ++depth;
0242       if (goLeft & goRight) {
0243         int left = current + 1;
0244         recSearch(left, trackBox, depth);
0245         // continue with right
0246         current = right;
0247       } else if (goLeft) {
0248         ++current;
0249       } else if (goRight) {
0250         current = right;
0251       } else {
0252         break;
0253       }
0254     }
0255   }
0256 }
0257 
0258 template <typename DATA, unsigned int DIM>
0259 int KDTreeLinkerAlgo<DATA, DIM>::recBuild(int low, int high, int depth) {
0260   int portionSize = high - low;
0261 
0262   if (portionSize == 1) {  // Leaf case
0263     int leaf = nodePool_.getNextNode();
0264     const KDTreeNodeInfo<DATA, DIM> &info = (*initialEltList)[low];
0265     nodePool_.right[leaf] = 0;
0266     for (unsigned i = 0; i < DIM; ++i) {
0267       nodePool_.dims[i][leaf] = info.dims[i];
0268     }
0269     nodePool_.data[leaf] = info.data;
0270     return leaf;
0271 
0272   } else {  // Node case
0273 
0274     // The even depth is associated to dim1 dimension
0275     // The odd one to dim2 dimension
0276     int medianId = medianSearch(low, high, depth);
0277     int dimIndex = depth % DIM;
0278     float medianVal = (*initialEltList)[medianId].dims[dimIndex];
0279 
0280     // We create the node
0281     int nodeInd = nodePool_.getNextNode();
0282 
0283     ++depth;
0284     ++medianId;
0285 
0286     // We recursively build the son nodes
0287     int left = recBuild(low, medianId, depth);
0288     assert(nodeInd + 1 == left);
0289     int right = recBuild(medianId, high, depth);
0290     nodePool_.right[nodeInd] = right;
0291 
0292     nodePool_.dims[dimIndex][nodeInd] = medianVal;
0293 
0294     return nodeInd;
0295   }
0296 }
0297 
0298 #endif