Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:21:27

0001 #ifndef NNET_MULT_H_
0002 #define NNET_MULT_H_
0003 
0004 #include "nnet_common.h"
0005 #include <iostream>
0006 #include <math.h>
0007 
0008 namespace nnet {
0009 
/**
 * Compile-time ceil(log2(x)), clamped to a minimum of 1.
 *
 * Every x <= 2 (including zero and negative values) yields 1; for larger x the
 * result is the smallest n such that 2^n >= x.
 *
 * The body is a single return statement so the function remains a valid
 * C++11 constexpr function.
 */
constexpr int ceillog2(int x) {
  // x - x / 2 equals ceil(x / 2) for positive x and, unlike (x + 1) / 2,
  // cannot overflow when x is the largest representable int.
  return (x <= 2) ? 1 : 1 + ceillog2(x - x / 2);
}
0011 
0012   namespace product {
0013 
/* ---
 * Different strategies for computing the product of an input and a weight,
 * selected according to the types of each. Every class in this namespace
 * that derives from Product exposes a static product(a, w) function.
 * --- */
0018 
/// Empty tag type used as the common base of all product strategy classes.
struct Product {};
0020 
0021     template <class x_T, class w_T>
0022     class both_binary : public Product {
0023     public:
0024       static x_T product(x_T a, w_T w) { return a == w; }
0025     };
0026 
0027     template <class x_T, class w_T>
0028     class weight_binary : public Product {
0029     public:
0030       static auto product(x_T a, w_T w) -> decltype(-a) {
0031         if (w == 0)
0032           return -a;
0033         else
0034           return a;
0035       }
0036     };
0037 
0038     template <class x_T, class w_T>
0039     class data_binary : public Product {
0040     public:
0041       static auto product(x_T a, w_T w) -> decltype(-w) {
0042         if (a == 0)
0043           return -w;
0044         else
0045           return w;
0046       }
0047     };
0048 
0049     template <class x_T, class w_T>
0050     class weight_ternary : public Product {
0051     public:
0052       static auto product(x_T a, w_T w) -> decltype(-a) {
0053         if (w == 0)
0054           return 0;
0055         else if (w == -1)
0056           return -a;
0057         else
0058           return a;  // if(w == 1)
0059       }
0060     };
0061 
0062     template <class x_T, class w_T>
0063     class mult : public Product {
0064     public:
0065       static auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }
0066     };
0067 
0068     template <class x_T, class w_T>
0069     class weight_exponential : public Product {
0070     public:
0071       using r_T =
0072           ap_fixed<2 * (decltype(w_T::weight)::width + x_T::width), (decltype(w_T::weight)::width + x_T::width)>;
0073       static r_T product(x_T a, w_T w) {
0074         // Shift by the exponent. Negative weights shift right
0075         r_T y = static_cast<r_T>(a) << w.weight;
0076 
0077         // Negate or not depending on weight sign
0078         return w.sign == 1 ? y : static_cast<r_T>(-y);
0079       }
0080     };
0081 
0082   }  // namespace product
0083 
0084   template <class data_T, class res_T, typename CONFIG_T>
0085   inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
0086                                      std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
0087                                  ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>>::type
0088   cast(typename CONFIG_T::accum_t x) {
0089     return (ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>)(x - CONFIG_T::n_in / 2) * 2;
0090   }
0091 
0092   template <class data_T, class res_T, typename CONFIG_T>
0093   inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
0094                                      !std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
0095                                  res_T>::type
0096   cast(typename CONFIG_T::accum_t x) {
0097     return (res_T)x;
0098   }
0099 
0100   template <class data_T, class res_T, typename CONFIG_T>
0101   inline typename std::enable_if<(!std::is_same<data_T, ap_uint<1>>::value), res_T>::type cast(
0102       typename CONFIG_T::accum_t x) {
0103     return (res_T)x;
0104   }
0105 
0106 }  // namespace nnet
0107 
0108 #endif