Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:22:31

0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_MatriplexCommon_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_MatriplexCommon_h
0003 
0004 #include <cstring>
0005 
0006 // Use intrinsics version of code when available, done via CPP flags.
0007 // #define  MPLEX_USE_INTRINSICS
0008 
0009 //==============================================================================
0010 // Intrinsics -- preamble
0011 //==============================================================================
0012 
0013 #if defined(__x86_64__)
0014 #include "immintrin.h"
0015 #else
0016 #include <cstdlib>
0017 #endif
0018 
// Intrinsics are opt-in via MPLEX_USE_INTRINSICS (see flag above); when
// enabled, the best available ISA extension reported by the compiler selects
// the vector type and the LD/ST/ADD/MUL/FMA macro set below.
#if defined(MPLEX_USE_INTRINSICS)
// This seems unnecessary: __AVX__ is usually defined for all higher ISA extensions
#if defined(__AVX__) || defined(__AVX512F__)

#define MPLEX_INTRINSICS

#endif

// ---- AVX-512F branch: 512-bit vectors (16 floats), hardware FMA + gathers ----
#if defined(__AVX512F__)

typedef __m512 IntrVec_t;
#define MPLEX_INTRINSICS_WIDTH_BYTES 64
#define MPLEX_INTRINSICS_WIDTH_BITS 512
#define AVX512_INTRINSICS
#define GATHER_INTRINSICS
// Loads 16 32-bit gather indices; _mm512_load_epi32 is the aligned load.
#define GATHER_IDX_LOAD(name, arr) __m512i name = _mm512_load_epi32(arr);

// NOTE: the macros below reference free variables N and n, which must be in
// scope at the expansion site (the Matriplex kernel loops). Loads/stores are
// the aligned variants, so &a[i * N + n] must meet the vector alignment.
#define LD(a, i) _mm512_load_ps(&a[i * N + n])
#define ST(a, i, r) _mm512_store_ps(&a[i * N + n], r)
#define ADD(a, b) _mm512_add_ps(a, b)
#define MUL(a, b) _mm512_mul_ps(a, b)
#define FMA(a, b, v) _mm512_fmadd_ps(a, b, v)

// ---- AVX2 + FMA branch: 256-bit vectors (8 floats), hardware FMA + gathers ----
#elif defined(__AVX2__) && defined(__FMA__)

typedef __m256 IntrVec_t;
#define MPLEX_INTRINSICS_WIDTH_BYTES 32
#define MPLEX_INTRINSICS_WIDTH_BITS 256
#define AVX2_INTRINSICS
#define GATHER_INTRINSICS
// Previously used _mm256_load_epi32(arr) here, but that's part of AVX-512F, not AVX2
#define GATHER_IDX_LOAD(name, arr) __m256i name = _mm256_load_si256(reinterpret_cast<const __m256i *>(arr));

// Same N/n free-variable convention as the AVX-512 branch above.
#define LD(a, i) _mm256_load_ps(&a[i * N + n])
#define ST(a, i, r) _mm256_store_ps(&a[i * N + n], r)
#define ADD(a, b) _mm256_add_ps(a, b)
#define MUL(a, b) _mm256_mul_ps(a, b)
#define FMA(a, b, v) _mm256_fmadd_ps(a, b, v)

// ---- Plain AVX branch: 256-bit vectors, no hardware FMA, no gather macros ----
#elif defined(__AVX__)

typedef __m256 IntrVec_t;
#define MPLEX_INTRINSICS_WIDTH_BYTES 32
#define MPLEX_INTRINSICS_WIDTH_BITS 256
#define AVX_INTRINSICS

#define LD(a, i) _mm256_load_ps(&a[i * N + n])
#define ST(a, i, r) _mm256_store_ps(&a[i * N + n], r)
#define ADD(a, b) _mm256_add_ps(a, b)
#define MUL(a, b) _mm256_mul_ps(a, b)
// #define FMA(a, b, v)  { __m256 temp = _mm256_mul_ps(a, b); v = _mm256_add_ps(temp, v); }
// Software fallback FMA for plain AVX (no __FMA__): computes a * b + v with a
// separate multiply and add intrinsic. Unlike a true fused multiply-add, this
// performs two rounding steps, so results may differ from _mm256_fmadd_ps in
// the last bit.
inline __m256 FMA(const __m256 &a, const __m256 &b, const __m256 &v) {
  return _mm256_add_ps(_mm256_mul_ps(a, b), v);
}

#endif  // ISA dispatch: __AVX512F__ / __AVX2__+__FMA__ / __AVX__

#endif  // MPLEX_USE_INTRINSICS
0078 
// ASSUME_ALIGNED(a, b): tell the optimizer that pointer 'a' is aligned to 'b'
// bytes, enabling aligned vector loads/stores without a runtime check.
#ifdef __INTEL_COMPILER
#define ASSUME_ALIGNED(a, b) __assume_aligned(a, b)
#else
// GCC/Clang spelling: __builtin_assume_aligned returns a void* copy of the
// pointer carrying the alignment guarantee, so it is cast back to the original
// type and reassigned to 'a'.
#define ASSUME_ALIGNED(a, b) a = static_cast<decltype(a)>(__builtin_assume_aligned(a, b))
#endif
0084 
namespace Matriplex {
  // Index type used throughout Matriplex for dimensions and element offsets.
  typedef int idx_t;

  // Declaration only; the definition is not visible in this header.
  // Presumably checks/reports the alignment of address 'adr', labeling output
  // with 'pref' -- verify against the implementation file.
  void align_check(const char *pref, void *adr);
}  // namespace Matriplex
0090 
0091 #endif