Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:21:54

0001 // Integer representation of floating point arithmetic suitable for FPGA designs
0002 //
0003 // Author: Yuri Gershtein
0004 // Date:   March 2018
0005 //
0006 // Functionality:
0007 //
0008 //  *note* all integers are assumed to be signed
0009 //
0010 //  all variables have units, stored in a map <string,int>, with string a unit (i.e. "phi") and int the power
0011 //                   "2" is always present in the map, and its int pair is referred to as 'shift'
0012 //                   units are properly combined / propagated through calculations
0013 //                   adding/subtracting variables with different units throws an exception
0014 //                   adding/subtracting variables with different shifts is allowed and is handled correctly
0015 //
0016 // calculate() method re-calculates the variable double and int values based on its operands
0017 //                   returns false in case of overflows and/or mismatches between double and int calculations.
0018 //
0019 // the maximum and minimum values that the variable assumes are stored and updated each time calculate() is called
0020 // if IMATH_ROOT is defined, all values are also stored in a histogram
0021 //
0022 // VarDef     (string name, string units, double fmax, double K):
0023 //                   define variable with bit value fval = K*ival, and maximum absolute value fmax.
0024 //                   calculates nbins on its own
0025 //                   one can assign value to it using set_ methods
0026 //
0027 // VarParam   (string name, string units, double fval, int nbits):
0028 //                   define a parameter. K is calculated based on the fval and nbits
0029 //
0030 //         or (string name, std::string units, double fval, double K):
0031 //                   define a parameter with bit value fval = K*ival.
0032 //                   calculates nbins on its own
0033 //
0034 // VarAdd     (string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18):
0035 // VarSubtract(string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18):
0036 //                   add/subtract variables. Bit length increases by 1, but capped at nmax
0037 //                   if range>0 specified, bit length is decreased to drop unnecessary high bits
0038 //
0039 // VarMult    (string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18):
0040 //                   multiplication. Bit length is a sum of the lengths of the operands, but capped at nmax
0041 //                   if range>0 specified, bit length is decreased to drop unnecessary high bits or post-shift is reduced
0042 //
0043 // VarTimesC  (string name, VarBase *p1, double cF, int ps = 17):
0044 //                   multiplication by a constant. Bit length stays the same
0045 //                   ps defines number of bits used to represent the constant
0046 //
0047 // VarDSPPostadd (string name, VarBase *p1, VarBase *p2, VarBase *p3, double range = -1, int nmax = 18):
0048 //                   explicit instantiation of the 3-clock DSP postaddition: p1*p2+p3
0049 //                   range and nmax have the same meaning as for the VarMult.
0050 //
0051 // VarShift  (string name, VarBase *p1, int shift):
0052 //                   shifts the variable right by shift (equivalent to multiplication by pow(2, -shift));
0053 //                   Units stay the same, nbits are adjusted.
0054 //
0055 // VarShiftround  (string name, VarBase *p1, int shift):
0056 //                   shifts the variable right by shift, but doing rounding, i.e.
0057 //                   (p>>(shift-1)+1)>>1;
0058 //                   Units stay the same, nbits are adjusted.
0059 //
0060 // VarNeg    (string name, VarBase *p1):
0061 //                   multiplies the variable by -1
0062 //
0063 // VarInv     (string name, VarBase *p1, double offset, int nbits, int n, unsigned int shift, mode m, int nbaddr=-1):
0064 //                   LUT-based inversion, f = 1./(offset + f1) and  i = 2^n / (offsetI + i1)
0065 //                   nbits is the width of the LUT (signed)
0066 //                   m is from enum mode {pos, neg, both} and refers to possible sign values of f
0067 //                            for pos and neg, the most significant bit of p1 (i.e. the sign bit) is ignored
0068 //                   shift is a shift applied in i1<->address conversions (used to reduce size of LUT)
0069 //                   nbaddr: if not specified, it is taken to be equal to p1->nbits()
0070 //
0071 //
0072 // VarNounits (string name, VarBase *p1, int ps = 17):
0073 //                   convert a number with units to a number - needed for trig function expansion (i.e. 1 - 0.5*phi^2)
0074 //                   ps is a number of bits to represent the unit conversion constant
0075 //
0076 // VarAdjustK (string name, VarBase *p1, double Knew, double epsilon = 1e-5, bool do_assert = false, int nbits = -1)
0077 //                   adjust variable shift so the K is as close to Knew as possible (needed for bit length adjustments)
0078 //                   if do_assert is true, throw an exception if Knew/Kold is not a power of two
0079 //                   epsilon is a comparison precision, nbits forces the bit length (possibly discarding MSBs)
0080 //
0081 // VarAdjustKR (string name, VarBase *p1, double Knew, double epsilon = 1e-5, bool do_assert = false, int nbits = -1)
0082 //                   - same as adjustK(), but with rounding, and therefore latency=1
0083 //
0084 // bool calculate(int debug_level) runs through the entire formula tree recalculating both integer and floating point values
0085 //                     returns true if everything is OK, false if obvious problems with the calculation exist, i.e
0086 //                                  -  integer value does not fit into the allotted number of bins
0087 //                                  -  integer value is more than 10% or more than 2 away from fval_/K_
0088 //                     debug_level:  0 - no warnings
0089 //                                   1 - limited warning
0090 //                                   2 - as 1, but also include explicit warnings when LUT was used out of its range
0091 //                                   3 - maximum complaints level
0092 //
0093 // VarFlag (string name, VarBase *cut_var, VarBase *...)
0094 //
0095 //                    flag to apply cuts defined for any variable. When output as Verilog, the flag
0096 //                    is true if and only if the following conditions are all true:
0097 //                       1) the cut defined by each VarCut pointer in the argument list must be passed
0098 //                       by the associated variable
0099 //                       2) each VarBase pointer in the argument list that is not also a VarCut
0100 //                       pointer must pass all of its associated cuts
0101 //                       3) all children of the variables in the argument list must pass all of their
0102 //                       associated cuts
0103 //                    The VarFlag::passes() method replicates the behavior of the output Verilog,
0104 //                    returning true if and only if the above conditions are all true. The
0105 //                    VarBase::local_passes() method can be used to query if a given variable passes
0106 //                    its associated cuts, regardless of whether its children do.
0107 //
0108 #ifndef L1Trigger_TrackFindingTracklet_interface_imath_h
0109 #define L1Trigger_TrackFindingTracklet_interface_imath_h
0110 
0111 //use root if uncommented
0112 //#ifndef CMSSW_GIT_HASH
0113 //#define IMATH_ROOT
0114 //#endif
0115 
0116 #include <limits>
0117 #include <iostream>
0118 #include <fstream>
0119 #include <vector>
0120 #include <map>
0121 #include <cmath>
0122 #include <sstream>
0123 #include <string>
0124 #include <cassert>
0125 #include <set>
0126 
0127 #include "L1Trigger/TrackFindingTracklet/interface/Util.h"
0128 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0129 #include "FWCore/Utilities/interface/Exception.h"
0130 #include "L1Trigger/L1TCommon/interface/BitShift.h"
0131 
0132 #ifdef IMATH_ROOT
0133 #include <TH2F.h>
0134 #include <TFile.h>
0135 #include <TCanvas.h>
0136 #include <TTree.h>
0137 #endif
0138 
0139 //operation latencies for proper HDL pipelining
0140 #define MULT_LATENCY 1
0141 #define LUT_LATENCY 2
0142 #define DSP_LATENCY 3
0143 
0144 // Print out information on the pass/fail status of all variables. Warning:
0145 // this outputs a lot of information for busy events!
0146 
0147 namespace trklet {
0148 
0149   struct imathGlobals {
0150     bool printCutInfo_{false};
0151 #ifdef IMATH_ROOT
0152     TFile *h_file_ = new TFile("imath.root", "RECREATE");
0153     bool use_root;
0154 #endif
0155   };
0156 
0157   class VarCut;
0158   class VarFlag;
0159 
0160   class VarBase {
0161   public:
0162     VarBase(imathGlobals *globals, std::string name, VarBase *p1, VarBase *p2, VarBase *p3, int l) {
0163       globals_ = globals;
0164       p1_ = p1;
0165       p2_ = p2;
0166       p3_ = p3;
0167       name_ = name;
0168       latency_ = l;
0169       int step1 = (p1) ? p1->step() + p1->latency() : 0;
0170       int step2 = (p2) ? p2->step() + p2->latency() : 0;
0171       step_ = std::max(step1, step2);
0172 
0173       cuts_.clear();
0174       cut_var_ = nullptr;
0175 
0176       pipe_counter_ = 0;
0177       pipe_delays_.clear();
0178 
0179       minval_ = std::numeric_limits<double>::max();
0180       maxval_ = -std::numeric_limits<double>::max();
0181       readytoprint_ = true;
0182       readytoanalyze_ = true;
0183       usedasinput_ = false;
0184       Kmap_.clear();
0185       Kmap_["2"] = 0;  // initially, zero shift
0186 #ifdef IMATH_ROOT
0187       h_ = 0;
0188       h_nbins_ = 1024;
0189       h_precision_ = 0.02;
0190 #endif
0191     }
0192     virtual ~VarBase() {
0193 #ifdef IMATH_ROOT
0194       if (globals_->h_file_) {
0195         globals_->h_file_->ls();
0196         globals_->h_file_->Close();
0197         globals_->h_file_ = 0;
0198       }
0199 #endif
0200     }
0201 
0202     static struct Verilog {
0203     } verilog;
0204     static struct HLS {
0205     } hls;
0206 
0207     std::string kstring() const;
0208     std::string name() const { return name_; }
0209     std::string op() const { return op_; }
0210     VarBase *p1() const { return p1_; }
0211     VarBase *p2() const { return p2_; }
0212     VarBase *p3() const { return p3_; }
0213     double fval() const { return fval_; }
0214     long int ival() const { return ival_; }
0215 
0216     bool local_passes() const;
0217     void passes(std::map<const VarBase *, std::vector<bool> > &passes,
0218                 const std::map<const VarBase *, std::vector<bool> > *const previous_passes = nullptr) const;
0219     void print_cuts(std::map<const VarBase *, std::set<std::string> > &cut_strings,
0220                     const int step,
0221                     Verilog,
0222                     const std::map<const VarBase *, std::set<std::string> > *const previous_cut_strings = nullptr) const;
0223     void print_cuts(std::map<const VarBase *, std::set<std::string> > &cut_strings,
0224                     const int step,
0225                     HLS,
0226                     const std::map<const VarBase *, std::set<std::string> > *const previous_cut_strings = nullptr) const;
0227     void add_cut(VarCut *cut, const bool call_set_cut_var = true);
0228     VarBase *cut_var();
0229     // observed range of fval_ (only filled if debug_level > 0)
0230     double minval() const { return minval_; }
0231     double maxval() const { return maxval_; }
0232     void analyze();
0233 #ifdef IMATH_ROOT
0234     TH2F *h() { return h_; }
0235 #endif
0236     void reset() {
0237       minval_ = std::numeric_limits<double>::max();
0238       maxval_ = -std::numeric_limits<double>::max();
0239 #ifdef IMATH_ROOT
0240       h_->Clear();
0241 #endif
0242     }
0243 
0244     int nbits() const { return nbits_; }
0245     std::map<std::string, int> Kmap() const { return Kmap_; }
0246     double range() const { return (1 << (nbits_ - 1)) * K_; }  // everything is signed
0247     double K() const { return K_; };
0248     int shift() const { return Kmap_.at("2"); }
0249 
0250     void makeready();
0251     int step() const { return step_; }
0252     int latency() const { return latency_; }
0253     void add_latency(unsigned int l) { latency_ += l; }  //only call before using the variable in calculation!
0254     bool calculate(int debug_level = 0);
0255     virtual void local_calculate() {}
0256     void calcDebug(int debug_level, long int ival_prev, bool &all_ok);
0257     virtual void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) {
0258       fs << "// VarBase here. Soemthing is wrong!! " << l1 << ", " << l2 << ", " << l3 << "\n";
0259     }
0260     virtual void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) {
0261       fs << "// VarBase here. Soemthing is wrong!! " << l1 << ", " << l2 << ", " << l3 << "\n";
0262     }
0263     void print_step(int step, std::ofstream &fs, Verilog);
0264     void print_step(int step, std::ofstream &fs, HLS);
0265     void print_all(std::ofstream &fs, Verilog);
0266     void print_all(std::ofstream &fs, HLS);
0267     void print_truncation(std::string &t, const std::string &o1, const int ps, Verilog) const;
0268     void print_truncation(std::string &t, const std::string &o1, const int ps, HLS) const;
0269     void inputs(std::vector<VarBase *> *vd);  //collect all inputs
0270 
0271     int pipe_counter() { return pipe_counter_; }
0272     void pipe_increment() { pipe_counter_++; }
0273     void add_delay(int i) { pipe_delays_.push_back(i); }
0274     bool has_delay(int i);  //returns true if already have this variable delayed.
0275     static void verilog_print(const std::vector<VarBase *> &v, std::ofstream &fs) { design_print(v, fs, verilog); }
0276     static void hls_print(const std::vector<VarBase *> &v, std::ofstream &fs) { design_print(v, fs, hls); }
0277     static void design_print(const std::vector<VarBase *> &v, std::ofstream &fs, Verilog);
0278     static void design_print(const std::vector<VarBase *> &v, std::ofstream &fs, HLS);
0279     static std::string pipe_delay(VarBase *v, int nbits, int delay);
0280     std::string pipe_delays(const int step);
0281     static std::string pipe_delay_wire(VarBase *v, std::string name_delayed, int nbits, int delay);
0282 
0283 #ifdef IMATH_ROOT
0284     static TTree *addToTree(imathGlobals *globals, VarBase *v, char *s = 0);
0285     static TTree *addToTree(imathGlobals *globals, int *v, char *s);
0286     static TTree *addToTree(imathGlobals *globals, double *v, char *s);
0287     static void fillTree(imathGlobals *globals);
0288     static void writeTree(imathGlobals *globals);
0289 #endif
0290 
0291     void dump_msg();
0292     std::string dump();
0293     static std::string itos(int i);
0294 
0295   protected:
0296     imathGlobals *globals_;
0297     std::string name_;
0298     VarBase *p1_;
0299     VarBase *p2_;
0300     VarBase *p3_;
0301     std::string op_;  // operation
0302     int latency_;     // number of clock cycles for the operation (for HDL output)
0303     int step_;        // step number in the calculation (for HDL output)
0304 
0305     double fval_;    // exact calculation
0306     long int ival_;  // integer calculation
0307     double val_;     // integer calculation converted to double, ival_*K
0308 
0309     std::vector<VarBase *> cuts_;
0310     VarBase *cut_var_;
0311 
0312     int nbits_;
0313     double K_;
0314     std::map<std::string, int> Kmap_;
0315 
0316     int pipe_counter_;
0317     std::vector<int> pipe_delays_;
0318 
0319     bool readytoanalyze_;
0320     bool readytoprint_;
0321     bool usedasinput_;
0322 
0323     double minval_;
0324     double maxval_;
0325 #ifdef IMATH_ROOT
0326     void set_hist_pars(int n = 256, double p = 0.05) {
0327       h_nbins_ = n;
0328       h_precision_ = p;
0329     }
0330     int h_nbins_;
0331     double h_precision_;
0332     TH2F *h_;
0333 #endif
0334   };
0335 
0336   class VarAdjustK : public VarBase {
0337   public:
0338     VarAdjustK(imathGlobals *globals,
0339                std::string name,
0340                VarBase *p1,
0341                double Knew,
0342                double epsilon = 1e-5,
0343                bool do_assert = false,
0344                int nbits = -1)
0345         : VarBase(globals, name, p1, nullptr, nullptr, 0) {
0346       op_ = "adjustK";
0347       K_ = p1->K();
0348       Kmap_ = p1->Kmap();
0349 
0350       double r = Knew / K_;
0351 
0352       lr_ = (r > 1) ? log2(r) + epsilon : log2(r);
0353       K_ = K_ * pow(2, lr_);
0354       if (do_assert)
0355         assert(std::abs(Knew / K_ - 1) < epsilon);
0356 
0357       if (nbits > 0)
0358         nbits_ = nbits;
0359       else
0360         nbits_ = p1->nbits() - lr_;
0361 
0362       Kmap_["2"] = Kmap_["2"] + lr_;
0363     }
0364 
0365     ~VarAdjustK() override = default;
0366 
0367     void adjust(double Knew, double epsilon = 1e-5, bool do_assert = false, int nbits = -1);
0368 
0369     void local_calculate() override;
0370     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0371     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0372 
0373   protected:
0374     int lr_;
0375   };
0376 
0377   class VarAdjustKR : public VarBase {
0378   public:
0379     VarAdjustKR(imathGlobals *globals,
0380                 std::string name,
0381                 VarBase *p1,
0382                 double Knew,
0383                 double epsilon = 1e-5,
0384                 bool do_assert = false,
0385                 int nbits = -1)
0386         : VarBase(globals, name, p1, nullptr, nullptr, 1) {
0387       op_ = "adjustKR";
0388       K_ = p1->K();
0389       Kmap_ = p1->Kmap();
0390 
0391       double r = Knew / K_;
0392 
0393       lr_ = (r > 1) ? log2(r) + epsilon : log2(r);
0394       K_ = K_ * pow(2, lr_);
0395       if (do_assert)
0396         assert(std::abs(Knew / K_ - 1) < epsilon);
0397 
0398       if (nbits > 0)
0399         nbits_ = nbits;
0400       else
0401         nbits_ = p1->nbits() - lr_;
0402 
0403       Kmap_["2"] = Kmap_["2"] + lr_;
0404     }
0405 
0406     ~VarAdjustKR() override = default;
0407 
0408     void local_calculate() override;
0409     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0410     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0411 
0412   protected:
0413     int lr_;
0414   };
0415 
0416   class VarParam : public VarBase {
0417   public:
0418     VarParam(imathGlobals *globals, std::string name, double fval, int nbits)
0419         : VarBase(globals, name, nullptr, nullptr, nullptr, 0) {
0420       op_ = "const";
0421       nbits_ = nbits;
0422       int l = log2(std::abs(fval)) + 1.9999999 - nbits;
0423       Kmap_["2"] = l;
0424       K_ = pow(2, l);
0425       fval_ = fval;
0426       ival_ = fval / K_;
0427     }
0428     VarParam(imathGlobals *globals, std::string name, std::string units, double fval, double K)
0429         : VarBase(globals, name, nullptr, nullptr, nullptr, 0) {
0430       op_ = "const";
0431       K_ = K;
0432       nbits_ = log2(fval / K) + 1.999999;  //plus one to round up
0433       if (!units.empty())
0434         Kmap_[units] = 1;
0435       else {
0436         //defining a constant, K should be a power of two
0437         int l = log2(K);
0438         if (std::abs(pow(2, l) / K - 1) > 1e-5) {
0439           char slog[100];
0440           snprintf(slog, 100, "defining unitless constant, yet K is not a power of 2! %g, %g", K, pow(2, l));
0441           edm::LogVerbatim("Tracklet") << slog;
0442         }
0443         Kmap_["2"] = l;
0444       }
0445     }
0446 
0447     ~VarParam() override = default;
0448 
0449     void set_fval(double fval) {
0450       fval_ = fval;
0451       if (fval > 0)
0452         ival_ = fval / K_ + 0.5;
0453       else
0454         ival_ = fval / K_ - 0.5;
0455       val_ = ival_ * K_;
0456     }
0457     void set_ival(int ival) {
0458       ival_ = ival;
0459       fval_ = ival * K_;
0460       val_ = fval_;
0461     }
0462     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0463     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0464   };
0465 
0466   class VarDef : public VarBase {
0467   public:
0468     //construct from scratch
0469     VarDef(imathGlobals *globals, std::string name, std::string units, double fmax, double K)
0470         : VarBase(globals, name, nullptr, nullptr, nullptr, 1) {
0471       op_ = "def";
0472       K_ = K;
0473       nbits_ = log2(fmax / K) + 1.999999;  //plus one to round up
0474       if (!units.empty())
0475         Kmap_[units] = 1;
0476       else {
0477         //defining a constant, K should be a power of two
0478         int l = log2(K);
0479         if (std::abs(pow(2, l) / K - 1) > 1e-5) {
0480           char slog[100];
0481           snprintf(slog, 100, "defining unitless constant, yet K is not a power of 2! %g, %g", K, pow(2, l));
0482           edm::LogVerbatim("Tracklet") << slog;
0483         }
0484         Kmap_["2"] = l;
0485       }
0486     }
0487     //construct from abother variable (all provenance info is lost!)
0488     VarDef(imathGlobals *globals, std::string name, VarBase *p) : VarBase(globals, name, nullptr, nullptr, nullptr, 1) {
0489       op_ = "def";
0490       K_ = p->K();
0491       nbits_ = p->nbits();
0492       Kmap_ = p->Kmap();
0493     }
0494     void set_fval(double fval) {
0495       fval_ = fval;
0496       if (fval > 0)
0497         ival_ = fval / K_;
0498       else
0499         ival_ = fval / K_ - 1;
0500       val_ = ival_ * K_;
0501     }
0502     void set_ival(int ival) {
0503       ival_ = ival;
0504       fval_ = ival * K_;
0505       val_ = ival_ * K_;
0506     }
0507     ~VarDef() override = default;
0508     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0509     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0510   };
0511 
0512   class VarAdd : public VarBase {
0513   public:
0514     VarAdd(imathGlobals *globals, std::string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18)
0515         : VarBase(globals, name, p1, p2, nullptr, 1) {
0516       op_ = "add";
0517 
0518       std::map<std::string, int> map1 = p1->Kmap();
0519       std::map<std::string, int> map2 = p2->Kmap();
0520       int s1 = map1["2"];
0521       int s2 = map2["2"];
0522 
0523       //first check if the constants are all lined up
0524       //go over the two maps subtracting the units
0525       for (const auto &it : map2) {
0526         if (map1.find(it.first) == map1.end())
0527           map1[it.first] = -it.second;
0528         else
0529           map1[it.first] = map1[it.first] - it.second;
0530       }
0531 
0532       char slog[100];
0533 
0534       //assert if different
0535       for (const auto &it : map1) {
0536         if (it.second != 0) {
0537           if (it.first != "2") {
0538             snprintf(
0539                 slog, 100, "VarAdd: bad units! %s^%i for variable %s", (it.first).c_str(), it.second, name_.c_str());
0540             edm::LogVerbatim("Tracklet") << slog;
0541             p1->dump_msg();
0542             p2->dump_msg();
0543             throw cms::Exception("BadConfig") << "imath units are different!";
0544           }
0545         }
0546       }
0547 
0548       double ki1 = p1->K() / pow(2, s1);
0549       double ki2 = p2->K() / pow(2, s2);
0550       //those should be the same
0551       if (std::abs(ki1 / ki2 - 1.) > 1e-6) {
0552         snprintf(slog, 100, "VarAdd: bad constants! %f %f for variable %s", ki1, ki2, name_.c_str());
0553         edm::LogVerbatim("Tracklet") << slog;
0554         p1->dump_msg();
0555         p2->dump_msg();
0556         throw cms::Exception("BadConfig") << "imath constants are different!";
0557       }
0558       //everything checks out!
0559 
0560       Kmap_ = p1->Kmap();
0561 
0562       int s0 = s1 < s2 ? s1 : s2;
0563       shift1 = s1 - s0;
0564       shift2 = s2 - s0;
0565 
0566       int n1 = p1->nbits() + shift1;
0567       int n2 = p2->nbits() + shift2;
0568       int n0 = 1 + (n1 > n2 ? n1 : n2);
0569 
0570       //before shifting, check the range
0571       if (range > 0) {
0572         n0 = log2(range / ki1 / pow(2, s0)) + 1e-9;
0573         n0 = n0 + 2;
0574       }
0575 
0576       if (n0 <= nmax) {  //if it fits, we're done
0577         ps_ = 0;
0578         Kmap_["2"] = s0;
0579         nbits_ = n0;
0580       } else {
0581         ps_ = n0 - nmax;
0582         Kmap_["2"] = s0 + ps_;
0583         nbits_ = nmax;
0584       }
0585 
0586       K_ = ki1 * pow(2, Kmap_["2"]);
0587     }
0588     ~VarAdd() override = default;
0589     void local_calculate() override;
0590     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0591     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0592 
0593   protected:
0594     int ps_;
0595     int shift1;
0596     int shift2;
0597   };
0598 
0599   class VarSubtract : public VarBase {
0600   public:
0601     VarSubtract(imathGlobals *globals, std::string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18)
0602         : VarBase(globals, name, p1, p2, nullptr, 1) {
0603       op_ = "subtract";
0604 
0605       std::map<std::string, int> map1 = p1->Kmap();
0606       std::map<std::string, int> map2 = p2->Kmap();
0607       int s1 = map1["2"];
0608       int s2 = map2["2"];
0609 
0610       //first check if the constants are all lined up go over the two maps subtracting the units
0611       for (const auto &it : map2) {
0612         if (map1.find(it.first) == map1.end())
0613           map1[it.first] = -it.second;
0614         else
0615           map1[it.first] = map1[it.first] - it.second;
0616       }
0617 
0618       char slog[100];
0619 
0620       //assert if different
0621       for (const auto &it : map1) {
0622         if (it.second != 0) {
0623           if (it.first != "2") {
0624             snprintf(
0625                 slog, 100, "VarAdd: bad units! %s^%i for variable %s", (it.first).c_str(), it.second, name_.c_str());
0626             edm::LogVerbatim("Tracklet") << slog;
0627             p1->dump_msg();
0628             p2->dump_msg();
0629             throw cms::Exception("BadConfig") << "imath units are different!";
0630           }
0631         }
0632       }
0633 
0634       double ki1 = p1->K() / pow(2, s1);
0635       double ki2 = p2->K() / pow(2, s2);
0636       //those should be the same
0637       if (std::abs(ki1 / ki2 - 1.) > 1e-6) {
0638         snprintf(slog, 100, "VarAdd: bad constants! %f %f for variable %s", ki1, ki2, name_.c_str());
0639         edm::LogVerbatim("Tracklet") << slog;
0640         p1->dump_msg();
0641         p2->dump_msg();
0642         throw cms::Exception("BadConfig") << "imath constants are different!";
0643       }
0644       //everything checks out!
0645 
0646       Kmap_ = p1->Kmap();
0647 
0648       int s0 = s1 < s2 ? s1 : s2;
0649       shift1 = s1 - s0;
0650       shift2 = s2 - s0;
0651 
0652       int n1 = p1->nbits() + shift1;
0653       int n2 = p2->nbits() + shift2;
0654       int n0 = 1 + (n1 > n2 ? n1 : n2);
0655 
0656       //before shifting, check the range
0657       if (range > 0) {
0658         n0 = log2(range / ki1 / pow(2, s0)) + 1e-9;
0659         n0 = n0 + 2;
0660       }
0661 
0662       if (n0 <= nmax) {  //if it fits, we're done
0663         ps_ = 0;
0664         Kmap_["2"] = s0;
0665         nbits_ = n0;
0666       } else {
0667         ps_ = n0 - nmax;
0668         Kmap_["2"] = s0 + ps_;
0669         nbits_ = nmax;
0670       }
0671 
0672       K_ = ki1 * pow(2, Kmap_["2"]);
0673     }
0674 
0675     ~VarSubtract() override = default;
0676 
0677     void local_calculate() override;
0678     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0679     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0680 
0681   protected:
0682     int ps_;
0683     int shift1;
0684     int shift2;
0685   };
0686 
0687   class VarNounits : public VarBase {
0688   public:
0689     VarNounits(imathGlobals *globals, std::string name, VarBase *p1, int ps = 17)
0690         : VarBase(globals, name, p1, nullptr, nullptr, MULT_LATENCY) {
0691       op_ = "nounits";
0692       ps_ = ps;
0693       nbits_ = p1->nbits();
0694 
0695       int s1 = p1->shift();
0696       double ki = p1->K() / pow(2, s1);
0697       int m = log2(ki);
0698 
0699       K_ = pow(2, s1 + m);
0700       Kmap_["2"] = s1 + m;
0701       double c = ki * pow(2, -m);
0702       cI_ = c * pow(2, ps_);
0703     }
0704     ~VarNounits() override = default;
0705 
0706     void local_calculate() override;
0707     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0708     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0709 
0710   protected:
0711     int ps_;
0712     int cI_;
0713   };
0714 
0715   class VarShiftround : public VarBase {
0716   public:
0717     VarShiftround(imathGlobals *globals, std::string name, VarBase *p1, int shift)
0718         : VarBase(globals, name, p1, nullptr, nullptr, 1) {  // latency is one because there is an addition
0719       op_ = "shiftround";
0720       shift_ = shift;
0721 
0722       nbits_ = p1->nbits() - shift;
0723       Kmap_ = p1->Kmap();
0724       K_ = p1->K();
0725     }
0726     ~VarShiftround() override = default;
0727 
0728     void local_calculate() override;
0729     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0730     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0731 
0732   protected:
0733     int shift_;
0734   };
0735 
0736   class VarShift : public VarBase {
0737   public:
0738     VarShift(imathGlobals *globals, std::string name, VarBase *p1, int shift)
0739         : VarBase(globals, name, p1, nullptr, nullptr, 0) {
0740       op_ = "shift";
0741       shift_ = shift;
0742 
0743       nbits_ = p1->nbits() - shift;
0744       Kmap_ = p1->Kmap();
0745       K_ = p1->K();
0746     }
0747     ~VarShift() override = default;
0748     void local_calculate() override;
0749     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0750     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0751 
0752   protected:
0753     int shift_;
0754   };
0755 
0756   class VarNeg : public VarBase {
0757   public:
0758     VarNeg(imathGlobals *globals, std::string name, VarBase *p1) : VarBase(globals, name, p1, nullptr, nullptr, 1) {
0759       op_ = "neg";
0760       nbits_ = p1->nbits();
0761       Kmap_ = p1->Kmap();
0762       K_ = p1->K();
0763     }
0764     ~VarNeg() override = default;
0765     void local_calculate() override;
0766     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0767     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0768   };
0769 
0770   class VarTimesC : public VarBase {
0771   public:
0772     VarTimesC(imathGlobals *globals, std::string name, VarBase *p1, double cF, int ps = 17)
0773         : VarBase(globals, name, p1, nullptr, nullptr, MULT_LATENCY) {
0774       op_ = "timesC";
0775       cF_ = cF;
0776       ps_ = ps;
0777 
0778       nbits_ = p1->nbits();
0779       Kmap_ = p1->Kmap();
0780       K_ = p1->K();
0781 
0782       int s1 = Kmap_["2"];
0783       double l = log2(std::abs(cF));
0784       if (l > 0)
0785         l += 0.999999;
0786       int m = l;
0787 
0788       cI_ = cF * pow(2, ps - m);
0789       K_ = K_ * pow(2, m);
0790       Kmap_["2"] = s1 + m;
0791     }
0792     ~VarTimesC() override = default;
0793     void local_calculate() override;
0794     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0795     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0796 
0797   protected:
0798     int ps_;
0799     int cI_;
0800     double cF_;
0801   };
0802 
0803   class VarMult : public VarBase {
0804   public:
0805     VarMult(imathGlobals *globals, std::string name, VarBase *p1, VarBase *p2, double range = -1, int nmax = 18)
0806         : VarBase(globals, name, p1, p2, nullptr, MULT_LATENCY) {
0807       op_ = "mult";
0808 
0809       const std::map<std::string, int> map1 = p1->Kmap();
0810       const std::map<std::string, int> map2 = p2->Kmap();
0811       for (const auto &it : map1) {
0812         if (Kmap_.find(it.first) == Kmap_.end())
0813           Kmap_[it.first] = it.second;
0814         else
0815           Kmap_[it.first] = Kmap_[it.first] + it.second;
0816       }
0817       for (const auto &it : map2) {
0818         if (Kmap_.find(it.first) == Kmap_.end())
0819           Kmap_[it.first] = it.second;
0820         else
0821           Kmap_[it.first] = Kmap_[it.first] + it.second;
0822       }
0823       K_ = p1->K() * p2->K();
0824       int s0 = Kmap_["2"];
0825 
0826       int n0 = p1->nbits() + p2->nbits();
0827       if (range > 0) {
0828         n0 = log2(range / K_) + 1e-9;
0829         n0 = n0 + 2;
0830       }
0831       if (n0 < nmax) {
0832         ps_ = 0;
0833         nbits_ = n0;
0834       } else {
0835         ps_ = n0 - nmax;
0836         nbits_ = nmax;
0837         Kmap_["2"] = s0 + ps_;
0838         K_ = K_ * pow(2, ps_);
0839       }
0840     }
0841     ~VarMult() override = default;
0842     void local_calculate() override;
0843     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0844     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
0845 
0846   protected:
0847     int ps_;
0848   };
0849 
0850   class VarDSPPostadd : public VarBase {
0851   public:
0852     VarDSPPostadd(
0853         imathGlobals *globals, std::string name, VarBase *p1, VarBase *p2, VarBase *p3, double range = -1, int nmax = 18)
0854         : VarBase(globals, name, p1, p2, p3, DSP_LATENCY) {
0855       op_ = "DSP_postadd";
0856 
0857       //first, get constants for the p1*p2
0858       std::map<std::string, int> map1 = p1->Kmap();
0859       std::map<std::string, int> map2 = p2->Kmap();
0860       for (const auto &it : map2) {
0861         if (map1.find(it.first) == map1.end())
0862           map1[it.first] = it.second;
0863         else
0864           map1[it.first] = map1[it.first] + it.second;
0865       }
0866       double k0 = p1->K() * p2->K();
0867       int s0 = map1["2"];
0868 
0869       //now addition
0870       std::map<std::string, int> map3 = p3->Kmap();
0871       int s3 = map3["2"];
0872 
0873       //first check if the constants are all lined up
0874       //go over the two maps subtracting the units
0875       for (const auto &it : map3) {
0876         if (map1.find(it.first) == map1.end())
0877           map1[it.first] = -it.second;
0878         else
0879           map1[it.first] = map1[it.first] - it.second;
0880       }
0881 
0882       char slog[100];
0883 
0884       //assert if different
0885       for (const auto &it : map1) {
0886         if (it.second != 0) {
0887           if (it.first != "2") {
0888             snprintf(slog,
0889                      100,
0890                      "VarDSPPostadd: bad units! %s^%i for variable %s",
0891                      (it.first).c_str(),
0892                      it.second,
0893                      name_.c_str());
0894             edm::LogVerbatim("Tracklet") << slog;
0895             p1->dump_msg();
0896             p2->dump_msg();
0897             p3->dump_msg();
0898             throw cms::Exception("BadConfig") << "imath units are different!";
0899           }
0900         }
0901       }
0902 
0903       double ki1 = k0 / pow(2, s0);
0904       double ki2 = p3->K() / pow(2, s3);
0905       //those should be the same
0906       if (std::abs(ki1 / ki2 - 1.) > 1e-6) {
0907         snprintf(slog, 100, "VarDSPPostadd: bad constants! %f %f for variable %s", ki1, ki2, name_.c_str());
0908         edm::LogVerbatim("Tracklet") << slog;
0909         p1->dump_msg();
0910         p2->dump_msg();
0911         p3->dump_msg();
0912         throw cms::Exception("BadConfig") << "imath constants are different!";
0913       }
0914       //everything checks out!
0915 
0916       shift3_ = s3 - s0;
0917       if (shift3_ < 0) {
0918         throw cms::Exception("BadConfig") << "imath VarDSPPostadd: loosing precision on C in A*B+C: " << shift3_;
0919       }
0920 
0921       Kmap_ = p3->Kmap();
0922       Kmap_["2"] = Kmap_["2"] - shift3_;
0923 
0924       int n12 = p1->nbits() + p2->nbits();
0925       int n3 = p3->nbits() + shift3_;
0926       int n0 = 1 + (n12 > n3 ? n12 : n3);
0927 
0928       //before shifting, check the range
0929       if (range > 0) {
0930         n0 = log2(range / ki2 / pow(2, s3)) + 1e-9;
0931         n0 = n0 + 2;
0932       }
0933 
0934       if (n0 <= nmax) {  //if it fits, we're done
0935         ps_ = 0;
0936         nbits_ = n0;
0937       } else {
0938         ps_ = n0 - nmax;
0939         Kmap_["2"] = Kmap_["2"] + ps_;
0940         nbits_ = nmax;
0941       }
0942 
0943       K_ = ki2 * pow(2, Kmap_["2"]);
0944     }
0945     ~VarDSPPostadd() override = default;
0946 
0947     void local_calculate() override;
0948     using VarBase::print;
0949     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
0950 
0951   protected:
0952     int ps_;
0953     int shift3_;
0954   };
0955 
0956   class VarInv : public VarBase {
0957   public:
0958     enum mode { pos, neg, both };
0959 
0960     VarInv(imathGlobals *globals,
0961            std::string name,
0962            VarBase *p1,
0963            double offset,
0964            int nbits,
0965            int n,
0966            unsigned int shift,
0967            mode m,
0968            int nbaddr = -1)
0969         : VarBase(globals, name, p1, nullptr, nullptr, LUT_LATENCY) {
0970       op_ = "inv";
0971       offset_ = offset;
0972       nbits_ = nbits;
0973       n_ = n;
0974       shift_ = shift;
0975       m_ = m;
0976       if (nbaddr < 0)
0977         nbaddr = p1->nbits();
0978       nbaddr_ = nbaddr - shift;
0979       if (m_ != mode::both)
0980         nbaddr_--;
0981       Nelements_ = 1 << nbaddr_;
0982       mask_ = Nelements_ - 1;
0983       ashift_ = sizeof(int) * 8 - nbaddr_;
0984 
0985       const std::map<std::string, int> map1 = p1->Kmap();
0986       for (const auto &it : map1)
0987         Kmap_[it.first] = -it.second;
0988       Kmap_["2"] = Kmap_["2"] - n;
0989       K_ = pow(2, -n) / p1->K();
0990 
0991       LUT = new int[Nelements_];
0992       double offsetI = lround(offset_ / p1_->K());
0993       for (int i = 0; i < Nelements_; ++i) {
0994         int i1 = addr_to_ival(i);
0995         LUT[i] = gen_inv(offsetI + i1);
0996       }
0997     }
0998     ~VarInv() override { delete[] LUT; }
0999 
1000     void set_mode(mode m) { m_ = m; }
1001     void initLUT(double offset);
1002     double offset() { return offset_; }
1003     double Ioffset() { return offset_ / p1_->K(); }
1004 
1005     void local_calculate() override;
1006     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
1007     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
1008     void writeLUT(std::ofstream &fs) const { writeLUT(fs, verilog); }
1009     void writeLUT(std::ofstream &fs, Verilog) const;
1010     void writeLUT(std::ofstream &fs, HLS) const;
1011 
1012     int ival_to_addr(int ival) { return ((ival >> shift_) & mask_); }
1013     int addr_to_ival(int addr) {
1014       switch (m_) {
1015         case mode::pos:
1016           return l1t::bitShift(addr, shift_);
1017         case mode::neg:
1018           return l1t::bitShift((addr - Nelements_), shift_);
1019         case mode::both:
1020           return l1t::bitShift(addr, ashift_) >> (ashift_ - shift_);
1021       }
1022       assert(0);
1023     }
1024     int gen_inv(int i) {
1025       unsigned int ms = sizeof(int) * 8 - nbits_;
1026       int lut = 0;
1027       if (i > 0) {
1028         int i1 = i + (1 << shift_) - 1;
1029         int lut1 = (lround((1 << n_) / i) << ms) >> ms;
1030         int lut2 = (lround((1 << n_) / (i1)) << ms) >> ms;
1031         lut = 0.5 * (lut1 + lut2);
1032       } else if (i < -1) {
1033         int i1 = i + (1 << shift_) - 1;
1034         int i2 = i;
1035         int lut1 = (lround((1 << n_) / i1) << ms) >> ms;
1036         int lut2 = (lround((1 << n_) / i2) << ms) >> ms;
1037         lut = 0.5 * (lut1 + lut2);
1038       }
1039       return lut;
1040     }
1041 
1042   protected:
1043     double offset_;
1044     int n_;
1045     mode m_;
1046     unsigned int shift_;
1047     unsigned int mask_;
1048     unsigned int ashift_;
1049     int Nelements_;
1050     int nbaddr_;
1051 
1052     int *LUT;
1053   };
1054 
1055   class VarCut : public VarBase {
1056   public:
1057     VarCut(imathGlobals *globals, double lower_cut, double upper_cut)
1058         : VarBase(globals, "", nullptr, nullptr, nullptr, 0),
1059           lower_cut_(lower_cut),
1060           upper_cut_(upper_cut),
1061           parent_flag_(nullptr) {
1062       op_ = "cut";
1063     }
1064 
1065     VarCut(imathGlobals *globals, VarBase *cut_var, double lower_cut, double upper_cut)
1066         : VarCut(globals, lower_cut, upper_cut) {
1067       set_cut_var(cut_var);
1068     }
1069     ~VarCut() override = default;
1070 
1071     double lower_cut() const { return lower_cut_; }
1072     double upper_cut() const { return upper_cut_; }
1073 
1074     void local_passes(std::map<const VarBase *, std::vector<bool> > &passes,
1075                       const std::map<const VarBase *, std::vector<bool> > *const previous_passes = nullptr) const;
1076     using VarBase::print;
1077     void print(std::map<const VarBase *, std::set<std::string> > &cut_strings,
1078                const int step,
1079                Verilog,
1080                const std::map<const VarBase *, std::set<std::string> > *const previous_cut_strings = nullptr) const;
1081     void print(std::map<const VarBase *, std::set<std::string> > &cut_strings,
1082                const int step,
1083                HLS,
1084                const std::map<const VarBase *, std::set<std::string> > *const previous_cut_strings = nullptr) const;
1085 
1086     void set_parent_flag(VarFlag *parent_flag, const bool call_add_cut);
1087     VarFlag *parent_flag() { return parent_flag_; }
1088     void set_cut_var(VarBase *cut_var, const bool call_add_cut = true);
1089 
1090   protected:
1091     double lower_cut_;
1092     double upper_cut_;
1093     VarFlag *parent_flag_;
1094   };
1095 
1096   class VarFlag : public VarBase {
1097   public:
1098     template <class... Args>
1099     VarFlag(imathGlobals *globals, std::string name, VarBase *cut, Args... args)
1100         : VarBase(globals, name, nullptr, nullptr, nullptr, 0) {
1101       op_ = "flag";
1102       nbits_ = 1;
1103       add_cuts(cut, args...);
1104     }
1105 
1106     template <class... Args>
1107     void add_cuts(VarBase *cut, Args... args) {
1108       add_cut(cut);
1109       add_cuts(args...);
1110     }
1111 
1112     void add_cuts(VarBase *cut) { add_cut(cut); }
1113 
1114     void add_cut(VarBase *cut, const bool call_set_parent_flag = true);
1115 
1116     void calculate_step();
1117     bool passes();
1118     void print(std::ofstream &fs, Verilog, int l1 = 0, int l2 = 0, int l3 = 0) override;
1119     void print(std::ofstream &fs, HLS, int l1 = 0, int l2 = 0, int l3 = 0) override;
1120   };
1121 };  // namespace trklet
1122 #endif