Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-04-01 23:54:08

0001 #!/usr/bin/env python
0002 # encoding: utf-8
0003 
0004 # File        : makeTrackConversionLUTs.py
0005 # Author      : Zhenbin Wu
0006 # Contact     : zhenbin.wu@gmail.com
0007 # Date        : 2021 Apr 13
0008 #
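# Description : Builds the compressed track pT and eta conversion LUTs (plus
#               their address-shift tables) and prints them as C++ array
#               initialisers.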
0010 
0011 from __future__ import division
0012 import math
0013 import numpy as np
0014 from collections import defaultdict, Counter
0015 from pprint import pprint
0016 
# ---------------------------------------------------------------------------
# pT conversion parameters
# ---------------------------------------------------------------------------
BITSABSCURV = 14      # address bits of the |curvature| code (GetPtLUT walks 1 .. 2**14-1)
BITSPT = 13           # bits of an output pT code; larger codes saturate at 2**13-1
maxCurv = 0.00855     # curvature full scale (units presumably 1/cm -- TODO confirm)
ptLSB = 0.025         # pT value of one output LSB (presumably GeV)
ptLUT = []            # filled by GetPtLUT(): quantised pT codes, as strings
pts = []              # filled by GetPtLUT(): un-quantised pT values
ptshifts = []         # NOTE(review): not referenced anywhere else in this file


# ---------------------------------------------------------------------------
# eta conversion parameters
# ---------------------------------------------------------------------------
BITSTTTANL = 16 - 1   # magnitude bits of tan(lambda); presumably 1 of 16 is the sign -- confirm
BITSETA = 13 - 1      # magnitude bits of an output eta code
maxTanL = 8.0         # tan(lambda) full scale
etaLSB = math.pi / (1 << BITSETA)   # eta codes span [0, pi)
etaLUT = []           # filled by GetEtaLUT(): quantised eta codes, as strings
etas = []             # filled by GetEtaLUT(): un-quantised eta values
etashifts = []        # NOTE(review): not referenced anywhere else in this file
0033 
def GetPtLUT():
    """Fill the module lists ``pts`` and ``ptLUT`` for every |curvature| code.

    For each code c in 1 .. 2**BITSABSCURV - 1 the curvature is
    maxCurv * c / 2**BITSABSCURV and pT = 0.3 * 3.8 * 0.01 / curvature
    (presumably 0.3*B*R with B = 3.8 T and 0.01 a cm -> m factor -- confirm).
    ``pts`` receives the float pT; ``ptLUT`` receives the pT rounded to
    ptLSB units, saturated at 2**BITSPT - 1, as a string.

    Appends to the module lists, so calling it twice duplicates entries.
    """
    nCodes = 1 << BITSABSCURV
    ptSaturation = (1 << BITSPT) - 1
    for curvCode in range(1, nCodes):
        curvature = (maxCurv * curvCode) / nCodes
        ptValue = 0.3 * 3.8 * 0.01 / curvature
        pts.append(ptValue)
        ptCode = int(round(ptValue / ptLSB))
        # Clamp to the largest code representable in BITSPT bits.
        ptLUT.append(str(min(ptCode, ptSaturation)))
0044 
0045 
def GetEtaLUT():
    """Fill the module lists ``etas`` and ``etaLUT`` for every tan(lambda) code.

    For each code i in 0 .. 2**BITSTTTANL - 1, tan(lambda) = maxTanL * i /
    2**BITSTTTANL is converted to pseudorapidity via theta = pi/2 - lambda,
    eta = -ln(tan(theta/2)).  ``etas`` receives the float eta; ``etaLUT``
    receives eta rounded to etaLSB (= pi / 2**BITSETA) units, as a string.

    Appends to the module lists, so calling it twice duplicates entries.
    """
    nCodes = 1 << BITSTTTANL
    for i in range(0, nCodes):
        tanL = (maxTanL * i) / nCodes
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        etas.append(eta)
        etaINT = int(round(eta * (1 << BITSETA) / math.pi))
        # Codes only cover [0, pi) (etaLSB = pi / 2**BITSETA).  The original
        # test was written abs(eta < pi), i.e. abs() of a boolean.
        # NOTE(review): if this ever rejects a value, etaLUT becomes shorter
        # than etas and the two lists stop lining up by index; with
        # maxTanL = 8 the largest eta is about 2.78, so it never triggers.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
0056 
def consecutive(data, stepsize=1):
    """Split *data* into runs whose neighbouring elements differ by *stepsize*.

    Returns a list of numpy arrays; an empty input yields one empty array.
    Example: [1, 2, 3, 5, 6] -> [array([1, 2, 3]), array([5, 6])].
    """
    values = np.asarray(data)
    gaps = np.diff(values) != stepsize
    # A new run starts right after every position whose step breaks the pattern.
    runStarts = np.nonzero(gaps)[0] + 1
    return np.split(values, runStarts)
0059 
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Derive a piecewise address-compression ("shift") map for a float LUT.

    The table is walked from address 0 upward and cut into segments, one per
    shift candidate in *nbits*.  cross[i] is (roughly) the first address after
    which consecutive entries differ by less than lsb / 2**nbits[i]; segment i
    runs from the end of the previous segment up to cross[i] and is compressed
    by dropping nbits[i]-1 low address bits (no shift for nb <= 1).  Entries
    outside [lowerbound, upperbound] are split off into saturation rows.

    Args:
        fltLUT: float LUT values indexed by address.  The searchsorted step
            below assumes |consecutive differences| shrink along the table
            (true for ``pts`` and ``etas`` built in this file) -- confirm for
            any other input.
        lsb: size of one output LSB, in fltLUT units.
        isPT: fltLUT decreases with address (pT), so differences are negated
            and the saturation codes for upper/lower bound are swapped.
        minshift, maxshift: nbits defaults to range(minshift, maxshift).
        lowerbound, upperbound: optional saturation limits in fltLUT units.
        nbits: shift candidates, tried in order; the last one is extended to
            the end of the table.

    Returns:
        list of rows [start, end, shift, offset, bound]:
          * regular rows (shift >= 0): address a with start <= a < end is
            meant to map to (a >> shift) + offset; bound is 0.
          * saturation rows: shift -1 marks the low-address end (pT upper
            bound / non-pT lower bound), -2 the high-address end (pT lower
            bound / non-pT upper bound); bound is the saturation value in LSB
            units (int(upperbound/lsb), or lowerbound/lsb left as a float).
    """
    length = len(fltLUT)
    steps  = np.diff(fltLUT)
    if isPT:
        # pT falls with address; make the steps positive.
        steps  = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # steps[::-1] is ascending under the shrinking-steps assumption, so
    # searchsorted counts trailing steps below each threshold lsb/2**nb.
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    x = []       # compressed addresses emitted so far (drives the offsets)
    val = []     # NOTE(review): accumulated but never returned or read
    shifts = []  # result rows
    bfrom = 0    # first address of the current segment
    for i, nb in enumerate(nbits):
        shifty_= 0
        bitcross = cross[i]
        if bitcross == 1:
            # Every step is already below this threshold: no segment for nb.
            continue
        if nb == nbits[-1]:
            # The last candidate absorbs the rest of the table.
            bitcross = length
            # bitcross = length if isPT else length-1
        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2
            ## sel[-1]+1 , since we will use < in the final function
            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # NOTE(review): values exactly equal to upperbound are in the
            # saturation row above AND kept here (>= vs <=).  An empty orgx
            # after this filter would make orgx[0] / orgx[-1] raise.
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            if isPT:
                # Decreasing pT: saturated entries sit at the low-address end.
                bfrom=orgx[0]
            else:
                # Increasing table: saturated entries sit at the high end.
                # bitcross now holds the last kept address (inclusive).
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            # NOTE(review): same boundary-equality overlap as above.
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:
            ### Important: We can only shift by nb-1 bit to keep the precision
            shiftx_ = ((orgx >> (nb-1)) )
            if len(shiftx_) == 0:
                # Nothing left in this segment; bfrom is not advanced.
                continue
            if len(x) > 0:
                # Offset so this segment's compressed addresses continue right
                # after the previous segment's last one.
                shifty_ = int(x[-1] + 1 - shiftx_[0]) ## +1 to make sure it won't overlap
                shiftx_ = shiftx_ + shifty_
        else:
            shiftx_ = orgx
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_)
        x = np.append(x, x_)
        val = np.append(val, val_)
        # NOTE(review): the row ends at bitcross+1 although orgx stopped at
        # bitcross-1 when bitcross was not reassigned by a bound branch, so
        # address bitcross is covered by this row's shift.
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        shifts.append(sht)
        # print("Shifting {nbit} bits with intercept {itsect}, input start from {bfrom} ({ffrom}) to {bto} ({fto}), LUT size {nLUT} ".format(
            # nbit=nb, itsect=shifty_, bfrom=bfrom, bto=bitcross, ffrom=fltLUT[bfrom], fto=fltLUT[bitcross],  nLUT =len(val_)))
        bfrom = bitcross+1

    return shifts
0129 
def Modification(inval, intINT, config):
    """Route one original address through the shift map *config*.

    The first row [start, end, shift, offset, bound] with start <= inval < end
    decides the result (nshift, newidx, newINT):
      * shift >= 0               -> (shift, (inval >> shift) + offset, intINT)
      * shift < 0 and bound == 0 -> (shift, inval, intINT)
      * shift == -1, bound != 0  -> (-1, -1, bound)        low-end saturation
      * shift == -2, bound != 0  -> (-2, 9999999, bound)   high-end saturation
    Any other negative shift with a non-zero bound is skipped and the search
    continues.  Returns None when no row covers *inval*.
    """
    for row in config:
        if not (row[0] <= inval < row[1]):
            continue
        shift = row[2]
        if shift >= 0:
            return shift, (inval >> shift) + row[3], intINT
        if row[4] == 0:
            return shift, inval, intINT
        if shift == -1:
            return shift, -1, row[4]
        if shift == -2:
            return shift, 9999999, row[4]
    return None
0142 
def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Collapse *orgLUT* onto the compressed addresses defined by *shiftMap*.

    Entry i of *orgLUT* (looked up as address i + 1 when *isPT*, because the
    pT table built by GetPtLUT starts at curvature code 1) is routed through
    Modification(); all entries landing on the same compressed address are
    merged into one value:
      * one distinct value  -> that value;
      * several values, all equally frequent -> the numerically smallest;
      * otherwise -> the most frequent (ties go to the value seen first).
    More than 3 distinct values in one bin is reported with a warning, but the
    bin still gets a value so ``x`` and ``y`` stay aligned.

    Args:
        orgLUT: per-address LUT values (this file passes strings of ints).
        shiftMap: rows [start, end, shift, offset, bound].
        isPT: add 1 to every original index before the lookup.

    Returns:
        (x, y): compressed addresses in first-seen order and the merged value
        for each.  Addresses -1 and 9999999 (saturation sentinels returned by
        Modification) are dropped.  len(x) == len(y).
    """
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        ii = i + 1 if isPT else i
        _, newidx, newINT = Modification(ii, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for k, v in tempmap.items():
        if k == -1 or k == 9999999:
            continue
        x.append(k)
        contv = Counter(v)
        if len(contv) == 1:
            y.append(v[0])
            continue
        if len(contv) > 3:
            # Previously this bin got no y entry, shifting every later value.
            print("----- allow up to 3 values per bins")
        if len(set(contv.values())) == 1:
            ## Using min for now. To be decided
            # Compare numerically: the LUT values are strings, and a plain
            # min() would rank "100" below "99".
            y.append(min(contv.keys(), key=float))
        else:
            y.append(contv.most_common(1)[0][0])
    return x, y
0175 
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Turn a raw shift map into the final (shift map, compressed LUT) pair.

    Args:
        LUT: original per-address LUT values (ptLUT / etaLUT in this file).
        k: shift-map rows [start, end, shift, offset, bound] from GetLUTwrtLSB.
        isPT: pT table, whose LUT index 0 corresponds to curvature code 1
            (GetPtLUT starts at 1), so the address ranges are moved by +1.
        bounderidx: if True, store the saturation values inside the
            compressed LUT (shift -1 value at index 0, shift -2 value at the
            end) and make the saturation rows' bound column hold those indices.

    Returns:
        (k, x, y): the adjusted integer shift map (rows sorted by start), the
        compressed addresses from GetLUTModified, and the compressed LUT
        values (a list).
    """
    # Integer copy of the map; float bounds (e.g. lowerbound/lsb) truncate here.
    k = np.asarray(k).astype(int)
    if isPT:
        # Move [start, end) into curvature-code space (codes start at 1).
        k[:, [0, 1]] +=1
        # NOTE(review): offset of shifting rows bumped by one, presumably to
        # compensate the +1 address move -- confirm.
        k[k[:, 2] > 0, 3] +=1
    # First pass only to learn the first compressed address produced.
    x, y = GetLUTModified(LUT, k,isPT)
    if x[0] != 0:
        # Rebase regular rows so the first compressed address seen becomes 0.
        # `i` is a row view of k, so this edits k in place.
        for i in k:
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
    # Order rows by start address, then rebuild with the final offsets.
    k = k[k[:, 0].argsort()]
    x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:
        ### Has a low-end saturation row (shift -1): put its value at y[0],
        ### point its bound column at index 0, and move every regular row's
        ### offset up by one to make room.
        if np.any(k[:,2] == -1):
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            k[k[:, 2] >= 0, 3] +=1
        ### Has a high-end saturation row (shift -2): append its value and
        ### point its bound column at that last index.
        if np.any(k[:,2] == -2):
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    # NOTE(review): x is not updated for the entries inserted/appended above,
    # so with bounderidx it no longer lines up index-for-index with y.
    return k, x, y
0200 
0201 ### PT
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Emulate the firmware lookup of original address *inpt*.

    Uses the first row [start, end, shift, offset, bound] covering *inpt*:
      * regular row (shift >= 0): returns (addr, LUT[addr]) with
        addr = (inpt >> shift) + offset;
      * saturation row (shift < 0): returns (bound, LUT[bound]) when
        *bounderidx* (bound holds a LUT index), otherwise (-1, bound).
    Returns None if no row covers *inpt*.
    """
    for row in shiftmap:
        if not (row[0] <= inpt < row[1]):
            continue
        if row[2] >= 0:
            address = (inpt >> row[2]) + row[3]
            return address, LUT[address]
        if not bounderidx:
            return -1, row[4]
        return row[4], LUT[row[4]]
    return None
0212 
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Print every curvature code whose compressed pT misses the exact pT by > 1 LSB.

    Only codes whose exact pT lies in [2, 2**BITSPT * ptLSB] are checked.
    Each mismatch prints: code, exact pT, pts[code-1], ptLUT[code-1], the
    compressed address, the compressed code and its value in pT units.

    Args:
        shiftmap: final shift map from ProducedFinalLUT.
        LUT: compressed LUT values from ProducedFinalLUT.
        bounderidx: same meaning as in LookUp().
    """
    nCodes = 1 << BITSABSCURV
    ptMax = (1 << BITSPT) * ptLSB
    for i in range(1, nCodes - 1):
        k = (maxCurv * i) / nCodes
        pOB = 0.3 * 3.8 * 0.01 / k
        ## We don't need to check beyond the boundary
        # Skip before looking up, so out-of-range codes that no shift row
        # covers do not abort the whole check.
        if pOB > ptMax or pOB < 2:
            continue
        idx, pINT = LookUp(i, shiftmap, LUT, bounderidx)
        # Allow +-1 LSB (uses ptLSB instead of a duplicated 0.025 literal).
        if abs(pOB - float(pINT) * ptLSB) > ptLSB:
            print("pt : ", i, pOB, pts[i - 1], ptLUT[i - 1], idx, pINT, int(pINT) * ptLSB)
0224 
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Print every tan(lambda) code whose compressed eta misses the exact eta.

    Codes with eta > 2.45 are not checked.  The tolerance is 1 etaLSB for
    eta < 1.59 and 2 etaLSB above.  Each mismatch prints: code, exact eta,
    the directly rounded code, the compressed address, the compressed code
    and its value in eta units.

    Args:
        shiftmap: final shift map from ProducedFinalLUT.
        LUT: compressed LUT values from ProducedFinalLUT.
        bounderidx: same meaning as in LookUp().
    """
    nCodes = 1 << BITSTTTANL
    for i in range(0, nCodes):
        tanL = (maxTanL * i) / nCodes
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        ## We don't need to check beyond the boundary
        if eta > 2.45:
            continue
        # Same rounding as GetEtaLUT, so the printed reference matches the builder.
        eINT = int(round(eta * (1 << BITSETA) / math.pi))
        idx, etaINT = LookUp(i, shiftmap, LUT, bounderidx)
        ## For high eta region, we allow up to 2LSB
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if abs(eta - int(etaINT) * etaLSB) > tolerance:
            print("eta : ", i, eta, eINT, idx, etaINT, int(etaINT) * etaLSB)
0241 
0242 
def PrintPTLUT(k, ptLUT):
    """Print the pT shift map and compressed LUT as C++ array initialisers.

    Emits two lines: ``int ptShifts[N][5]={{...}, ...};`` built from the rows
    of *k*, and ``ap_uint<BITSPT> ptLUT[M]={...};`` built from the strings in
    *ptLUT*.
    """
    rows = ["{" + ",".join(map(str, row)) + "}" for row in k]
    shiftHeader = "int ptShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shiftHeader + ", ".join(rows) + "};")
    lutHeader = "ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT))
    print(lutHeader + ', '.join(ptLUT) + '};')
0250     print("ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT))+', '.join(ptLUT)+'};')
0251 
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift map and compressed LUT as C++ array initialisers.

    Emits two lines: ``int etaShifts[N][5]={{...}, ...};`` built from the rows
    of *k*, and ``ap_uint<BITSETA> etaLUT[M]={...};`` built from the strings
    in *etaLUT*.
    """
    rows = ["{" + ",".join(map(str, row)) + "}" for row in k]
    shiftHeader = "int etaShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shiftHeader + ", ".join(rows) + "};")
    lutHeader = "ap_uint<BITSETA> etaLUT[{address}]={{".format(address=len(etaLUT))
    print(lutHeader + ', '.join(etaLUT) + '};')
0260 
if __name__ == "__main__":
    # When True the saturation values are stored inside the printed LUTs and
    # the saturation rows point at them (see ProducedFinalLUT).
    bounderidx = True

    # ---- pT ---------------------------------------------------------------
    GetPtLUT()
    ptUpper = ((1 << BITSPT) - 1) * ptLSB
    ptShiftMap = GetLUTwrtLSB(pts, ptLSB, isPT=True, nbits=[1, 2, 3, 4, 5, 6, 7],
                              lowerbound=2, upperbound=ptUpper)
    ptShiftMap, ptAddr, ptValues = ProducedFinalLUT(ptLUT, ptShiftMap, isPT=True,
                                                    bounderidx=bounderidx)
    ptRuns = consecutive(ptAddr)
    if len(ptRuns) > 1:
        print("index is not continuous: ", ptRuns)
    # Optional diagnostics:
    # ptChecks(ptShiftMap, ptValues, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(ptValues))
    PrintPTLUT(ptShiftMap, ptValues)

    # ---- eta --------------------------------------------------------------
    GetEtaLUT()
    etaShiftMap = GetLUTwrtLSB(etas, etaLSB, isPT=False, nbits=[0, 1, 2, 3, 5],
                               upperbound=2.45)
    etaShiftMap, etaAddr, etaValues = ProducedFinalLUT(etaLUT, etaShiftMap,
                                                       bounderidx=bounderidx)
    etaRuns = consecutive(etaAddr)
    if len(etaRuns) > 1:
        print("index is not continuous: ", etaRuns)
    # Optional diagnostics:
    # etaChecks(etaShiftMap, etaValues, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(etaValues))
    PrintEtaLUT(etaShiftMap, etaValues)