File indexing completed on 2024-11-27 03:17:56
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 import math
0012 import numpy as np
0013 from collections import defaultdict, Counter
0014 from pprint import pprint
0015
0016
0017
0018
# Physics constants of the 1/R -> pt conversion used by GetPtLUT / ptChecks:
#   pt = (SpeedOfLight / 10) * Bfield * 0.01 / (code * stepRinv)
Bfield = 3.81120228767395  # presumably tesla -- confirm
SpeedOfLight = 2.99792458  # numerically c in units of 1e8 m/s

# 1/R encoding: stepRinv spreads 2 * |minRinv| over 2**BITSRinv codes, and the
# pt table is indexed by a BITSABSRinv-bit magnitude code.
minRinv = 0.006
BITSRinv = 15
BITSABSRinv = 14
stepRinv = 2.0 * abs(minRinv) / (1 << BITSRinv)

# pt output word: BITSPT bits, in units of ptLSB (saturates at 2**BITSPT - 1).
BITSPT = 13
ptLSB = 0.03125

# pts (float pt) and ptLUT (quantised pt as decimal strings) are filled by GetPtLUT.
ptLUT = []
pts = []
ptshifts = []

drawPT = False  # when exactly True, ptChecks saves pt_check.png
0032
0033
# tan(lambda) -> eta encoding: the input code has BITSTTTANL bits spanning
# tan(lambda) in [0, maxTanL); eta is stored in units of etaLSB = pi / 2**BITSETA.
BITSTTTANL = 16 - 1
BITSETA = 13 - 1
maxTanL = 8.0
etaLSB = math.pi / 2**BITSETA

# etas (float eta) and etaLUT (quantised eta as decimal strings) are filled by GetEtaLUT.
etaLUT = []
etas = []
etashifts = []
0041
def GetPtLUT():
    """Fill the module-level ``pts`` and ``ptLUT`` for every nonzero 1/R code.

    For each code in [1, 2**BITSABSRinv) the pt is
    ``(SpeedOfLight / 10) * Bfield * 0.01 / (code * stepRinv)``. The float is
    appended to ``pts``. The value rounded to units of ``ptLSB`` and saturated
    at ``2**BITSPT - 1`` is appended to ``ptLUT`` as a decimal string.
    """
    ptMax = (1 << BITSPT) - 1
    for code in range(1, 1 << BITSABSRinv):
        rinv = code * stepRinv
        pt = (SpeedOfLight / 10) * Bfield * 0.01 / rinv
        pts.append(pt)
        ptLUT.append(str(min(int(round(pt / ptLSB)), ptMax)))
0054
0055
def GetEtaLUT():
    """Fill the module-level ``etas`` and ``etaLUT`` for every tan(lambda) code.

    Code ``i`` in [0, 2**BITSTTTANL) maps to
    ``tan(lambda) = maxTanL * i / 2**BITSTTTANL``, with ``theta = pi/2 - lambda``
    and ``eta = -ln(tan(theta / 2))``. The float eta is appended to ``etas``.
    The value rounded to units of ``pi / 2**BITSETA`` is appended to ``etaLUT``
    as a decimal string.

    ``etaLUT`` gets exactly one entry per code, because later stages treat the
    list position as the input code. An eta outside (-pi, pi) is therefore
    saturated to ``2**BITSETA - 1`` rather than dropped, mirroring the pt
    saturation in GetPtLUT.
    """
    etaMax = (1 << BITSETA) - 1
    for i in range(0, (1 << BITSTTTANL)):
        tanL = (maxTanL * i) / (1 << BITSTTTANL)
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        etas.append(eta)
        etaINT = int(round(eta * (1 << BITSETA) / math.pi))
        # Previously written ``abs(eta<math.pi)``, i.e. abs() of a bool;
        # the intended range check is on |eta|.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
        else:
            # Skipping here would shift every later address off its code.
            etaLUT.append(str(etaMax))
0066
def consecutive(data, stepsize=1):
    """Split a 1-D sequence into runs whose neighbours differ by ``stepsize``.

    Returns a list of numpy arrays; a new run starts wherever
    ``data[j + 1] - data[j] != stepsize``.
    """
    gaps = np.diff(data) != stepsize
    cutPoints = np.nonzero(gaps)[0] + 1
    return np.split(data, cutPoints)
