File indexing completed on 2022-04-01 23:54:08
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 from __future__ import division
0012 import math
0013 import numpy as np
0014 from collections import defaultdict, Counter
0015 from pprint import pprint
0016
# ---- pT LUT configuration ----
# GetPtLUT maps curvature index i (1 .. 2**BITSABSCURV - 1) to
# k = maxCurv * i / 2**BITSABSCURV and pT = 0.3 * 3.8 * 0.01 / k.
BITSABSCURV = 14        # bits of the curvature index
BITSPT = 13             # quantised pT values saturate at 2**BITSPT - 1
maxCurv = 0.00855       # curvature scale of the index
ptLSB = 0.025           # pT quantisation step
ptLUT = []              # quantised pT per index, as strings (GetPtLUT)
pts = []                # float pT per index, pts[i - 1] for index i (GetPtLUT)
ptshifts = []           # not referenced elsewhere in this script

# ---- eta LUT configuration ----
# GetEtaLUT maps tan(lambda) index i (0 .. 2**BITSTTTANL - 1) to
# tanL = maxTanL * i / 2**BITSTTTANL and quantises eta in steps of etaLSB.
BITSTTTANL = 16 - 1     # bits of the tan(lambda) index
BITSETA = 13 - 1        # eta is quantised as eta * 2**BITSETA / pi
maxTanL = 8.0           # tan(lambda) scale of the index
etaLSB = math.pi / (1 << BITSETA)
etaLUT = []             # quantised eta per index, as strings (GetEtaLUT)
etas = []               # float eta per index (GetEtaLUT)
etashifts = []          # not referenced elsewhere in this script
0033
def GetPtLUT():
    """Fill the module-level ``pts`` and ``ptLUT`` lists.

    For every curvature index ``1 .. 2**BITSABSCURV - 1`` the float pT
    ``0.3 * 3.8 * 0.01 / k`` (with ``k = maxCurv * index / 2**BITSABSCURV``)
    is appended to ``pts``.  Its quantised value ``round(pT / ptLSB)``,
    saturated at ``2**BITSPT - 1``, is appended to ``ptLUT`` as a string.
    The lists are appended to, so calling this twice doubles them.
    """
    full_scale = 1 << BITSABSCURV
    saturated = (1 << BITSPT) - 1
    for index in range(1, full_scale):
        curvature = (maxCurv * index) / full_scale
        pt = 0.3 * 3.8 * 0.01 / curvature
        pts.append(pt)
        quantised = int(round(pt / ptLSB))
        ptLUT.append(str(min(quantised, saturated)))
0044
0045
def GetEtaLUT():
    """Fill the module-level ``etas`` and ``etaLUT`` lists.

    For every tan(lambda) index ``i`` in ``0 .. 2**BITSTTTANL - 1``:
    ``tanL = maxTanL * i / 2**BITSTTTANL``, ``theta = pi/2 - atan(tanL)``
    and ``eta = -ln(tan(theta / 2))``.  The float eta is appended to
    ``etas``.  ``round(eta * 2**BITSETA / pi)`` is appended to ``etaLUT`` as
    a string, but only when ``|eta| < pi``.
    """
    for i in range(0, (1 << BITSTTTANL)):
        tanL = (maxTanL * i) / (1 << BITSTTTANL)
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        etas.append(eta)
        etaINT = int(round(eta * (1 << BITSETA) / math.pi))
        # Compare the magnitude of eta (not abs() of a boolean).  etaLUT is
        # only index-aligned with etas while this never fails.  eta equals
        # asinh(tanL), which stays below asinh(maxTanL) (~2.78 for
        # maxTanL = 8.0), so nothing is skipped with the current settings.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
0056
def consecutive(data, stepsize=1):
    """Split ``data`` into runs whose neighbouring elements differ by ``stepsize``.

    Returns the list of NumPy sub-arrays produced by ``np.split``.  A new
    run starts right after every position where the difference to the
    next element is not ``stepsize``.
    """
    gaps = np.diff(data) != stepsize
    split_points = np.where(gaps)[0] + 1
    return np.split(data, split_points)
