Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-06-20 01:53:32

0001 #ifndef L1TRIGGER_PHASE2L1PARTICLE_FLOW_HGCAL_CONV_NNET_ACTIVATION_H_
0002 #define L1TRIGGER_PHASE2L1PARTICLE_FLOW_HGCAL_CONV_NNET_ACTIVATION_H_
0003 
0004 #include "ap_fixed.h"
0005 #include "nnet_common.h"
0006 #include <cmath>
0007 
0008 namespace nnet {
0009 
0010   inline float exp_fcn_float(float input) { return std::exp(input); }
0011 
0012   template <class data_T, typename CONFIG_T>
0013   inline unsigned softmax_idx_from_real_val(data_T x) {
0014     // Slice the top N bits to get an index into the table
0015     static constexpr int N = ceillog2(CONFIG_T::table_size);  // number of address bits for table
0016     ap_uint<N> y = x(x.width - 1, x.width - N);               // slice the top N bits of input
0017     return (unsigned)y(N - 1, 0);
0018   }
0019 
0020   template <class data_T, typename CONFIG_T>
0021   inline float softmax_real_val_from_idx(unsigned i) {
0022     // Treat the index as the top N bits
0023     static constexpr int N = ceillog2(CONFIG_T::table_size);  // number of address bits for table
0024     data_T x(0);
0025     x(x.width - 1, x.width - N) = i;
0026     return (float)x;
0027   }
0028 
0029   template <class data_T, typename CONFIG_T>
0030   void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::table_size]) {
0031     // The template data_T is the data type used to address the table
0032     for (unsigned i = 0; i < CONFIG_T::table_size; i++) {
0033       // Slicing bits for address is going to round towards 0, so take the central value
0034       float x = softmax_real_val_from_idx<data_T, CONFIG_T>(i);
0035       typename CONFIG_T::exp_table_t exp_x = exp_fcn_float(x);
0036       table_out[i] = exp_x;
0037     }
0038   }
0039 
0040   template <class data_T, typename CONFIG_T>
0041   void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::table_size]) {
0042     // The template data_T is the data type used to address the table
0043     for (unsigned i = 0; i < CONFIG_T::table_size; i++) {
0044       float x = softmax_real_val_from_idx<data_T, CONFIG_T>(i);
0045       typename CONFIG_T::inv_table_t inv_x = 1 / x;
0046       table_out[i] = inv_x;
0047     }
0048   }
0049 
0050   template <class data_T, class res_T, typename CONFIG_T>
0051   void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) {
0052     // Initialize the lookup tables
0053 #ifdef __HLS_SYN__
0054     bool initialized = false;
0055     typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
0056     typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size];
0057 #else
0058     static bool initialized = false;
0059     static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
0060     static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size];
0061 
0062 #endif
0063     if (!initialized) {
0064       // Note we are exponentiating the inputs, which have type data_T
0065       init_exp_table<data_T, CONFIG_T>(exp_table);
0066       // Note we are inverting the exponentials, which have type exp_table_t
0067       init_invert_table<typename CONFIG_T::exp_table_t, CONFIG_T>(invert_table);
0068       initialized = true;
0069     }
0070 
0071     // Find the max and compute all delta(x_i, x_max)
0072     Op_max<data_T> op_max;
0073     data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);
0074 
0075     // For the diffs, use the same type as the input but force rounding and saturation
0076     ap_fixed<data_T::width, data_T::iwidth, AP_RND, AP_SAT> d_xi_xmax[CONFIG_T::n_in];
0077     for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
0078       d_xi_xmax[i] = data[i] - x_max;
0079     }
0080 
0081     // Calculate all the e^x's
0082     typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
0083     typename CONFIG_T::exp_table_t exp_sum(0);
0084     for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
0085       unsigned x = softmax_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i]);
0086       exp_res[i] = exp_table[x];
0087     }
0088 
0089     // Explicitly sum the results with an adder tree.
0090     // Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing
0091     Op_add<typename CONFIG_T::exp_table_t> op_add;
0092     exp_sum =
0093         reduce<typename CONFIG_T::exp_table_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::exp_table_t>>(exp_res, op_add);
0094 
0095     typename CONFIG_T::inv_table_t inv_exp_sum =
0096         invert_table[softmax_idx_from_real_val<typename CONFIG_T::exp_table_t, CONFIG_T>(exp_sum)];
0097     for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
0098       res[i] = exp_res[i] * inv_exp_sum;
0099     }
0100   }
0101 
0102 }  // namespace nnet
0103 #endif