0069
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Build a piecewise index-shift map that compresses a monotonic float table.

    The index range of ``fltLUT`` is cut into consecutive segments, one per
    entry of ``nbits``. The segment for ``nb`` ends where the step between
    neighbouring entries falls below ``lsb / 2**nb``. Inside it, indices are
    right-shifted by ``nb - 1`` bits (no shift for ``nb`` 0 or 1) and offset so
    that the shifted addresses continue from the previous segment's last one.

    Args:
        fltLUT: 1-D sequence of floats. Assumed monotonic with shrinking step
            size, because ``np.searchsorted`` on ``steps[::-1]`` needs them in
            ascending order -- NOTE(review): confirm for any new input table.
        lsb: Output quantum the steps are compared against.
        isPT: True for a table that decreases with index. Its steps are
            negated, and the sentinel shift codes of the two bounds are swapped.
        minshift, maxshift: Only used when ``nbits`` is None, to build
            ``range(minshift, maxshift)``.
        lowerbound, upperbound: Entries at or beyond a bound are carved out of
            their segment into a separate row with a negative shift code.
        nbits: Segment widths in bits, in index order.

    Returns:
        List of rows ``[start, stop, shift, offset, payload]``, each covering
        indices ``[start, stop)``. For ``shift >= 0`` the address used by
        Modification / LookUp is ``(index >> shift) + offset``. A shift of -1
        or -2 marks a bound row whose ``payload`` is ``bound / lsb``.
    """
    length = len(fltLUT)
    steps = np.diff(fltLUT)
    if isPT:
        steps = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # cross[i] == 1 + (number of leading steps >= lsb / 2**nbits[i]),
    # i.e. the (exclusive) end index of the segment for nbits[i].
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    x = []       # addresses emitted so far; x[-1] anchors the next segment's offset
    val = []     # representative values per address (accumulated, not returned)
    shifts = []
    bfrom = 0    # first index of the current segment
    for i, nb in enumerate(nbits):
        shifty_= 0
        bitcross = cross[i]
        # No step reaches this threshold: emit nothing for this nb.
        # NOTE(review): bfrom is not advanced in this case.
        if bitcross == 1:
            continue
        # The last segment always runs to the end of the table.
        if nb == nbits[-1]:
            bitcross = length

        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            # Carve entries >= upperbound into their own bound row.
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2

            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # NOTE(review): an entry exactly equal to upperbound is both in the
            # bound row above and kept here (>= vs <=) -- confirm intended.
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            # isPT: the carved entries sit at the start of the segment, so the
            # segment now begins after them; otherwise its end moves back.
            if isPT:
                bfrom=orgx[0]
            else:
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            # Carve entries <= lowerbound into their own bound row.
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            # Payload stays a float here (the upper one uses int()); the
            # astype(int) in ProducedFinalLUT truncates it later.
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:
            # Compress by nb-1 bits, then offset so the first shifted address
            # follows the previous segment's last address.
            shiftx_ = ((orgx >> (nb-1)) )
            if len(shiftx_) == 0:
                # NOTE(review): this also skips the bfrom update below.
                continue
            if len(x) > 0:
                shifty_ = int(x[-1] + 1 - shiftx_[0])
                shiftx_ = shiftx_ + shifty_
        else:
            shiftx_ = orgx
        # One representative (first occurrence) per distinct address.
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_)
        x = np.append(x, x_)
        val = np.append(val, val_)
        # nb == 0 and nb == 1 both mean "no shift".
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        shifts.append(sht)

        # NOTE(review): when bitcross came straight from cross (an exclusive
        # end of orgx), the row above also covers index bitcross, which was
        # not sampled into x. That index may share an address with the next
        # segment's first entries; GetLUTModified merges such collisions.
        bfrom = bitcross+1

    return shifts
0139
def Modification(inval, intINT, config):
    """Map one input code to ``(shift, address, value)`` via the first matching row.

    ``config`` rows are ``[start, stop, shift, offset, payload]``. Within the
    first row with ``start <= inval < stop``:

    * ``shift`` not negative: address ``(inval >> shift) + offset``, value ``intINT``.
    * negative ``shift`` and ``payload == 0``: address ``inval``, value ``intINT``.
    * ``shift == -1`` / ``-2`` with nonzero payload: sentinel address ``-1`` /
      ``9999999``, with the payload as the value.

    A row with any other negative shift and a nonzero payload does not decide
    the result, and the search continues. Returns None when no row decides.
    """
    for row in config:
        if not (row[0] <= inval < row[1]):
            continue
        shift, payload = row[2], row[4]
        if not shift < 0:
            return shift, (inval >> shift) + row[3], intINT
        if payload != 0:
            if shift == -1:
                return shift, -1, payload
            if shift == -2:
                return shift, 9999999, payload
        elif payload == 0:
            return shift, inval, intINT
    return None
0152
def _pick_bin_value(values):
    """Return the one value to store for an address hit by several positions.

    The most frequent value wins. Ties for the highest count resolve to the
    numerically smallest value. Values are compared through ``float`` because
    the LUTs hold decimal strings, and plain string comparison ranks "1000"
    below "999". More than three distinct values prints a warning, but a value
    is still returned so addresses and values stay aligned.
    """
    counts = Counter(values)
    if len(counts) > 3:
        print("----- allow up to 3 values per bins")
    topCount = max(counts.values())
    candidates = [val for val, cnt in counts.items() if cnt == topCount]
    if len(candidates) == 1:
        return candidates[0]
    return min(candidates, key=float)


def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Re-address a LUT through ``shiftMap`` and merge entries sharing an address.

    Position ``i`` of ``orgLUT`` is treated as input code ``i``, or ``i + 1``
    when ``isPT`` (the pt table built by GetPtLUT starts at code 1). Each code
    is mapped with ``Modification``. Entries sent to the sentinel addresses
    ``-1`` and ``9999999`` (bound rows) are dropped.

    Returns:
        ``(x, y)``: addresses in first-seen order and, for each address, the
        single value chosen by ``_pick_bin_value``. ``len(x) == len(y)`` always.
    """
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        code = i + 1 if isPT else i
        _, newidx, newINT = Modification(code, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for addr, binValues in tempmap.items():
        if addr == -1 or addr == 9999999:
            continue
        x.append(addr)
        y.append(_pick_bin_value(binValues))
    return x, y
0185
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Finalise a shift map and produce the compressed LUT for it.

    Args:
        LUT: Original table, one entry per input position (the main script
            passes ptLUT / etaLUT, lists of decimal strings).
        k: Rows ``[start, stop, shift, offset, payload]`` from GetLUTwrtLSB.
        isPT: Position i of the pt table is code i + 1 (GetLUTModified adds
            1), so the row ranges are moved up by one here as well.
        bounderidx: When True, the payloads of the -1 / -2 bound rows are
            stored in the output table, and each such row's payload becomes
            the index of that stored value.

    Returns:
        ``(k, x, y)``: the integer shift map sorted by start, the addresses
        produced by GetLUTModified, and the compressed table values.
    """
    # astype(int) also truncates any float payload (e.g. lowerbound/lsb).
    k = np.asarray(k).astype(int)
    if isPT:
        k[:, [0, 1]] +=1
        # NOTE(review): offsets of rows with shift > 0 are bumped by one,
        # presumably to compensate for the +1 code shift -- confirm.
        k[k[:, 2] > 0, 3] +=1
    x, y = GetLUTModified(LUT, k,isPT)
    if x[0] != 0:
        # Rebase so the first address is 0. Iterating a 2-D numpy array
        # yields row views, so this edits k in place.
        for i in k:
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
    k = k[k[:, 0].argsort()]
    # Recompute addresses/values with the (possibly) rebased offsets.
    x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:

        # -1 bound row: its payload value goes to y[0], its payload becomes
        # address 0, and every non-negative-shift row moves up one address.
        if np.any(k[:,2] == -1):
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            k[k[:, 2] >= 0, 3] +=1
        # -2 bound row: its payload value is appended, and its payload becomes
        # that last address.
        if np.any(k[:,2] == -2):
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    # NOTE: x is not adjusted for the y[0] insertion above.
    return k, x, y
0210
0211
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Resolve input code ``inpt`` through a shift map to ``(address, value)``.

    Uses the first row ``[start, stop, shift, offset, payload]`` with
    ``start <= inpt < stop``:

    * ``shift >= 0``: address ``(inpt >> shift) + offset``, value ``LUT[address]``.
    * negative shift with ``bounderidx``: the payload is an address into ``LUT``.
    * negative shift without ``bounderidx``: returns ``(-1, payload)``.

    Returns None when no row covers ``inpt``.
    """
    for entry in shiftmap:
        if not (entry[0] <= inpt < entry[1]):
            continue
        shift = entry[2]
        if shift < 0:
            if not bounderidx:
                return -1, entry[4]
            return entry[4], LUT[entry[4]]
        address = (inpt >> shift) + entry[3]
        return address, LUT[address]
    return None
0222
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Report 1/R codes whose looked-up pt misses the exact pt by > one ptLSB.

    Codes 1 .. 2**BITSABSRinv - 2 are resolved through LookUp. Codes whose exact
    pt is above ``2**BITSPT * ptLSB`` or below 2 are skipped. When the module
    flag ``drawPT`` is exactly True, a two-panel plot (rounded vs looked-up pt,
    and their difference in LSBs) is saved to ``pt_check.png``.
    """
    exactPt, roundedPt, lookedUpPt = [], [], []
    ptCeiling = (1 << BITSPT) * ptLSB
    for code in range(1, (1 << BITSABSRinv) - 1):
        pOB = (SpeedOfLight / 10) * Bfield * 0.01 / (code * stepRinv)
        idx, pINT = LookUp(code, shiftmap, LUT, bounderidx)
        if pOB > ptCeiling or pOB < 2:
            continue
        lookedUp = float(pINT) * ptLSB
        exactPt.append(pOB)
        roundedPt.append(int(round(pOB / ptLSB)) * ptLSB)
        lookedUpPt.append(lookedUp)
        if abs(pOB - lookedUp) > ptLSB:
            print("pt : ", code, pOB, pts[code - 1], ptLUT[code - 1], idx, pINT, int(pINT) * ptLSB)

    if drawPT is not True:
        return

    import matplotlib.pyplot as plt
    fig, (top, bottom) = plt.subplots(2, 1, layout='constrained')
    top.plot(exactPt, roundedPt, 'bo', label="Round up")
    top.plot(exactPt, lookedUpPt, 'r+', label="Look up")
    top.set_xscale('log')
    top.legend()
    top.set_xlabel("input pt")
    top.set_ylabel("convert pt")

    residual = (np.array(roundedPt) - np.array(lookedUpPt)) / ptLSB
    bottom.plot(exactPt, residual, 'bo', label="Difference")
    bottom.set_xscale('log')
    bottom.legend()
    bottom.set_xlabel("input pt")
    bottom.set_ylabel("pt LSB")
    bottom.set_ylim([-2.5, 2.5])
    plt.savefig("pt_check.png")
0259
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Report tan(lambda) codes whose looked-up eta misses the exact eta.

    Codes with eta above 2.45 are skipped. The allowed error is one etaLSB for
    eta below 1.59 and two etaLSB from 1.59 up. Offending codes are printed
    together with the truncated exact code and the looked-up address/value.
    """
    scale = 1 << BITSTTTANL
    for code in range(scale):
        tanL = (maxTanL * code) / scale
        theta = math.pi / 2.0 - math.atan(tanL)
        eta = -math.log(math.tan(theta / 2.0))
        if eta > 2.45:
            continue
        eINT = int(eta * (1 << BITSETA) / math.pi)
        idx, etaINT = LookUp(code, shiftmap, LUT, bounderidx)
        error = abs(eta - int(etaINT) * etaLSB)
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if error > tolerance:
            print("eta : ", code, eta, eINT, idx, etaINT, int(etaINT) * etaLSB)
0276
0277
def PrintPTLUT(k, ptLUT):
    """Print the pt shift map and pt table as C++ array initialisers.

    ``k`` is a sequence of 5-element rows; ``ptLUT`` is a sequence of strings.
    """
    rows = ", ".join("{" + ",".join(map(str, row)) + "}" for row in k)
    print(f"int ptShifts[{len(k)}][5]={{{rows}}};")
    print(f"ap_uint<BITSPT> ptLUT[{len(ptLUT)}]={{{', '.join(ptLUT)}}};")
0286
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift map and eta table as C++ array initialisers.

    ``k`` is a sequence of 5-element rows; ``etaLUT`` is a sequence of strings.
    """
    rows = ", ".join("{" + ",".join(map(str, row)) + "}" for row in k)
    print(f"int etaShifts[{len(k)}][5]={{{rows}}};")
    print(f"ap_uint<BITSETA> etaLUT[{len(etaLUT)}]={{{', '.join(etaLUT)}}};")
0295
0296
0297
0298
if __name__ == "__main__":
    # Store the bound values inside the printed LUTs (see ProducedFinalLUT).
    bounderidx = True

    # ---- pt: |1/R| code -> pt ----
    GetPtLUT()
    ptSaturation = ((1 << BITSPT) - 1) * ptLSB
    ptShiftMap = GetLUTwrtLSB(
        pts,
        ptLSB,
        isPT=True,
        nbits=[1, 2, 3, 4, 5, 6, 7],
        lowerbound=1.9,
        upperbound=ptSaturation,
    )
    ptShiftMap, ptAddrs, ptTable = ProducedFinalLUT(
        ptLUT, ptShiftMap, isPT=True, bounderidx=bounderidx
    )
    ptRuns = consecutive(ptAddrs)
    if len(ptRuns) > 1:
        print("index is not continuous: ", ptRuns)
    PrintPTLUT(ptShiftMap, ptTable)

    # ---- eta: tan(lambda) code -> eta ----
    GetEtaLUT()
    etaShiftMap = GetLUTwrtLSB(
        etas, etaLSB, isPT=False, nbits=[0, 1, 2, 3, 5], upperbound=2.45
    )
    etaShiftMap, etaAddrs, etaTable = ProducedFinalLUT(
        etaLUT, etaShiftMap, bounderidx=bounderidx
    )
    etaRuns = consecutive(etaAddrs)
    if len(etaRuns) > 1:
        print("index is not continuous: ", etaRuns)
    PrintEtaLUT(etaShiftMap, etaTable)
0323 PrintEtaLUT(k, y)