File indexing completed on 2024-04-06 12:15:44
0001 #ifndef HeterogeneousCore_CUDAUtilities_interface_eigenSoA_h
0002 #define HeterogeneousCore_CUDAUtilities_interface_eigenSoA_h
0003
0004 #include <algorithm>
0005 #include <cmath>
0006 #include <cstdint>
0007
0008 #include <Eigen/Core>
0009
0010 #include "HeterogeneousCore/CUDAUtilities/interface/cudaCompat.h"
0011
0012 namespace eigenSoA {
0013
0014 constexpr bool isPowerOf2(int32_t v) { return v && !(v & (v - 1)); }
0015
0016 template <typename T, int S>
0017 class alignas(128) ScalarSoA {
0018 public:
0019 using Scalar = T;
0020
0021 __host__ __device__ constexpr Scalar& operator()(int32_t i) { return data_[i]; }
0022 __device__ constexpr const Scalar operator()(int32_t i) const { return __ldg(data_ + i); }
0023 __host__ __device__ constexpr Scalar& operator[](int32_t i) { return data_[i]; }
0024 __device__ constexpr const Scalar operator[](int32_t i) const { return __ldg(data_ + i); }
0025
0026 __host__ __device__ constexpr Scalar* data() { return data_; }
0027 __host__ __device__ constexpr Scalar const* data() const { return data_; }
0028
0029 private:
0030 Scalar data_[S];
0031 static_assert(isPowerOf2(S), "SoA stride not a power of 2");
0032 static_assert(sizeof(data_) % 128 == 0, "SoA size not a multiple of 128");
0033 };
0034
0035 template <typename M, int S>
0036 class alignas(128) MatrixSoA {
0037 public:
0038 using Scalar = typename M::Scalar;
0039 using Map = Eigen::Map<M, 0, Eigen::Stride<M::RowsAtCompileTime * S, S> >;
0040 using CMap = Eigen::Map<const M, 0, Eigen::Stride<M::RowsAtCompileTime * S, S> >;
0041
0042 __host__ __device__ constexpr Map operator()(int32_t i) { return Map(data_ + i); }
0043 __host__ __device__ constexpr CMap operator()(int32_t i) const { return CMap(data_ + i); }
0044 __host__ __device__ constexpr Map operator[](int32_t i) { return Map(data_ + i); }
0045 __host__ __device__ constexpr CMap operator[](int32_t i) const { return CMap(data_ + i); }
0046
0047 private:
0048 Scalar data_[S * M::RowsAtCompileTime * M::ColsAtCompileTime];
0049 static_assert(isPowerOf2(S), "SoA stride not a power of 2");
0050 static_assert(sizeof(data_) % 128 == 0, "SoA size not a multiple of 128");
0051 };
0052
0053 }
0054
0055 #endif