Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-05-04 22:50:31

0001 #define CATCH_CONFIG_MAIN
0002 #include <catch.hpp>
0003 
0004 #include "DataFormats/SoATemplate/interface/SoALayout.h"
0005 
0006 GENERATE_SOA_LAYOUT(SoATemplate,
0007                     SOA_COLUMN(float, x),
0008                     SOA_COLUMN(float, y),
0009                     SOA_COLUMN(float, z),
0010                     SOA_COLUMN(double, v_x),
0011                     SOA_COLUMN(double, v_y),
0012                     SOA_COLUMN(double, v_z),
0013 
0014                     SOA_ELEMENT_METHODS(
0015 
0016                         void normalise() {
0017                           float norm_position = square_norm_position();
0018                           if (norm_position > 0.0f) {
0019                             x() /= norm_position;
0020                             y() /= norm_position;
0021                             z() /= norm_position;
0022                           };
0023                           double norm_velocity = square_norm_velocity();
0024                           if (norm_velocity > 0.0f) {
0025                             v_x() /= norm_velocity;
0026                             v_y() /= norm_velocity;
0027                             v_z() /= norm_velocity;
0028                           };
0029                         }),
0030 
0031                     SOA_CONST_ELEMENT_METHODS(
0032                         float square_norm_position() const { return sqrt(x() * x() + y() * y() + z() * z()); };
0033 
0034                         double square_norm_velocity()
0035                             const { return sqrt(v_x() * v_x() + v_y() * v_y() + v_z() * v_z()); };
0036 
0037                         template <typename T1, typename T2>
0038                         static auto time(T1 pos, T2 vel) {
0039                           if (not(vel == 0))
0040                             return pos / vel;
0041                           return 0.;
0042                         }),
0043 
0044                     SOA_SCALAR(int, detectorType))
0045 
0046 using SoA = SoATemplate<>;
0047 using SoAView = SoA::View;
0048 using SoAConstView = SoA::ConstView;
0049 
0050 TEST_CASE("SoACustomizedMethods") {
0051   // common number of elements for the SoAs
0052   const std::size_t elems = 10;
0053 
0054   // buffer size
0055   const std::size_t bufferSize = SoA::computeDataSize(elems);
0056 
0057   // memory buffer for the SoA
0058   std::unique_ptr<std::byte, decltype(std::free) *> buffer{
0059       reinterpret_cast<std::byte *>(aligned_alloc(SoA::alignment, bufferSize)), std::free};
0060 
0061   // SoA objects
0062   SoA soa{buffer.get(), elems};
0063   SoAView view{soa};
0064   SoAConstView const_view{soa};
0065 
0066   // fill up
0067   for (size_t i = 0; i < elems; i++) {
0068     view[i].x() = static_cast<float>(i);
0069     view[i].y() = static_cast<float>(i) * 2.0f;
0070     view[i].z() = static_cast<float>(i) * 3.0f;
0071     view[i].v_x() = static_cast<double>(i);
0072     view[i].v_y() = static_cast<double>(i) * 20;
0073     view[i].v_z() = static_cast<double>(i) * 30;
0074   }
0075   view.detectorType() = 42;
0076 
0077   SECTION("ConstView methods") {
0078     // arrays of norms
0079     std::array<float, elems> position_norms;
0080     std::array<double, elems> velocity_norms;
0081 
0082     // Check for the correctness of the square_norm() functions
0083     for (size_t i = 0; i < elems; i++) {
0084       position_norms[i] = sqrt(const_view[i].x() * const_view[i].x() + const_view[i].y() * const_view[i].y() +
0085                                const_view[i].z() * const_view[i].z());
0086       velocity_norms[i] = sqrt(const_view[i].v_x() * const_view[i].v_x() + const_view[i].v_y() * const_view[i].v_y() +
0087                                const_view[i].v_z() * const_view[i].v_z());
0088       REQUIRE(position_norms[i] == const_view[i].square_norm_position());
0089       REQUIRE(velocity_norms[i] == const_view[i].square_norm_velocity());
0090     }
0091   }
0092 
0093   SECTION("View methods") {
0094     // array of times
0095     std::array<double, elems> times;
0096 
0097     // Check for the correctness of the time() function
0098     times[0] = 0.;
0099     for (size_t i = 0; i < elems; i++) {
0100       if (not(i == 0))
0101         times[i] = view[i].x() / view[i].v_x();
0102       REQUIRE(times[i] == SoAView::const_element::time(view[i].x(), view[i].v_x()));
0103     }
0104 
0105     // normalise the particles data
0106     for (size_t i = 0; i < elems; i++) {
0107       view[i].normalise();
0108     }
0109 
0110     // Check for the norm equal to 1 except for the first element
0111     REQUIRE(view[0].square_norm_position() == 0.f);
0112     REQUIRE(view[0].square_norm_velocity() == 0.);
0113     for (size_t i = 1; i < elems; i++) {
0114       REQUIRE_THAT(view[i].square_norm_position(), Catch::Matchers::WithinAbs(1.f, 1.e-6));
0115       REQUIRE_THAT(view[i].square_norm_velocity(), Catch::Matchers::WithinAbs(1., 1.e-9));
0116     }
0117   }
0118 }