File indexing completed on 2025-05-04 22:50:31
0001 #define CATCH_CONFIG_MAIN
0002 #include <catch.hpp>
0003
0004 #include "DataFormats/SoATemplate/interface/SoALayout.h"
0005
0006 GENERATE_SOA_LAYOUT(SoATemplate,
0007 SOA_COLUMN(float, x),
0008 SOA_COLUMN(float, y),
0009 SOA_COLUMN(float, z),
0010 SOA_COLUMN(double, v_x),
0011 SOA_COLUMN(double, v_y),
0012 SOA_COLUMN(double, v_z),
0013
0014 SOA_ELEMENT_METHODS(
0015
0016 void normalise() {
0017 float norm_position = square_norm_position();
0018 if (norm_position > 0.0f) {
0019 x() /= norm_position;
0020 y() /= norm_position;
0021 z() /= norm_position;
0022 };
0023 double norm_velocity = square_norm_velocity();
0024 if (norm_velocity > 0.0f) {
0025 v_x() /= norm_velocity;
0026 v_y() /= norm_velocity;
0027 v_z() /= norm_velocity;
0028 };
0029 }),
0030
0031 SOA_CONST_ELEMENT_METHODS(
0032 float square_norm_position() const { return sqrt(x() * x() + y() * y() + z() * z()); };
0033
0034 double square_norm_velocity()
0035 const { return sqrt(v_x() * v_x() + v_y() * v_y() + v_z() * v_z()); };
0036
0037 template <typename T1, typename T2>
0038 static auto time(T1 pos, T2 vel) {
0039 if (not(vel == 0))
0040 return pos / vel;
0041 return 0.;
0042 }),
0043
0044 SOA_SCALAR(int, detectorType))
0045
0046 using SoA = SoATemplate<>;
0047 using SoAView = SoA::View;
0048 using SoAConstView = SoA::ConstView;
0049
0050 TEST_CASE("SoACustomizedMethods") {
0051
0052 const std::size_t elems = 10;
0053
0054
0055 const std::size_t bufferSize = SoA::computeDataSize(elems);
0056
0057
0058 std::unique_ptr<std::byte, decltype(std::free) *> buffer{
0059 reinterpret_cast<std::byte *>(aligned_alloc(SoA::alignment, bufferSize)), std::free};
0060
0061
0062 SoA soa{buffer.get(), elems};
0063 SoAView view{soa};
0064 SoAConstView const_view{soa};
0065
0066
0067 for (size_t i = 0; i < elems; i++) {
0068 view[i].x() = static_cast<float>(i);
0069 view[i].y() = static_cast<float>(i) * 2.0f;
0070 view[i].z() = static_cast<float>(i) * 3.0f;
0071 view[i].v_x() = static_cast<double>(i);
0072 view[i].v_y() = static_cast<double>(i) * 20;
0073 view[i].v_z() = static_cast<double>(i) * 30;
0074 }
0075 view.detectorType() = 42;
0076
0077 SECTION("ConstView methods") {
0078
0079 std::array<float, elems> position_norms;
0080 std::array<double, elems> velocity_norms;
0081
0082
0083 for (size_t i = 0; i < elems; i++) {
0084 position_norms[i] = sqrt(const_view[i].x() * const_view[i].x() + const_view[i].y() * const_view[i].y() +
0085 const_view[i].z() * const_view[i].z());
0086 velocity_norms[i] = sqrt(const_view[i].v_x() * const_view[i].v_x() + const_view[i].v_y() * const_view[i].v_y() +
0087 const_view[i].v_z() * const_view[i].v_z());
0088 REQUIRE(position_norms[i] == const_view[i].square_norm_position());
0089 REQUIRE(velocity_norms[i] == const_view[i].square_norm_velocity());
0090 }
0091 }
0092
0093 SECTION("View methods") {
0094
0095 std::array<double, elems> times;
0096
0097
0098 times[0] = 0.;
0099 for (size_t i = 0; i < elems; i++) {
0100 if (not(i == 0))
0101 times[i] = view[i].x() / view[i].v_x();
0102 REQUIRE(times[i] == SoAView::const_element::time(view[i].x(), view[i].v_x()));
0103 }
0104
0105
0106 for (size_t i = 0; i < elems; i++) {
0107 view[i].normalise();
0108 }
0109
0110
0111 REQUIRE(view[0].square_norm_position() == 0.f);
0112 REQUIRE(view[0].square_norm_velocity() == 0.);
0113 for (size_t i = 1; i < elems; i++) {
0114 REQUIRE_THAT(view[i].square_norm_position(), Catch::Matchers::WithinAbs(1.f, 1.e-6));
0115 REQUIRE_THAT(view[i].square_norm_velocity(), Catch::Matchers::WithinAbs(1., 1.e-9));
0116 }
0117 }
0118 }