0059
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Derive a piecewise index-compression ("shift") table for a float LUT.

    The table is walked in index order and split into consecutive ranges,
    one per candidate shift ``nb`` in ``nbits``.  A split point is placed
    where the step between neighbouring ``fltLUT`` samples falls below
    ``lsb / 2**nb``.  Inside a range the index is right-shifted by
    ``nb - 1`` bits (no shift for ``nb`` of 0 or 1).  An offset keeps the
    compressed indices of successive ranges contiguous.

    Args:
        fltLUT: float LUT samples, one per original index.
        lsb: quantisation step of the LUT values.
        isPT: treat ``fltLUT`` as decreasing with index.  The steps are
            negated, and the -1/-2 codes of the bound ranges are swapped.
        minshift, maxshift: ``nbits`` defaults to
            ``range(minshift, maxshift)``.
        lowerbound, upperbound: optional value limits.  Samples at or beyond
            a limit are emitted as a separate range with a negative code.
        nbits: candidate shifts, processed in order.  The last one is
            extended to the end of the table.

    Returns:
        list of ``[start, stop, shift, offset, payload]`` rows.  For
        ``shift >= 0`` an index ``i`` in ``[start, stop)`` is meant to map
        to ``(i >> shift) + offset``.  Rows with ``shift`` of -1 or -2 are
        bound ranges whose ``payload`` is the bound divided by ``lsb``.
    """
    length = len(fltLUT)
    steps = np.diff(fltLUT)
    if isPT:
        # isPT tables decrease with index; flip the sign of the steps
        steps = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # cross[i] = length - (number of steps below lsb/2**nbits[i]).
    # searchsorted needs steps[::-1] ascending, i.e. the steps are assumed
    # to shrink along the table -- TODO confirm for new inputs.
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    x = []        # compressed indices emitted so far (used to chain offsets)
    val = []      # matching sample values; accumulated but not returned
    shifts = []
    bfrom = 0     # first original index of the current range
    for i, nb in enumerate(nbits):
        shifty_= 0
        bitcross = cross[i]
        # NOTE(review): a candidate whose crossing is index 1 is skipped
        # entirely -- the reason is not visible here.
        if bitcross == 1:
            continue
        # the last candidate always runs to the end of the table
        if nb == nbits[-1]:
            bitcross = length

        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            # split samples >= upperbound off into their own bound range
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2

            # NOTE(review): int() truncates this payload, while the lower
            # bound payload below is stored as a float.
            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # NOTE(review): samples exactly equal to upperbound stay in the
            # regular range as well (>= above, <= here).
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            # assumes the large values sit at the start of a decreasing
            # (isPT) table and at the end otherwise
            if isPT:
                bfrom=orgx[0]
            else:
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            # split samples <= lowerbound off into their own bound range
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:
            # compress the range by dropping nb-1 low index bits
            shiftx_ = ((orgx >> (nb-1)) )
            if len(shiftx_) == 0:
                continue
            if len(x) > 0:
                # offset so this range starts right after the last
                # compressed index already emitted
                shifty_ = int(x[-1] + 1 - shiftx_[0])
                shiftx_ = shiftx_ + shifty_
        else:
            # nb of 0 or 1: indices are kept unshifted (offset stays 0)
            shiftx_ = orgx
        # keep one (the first) sample per compressed index
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_)
        x = np.append(x, x_)
        val = np.append(val, val_)
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        shifts.append(sht)


        # NOTE(review): unless a bound branch reassigned it, bitcross is the
        # exclusive end of orgx here.  The row above then ends one index
        # past orgx and the next range starts after that index -- confirm
        # this is intended.
        bfrom = bitcross+1

    return shifts
0129
def Modification(inval, intINT, config):
    """Map one original LUT index through a shift table.

    Args:
        inval: original LUT index.
        intINT: LUT value stored at that index.
        config: iterable of ``[start, stop, shift, offset, payload]`` rows.
            The first row with ``start <= inval < stop`` that produces a
            result decides.

    Returns:
        ``(shift, new_index, value)``:

        * ``shift >= 0``: ``new_index = (inval >> shift) + offset`` and
          ``value = intINT``;
        * ``shift == -1`` with non-zero payload: ``(-1, -1, payload)``;
        * ``shift == -2`` with non-zero payload: ``(-2, 9999999, payload)``;
        * negative shift with zero payload: ``(shift, inval, intINT)``.

        Rows with any other negative shift and a non-zero payload are
        passed over and the scan continues.

    Raises:
        ValueError: if no row of ``config`` yields a mapping for ``inval``.
    """
    for cfg in config:
        if not (cfg[0] <= inval < cfg[1]):
            continue
        shift = cfg[2]
        if shift < 0 and cfg[4] != 0:
            if shift == -1:
                return shift, -1, cfg[4]
            if shift == -2:
                return shift, 9999999, cfg[4]
            # unknown bound code with a payload: keep scanning
            continue
        if shift < 0:
            return shift, inval, intINT
        return shift, (inval >> shift) + cfg[3], intINT
    # Previously fell off the end and returned None, which made the
    # caller's tuple unpacking fail with an unrelated TypeError.
    raise ValueError(
        "index {} is not covered by any shift-table row".format(inval))
0142
def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Collapse an original LUT into compressed bins using a shift table.

    Each entry of ``orgLUT`` is mapped through ``Modification``.  Entries
    are addressed from 1 when ``isPT``, matching GetPtLUT, which starts at
    curvature index 1.  Entries sharing a compressed index form one bin.
    Bins keyed by the bound sentinels -1 and 9999999 are dropped.

    Each bin gets one value:

    * its single value, if it holds only one;
    * the numerically smallest value, if every distinct value occurs
      equally often;
    * otherwise ``Counter.most_common(1)`` (first seen among the most
      frequent).

    Bins with more than 3 distinct values still print a warning.  They now
    also get a value, so ``x`` and ``y`` always stay aligned.

    Args:
        orgLUT: original LUT values (numeric strings in this script).
        shiftMap: ``[start, stop, shift, offset, payload]`` rows.
        isPT: address ``orgLUT`` from index 1 instead of 0.

    Returns:
        ``(x, y)``: parallel lists of compressed indices (in first-seen
        order) and their chosen values.
    """
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        ii = i + 1 if isPT else i
        _, newidx, newINT = Modification(ii, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for k, v in tempmap.items():
        if k == -1 or k == 9999999:
            continue
        contv = Counter(v)
        if len(contv) > 3:
            print("----- allow up to 3 values per bins")
        x.append(k)
        if len(contv) == 1:
            y.append(v[0])
        elif len(set(contv.values())) == 1:
            # Values are numeric strings: compare numerically, not
            # lexicographically (where "100" < "99").
            y.append(min(contv, key=float))
        else:
            y.append(contv.most_common(1)[0][0])
    return x, y
0175
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Build the final compressed LUT together with its adjusted shift table.

    Args:
        LUT: original per-index LUT values (strings in this script).
        k: shift rows ``[start, stop, shift, offset, payload]``, e.g. as
            returned by GetLUTwrtLSB.  They are converted to an integer
            array.
        isPT: the pT LUT is addressed from index 1, so the row ranges are
            moved up by one.
        bounderidx: if true, store each bound range's payload as an extra
            LUT entry (front for shift -1, back for shift -2).  The row's
            payload is then replaced by the index of that entry.

    Returns:
        ``(k, x, y)``: the adjusted shift table, the compressed indices and
        the compressed LUT values.
    """
    # float payloads (e.g. lowerbound/lsb) are truncated toward zero here
    k = np.asarray(k).astype(int)
    if isPT:
        # ranges follow the 1-based pT addressing used in GetLUTModified
        k[:, [0, 1]] +=1
        # NOTE(review): only rows with shift > 0 get their offset bumped;
        # shift-0 rows do not -- confirm intent.
        k[k[:, 2] > 0, 3] +=1
    x, y = GetLUTModified(LUT, k,isPT)
    # NOTE(review): x[0] raises IndexError if no regular bin was produced.
    if x[0] != 0:
        # re-base all non-negative-shift rows so the first compressed
        # index becomes 0; rows are views, so this edits k in place
        for i in k:
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
        # order the rows by start index
        k = k[k[:, 0].argsort()]
        x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:

        if np.any(k[:,2] == -1):
            # the first -1 row's payload becomes LUT entry 0 ...
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            # ... so every regular index moves up by one (x is not updated)
            k[k[:, 2] >= 0, 3] +=1
        if np.any(k[:,2] == -2):
            # the first -2 row's payload becomes the last LUT entry
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    return k, x, y
0200
0201
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Return ``(index, value)`` for original index ``inpt``.

    The first row ``[start, stop, shift, offset, payload]`` of ``shiftmap``
    with ``start <= inpt < stop`` is used:

    * ``shift >= 0``: ``index = (inpt >> shift) + offset`` and
      ``value = LUT[index]``;
    * ``shift < 0`` with ``bounderidx``: ``(payload, LUT[payload])``;
    * ``shift < 0`` otherwise: ``(-1, payload)``.

    Returns None when no row covers ``inpt``.
    """
    row = next((r for r in shiftmap if r[0] <= inpt < r[1]), None)
    if row is None:
        return None
    if row[2] < 0:
        if not bounderidx:
            return -1, row[4]
        return row[4], LUT[row[4]]
    index = (inpt >> row[2]) + row[3]
    return index, LUT[index]
0212
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Print curvature indices where the compressed pT LUT is off by > 1 LSB.

    The exact pT of every index GetPtLUT fills (``1 .. 2**BITSABSCURV - 1``)
    is compared with ``LookUp(i, shiftmap, LUT, bounderidx)`` times
    ``ptLSB``.  Indices whose exact pT is above ``2**BITSPT * ptLSB`` or
    below 2 are skipped.

    Args:
        shiftmap: ``[start, stop, shift, offset, payload]`` rows.
        LUT: compressed pT LUT values (numeric strings or numbers).
        bounderidx: forwarded to LookUp.
    """
    nbins = 1 << BITSABSCURV
    ptmax = (1 << BITSPT) * ptLSB
    # cover every index GetPtLUT fills, including the last one
    for i in range(1, nbins):
        k = (maxCurv * i) / nbins
        pOB = 0.3 * 3.8 * 0.01 / (k)
        if pOB > ptmax or pOB < 2:
            continue
        # only look up indices that are actually checked
        idx, pINT = LookUp(i, shiftmap, LUT, bounderidx)
        if abs(pOB - float(pINT) * ptLSB) > ptLSB:
            print("pt : ", i, pOB, pts[i-1], ptLUT[i-1], idx, pINT, int(pINT)*ptLSB)
0224
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Print tan(lambda) indices where the compressed eta LUT is too far off.

    For each index the exact eta is recomputed as in GetEtaLUT.  Samples
    with ``eta > 2.45`` are skipped.  The allowed deviation between eta and
    ``LUT value * etaLSB`` is one etaLSB for ``eta < 1.59`` and two etaLSB
    from there on.
    """
    scale = 1 << BITSTTTANL
    for i in range(scale):
        tanL = (maxTanL * i) / scale
        theta = math.pi / 2.0 - math.atan(tanL)
        eta = -math.log(math.tan(theta / 2.0))
        if eta > 2.45:
            continue
        eINT = int(eta * (1 << BITSETA) / math.pi)
        idx, etaINT = LookUp(i, shiftmap, LUT, bounderidx)
        decoded = int(etaINT) * etaLSB
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if abs(eta - decoded) > tolerance:
            print("eta : ", i, eta, eINT, idx, etaINT, decoded)
0241
0242
def PrintPTLUT(k, ptLUT):
    """Print the pT shift table and compressed pT LUT as C++ initialisers.

    Emits two lines: ``int ptShifts[N][5]={{...}, ...};`` with one brace
    group per row of ``k``, and ``ap_uint<BITSPT> ptLUT[M]={v0, v1, ...};``.
    ``ptLUT`` must hold strings.
    """
    cells = (",".join(str(c) for c in row) for row in k)
    body = ", ".join("{" + c + "}" for c in cells)
    print("int ptShifts[{nOps}][5]={{".format(nOps=len(k)) + body + "};")
    header = "ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT))
    print(header + ', '.join(ptLUT) + '};')
0251
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift table and compressed eta LUT as C++ initialisers.

    Emits two lines: ``int etaShifts[N][5]={{...}, ...};`` with one brace
    group per row of ``k``, and ``ap_uint<BITSETA> etaLUT[M]={v0, ...};``.
    ``etaLUT`` must hold strings.
    """
    cells = (",".join(str(c) for c in row) for row in k)
    body = ", ".join("{" + c + "}" for c in cells)
    print("int etaShifts[{nOps}][5]={{".format(nOps=len(k)) + body + "};")
    header = "ap_uint<BITSETA> etaLUT[{address}]={{".format(address=len(etaLUT))
    print(header + ', '.join(etaLUT) + '};')
0260
if __name__ == "__main__":
    # Store the bound-range values as extra LUT entries and point the bound
    # rows at them (see ProducedFinalLUT).
    bounderidx = True

    def report_gaps(indices):
        """Warn when the compressed indices are not one contiguous run."""
        runs = consecutive(indices)
        if len(runs) > 1:
            print("index is not continuous: ", runs)

    # pT LUT: values at or below 2, and at or above the largest BITSPT-bit
    # value times ptLSB, are split off as bound ranges.
    GetPtLUT()
    shift_table = GetLUTwrtLSB(pts, ptLSB, isPT=True,
                               nbits=[1, 2, 3, 4, 5, 6, 7],
                               lowerbound=2,
                               upperbound=((1 << BITSPT) - 1) * ptLSB)
    shift_table, indices, values = ProducedFinalLUT(
        ptLUT, shift_table, isPT=True, bounderidx=bounderidx)
    report_gaps(indices)
    PrintPTLUT(shift_table, values)

    # eta LUT: values at or above 2.45 are split off as a bound range.
    GetEtaLUT()
    shift_table = GetLUTwrtLSB(etas, etaLSB, isPT=False,
                               nbits=[0, 1, 2, 3, 5], upperbound=2.45)
    shift_table, indices, values = ProducedFinalLUT(
        etaLUT, shift_table, bounderidx=bounderidx)
    report_gaps(indices)
    PrintEtaLUT(shift_table, values)
0282 PrintEtaLUT(k, y)