Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:13:10

0001 //
0002 //    rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
0003 //
0004 //    Copyright (C) 2017 EJ Kreinar
0005 //
0006 //    This program is free software: you can redistribute it and/or modify
0007 //    it under the terms of the GNU General Public License as published by
0008 //    the Free Software Foundation, either version 3 of the License, or
0009 //    (at your option) any later version.
0010 //
0011 //    This program is distributed in the hope that it will be useful,
0012 //    but WITHOUT ANY WARRANTY; without even the implied warranty of
0013 //    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
0014 //    GNU General Public License for more details.
0015 //
0016 //    You should have received a copy of the GNU General Public License
0017 //    along with this program.  If not, see <http://www.gnu.org/licenses/>.
0018 //
0019 
0020 #ifndef NNET_LAYER_H_
0021 #define NNET_LAYER_H_
0022 
0023 #include "nnet_common.h"
0024 #include <cmath>
0025 
0026 namespace nnet {
0027 
0028   struct layer_config {
0029     // Internal data type definitions
0030     typedef float bias_t;
0031     typedef float weight_t;
0032     typedef float accum_t;
0033 
0034     // Layer Sizes
0035     static const unsigned n_in = 10;
0036     static const unsigned n_out = 10;
0037 
0038     // Resource reuse info
0039     static const unsigned io_type = io_parallel;
0040     static const unsigned reuse_factor = 1;
0041     static const bool store_weights_in_bram = false;
0042     static const unsigned n_zeros = 0;
0043     static const bool use_lowlatency = true;
0044     // partitioning arrays cyclically to go with roll factors?
0045   };
0046 
// Integer division of n by d, rounded up (valid for non-negative n and d > 0).
// Every argument use is parenthesized so that expression arguments such as
// DIV_ROUNDUP(a, b + c) expand with the intended precedence. Arguments are
// still evaluated more than once, so avoid side effects in them.
#define DIV_ROUNDUP(n, d) (((n) + (d) - 1) / (d))
0048 
0049   template <class data_T, class res_T, typename CONFIG_T>
0050   void compute_layer(data_T data[CONFIG_T::n_in],
0051                      res_T res[CONFIG_T::n_out],
0052                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
0053                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
0054     unsigned cycle_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);
0055     typename CONFIG_T::weight_t mult[CONFIG_T::n_in * CONFIG_T::n_out];
0056     /*
0057     if(CONFIG_T::use_lowlatency) { 
0058       int multiplier_limit  = ceil(float(CONFIG_T::n_in*CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::n_zeros) / float(CONFIG_T::reuse_factor));
0059     } 
0060     */
0061     typename CONFIG_T::accum_t acc[CONFIG_T::n_out];
0062     for (unsigned iacc = 0; iacc < CONFIG_T::n_out; iacc++) {
0063       acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc];
0064     }
0065     unsigned rufactor = CONFIG_T::reuse_factor;
0066     if (CONFIG_T::use_lowlatency) {
0067       rufactor = CONFIG_T::n_in;
0068       cycle_factor = CONFIG_T::n_out;
0069     }
0070     data_T cache;
0071     for (unsigned ii = 0; ii < rufactor; ii++) {
0072       if (CONFIG_T::use_lowlatency) {
0073         cache = data[ii];
0074       }
0075       for (unsigned jj = 0; jj < cycle_factor; jj++) {
0076         unsigned windex = ii * cycle_factor + jj;
0077         unsigned index = windex / CONFIG_T::n_out;
0078         if (windex > CONFIG_T::n_in * CONFIG_T::n_out - 1)
0079           continue;
0080         if (CONFIG_T::use_lowlatency) {
0081           mult[windex] = cache * (weights[windex]);
0082         } else {
0083           int aindex = windex / CONFIG_T::n_in;
0084           acc[aindex] += data[index] * weights[windex];
0085         }
0086       }
0087     }
0088     if (CONFIG_T::use_lowlatency) {
0089       // Accumulate multiplication result
0090       for (unsigned ii = 0; ii < CONFIG_T::n_in; ii++) {
0091         for (unsigned jj = 0; jj < CONFIG_T::n_out; jj++) {
0092           int index = ii * CONFIG_T::n_out + jj;
0093           acc[jj] += mult[index];
0094         }
0095       }
0096     }
0097     for (unsigned ires = 0; ires < CONFIG_T::n_out; ires++) {
0098       res[ires] = (res_T)(acc[ires]);
0099     }
0100   }
0101 
0102 }  // namespace nnet
0103 
0104 #endif