File indexing completed on 2024-04-06 12:21:27
0001 #ifndef NNET_MULT_H_
0002 #define NNET_MULT_H_
0003
0004 #include "nnet_common.h"
0005 #include <iostream>
0006 #include <math.h>
0007
0008 namespace nnet {
0009
// Returns ceil(log2(x)) for x > 2. Any x <= 2 (including 0 and negatives)
// yields 1, so the result is always usable as a non-zero bit width.
// Recurses on ceil(x / 2), written as (x + 1) / 2, and stays a single return
// statement so it remains a valid C++11 constexpr function.
constexpr int ceillog2(int x) { return (x > 2) ? ceillog2((x + 1) / 2) + 1 : 1; }
0011
0012 namespace product {
0013
0014
0015
0016
0017
0018
// Empty common base type shared by the product strategies in this namespace.
// It declares no members of its own.
class Product {
    // Intentionally empty.
};
0020
0021 template <class x_T, class w_T>
0022 class both_binary : public Product {
0023 public:
0024 static x_T product(x_T a, w_T w) { return a == w; }
0025 };
0026
0027 template <class x_T, class w_T>
0028 class weight_binary : public Product {
0029 public:
0030 static auto product(x_T a, w_T w) -> decltype(-a) {
0031 if (w == 0)
0032 return -a;
0033 else
0034 return a;
0035 }
0036 };
0037
0038 template <class x_T, class w_T>
0039 class data_binary : public Product {
0040 public:
0041 static auto product(x_T a, w_T w) -> decltype(-w) {
0042 if (a == 0)
0043 return -w;
0044 else
0045 return w;
0046 }
0047 };
0048
0049 template <class x_T, class w_T>
0050 class weight_ternary : public Product {
0051 public:
0052 static auto product(x_T a, w_T w) -> decltype(-a) {
0053 if (w == 0)
0054 return 0;
0055 else if (w == -1)
0056 return -a;
0057 else
0058 return a;
0059 }
0060 };
0061
0062 template <class x_T, class w_T>
0063 class mult : public Product {
0064 public:
0065 static auto product(x_T a, w_T w) -> decltype(a * w) { return a * w; }
0066 };
0067
0068 template <class x_T, class w_T>
0069 class weight_exponential : public Product {
0070 public:
0071 using r_T =
0072 ap_fixed<2 * (decltype(w_T::weight)::width + x_T::width), (decltype(w_T::weight)::width + x_T::width)>;
0073 static r_T product(x_T a, w_T w) {
0074
0075 r_T y = static_cast<r_T>(a) << w.weight;
0076
0077
0078 return w.sign == 1 ? y : static_cast<r_T>(-y);
0079 }
0080 };
0081
0082 }
0083
0084 template <class data_T, class res_T, typename CONFIG_T>
0085 inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
0086 std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
0087 ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>>::type
0088 cast(typename CONFIG_T::accum_t x) {
0089 return (ap_int<nnet::ceillog2(CONFIG_T::n_in) + 2>)(x - CONFIG_T::n_in / 2) * 2;
0090 }
0091
0092 template <class data_T, class res_T, typename CONFIG_T>
0093 inline typename std::enable_if<std::is_same<data_T, ap_uint<1>>::value &&
0094 !std::is_same<typename CONFIG_T::weight_t, ap_uint<1>>::value,
0095 res_T>::type
0096 cast(typename CONFIG_T::accum_t x) {
0097 return (res_T)x;
0098 }
0099
0100 template <class data_T, class res_T, typename CONFIG_T>
0101 inline typename std::enable_if<(!std::is_same<data_T, ap_uint<1>>::value), res_T>::type cast(
0102 typename CONFIG_T::accum_t x) {
0103 return (res_T)x;
0104 }
0105
0106 }
0107
0108 #endif