Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-06-06 01:33:39

0001 #ifndef RecoTracker_LSTCore_src_alpaka_Triplet_h
0002 #define RecoTracker_LSTCore_src_alpaka_Triplet_h
0003 
0004 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0005 #include "FWCore/Utilities/interface/isFinite.h"
0006 
0007 #include "RecoTracker/LSTCore/interface/alpaka/Common.h"
0008 #include "RecoTracker/LSTCore/interface/ModulesSoA.h"
0009 #include "RecoTracker/LSTCore/interface/ObjectRangesSoA.h"
0010 #include "RecoTracker/LSTCore/interface/TripletsSoA.h"
0011 #include "RecoTracker/LSTCore/interface/Circle.h"
0012 
0013 #include "NeuralNetwork.h"
0014 
0015 namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
0016 
0017   ALPAKA_FN_ACC ALPAKA_FN_INLINE void addTripletToMemory(ModulesConst modules,
0018                                                          MiniDoubletsConst mds,
0019                                                          SegmentsConst segments,
0020                                                          Triplets& triplets,
0021                                                          unsigned int innerSegmentIndex,
0022                                                          unsigned int outerSegmentIndex,
0023                                                          uint16_t innerInnerLowerModuleIndex,
0024                                                          uint16_t middleLowerModuleIndex,
0025                                                          uint16_t outerOuterLowerModuleIndex,
0026 #ifdef CUT_VALUE_DEBUG
0027                                                          float zOut,
0028                                                          float rtOut,
0029 #endif
0030                                                          float betaIn,
0031                                                          float betaInCut,
0032                                                          float circleRadius,
0033                                                          float circleCenterX,
0034                                                          float circleCenterY,
0035                                                          unsigned int tripletIndex,
0036                                                          float (&t3Scores)[dnn::t3dnn::kOutputFeatures]) {
0037     triplets.segmentIndices()[tripletIndex][0] = innerSegmentIndex;
0038     triplets.segmentIndices()[tripletIndex][1] = outerSegmentIndex;
0039     triplets.lowerModuleIndices()[tripletIndex][0] = innerInnerLowerModuleIndex;
0040     triplets.lowerModuleIndices()[tripletIndex][1] = middleLowerModuleIndex;
0041     triplets.lowerModuleIndices()[tripletIndex][2] = outerOuterLowerModuleIndex;
0042 
0043     triplets.betaIn()[tripletIndex] = __F2H(betaIn);
0044     triplets.radius()[tripletIndex] = circleRadius;
0045     triplets.centerX()[tripletIndex] = circleCenterX;
0046     triplets.centerY()[tripletIndex] = circleCenterY;
0047     triplets.logicalLayers()[tripletIndex][0] =
0048         modules.layers()[innerInnerLowerModuleIndex] + (modules.subdets()[innerInnerLowerModuleIndex] == 4) * 6;
0049     triplets.logicalLayers()[tripletIndex][1] =
0050         modules.layers()[middleLowerModuleIndex] + (modules.subdets()[middleLowerModuleIndex] == 4) * 6;
0051     triplets.logicalLayers()[tripletIndex][2] =
0052         modules.layers()[outerOuterLowerModuleIndex] + (modules.subdets()[outerOuterLowerModuleIndex] == 4) * 6;
0053     //get the hits
0054     unsigned int firstMDIndex = segments.mdIndices()[innerSegmentIndex][0];
0055     unsigned int secondMDIndex = segments.mdIndices()[innerSegmentIndex][1];
0056     unsigned int thirdMDIndex = segments.mdIndices()[outerSegmentIndex][1];
0057 
0058     triplets.hitIndices()[tripletIndex][0] = mds.anchorHitIndices()[firstMDIndex];
0059     triplets.hitIndices()[tripletIndex][1] = mds.outerHitIndices()[firstMDIndex];
0060     triplets.hitIndices()[tripletIndex][2] = mds.anchorHitIndices()[secondMDIndex];
0061     triplets.hitIndices()[tripletIndex][3] = mds.outerHitIndices()[secondMDIndex];
0062     triplets.hitIndices()[tripletIndex][4] = mds.anchorHitIndices()[thirdMDIndex];
0063     triplets.hitIndices()[tripletIndex][5] = mds.outerHitIndices()[thirdMDIndex];
0064 #ifdef CUT_VALUE_DEBUG
0065     triplets.zOut()[tripletIndex] = zOut;
0066     triplets.rtOut()[tripletIndex] = rtOut;
0067     triplets.betaInCut()[tripletIndex] = betaInCut;
0068 #endif
0069 
0070     triplets.fakeScore()[tripletIndex] = t3Scores[0];
0071     triplets.promptScore()[tripletIndex] = t3Scores[1];
0072     triplets.displacedScore()[tripletIndex] = t3Scores[2];
0073   }
0074 
0075   template <typename TAcc>
0076   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passRZConstraint(TAcc const& acc,
0077                                                        ModulesConst modules,
0078                                                        MiniDoubletsConst mds,
0079                                                        uint16_t innerInnerLowerModuleIndex,
0080                                                        uint16_t middleLowerModuleIndex,
0081                                                        uint16_t outerOuterLowerModuleIndex,
0082                                                        unsigned int firstMDIndex,
0083                                                        unsigned int secondMDIndex,
0084                                                        unsigned int thirdMDIndex,
0085                                                        float circleRadius,
0086                                                        float circleCenterX,
0087                                                        float circleCenterY) {
0088     // Using lst_layer numbering convention defined in ModuleMethods.h
0089     const int layer1 = modules.lstLayers()[innerInnerLowerModuleIndex];
0090     const int layer2 = modules.lstLayers()[middleLowerModuleIndex];
0091     const int layer3 = modules.lstLayers()[outerOuterLowerModuleIndex];
0092 
0093     //all the values are stored in the unit of cm, in the calculation below we need to be cautious if we want to use the meter unit
0094     //get r and z
0095     const float r1 = mds.anchorRt()[firstMDIndex] / 100;
0096     const float r2 = mds.anchorRt()[secondMDIndex] / 100;
0097     const float r3 = mds.anchorRt()[thirdMDIndex] / 100;
0098 
0099     const float z1 = mds.anchorZ()[firstMDIndex] / 100;
0100     const float z2 = mds.anchorZ()[secondMDIndex] / 100;
0101     const float z3 = mds.anchorZ()[thirdMDIndex] / 100;
0102 
0103     //use linear approximation for regions 9 and 20-24 because it works better (see https://github.com/SegmentLinking/cmssw/pull/92)
0104     float residual = alpaka::math::abs(acc, z2 - ((z3 - z1) / (r3 - r1) * (r2 - r1) + z1));
0105 
0106     //region definitions: https://github.com/user-attachments/assets/2b3c1425-66eb-4524-83de-deb6f3b31f71
0107     if (layer1 == 1 && layer2 == 7) {
0108       return residual < 0.01f;  // Region 9
0109     } else if (layer1 == 3 && layer2 == 4) {
0110       if (layer3 == 5) {
0111         return residual < 0.037127972f;  // Region 20
0112       } else if (layer3 == 12) {
0113         return residual < 0.05f;  // Region 21
0114       }
0115     } else if (layer1 == 4) {
0116       if (layer2 == 12) {
0117         return residual < 0.063831687f;  // Region 22
0118       } else if (layer2 == 5) {
0119         if (layer3 == 6) {
0120           return residual < 0.04362525f;  // Region 23
0121         } else if (layer3 == 12) {
0122           return residual < 0.05f;  // Region 24
0123         }
0124       }
0125     }
0126 
0127     //get the type of module: 0 is ps, 1 is 2s
0128     const int moduleType3 = modules.moduleType()[outerOuterLowerModuleIndex];
0129 
0130     //get the x,y position of each MD
0131     const float x1 = mds.anchorX()[firstMDIndex] / 100;
0132     const float x2 = mds.anchorX()[secondMDIndex] / 100;
0133     const float x3 = mds.anchorX()[thirdMDIndex] / 100;
0134 
0135     const float y1 = mds.anchorY()[firstMDIndex] / 100;
0136     const float y2 = mds.anchorY()[secondMDIndex] / 100;
0137     const float y3 = mds.anchorY()[thirdMDIndex] / 100;
0138 
0139     //set initial and target points
0140     float x_init = x2;
0141     float y_init = y2;
0142     float z_init = z2;
0143     float r_init = r2;
0144 
0145     float z_target = z3;
0146     float r_target = r3;
0147 
0148     float x_other = x1;
0149     float y_other = y1;
0150 
0151     float dz = z2 - z1;
0152 
0153     //use MD2 for regions 5 and 19 because it works better (see https://github.com/SegmentLinking/cmssw/pull/92)
0154     if ((layer1 == 8 && layer2 == 14 && layer3 == 15) || (layer1 == 3 && layer2 == 12 && layer3 == 13)) {
0155       x_init = x1;
0156       y_init = y1;
0157       z_init = z1;
0158       r_init = r1;
0159 
0160       z_target = z2;
0161       r_target = r2;
0162 
0163       x_other = x3;
0164       y_other = y3;
0165 
0166       dz = z3 - z1;
0167     }
0168 
0169     //use the 3 MDs to fit a circle. This is the circle parameters, for circle centers and circle radius
0170     float x_center = circleCenterX / 100;
0171     float y_center = circleCenterY / 100;
0172     float pt = 2 * k2Rinv1GeVf * circleRadius;  //k2Rinv1GeVf is already in cm^(-1)
0173 
0174     //determine the charge
0175     int charge = 0;
0176     if ((x2 - x1) * (y3 - y1) - (y2 - y1) * (x3 - x1) > 0)
0177       charge = -1;
0178     else
0179       charge = 1;
0180 
0181     //get the absolute value of px and py at the initial point
0182     float px = 2 * k2Rinv1GeVf * alpaka::math::abs(acc, (y_init - y_center)) * 100;
0183     float py = 2 * k2Rinv1GeVf * alpaka::math::abs(acc, (x_init - x_center)) * 100;
0184 
0185     //Above line only gives you the correct value of px and py, but signs of px and py calculated below.
0186     //We look at if the circle is clockwise or anti-clock wise, to make it simpler, we separate the x-y plane into 4 quarters.
0187     if (x_init > x_center && y_init > y_center)  //1st quad
0188     {
0189       if (charge == 1)
0190         py = -py;
0191       if (charge == -1)
0192         px = -px;
0193     }
0194     if (x_init < x_center && y_init > y_center)  //2nd quad
0195     {
0196       if (charge == -1) {
0197         px = -px;
0198         py = -py;
0199       }
0200     }
0201     if (x_init < x_center && y_init < y_center)  //3rd quad
0202     {
0203       if (charge == 1)
0204         px = -px;
0205       if (charge == -1)
0206         py = -py;
0207     }
0208     if (x_init > x_center && y_init < y_center)  //4th quad
0209     {
0210       if (charge == 1) {
0211         px = -px;
0212         py = -py;
0213       }
0214     }
0215 
0216     //But if the initial T3 curve goes across quarters(i.e. cross axis to separate the quarters), need special redeclaration of px,py signs on these to avoid errors
0217     if (x3 < x2 && x2 < x1)
0218       px = -alpaka::math::abs(acc, px);
0219     else if (x3 > x2 && x2 > x1)
0220       px = alpaka::math::abs(acc, px);
0221     if (y3 < y2 && y2 < y1)
0222       py = -alpaka::math::abs(acc, py);
0223     else if (y3 > y2 && y2 > y1)
0224       py = alpaka::math::abs(acc, py);
0225 
0226     float AO = alpaka::math::sqrt(
0227         acc, (x_other - x_center) * (x_other - x_center) + (y_other - y_center) * (y_other - y_center));
0228     float BO =
0229         alpaka::math::sqrt(acc, (x_init - x_center) * (x_init - x_center) + (y_init - y_center) * (y_init - y_center));
0230     float AB2 = (x_other - x_init) * (x_other - x_init) + (y_other - y_init) * (y_other - y_init);
0231     float dPhi = alpaka::math::acos(acc, (AO * AO + BO * BO - AB2) / (2 * AO * BO));  //Law of Cosines
0232     float ds = circleRadius / 100 * dPhi;
0233     float pz = dz / ds * pt;
0234 
0235     float p = alpaka::math::sqrt(acc, px * px + py * py + pz * pz);
0236     float a = -2.f * k2Rinv1GeVf * 100 * charge;
0237     float rou = a / p;
0238 
0239     float rzChiSquared = 0;
0240     float error = 0;
0241 
0242     //check the tilted module, side: PosZ, NegZ, Center(for not tilted)
0243     float drdz = alpaka::math::abs(acc, modules.drdzs()[outerOuterLowerModuleIndex]);
0244     short side = modules.sides()[outerOuterLowerModuleIndex];
0245     short subdets = modules.subdets()[outerOuterLowerModuleIndex];
0246 
0247     //calculate residual
0248     if (layer3 <= 6 && ((side == lst::Center) or (drdz < 1))) {  // for barrel
0249       float paraA = r_init * r_init + 2 * (px * px + py * py) / (a * a) + 2 * (y_init * px - x_init * py) / a -
0250                     r_target * r_target;
0251       float paraB = 2 * (x_init * px + y_init * py) / a;
0252       float paraC = 2 * (y_init * px - x_init * py) / a + 2 * (px * px + py * py) / (a * a);
0253       float A = paraB * paraB + paraC * paraC;
0254       float B = 2 * paraA * paraB;
0255       float C = paraA * paraA - paraC * paraC;
0256       float sol1 = (-B + alpaka::math::sqrt(acc, B * B - 4 * A * C)) / (2 * A);
0257       float sol2 = (-B - alpaka::math::sqrt(acc, B * B - 4 * A * C)) / (2 * A);
0258       float solz1 = alpaka::math::asin(acc, sol1) / rou * pz / p + z_init;
0259       float solz2 = alpaka::math::asin(acc, sol2) / rou * pz / p + z_init;
0260       float diffz1 = (solz1 - z_target) * 100;
0261       float diffz2 = (solz2 - z_target) * 100;
0262       if (edm::isNotFinite(diffz1))
0263         residual = diffz2;
0264       else if (edm::isNotFinite(diffz2))
0265         residual = diffz1;
0266       else {
0267         residual = (alpaka::math::abs(acc, diffz1) < alpaka::math::abs(acc, diffz2)) ? diffz1 : diffz2;
0268       }
0269     } else {  // for endcap
0270       float s = (z_target - z_init) * p / pz;
0271       float x = x_init + px / a * alpaka::math::sin(acc, rou * s) - py / a * (1 - alpaka::math::cos(acc, rou * s));
0272       float y = y_init + py / a * alpaka::math::sin(acc, rou * s) + px / a * (1 - alpaka::math::cos(acc, rou * s));
0273       residual = (r_target - alpaka::math::sqrt(acc, x * x + y * y)) * 100;
0274     }
0275 
0276     // error
0277     if (moduleType3 == 0) {
0278       error = 0.15f;  //PS
0279     } else {
0280       error = 5.0f;  //2S
0281     }
0282 
0283     float projection_missing2 = 1;
0284     if (drdz < 1)
0285       projection_missing2 = ((subdets == lst::Endcap) or (side == lst::Center))
0286                                 ? 1.f
0287                                 : 1 / (1 + drdz * drdz);  // cos(atan(drdz)), if dr/dz<1
0288     if (drdz > 1)
0289       projection_missing2 = ((subdets == lst::Endcap) or (side == lst::Center))
0290                                 ? 1.f
0291                                 : drdz * drdz / (1 + drdz * drdz);  //sin(atan(drdz)), if dr/dz>1
0292 
0293     rzChiSquared = 12 * (residual * residual) / (error * error * projection_missing2);
0294 
0295     //helix calculation returns NaN, use linear approximation
0296     if (edm::isNotFinite(rzChiSquared) || circleRadius < 0) {
0297       float slope = (z3 - z1) / (r3 - r1);
0298 
0299       residual = (layer3 <= 6) ? ((z3 - z1) - slope * (r3 - r1)) : ((r3 - r1) - (z3 - z1) / slope);
0300       residual = (moduleType3 == 0) ? residual / 0.15f : residual / 5.0f;
0301 
0302       rzChiSquared = 12 * residual * residual;
0303       return rzChiSquared < 2.8e-4;
0304     }
0305 
0306     //cuts for different regions
0307     //region definitions: https://github.com/user-attachments/assets/2b3c1425-66eb-4524-83de-deb6f3b31f71
0308     //for the logic behind the cuts, see https://github.com/SegmentLinking/cmssw/pull/92
0309     if (layer1 == 7) {
0310       if (layer2 == 8) {
0311         if (layer3 == 9) {
0312           return rzChiSquared < 65.47191f;  // Region 0
0313         } else if (layer3 == 14) {
0314           return rzChiSquared < 3.3200853f;  // Region 1
0315         }
0316       } else if (layer2 == 13) {
0317         return rzChiSquared < 17.194584f;  // Region 2
0318       }
0319     } else if (layer1 == 8) {
0320       if (layer2 == 9) {
0321         if (layer3 == 10) {
0322           return rzChiSquared < 114.91959f;  // Region 3
0323         } else if (layer3 == 15) {
0324           return rzChiSquared < 3.4359624f;  // Region 4
0325         }
0326       } else if (layer2 == 14) {
0327         return rzChiSquared < 4.6487956f;  // Region 5
0328       }
0329     } else if (layer1 == 9) {
0330       if (layer2 == 10) {
0331         if (layer3 == 11) {
0332           return rzChiSquared < 97.34339f;  // Region 6
0333         } else if (layer3 == 16) {
0334           return rzChiSquared < 3.095819f;  // Region 7
0335         }
0336       } else if (layer2 == 15) {
0337         return rzChiSquared < 11.477617f;  // Region 8
0338       }
0339     } else if (layer1 == 1) {
0340       if (layer3 == 7) {
0341         return rzChiSquared < 96.949936f;  // Region 10
0342       } else if (layer3 == 3) {
0343         return rzChiSquared < 458.43982f;  // Region 11
0344       }
0345     } else if (layer1 == 2) {
0346       if (layer2 == 7) {
0347         if (layer3 == 8) {
0348           return rzChiSquared < 218.82303f;  // Region 12
0349         } else if (layer3 == 13) {
0350           return rzChiSquared < 3.155554f;  // Region 13
0351         }
0352       } else if (layer2 == 3) {
0353         if (layer3 == 7) {
0354           return rzChiSquared < 235.5005f;  // Region 14
0355         } else if (layer3 == 12) {
0356           return rzChiSquared < 3.8522234f;  // Region 15
0357         } else if (layer3 == 4) {
0358           return rzChiSquared < 3.5852437f;  // Region 16
0359         }
0360       }
0361     } else if (layer1 == 3) {
0362       if (layer2 == 7) {
0363         if (layer3 == 8) {
0364           return rzChiSquared < 42.68f;  // Region 17
0365         } else if (layer3 == 13) {
0366           return rzChiSquared < 3.853796f;  // Region 18
0367         }
0368       } else if (layer2 == 12) {
0369         return rzChiSquared < 6.2774787f;  // Region 19
0370       }
0371     }
0372     return false;
0373   }
0374 
0375   template <typename TAcc>
0376   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPointingConstraintBBB(TAcc const& acc,
0377                                                                 ModulesConst modules,
0378                                                                 MiniDoubletsConst mds,
0379                                                                 SegmentsConst segments,
0380                                                                 uint16_t innerInnerLowerModuleIndex,
0381                                                                 uint16_t middleLowerModuleIndex,
0382                                                                 uint16_t outerOuterLowerModuleIndex,
0383                                                                 unsigned int firstMDIndex,
0384                                                                 unsigned int secondMDIndex,
0385                                                                 unsigned int thirdMDIndex,
0386                                                                 float& zOut,
0387                                                                 float& rtOut,
0388                                                                 unsigned int innerSegmentIndex,
0389                                                                 float& betaIn,
0390                                                                 float& betaInCut,
0391                                                                 const float ptCut) {
0392     float rtIn = mds.anchorRt()[firstMDIndex];
0393     float rtMid = mds.anchorRt()[secondMDIndex];
0394     float drt_InSeg = rtMid - rtIn;
0395 
0396     // raw betaIn value without any correction, based on the mini-doublet hit positions
0397     float alpha_InLo = __H2F(segments.dPhiChanges()[innerSegmentIndex]);
0398     float tl_axis_x = mds.anchorX()[thirdMDIndex] - mds.anchorX()[firstMDIndex];
0399     float tl_axis_y = mds.anchorY()[thirdMDIndex] - mds.anchorY()[firstMDIndex];
0400     betaIn = alpha_InLo - cms::alpakatools::reducePhiRange(
0401                               acc, cms::alpakatools::phi(acc, tl_axis_x, tl_axis_y) - mds.anchorPhi()[firstMDIndex]);
0402 
0403     //beta computation
0404     float drt_tl_axis = alpaka::math::sqrt(acc, tl_axis_x * tl_axis_x + tl_axis_y * tl_axis_y);
0405 
0406     //innerOuterAnchor - innerInnerAnchor
0407     const float rt_InSeg = alpaka::math::sqrt(acc,
0408                                               (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) *
0409                                                       (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) +
0410                                                   (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]) *
0411                                                       (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]));
0412     betaInCut =
0413         alpaka::math::asin(acc, alpaka::math::min(acc, (-rt_InSeg + drt_tl_axis) * k2Rinv1GeVf / ptCut, kSinAlphaMax)) +
0414         (0.02f / drt_InSeg);
0415 
0416     //Beta cut
0417     return alpaka::math::abs(acc, betaIn) < betaInCut;
0418   }
0419 
0420   template <typename TAcc>
0421   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPointingConstraintBBE(TAcc const& acc,
0422                                                                 ModulesConst modules,
0423                                                                 MiniDoubletsConst mds,
0424                                                                 SegmentsConst segments,
0425                                                                 uint16_t innerInnerLowerModuleIndex,
0426                                                                 uint16_t middleLowerModuleIndex,
0427                                                                 uint16_t outerOuterLowerModuleIndex,
0428                                                                 unsigned int firstMDIndex,
0429                                                                 unsigned int secondMDIndex,
0430                                                                 unsigned int thirdMDIndex,
0431                                                                 float& zOut,
0432                                                                 float& rtOut,
0433                                                                 uint16_t innerOuterLowerModuleIndex,
0434                                                                 unsigned int innerSegmentIndex,
0435                                                                 unsigned int outerSegmentIndex,
0436                                                                 float& betaIn,
0437                                                                 float& betaInCut,
0438                                                                 const float ptCut) {
0439     float rt_InLo = mds.anchorRt()[firstMDIndex];
0440     float rt_InOut = mds.anchorRt()[secondMDIndex];
0441 
0442     float sdIn_alpha = __H2F(segments.dPhiChanges()[innerSegmentIndex]);
0443 
0444     float tl_axis_x = mds.anchorX()[thirdMDIndex] - mds.anchorX()[firstMDIndex];
0445     float tl_axis_y = mds.anchorY()[thirdMDIndex] - mds.anchorY()[firstMDIndex];
0446 
0447     betaIn = sdIn_alpha - cms::alpakatools::reducePhiRange(
0448                               acc, cms::alpakatools::phi(acc, tl_axis_x, tl_axis_y) - mds.anchorPhi()[firstMDIndex]);
0449 
0450     float betaInRHmin = betaIn;
0451     float betaInRHmax = betaIn;
0452 
0453     float swapTemp;
0454 
0455     if (alpaka::math::abs(acc, betaInRHmin) > alpaka::math::abs(acc, betaInRHmax)) {
0456       swapTemp = betaInRHmin;
0457       betaInRHmin = betaInRHmax;
0458       betaInRHmax = swapTemp;
0459     }
0460 
0461     float sdIn_dr = alpaka::math::sqrt(acc,
0462                                        (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) *
0463                                                (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) +
0464                                            (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]) *
0465                                                (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]));
0466     float sdIn_d = rt_InOut - rt_InLo;
0467 
0468     float dr = alpaka::math::sqrt(acc, tl_axis_x * tl_axis_x + tl_axis_y * tl_axis_y);
0469     betaInCut = alpaka::math::asin(acc, alpaka::math::min(acc, (-sdIn_dr + dr) * k2Rinv1GeVf / ptCut, kSinAlphaMax)) +
0470                 (0.02f / sdIn_d);
0471 
0472     //Beta cut
0473     return alpaka::math::abs(acc, betaInRHmin) < betaInCut;
0474   }
0475 
0476   template <typename TAcc>
0477   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPointingConstraintEEE(TAcc const& acc,
0478                                                                 ModulesConst modules,
0479                                                                 MiniDoubletsConst mds,
0480                                                                 SegmentsConst segments,
0481                                                                 uint16_t innerInnerLowerModuleIndex,
0482                                                                 uint16_t middleLowerModuleIndex,
0483                                                                 uint16_t outerOuterLowerModuleIndex,
0484                                                                 unsigned int firstMDIndex,
0485                                                                 unsigned int secondMDIndex,
0486                                                                 unsigned int thirdMDIndex,
0487                                                                 float& zOut,
0488                                                                 float& rtOut,
0489                                                                 unsigned int innerSegmentIndex,
0490                                                                 unsigned int outerSegmentIndex,
0491                                                                 float& betaIn,
0492                                                                 float& betaInCut,
0493                                                                 const float ptCut) {
0494     float rt_InLo = mds.anchorRt()[firstMDIndex];
0495     float rt_InOut = mds.anchorRt()[secondMDIndex];
0496     float sdIn_alpha = __H2F(segments.dPhiChanges()[innerSegmentIndex]);
0497 
0498     float tl_axis_x = mds.anchorX()[thirdMDIndex] - mds.anchorX()[firstMDIndex];
0499     float tl_axis_y = mds.anchorY()[thirdMDIndex] - mds.anchorY()[firstMDIndex];
0500 
0501     betaIn = sdIn_alpha - cms::alpakatools::reducePhiRange(
0502                               acc, cms::alpakatools::phi(acc, tl_axis_x, tl_axis_y) - mds.anchorPhi()[firstMDIndex]);
0503 
0504     float sdIn_alphaRHmin = __H2F(segments.dPhiChangeMins()[innerSegmentIndex]);
0505     float sdIn_alphaRHmax = __H2F(segments.dPhiChangeMaxs()[innerSegmentIndex]);
0506     float betaInRHmin = betaIn + sdIn_alphaRHmin - sdIn_alpha;
0507     float betaInRHmax = betaIn + sdIn_alphaRHmax - sdIn_alpha;
0508 
0509     float swapTemp;
0510 
0511     if (alpaka::math::abs(acc, betaInRHmin) > alpaka::math::abs(acc, betaInRHmax)) {
0512       swapTemp = betaInRHmin;
0513       betaInRHmin = betaInRHmax;
0514       betaInRHmax = swapTemp;
0515     }
0516     float sdIn_dr = alpaka::math::sqrt(acc,
0517                                        (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) *
0518                                                (mds.anchorX()[secondMDIndex] - mds.anchorX()[firstMDIndex]) +
0519                                            (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]) *
0520                                                (mds.anchorY()[secondMDIndex] - mds.anchorY()[firstMDIndex]));
0521     float sdIn_d = rt_InOut - rt_InLo;
0522 
0523     float dr = alpaka::math::sqrt(acc, tl_axis_x * tl_axis_x + tl_axis_y * tl_axis_y);
0524     betaInCut = alpaka::math::asin(acc, alpaka::math::min(acc, (-sdIn_dr + dr) * k2Rinv1GeVf / ptCut, kSinAlphaMax)) +
0525                 (0.02f / sdIn_d);
0526 
0527     //Beta cut
0528     return alpaka::math::abs(acc, betaInRHmin) < betaInCut;
0529   }
0530 
0531   template <typename TAcc>
0532   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPointingConstraint(TAcc const& acc,
0533                                                              ModulesConst modules,
0534                                                              MiniDoubletsConst mds,
0535                                                              SegmentsConst segments,
0536                                                              uint16_t innerInnerLowerModuleIndex,
0537                                                              uint16_t middleLowerModuleIndex,
0538                                                              uint16_t outerOuterLowerModuleIndex,
0539                                                              unsigned int firstMDIndex,
0540                                                              unsigned int secondMDIndex,
0541                                                              unsigned int thirdMDIndex,
0542                                                              float& zOut,
0543                                                              float& rtOut,
0544                                                              uint16_t innerOuterLowerModuleIndex,
0545                                                              unsigned int innerSegmentIndex,
0546                                                              unsigned int outerSegmentIndex,
0547                                                              float& betaIn,
0548                                                              float& betaInCut,
0549                                                              const float ptCut) {
0550     short innerInnerLowerModuleSubdet = modules.subdets()[innerInnerLowerModuleIndex];
0551     short middleLowerModuleSubdet = modules.subdets()[middleLowerModuleIndex];
0552     short outerOuterLowerModuleSubdet = modules.subdets()[outerOuterLowerModuleIndex];
0553 
0554     if (innerInnerLowerModuleSubdet == Barrel and middleLowerModuleSubdet == Barrel and
0555         outerOuterLowerModuleSubdet == Barrel) {
0556       return passPointingConstraintBBB(acc,
0557                                        modules,
0558                                        mds,
0559                                        segments,
0560                                        innerInnerLowerModuleIndex,
0561                                        middleLowerModuleIndex,
0562                                        outerOuterLowerModuleIndex,
0563                                        firstMDIndex,
0564                                        secondMDIndex,
0565                                        thirdMDIndex,
0566                                        zOut,
0567                                        rtOut,
0568                                        innerSegmentIndex,
0569                                        betaIn,
0570                                        betaInCut,
0571                                        ptCut);
0572     } else if (innerInnerLowerModuleSubdet == Barrel and middleLowerModuleSubdet == Barrel and
0573                outerOuterLowerModuleSubdet == Endcap) {
0574       return passPointingConstraintBBE(acc,
0575                                        modules,
0576                                        mds,
0577                                        segments,
0578                                        innerInnerLowerModuleIndex,
0579                                        middleLowerModuleIndex,
0580                                        outerOuterLowerModuleIndex,
0581                                        firstMDIndex,
0582                                        secondMDIndex,
0583                                        thirdMDIndex,
0584                                        zOut,
0585                                        rtOut,
0586                                        innerOuterLowerModuleIndex,
0587                                        innerSegmentIndex,
0588                                        outerSegmentIndex,
0589                                        betaIn,
0590                                        betaInCut,
0591                                        ptCut);
0592     } else if (innerInnerLowerModuleSubdet == Barrel and middleLowerModuleSubdet == Endcap and
0593                outerOuterLowerModuleSubdet == Endcap) {
0594       return passPointingConstraintBBE(acc,
0595                                        modules,
0596                                        mds,
0597                                        segments,
0598                                        innerInnerLowerModuleIndex,
0599                                        middleLowerModuleIndex,
0600                                        outerOuterLowerModuleIndex,
0601                                        firstMDIndex,
0602                                        secondMDIndex,
0603                                        thirdMDIndex,
0604                                        zOut,
0605                                        rtOut,
0606                                        innerOuterLowerModuleIndex,
0607                                        innerSegmentIndex,
0608                                        outerSegmentIndex,
0609                                        betaIn,
0610                                        betaInCut,
0611                                        ptCut);
0612 
0613     }
0614 
0615     else if (innerInnerLowerModuleSubdet == Endcap and middleLowerModuleSubdet == Endcap and
0616              outerOuterLowerModuleSubdet == Endcap) {
0617       return passPointingConstraintEEE(acc,
0618                                        modules,
0619                                        mds,
0620                                        segments,
0621                                        innerInnerLowerModuleIndex,
0622                                        middleLowerModuleIndex,
0623                                        outerOuterLowerModuleIndex,
0624                                        firstMDIndex,
0625                                        secondMDIndex,
0626                                        thirdMDIndex,
0627                                        zOut,
0628                                        rtOut,
0629                                        innerSegmentIndex,
0630                                        outerSegmentIndex,
0631                                        betaIn,
0632                                        betaInCut,
0633                                        ptCut);
0634     }
0635     return false;  // failsafe
0636   }
0637 
0638   template <typename TAcc>
0639   ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runTripletConstraintsAndAlgo(TAcc const& acc,
0640                                                                    ModulesConst modules,
0641                                                                    MiniDoubletsConst mds,
0642                                                                    SegmentsConst segments,
0643                                                                    uint16_t innerInnerLowerModuleIndex,
0644                                                                    uint16_t middleLowerModuleIndex,
0645                                                                    uint16_t outerOuterLowerModuleIndex,
0646                                                                    unsigned int innerSegmentIndex,
0647                                                                    unsigned int outerSegmentIndex,
0648                                                                    float& zOut,
0649                                                                    float& rtOut,
0650                                                                    float& betaIn,
0651                                                                    float& betaInCut,
0652                                                                    float& circleRadius,
0653                                                                    float& circleCenterX,
0654                                                                    float& circleCenterY,
0655                                                                    const float ptCut,
0656                                                                    float (&t3Scores)[dnn::t3dnn::kOutputFeatures]) {
0657     unsigned int firstMDIndex = segments.mdIndices()[innerSegmentIndex][0];
0658     unsigned int secondMDIndex = segments.mdIndices()[outerSegmentIndex][0];
0659     unsigned int thirdMDIndex = segments.mdIndices()[outerSegmentIndex][1];
0660 
0661     float x1 = mds.anchorX()[firstMDIndex];
0662     float x2 = mds.anchorX()[secondMDIndex];
0663     float x3 = mds.anchorX()[thirdMDIndex];
0664     float y1 = mds.anchorY()[firstMDIndex];
0665     float y2 = mds.anchorY()[secondMDIndex];
0666     float y3 = mds.anchorY()[thirdMDIndex];
0667 
0668     std::tie(circleRadius, circleCenterX, circleCenterY) =
0669         computeRadiusFromThreeAnchorHits(acc, x1, y1, x2, y2, x3, y3);
0670 
0671     if (not passRZConstraint(acc,
0672                              modules,
0673                              mds,
0674                              innerInnerLowerModuleIndex,
0675                              middleLowerModuleIndex,
0676                              outerOuterLowerModuleIndex,
0677                              firstMDIndex,
0678                              secondMDIndex,
0679                              thirdMDIndex,
0680                              circleRadius,
0681                              circleCenterX,
0682                              circleCenterY))
0683       return false;
0684 
0685     if (not passPointingConstraint(acc,
0686                                    modules,
0687                                    mds,
0688                                    segments,
0689                                    innerInnerLowerModuleIndex,
0690                                    middleLowerModuleIndex,
0691                                    outerOuterLowerModuleIndex,
0692                                    firstMDIndex,
0693                                    secondMDIndex,
0694                                    thirdMDIndex,
0695                                    zOut,
0696                                    rtOut,
0697                                    middleLowerModuleIndex,
0698                                    innerSegmentIndex,
0699                                    outerSegmentIndex,
0700                                    betaIn,
0701                                    betaInCut,
0702                                    ptCut))
0703       return false;
0704 
0705     bool inference =
0706         lst::t3dnn::runInference(acc, mds, firstMDIndex, secondMDIndex, thirdMDIndex, circleRadius, betaIn, t3Scores);
0707     if (!inference)  // T3-building cut
0708       return false;
0709 
0710     return true;
0711   }
0712 
0713   struct CreateTriplets {
0714     ALPAKA_FN_ACC void operator()(Acc3D const& acc,
0715                                   ModulesConst modules,
0716                                   MiniDoubletsConst mds,
0717                                   SegmentsConst segments,
0718                                   SegmentsOccupancyConst segmentsOccupancy,
0719                                   Triplets triplets,
0720                                   TripletsOccupancy tripletsOccupancy,
0721                                   ObjectRangesConst ranges,
0722                                   uint16_t* index_gpu,
0723                                   uint16_t nonZeroModules,
0724                                   const float ptCut) const {
0725       for (uint16_t innerLowerModuleArrayIdx : cms::alpakatools::uniform_elements_z(acc, nonZeroModules)) {
0726         uint16_t innerInnerLowerModuleIndex = index_gpu[innerLowerModuleArrayIdx];
0727         if (innerInnerLowerModuleIndex >= modules.nLowerModules())
0728           continue;
0729 
0730         uint16_t nConnectedModules = modules.nConnectedModules()[innerInnerLowerModuleIndex];
0731         if (nConnectedModules == 0)
0732           continue;
0733 
0734         unsigned int nInnerSegments = segmentsOccupancy.nSegments()[innerInnerLowerModuleIndex];
0735         for (unsigned int innerSegmentArrayIndex : cms::alpakatools::uniform_elements_y(acc, nInnerSegments)) {
0736           unsigned int innerSegmentIndex =
0737               ranges.segmentRanges()[innerInnerLowerModuleIndex][0] + innerSegmentArrayIndex;
0738 
0739           // middle lower module - outer lower module of inner segment
0740           uint16_t middleLowerModuleIndex = segments.outerLowerModuleIndices()[innerSegmentIndex];
0741 
0742           unsigned int nOuterSegments = segmentsOccupancy.nSegments()[middleLowerModuleIndex];
0743           for (unsigned int outerSegmentArrayIndex : cms::alpakatools::uniform_elements_x(acc, nOuterSegments)) {
0744             unsigned int outerSegmentIndex = ranges.segmentRanges()[middleLowerModuleIndex][0] + outerSegmentArrayIndex;
0745 
0746             //this cut reduces the number of candidates by a factor of 4, i.e., 3 out of 4 warps can end right here!
0747             if (segments.mdIndices()[innerSegmentIndex][1] != segments.mdIndices()[outerSegmentIndex][0])
0748               continue;
0749 
0750             uint16_t outerOuterLowerModuleIndex = segments.outerLowerModuleIndices()[outerSegmentIndex];
0751 
0752             float zOut, rtOut, betaIn, betaInCut, circleRadius, circleCenterX, circleCenterY;
0753 
0754             float t3Scores[dnn::t3dnn::kOutputFeatures] = {0.f};
0755 
0756             bool success = runTripletConstraintsAndAlgo(acc,
0757                                                         modules,
0758                                                         mds,
0759                                                         segments,
0760                                                         innerInnerLowerModuleIndex,
0761                                                         middleLowerModuleIndex,
0762                                                         outerOuterLowerModuleIndex,
0763                                                         innerSegmentIndex,
0764                                                         outerSegmentIndex,
0765                                                         zOut,
0766                                                         rtOut,
0767                                                         betaIn,
0768                                                         betaInCut,
0769                                                         circleRadius,
0770                                                         circleCenterX,
0771                                                         circleCenterY,
0772                                                         ptCut,
0773                                                         t3Scores);
0774 
0775             if (success) {
0776               unsigned int totOccupancyTriplets =
0777                   alpaka::atomicAdd(acc,
0778                                     &tripletsOccupancy.totOccupancyTriplets()[innerInnerLowerModuleIndex],
0779                                     1u,
0780                                     alpaka::hierarchy::Threads{});
0781               if (static_cast<int>(totOccupancyTriplets) >=
0782                   ranges.tripletModuleOccupancy()[innerInnerLowerModuleIndex]) {
0783 #ifdef WARNINGS
0784                 printf("Triplet excess alert! Module index = %d, Occupancy = %d\n",
0785                        innerInnerLowerModuleIndex,
0786                        totOccupancyTriplets);
0787 #endif
0788               } else {
0789                 unsigned int tripletModuleIndex = alpaka::atomicAdd(
0790                     acc, &tripletsOccupancy.nTriplets()[innerInnerLowerModuleIndex], 1u, alpaka::hierarchy::Threads{});
0791                 unsigned int tripletIndex =
0792                     ranges.tripletModuleIndices()[innerInnerLowerModuleIndex] + tripletModuleIndex;
0793                 addTripletToMemory(modules,
0794                                    mds,
0795                                    segments,
0796                                    triplets,
0797                                    innerSegmentIndex,
0798                                    outerSegmentIndex,
0799                                    innerInnerLowerModuleIndex,
0800                                    middleLowerModuleIndex,
0801                                    outerOuterLowerModuleIndex,
0802 #ifdef CUT_VALUE_DEBUG
0803                                    zOut,
0804                                    rtOut,
0805 #endif
0806                                    betaIn,
0807                                    betaInCut,
0808                                    circleRadius,
0809                                    circleCenterX,
0810                                    circleCenterY,
0811                                    tripletIndex,
0812                                    t3Scores);
0813               }
0814             }
0815           }
0816         }
0817       }
0818     }
0819   };
0820 
0821   struct CreateTripletArrayRanges {
0822     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0823                                   ModulesConst modules,
0824                                   ObjectRanges ranges,
0825                                   SegmentsConst segments,
0826                                   SegmentsOccupancyConst segmentsOccupancy,
0827                                   const float ptCut) const {
0828       // implementation is 1D with a single block
0829       ALPAKA_ASSERT_ACC((alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0] == 1));
0830 
0831       // Initialize variables in shared memory and set to 0
0832       int& nTotalTriplets = alpaka::declareSharedVar<int, __COUNTER__>(acc);
0833       if (cms::alpakatools::once_per_block(acc)) {
0834         nTotalTriplets = 0;
0835       }
0836       alpaka::syncBlockThreads(acc);
0837 
0838       // Occupancy matrix for 0.8 GeV pT Cut
0839       constexpr int p08_occupancy_matrix[4][4] = {
0840           {543, 235, 88, 46},  // category 0
0841           {755, 347, 0, 0},    // category 1
0842           {0, 0, 0, 0},        // category 2
0843           {0, 38, 46, 39}      // category 3
0844       };
0845 
0846       // Occupancy matrix for 0.6 GeV pT Cut, 99.9%
0847       constexpr int p06_occupancy_matrix[4][4] = {
0848           {1146, 544, 216, 83},  // category 0
0849           {1032, 275, 0, 0},     // category 1
0850           {0, 0, 0, 0},          // category 2
0851           {0, 115, 110, 76}      // category 3
0852       };
0853 
0854       // Select the appropriate occupancy matrix based on ptCut
0855       const auto& occupancy_matrix = (ptCut < 0.8f) ? p06_occupancy_matrix : p08_occupancy_matrix;
0856 
0857       for (uint16_t i : cms::alpakatools::uniform_elements(acc, modules.nLowerModules())) {
0858         if (segmentsOccupancy.nSegments()[i] == 0) {
0859           ranges.tripletModuleIndices()[i] = nTotalTriplets;
0860           ranges.tripletModuleOccupancy()[i] = 0;
0861           continue;
0862         }
0863 
0864         short module_rings = modules.rings()[i];
0865         short module_layers = modules.layers()[i];
0866         short module_subdets = modules.subdets()[i];
0867         float module_eta = alpaka::math::abs(acc, modules.eta()[i]);
0868 
0869         int category_number = getCategoryNumber(module_layers, module_subdets, module_rings);
0870         int eta_number = getEtaBin(module_eta);
0871 
0872         int dynamic_count = 0;
0873         // How many segments are in module i?
0874         int nSegments_i = segmentsOccupancy.nSegments()[i];
0875         int firstSegmentIdx = ranges.segmentRanges()[i][0];
0876         // Loop over all segments that live in module i
0877         for (int s = 0; s < nSegments_i; ++s) {
0878           int segIndex = firstSegmentIdx + s;
0879           uint16_t midModule = segments.outerLowerModuleIndices()[segIndex];
0880           dynamic_count += segmentsOccupancy.nSegments()[midModule];
0881         }
0882 
0883 #ifdef WARNINGS
0884         if (category_number == -1 || eta_number == -1) {
0885           printf("Unhandled case in createTripletArrayRanges! Module index = %i\n", i);
0886         }
0887 #endif
0888         // Get matrix-based cap
0889         int matrix_cap =
0890             (category_number != -1 && eta_number != -1) ? occupancy_matrix[category_number][eta_number] : 0;
0891 
0892         // Cap occupancy at minimum of dynamic count and matrix value
0893         int occupancy = alpaka::math::min(acc, dynamic_count, matrix_cap);
0894 
0895         ranges.tripletModuleOccupancy()[i] = occupancy;
0896         unsigned int nTotT = alpaka::atomicAdd(acc, &nTotalTriplets, occupancy, alpaka::hierarchy::Threads{});
0897         ranges.tripletModuleIndices()[i] = nTotT;
0898       }
0899 
0900       // Wait for all threads to finish before reporting final values
0901       alpaka::syncBlockThreads(acc);
0902       if (cms::alpakatools::once_per_block(acc)) {
0903         ranges.nTotalTrips() = nTotalTriplets;
0904       }
0905     }
0906   };
0907 
0908   struct AddTripletRangesToEventExplicit {
0909     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0910                                   ModulesConst modules,
0911                                   TripletsOccupancyConst tripletsOccupancy,
0912                                   ObjectRanges ranges) const {
0913       // implementation is 1D with a single block
0914       ALPAKA_ASSERT_ACC((alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0] == 1));
0915 
0916       for (uint16_t i : cms::alpakatools::uniform_elements(acc, modules.nLowerModules())) {
0917         if (tripletsOccupancy.nTriplets()[i] == 0) {
0918           ranges.tripletRanges()[i][0] = -1;
0919           ranges.tripletRanges()[i][1] = -1;
0920         } else {
0921           ranges.tripletRanges()[i][0] = ranges.tripletModuleIndices()[i];
0922           ranges.tripletRanges()[i][1] = ranges.tripletModuleIndices()[i] + tripletsOccupancy.nTriplets()[i] - 1;
0923         }
0924       }
0925     }
0926   };
0927 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst
0928 #endif