Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-06-06 01:33:37

0001 #ifndef RecoTracker_LSTCore_src_alpaka_NeuralNetwork_h
0002 #define RecoTracker_LSTCore_src_alpaka_NeuralNetwork_h
0003 
0004 #include "FWCore/Utilities/interface/CMSUnrollLoop.h"
0005 #include "HeterogeneousCore/AlpakaMath/interface/deltaPhi.h"
0006 
0007 #include "RecoTracker/LSTCore/interface/alpaka/Common.h"
0008 #include "RecoTracker/LSTCore/interface/MiniDoubletsSoA.h"
0009 
0010 #include "T5NeuralNetworkWeights.h"
0011 #include "T3NeuralNetworkWeights.h"
0012 #include "pT3NeuralNetworkWeights.h"
0013 #include "T5EmbedNetworkWeights.h"
0014 #include "pLSEmbedNetworkWeights.h"
0015 
0016 namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
0017 
0018   template <int FEATURES, typename TAcc>
0019   ALPAKA_FN_ACC ALPAKA_FN_INLINE void softmax_activation(TAcc const& acc, float (&input)[FEATURES]) {
0020     float sum = 0.f;
0021     // Compute exp and sum
0022     CMS_UNROLL_LOOP
0023     for (unsigned int i = 0; i < FEATURES; ++i) {
0024       input[i] = alpaka::math::exp(acc, input[i]);
0025       sum += input[i];
0026     }
0027 
0028     // Normalize
0029     CMS_UNROLL_LOOP
0030     for (unsigned int i = 0; i < FEATURES; ++i) {
0031       input[i] /= sum;
0032     }
0033   }
0034 
0035   template <int FEATURES>
0036   ALPAKA_FN_ACC ALPAKA_FN_INLINE void relu_activation(float (&input)[FEATURES]) {
0037     CMS_UNROLL_LOOP
0038     for (unsigned int col = 0; col < FEATURES; ++col) {
0039       input[col] = (input[col] > 0.f) ? input[col] : 0.f;
0040     }
0041   }
0042 
0043   template <typename TAcc>
0044   ALPAKA_FN_ACC ALPAKA_FN_INLINE float sigmoid_activation(TAcc const& acc, const float x) {
0045     return alpaka::math::exp(acc, x) / (alpaka::math::exp(acc, x) + 1.f);
0046   }
0047 
0048   template <int IN_FEATURES, int OUT_FEATURES>
0049   ALPAKA_FN_ACC ALPAKA_FN_INLINE void linear_layer(const float (&input)[IN_FEATURES],
0050                                                    float (&output)[OUT_FEATURES],
0051                                                    const float (&weights)[IN_FEATURES][OUT_FEATURES],
0052                                                    const float (&biases)[OUT_FEATURES]) {
0053     CMS_UNROLL_LOOP
0054     for (unsigned int i = 0; i < OUT_FEATURES; ++i) {
0055       output[i] = biases[i];
0056       CMS_UNROLL_LOOP
0057       for (int j = 0; j < IN_FEATURES; ++j) {
0058         output[i] += input[j] * weights[j][i];
0059       }
0060     }
0061   }
0062 
0063   namespace t3dnn {
0064     template <typename TAcc>
0065     ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runInference(TAcc const& acc,
0066                                                      MiniDoubletsConst mds,
0067                                                      const unsigned int mdIndex1,
0068                                                      const unsigned int mdIndex2,
0069                                                      const unsigned int mdIndex3,
0070                                                      const float radius,
0071                                                      const float betaIn,
0072                                                      float (&output)[dnn::t3dnn::kOutputFeatures]) {
0073       // Constants for T3 DNN
0074       constexpr unsigned int kInputFeatures = 14;
0075       constexpr unsigned int kHiddenFeatures = 32;
0076 
0077       // Extract hit information
0078       float eta1 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex1]);  // inner T3 anchor hit 1 eta
0079       float eta2 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex2]);  // inner T3 anchor hit 2 eta
0080       float eta3 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex3]);  // inner T3 anchor hit 3 eta
0081 
0082       float phi1 = mds.anchorPhi()[mdIndex1];  // inner T3 anchor hit 1 phi
0083       float phi2 = mds.anchorPhi()[mdIndex2];  // inner T3 anchor hit 2 phi
0084       float phi3 = mds.anchorPhi()[mdIndex3];  // inner T3 anchor hit 3 phi
0085 
0086       float z1 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex1]);  // inner T3 anchor hit 1 z
0087       float z2 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex2]);  // inner T3 anchor hit 2 z
0088       float z3 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex3]);  // inner T3 anchor hit 3 z
0089 
0090       float r1 = mds.anchorRt()[mdIndex1];  // inner T3 anchor hit 1 r
0091       float r2 = mds.anchorRt()[mdIndex2];  // inner T3 anchor hit 2 r
0092       float r3 = mds.anchorRt()[mdIndex3];  // inner T3 anchor hit 3 r
0093 
0094       // Build input feature vector matching training order
0095       float x[kInputFeatures] = {
0096           eta1 / dnn::t3dnn::kEta_norm,                   // First hit eta normalized
0097           alpaka::math::abs(acc, phi1) / dnn::kPhi_norm,  // First hit phi normalized
0098           z1 / dnn::t3dnn::kZ_max,                        // First hit z normalized
0099           r1 / dnn::t3dnn::kR_max,                        // First hit r normalized
0100 
0101           eta2 - eta1,                                                   // Difference in eta between hit 2 and 1
0102           cms::alpakatools::deltaPhi(acc, phi2, phi1) / dnn::kPhi_norm,  // Difference in phi between hit 2 and 1
0103           (z2 - z1) / dnn::t3dnn::kZ_max,  // Difference in z between hit 2 and 1 normalized
0104           (r2 - r1) / dnn::t3dnn::kR_max,  // Difference in r between hit 2 and 1 normalized
0105 
0106           eta3 - eta2,                                                   // Difference in eta between hit 3 and 2
0107           cms::alpakatools::deltaPhi(acc, phi3, phi2) / dnn::kPhi_norm,  // Difference in phi between hit 3 and 2
0108           (z3 - z2) / dnn::t3dnn::kZ_max,  // Difference in z between hit 3 and 2 normalized
0109           (r3 - r2) / dnn::t3dnn::kR_max,  // Difference in r between hit 3 and 2 normalized
0110 
0111           alpaka::math::log10(acc, radius),  // T3's circle radius
0112           betaIn                             // Beta angle of inner segment
0113       };
0114 
0115       float x_1[kHiddenFeatures];  // Layer 1 output
0116       float x_2[kHiddenFeatures];  // Layer 2 output
0117 
0118       // Layer 1: Linear + Relu
0119       linear_layer<kInputFeatures, kHiddenFeatures>(x, x_1, dnn::t3dnn::wgtT_layer1, dnn::t3dnn::bias_layer1);
0120       relu_activation<kHiddenFeatures>(x_1);
0121 
0122       // Layer 2: Linear + Relu
0123       linear_layer<kHiddenFeatures, kHiddenFeatures>(x_1, x_2, dnn::t3dnn::wgtT_layer2, dnn::t3dnn::bias_layer2);
0124       relu_activation<kHiddenFeatures>(x_2);
0125 
0126       // Layer 3: Linear + Softmax
0127       linear_layer<kHiddenFeatures, dnn::t3dnn::kOutputFeatures>(
0128           x_2, output, dnn::t3dnn::wgtT_output_layer, dnn::t3dnn::bias_output_layer);
0129       softmax_activation<dnn::t3dnn::kOutputFeatures>(acc, output);
0130 
0131       // Get pt and eta bin indices
0132       float t3_pt = radius * lst::k2Rinv1GeVf * 2;
0133       uint8_t pt_index = (t3_pt > 5);
0134       uint8_t bin_index = (eta1 > 2.5f) ? (dnn::kEtaBins - 1) : static_cast<unsigned int>(eta1 / dnn::kEtaSize);
0135 
0136       return output[1] > dnn::t3dnn::kWp_prompt[pt_index][bin_index] ||
0137              output[2] > dnn::t3dnn::kWp_displaced[pt_index][bin_index];
0138     }
0139   }  // namespace t3dnn
0140 
0141   namespace pt3dnn {
0142 
0143     template <typename TAcc>
0144     ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runInference(TAcc const& acc,
0145                                                      const float rPhiChiSquared,
0146                                                      const float tripletRadius,
0147                                                      const float pixelRadius,
0148                                                      const float pixRadiusError,
0149                                                      const float rzChiSquared,
0150                                                      const float pixelEta,
0151                                                      const float pixelPt) {
0152       constexpr unsigned int kInputFeatures = 6;
0153       constexpr unsigned int kHiddenFeatures = 32;
0154       constexpr unsigned int kOutputFeatures = 1;
0155 
0156       float x[kInputFeatures] = {alpaka::math::log10(acc, rPhiChiSquared),
0157                                  alpaka::math::log10(acc, tripletRadius),
0158                                  alpaka::math::log10(acc, pixelRadius),
0159                                  alpaka::math::log10(acc, pixRadiusError),
0160                                  alpaka::math::log10(acc, rzChiSquared),
0161                                  alpaka::math::abs(acc, pixelEta) / dnn::pt3dnn::kEta_norm};
0162 
0163       float x1[kHiddenFeatures];
0164       float x2[kHiddenFeatures];
0165       float x3[kOutputFeatures];
0166 
0167       linear_layer<kInputFeatures, kHiddenFeatures>(x, x1, dnn::pt3dnn::wgtT_layer1, dnn::pt3dnn::bias_layer1);
0168       relu_activation<kHiddenFeatures>(x1);
0169 
0170       linear_layer<kHiddenFeatures, kHiddenFeatures>(x1, x2, dnn::pt3dnn::wgtT_layer2, dnn::pt3dnn::bias_layer2);
0171       relu_activation<kHiddenFeatures>(x2);
0172 
0173       linear_layer<kHiddenFeatures, kOutputFeatures>(
0174           x2, x3, dnn::pt3dnn::wgtT_output_layer, dnn::pt3dnn::bias_output_layer);
0175       float output = sigmoid_activation(acc, x3[0]);
0176 
0177       uint8_t bin_index = (alpaka::math::abs(acc, pixelEta) > 2.5f)
0178                               ? (dnn::kEtaBins - 1)
0179                               : static_cast<unsigned int>(alpaka::math::abs(acc, pixelEta) / dnn::kEtaSize);
0180 
0181       if (pixelPt > 5.0f)
0182         return output > dnn::pt3dnn::kWpHigh;
0183 
0184       return output > dnn::pt3dnn::kWp[bin_index];
0185     }
0186 
0187   }  // namespace pt3dnn
0188 
0189   namespace t5dnn {
0190     template <typename TAcc>
0191     ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runInference(TAcc const& acc,
0192                                                      MiniDoubletsConst mds,
0193                                                      const unsigned int mdIndex1,
0194                                                      const unsigned int mdIndex2,
0195                                                      const unsigned int mdIndex3,
0196                                                      const unsigned int mdIndex4,
0197                                                      const unsigned int mdIndex5,
0198                                                      const float innerRadius,
0199                                                      const float outerRadius,
0200                                                      const float bridgeRadius) {
0201       // Constants
0202       constexpr unsigned int kInputFeatures = 23;
0203       constexpr unsigned int kHiddenFeatures = 32;
0204 
0205       float eta1 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex1]);  // inner T3 anchor hit 1 eta
0206       float eta2 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex2]);  // inner T3 anchor hit 2 eta
0207       float eta3 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex3]);  // inner T3 anchor hit 3 eta
0208       float eta4 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex4]);  // outer T3 anchor hit 4 eta
0209       float eta5 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex5]);  // outer T3 anchor hit 5 eta
0210 
0211       float phi1 = mds.anchorPhi()[mdIndex1];  // inner T3 anchor hit 1 phi
0212       float phi2 = mds.anchorPhi()[mdIndex2];  // inner T3 anchor hit 2 phi
0213       float phi3 = mds.anchorPhi()[mdIndex3];  // inner T3 anchor hit 3 phi
0214       float phi4 = mds.anchorPhi()[mdIndex4];  // outer T3 anchor hit 4 phi
0215       float phi5 = mds.anchorPhi()[mdIndex5];  // outer T3 anchor hit 5 phi
0216 
0217       float z1 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex1]);  // inner T3 anchor hit 1 z
0218       float z2 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex2]);  // inner T3 anchor hit 2 z
0219       float z3 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex3]);  // inner T3 anchor hit 3 z
0220       float z4 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex4]);  // outer T3 anchor hit 4 z
0221       float z5 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex5]);  // outer T3 anchor hit 5 z
0222 
0223       float r1 = mds.anchorRt()[mdIndex1];  // inner T3 anchor hit 1 r
0224       float r2 = mds.anchorRt()[mdIndex2];  // inner T3 anchor hit 2 r
0225       float r3 = mds.anchorRt()[mdIndex3];  // inner T3 anchor hit 3 r
0226       float r4 = mds.anchorRt()[mdIndex4];  // outer T3 anchor hit 4 r
0227       float r5 = mds.anchorRt()[mdIndex5];  // outer T3 anchor hit 5 r
0228 
0229       // Build the input feature vector using pairwise differences after the first hit
0230       float x[kInputFeatures] = {
0231           eta1 / dnn::t5dnn::kEta_norm,                   // inner T3: First hit eta normalized
0232           alpaka::math::abs(acc, phi1) / dnn::kPhi_norm,  // inner T3: First hit phi normalized
0233           z1 / dnn::t5dnn::kZ_max,                        // inner T3: First hit z normalized
0234           r1 / dnn::t5dnn::kR_max,                        // inner T3: First hit r normalized
0235 
0236           eta2 - eta1,  // inner T3: Difference in eta between hit 2 and 1
0237           cms::alpakatools::deltaPhi(acc, phi2, phi1) /
0238               dnn::kPhi_norm,              // inner T3: Difference in phi between hit 2 and 1
0239           (z2 - z1) / dnn::t5dnn::kZ_max,  // inner T3: Difference in z between hit 2 and 1 normalized
0240           (r2 - r1) / dnn::t5dnn::kR_max,  // inner T3: Difference in r between hit 2 and 1 normalized
0241 
0242           eta3 - eta2,  // inner T3: Difference in eta between hit 3 and 2
0243           cms::alpakatools::deltaPhi(acc, phi3, phi2) /
0244               dnn::kPhi_norm,              // inner T3: Difference in phi between hit 3 and 2
0245           (z3 - z2) / dnn::t5dnn::kZ_max,  // inner T3: Difference in z between hit 3 and 2 normalized
0246           (r3 - r2) / dnn::t5dnn::kR_max,  // inner T3: Difference in r between hit 3 and 2 normalized
0247 
0248           eta4 - eta3,  // outer T3: Difference in eta between hit 4 and 3
0249           cms::alpakatools::deltaPhi(acc, phi4, phi3) /
0250               dnn::kPhi_norm,              // outer T3: Difference in phi between hit 4 and 3
0251           (z4 - z3) / dnn::t5dnn::kZ_max,  // outer T3: Difference in z between hit 4 and 3 normalized
0252           (r4 - r3) / dnn::t5dnn::kR_max,  // outer T3: Difference in r between hit 4 and 3 normalized
0253 
0254           eta5 - eta4,  // outer T3: Difference in eta between hit 5 and 4
0255           cms::alpakatools::deltaPhi(acc, phi5, phi4) /
0256               dnn::kPhi_norm,              // outer T3: Difference in phi between hit 5 and 4
0257           (z5 - z4) / dnn::t5dnn::kZ_max,  // outer T3: Difference in z between hit 5 and 4 normalized
0258           (r5 - r4) / dnn::t5dnn::kR_max,  // outer T3: Difference in r between hit 5 and 4 normalized
0259 
0260           alpaka::math::log10(acc, innerRadius),   // T5 inner radius
0261           alpaka::math::log10(acc, bridgeRadius),  // T5 bridge radius
0262           alpaka::math::log10(acc, outerRadius)    // T5 outer radius
0263       };
0264 
0265       float x_1[kHiddenFeatures];  // Layer 1 output
0266       float x_2[kHiddenFeatures];  // Layer 2 output
0267       float x_3[1];                // Layer 3 linear output
0268 
0269       // Layer 1: Linear + Relu
0270       linear_layer<kInputFeatures, kHiddenFeatures>(x, x_1, dnn::t5dnn::wgtT_layer1, dnn::t5dnn::bias_layer1);
0271       relu_activation<kHiddenFeatures>(x_1);
0272 
0273       // Layer 2: Linear + Relu
0274       linear_layer<kHiddenFeatures, kHiddenFeatures>(x_1, x_2, dnn::t5dnn::wgtT_layer2, dnn::t5dnn::bias_layer2);
0275       relu_activation<kHiddenFeatures>(x_2);
0276 
0277       // Layer 3: Linear + Sigmoid
0278       linear_layer<kHiddenFeatures, 1>(x_2, x_3, dnn::t5dnn::wgtT_output_layer, dnn::t5dnn::bias_output_layer);
0279       float x_5 = sigmoid_activation(acc, x_3[0]);
0280 
0281       // Get the bin index based on abs(eta) of first hit and t5_pt
0282       float t5_pt = innerRadius * lst::k2Rinv1GeVf * 2;
0283 
0284       uint8_t pt_index = (t5_pt > 5.0f);
0285       uint8_t bin_index = (eta1 > 2.5f) ? (dnn::kEtaBins - 1) : static_cast<unsigned int>(eta1 / dnn::kEtaSize);
0286 
0287       // Compare x_5 to the cut value for the relevant bin
0288       return x_5 > dnn::t5dnn::kWp[pt_index][bin_index];
0289     }
0290   }  // namespace t5dnn
0291 
0292   namespace t5embdnn {
0293     template <typename TAcc>
0294     ALPAKA_FN_ACC ALPAKA_FN_INLINE void runEmbed(TAcc const& acc,
0295                                                  MiniDoubletsConst mds,
0296                                                  unsigned int mdIndex1,
0297                                                  unsigned int mdIndex2,
0298                                                  unsigned int mdIndex3,
0299                                                  unsigned int mdIndex4,
0300                                                  unsigned int mdIndex5,
0301                                                  float innerRadius,
0302                                                  float outerRadius,
0303                                                  float bridgeRadius,
0304                                                  float fakeScore1,
0305                                                  float promptScore1,
0306                                                  float dispScore1,
0307                                                  float fakeScore2,
0308                                                  float promptScore2,
0309                                                  float dispScore2,
0310                                                  float (&embedding)[Params_T5::kEmbed]) {
0311       constexpr unsigned int kInputFeatures = 30;
0312       constexpr unsigned int kHiddenFeatures = 32;
0313 
0314       float eta1 = mds.anchorEta()[mdIndex1];
0315       float eta2 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex2]);
0316       float eta3 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex3]);
0317       float eta4 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex4]);
0318       float eta5 = alpaka::math::abs(acc, mds.anchorEta()[mdIndex5]);
0319 
0320       float phi1 = mds.anchorPhi()[mdIndex1];
0321       float phi2 = mds.anchorPhi()[mdIndex2];
0322       float phi3 = mds.anchorPhi()[mdIndex3];
0323       float phi4 = mds.anchorPhi()[mdIndex4];
0324       float phi5 = mds.anchorPhi()[mdIndex5];
0325 
0326       float z1 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex1]);
0327       float z2 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex2]);
0328       float z3 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex3]);
0329       float z4 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex4]);
0330       float z5 = alpaka::math::abs(acc, mds.anchorZ()[mdIndex5]);
0331 
0332       float r1 = mds.anchorRt()[mdIndex1];
0333       float r2 = mds.anchorRt()[mdIndex2];
0334       float r3 = mds.anchorRt()[mdIndex3];
0335       float r4 = mds.anchorRt()[mdIndex4];
0336       float r5 = mds.anchorRt()[mdIndex5];
0337 
0338       float x[kInputFeatures] = {eta1 / dnn::t5dnn::kEta_norm,
0339                                  alpaka::math::cos(acc, phi1),
0340                                  alpaka::math::sin(acc, phi1),
0341                                  z1 / dnn::t5dnn::kZ_max,
0342                                  r1 / dnn::t5dnn::kR_max,
0343 
0344                                  eta2 - alpaka::math::abs(acc, eta1),
0345                                  cms::alpakatools::deltaPhi(acc, phi2, phi1),
0346                                  (z2 - z1) / dnn::t5dnn::kZ_max,
0347                                  (r2 - r1) / dnn::t5dnn::kR_max,
0348 
0349                                  eta3 - eta2,
0350                                  cms::alpakatools::deltaPhi(acc, phi3, phi2),
0351                                  (z3 - z2) / dnn::t5dnn::kZ_max,
0352                                  (r3 - r2) / dnn::t5dnn::kR_max,
0353 
0354                                  eta4 - eta3,
0355                                  cms::alpakatools::deltaPhi(acc, phi4, phi3),
0356                                  (z4 - z3) / dnn::t5dnn::kZ_max,
0357                                  (r4 - r3) / dnn::t5dnn::kR_max,
0358 
0359                                  eta5 - eta4,
0360                                  cms::alpakatools::deltaPhi(acc, phi5, phi4),
0361                                  (z5 - z4) / dnn::t5dnn::kZ_max,
0362                                  (r5 - r4) / dnn::t5dnn::kR_max,
0363 
0364                                  1.0f / innerRadius,
0365                                  1.0f / bridgeRadius,
0366                                  1.0f / outerRadius,
0367 
0368                                  fakeScore1,
0369                                  promptScore1,
0370                                  dispScore1,
0371                                  (fakeScore2 - fakeScore1),
0372                                  (promptScore2 - promptScore1),
0373                                  (dispScore2 - dispScore1)};
0374 
0375       float h1[kHiddenFeatures];
0376       float h2[kHiddenFeatures];
0377 
0378       linear_layer<kInputFeatures, kHiddenFeatures>(x, h1, dnn::t5embdnn::wgtT_fc1, dnn::t5embdnn::bias_fc1);
0379       relu_activation<kHiddenFeatures>(h1);
0380 
0381       linear_layer<kHiddenFeatures, kHiddenFeatures>(h1, h2, dnn::t5embdnn::wgtT_fc2, dnn::t5embdnn::bias_fc2);
0382       relu_activation<kHiddenFeatures>(h2);
0383 
0384       linear_layer<kHiddenFeatures, Params_T5::kEmbed>(h2, embedding, dnn::t5embdnn::wgtT_fc3, dnn::t5embdnn::bias_fc3);
0385     }
0386 
0387   }  // namespace t5embdnn
0388 
0389   namespace plsembdnn {
0390     template <typename TAcc>
0391     ALPAKA_FN_ACC ALPAKA_FN_INLINE void runEmbed(TAcc const& acc,
0392                                                  const float eta,
0393                                                  const float etaErr,
0394                                                  const float phi,
0395                                                  const float circleCenterX,
0396                                                  const float circleCenterY,
0397                                                  const float circleRadius,
0398                                                  const float ptIn,
0399                                                  const float ptErr,
0400                                                  const bool isQuad,
0401                                                  float (&embedding)[Params_pLS::kEmbed]) {
0402       constexpr unsigned int kInputFeatures = 10;
0403       constexpr unsigned int kHiddenFeatures = 32;
0404 
0405       float x[kInputFeatures] = {eta / dnn::plsembdnn::kEta_norm,
0406                                  etaErr / dnn::plsembdnn::kEtaErr_norm,
0407                                  alpaka::math::cos(acc, phi),
0408                                  alpaka::math::sin(acc, phi),
0409                                  1.0f / ptIn,
0410                                  alpaka::math::log10(acc, ptErr),
0411                                  isQuad ? 1.0f : 0.0f,
0412                                  alpaka::math::log10(acc, alpaka::math::abs(acc, circleCenterX)),
0413                                  alpaka::math::log10(acc, alpaka::math::abs(acc, circleCenterY)),
0414                                  alpaka::math::log10(acc, circleRadius)};
0415 
0416       float h1[kHiddenFeatures];
0417       float h2[kHiddenFeatures];
0418 
0419       linear_layer<kInputFeatures, kHiddenFeatures>(x, h1, dnn::plsembdnn::wgtT_fc1, dnn::plsembdnn::bias_fc1);
0420       relu_activation<kHiddenFeatures>(h1);
0421 
0422       linear_layer<kHiddenFeatures, kHiddenFeatures>(h1, h2, dnn::plsembdnn::wgtT_fc2, dnn::plsembdnn::bias_fc2);
0423       relu_activation<kHiddenFeatures>(h2);
0424 
0425       linear_layer<kHiddenFeatures, Params_pLS::kEmbed>(
0426           h2, embedding, dnn::plsembdnn::wgtT_fc3, dnn::plsembdnn::bias_fc3);
0427     }
0428 
0429   }  // namespace plsembdnn
0430 
0431 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst
0432 
0433 #endif