Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:01:57

0001 #ifndef TensorIndex_h
0002 #define TensorIndex_h
0003 ///
0004 ///Credit:
0005 ///Utility class from
0006 ///
0007 ///http://www.sitmo.com/doc/A_Simple_and_Extremely_Fast_CPP_Template_for_Matrices_and_Tensors
0008 ///
0009 ///Usage:
0010 ///
///The template below offers a simple and efficient solution for handling matrices and tensors in C++. The idea is to store the matrix (or tensor) in a standard vector by translating the multidimensional index into a one-dimensional index.
///
///All that is needed is to convert a multidimensional index such as (r,c) into a one-dimensional offset. Because the dimensions are template parameters, they are compile-time constants, so the runtime cost is only a few integer multiplications and additions. The first index varies fastest (column-major order). TensorIndex takes 0-based indices; TensorIndex_base1 takes 1-based indices.
0014 ///
/// Flattens a 0-based index into a d1 x d2 x d3 x d4 tensor to an offset in
/// contiguous storage.
///
/// Layout: the first index varies fastest (column-major order), i.e.
///   offset(i, j, k, l) = i + d1 * (j + d2 * (k + d3 * l)).
/// Trailing dimensions default to 1, so the same template serves vectors,
/// matrices and 3-/4-index tensors. No range checking is performed.
template <int d1, int d2 = 1, int d3 = 1, int d4 = 1>
class TensorIndex {
public:
  /// Total number of elements (product of all extents).
  enum { SIZE = d1 * d2 * d3 * d4 };
  /// Extent of dimension 1..4 (kept as separate enums, as in the original API).
  enum { LEN1 = d1 };
  enum { LEN2 = d2 };
  enum { LEN3 = d3 };
  enum { LEN4 = d4 };

  /// Single index: already a flat offset, returned unchanged.
  static int indexOf(const int i) { return i; }

  /// Offset of element (i, j); each step in j skips a full column of d1 entries.
  static int indexOf(const int i, const int j) { return i + d1 * j; }

  /// Offset of element (i, j, k), evaluated innermost-first (Horner form).
  static int indexOf(const int i, const int j, const int k) { return i + d1 * (j + d2 * k); }

  /// Offset of element (i, j, k, l), evaluated innermost-first (Horner form).
  static int indexOf(const int i, const int j, const int k, const int l) {
    return i + d1 * (j + d2 * (k + d3 * l));
  }
};
0029 
/// Same flattening as a 0-based column-major tensor index, but taking 1-based
/// indices: element (1, 1, 1, 1) maps to offset 0.
///
///   offset(i, j, k, l) = (i-1) + d1 * ((j-1) + d2 * ((k-1) + d3 * (l-1)))
///
/// The first index varies fastest. Trailing dimensions default to 1.
/// No range checking is performed.
template <int d1, int d2 = 1, int d3 = 1, int d4 = 1>
class TensorIndex_base1 {
public:
  /// Total number of elements (product of all extents).
  enum { SIZE = d1 * d2 * d3 * d4 };
  /// Extent of dimension 1..4 (kept as separate enums, as in the original API).
  enum { LEN1 = d1 };
  enum { LEN2 = d2 };
  enum { LEN3 = d3 };
  enum { LEN4 = d4 };

  /// Single 1-based index -> 0-based flat offset.
  static int indexOf(const int i) { return i - 1; }

  /// Offset of 1-based element (i, j).
  static int indexOf(const int i, const int j) { return (i - 1) + d1 * (j - 1); }

  /// Offset of 1-based element (i, j, k), shifted to 0-based and evaluated innermost-first.
  static int indexOf(const int i, const int j, const int k) { return (i - 1) + d1 * ((j - 1) + d2 * (k - 1)); }

  /// Offset of 1-based element (i, j, k, l), shifted to 0-based and evaluated innermost-first.
  static int indexOf(const int i, const int j, const int k, const int l) {
    return (i - 1) + d1 * ((j - 1) + d2 * ((k - 1) + d3 * (l - 1)));
  }
};
0046 #endif