File indexing completed on 2024-07-16 02:43:05
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 from __future__ import division
0012 import math
0013 import numpy as np
0014 from collections import defaultdict, Counter
0015 from pprint import pprint
0016
0017
0018
0019
# --- pt / curvature LUT configuration --------------------------------------
# Constants of the pt conversion used in GetPtLUT / ptChecks:
#   pt = (SpeedOfLight / 10) * Bfield * 0.01 / (1/R)
# NOTE(review): units (tesla, 1e8 m/s, 1/cm) inferred from the values — confirm.
Bfield = 3.81120228767395
SpeedOfLight = 2.99792458

minRinv = 0.006     # the signed 1/R span is 2 * |minRinv|
BITSRinv = 15       # number of codes spanning 2 * |minRinv| is 2**BITSRinv
BITSABSRinv = 14    # GetPtLUT iterates |1/R| codes 1 .. 2**BITSABSRinv - 1
BITSPT = 13         # integer pt saturates at 2**BITSPT - 1
ptLSB = 0.03125     # pt value of one integer LSB
ptLUT = []          # filled by GetPtLUT: saturated integer pt, as str
pts = []            # filled by GetPtLUT: exact float pt
ptshifts = []
stepRinv = 2. * abs(minRinv) / (1 << BITSRinv)   # 1/R value of one code
drawPT = False      # when True, ptChecks saves pt_check.png
0033
0034
# --- eta / tan(lambda) LUT configuration -----------------------------------
BITSTTTANL = 16 - 1   # GetEtaLUT iterates tan(lambda) codes 0 .. 2**BITSTTTANL - 1
                      # NOTE(review): the "- 1" presumably drops a sign bit — confirm
BITSETA = 13 - 1      # eta scale: one LSB is pi / 2**BITSETA
maxTanL = 8.0         # tan(lambda) = maxTanL * code / 2**BITSTTTANL
etaLSB = math.pi / (1 << BITSETA)
etaLUT = []           # filled by GetEtaLUT: integer eta, as str
etas = []             # filled by GetEtaLUT: exact float eta
etashifts = []
0042
def GetPtLUT():
    """Fill the module-level ``pts`` and ``ptLUT`` lists.

    For every |1/R| code ``c`` in ``[1, 2**BITSABSRinv)`` the curvature is
    ``c * stepRinv`` and the exact pt is
    ``(SpeedOfLight / 10) * Bfield * 0.01 / curvature``.
    That float is appended to ``pts``. Its integer form
    ``round(pt / ptLSB)``, saturated at ``2**BITSPT - 1``, is appended to
    ``ptLUT`` as a decimal string.

    List position ``j`` therefore corresponds to code ``j + 1``. The lists
    are appended to and never cleared, so calling this twice duplicates
    their contents.
    """
    saturation = (1 << BITSPT) - 1
    for code in range(1, 1 << BITSABSRinv):
        rinv = code * stepRinv
        exact_pt = (SpeedOfLight / 10) * Bfield * 0.01 / rinv
        pts.append(exact_pt)
        ptLUT.append(str(min(int(round(exact_pt / ptLSB)), saturation)))
0055
0056
def GetEtaLUT():
    """Fill the module-level ``etas`` and ``etaLUT`` lists.

    For every code ``i`` in ``[0, 2**BITSTTTANL)``:

    * ``tanL = maxTanL * i / 2**BITSTTTANL``;
    * ``theta = pi/2 - atan(tanL)``;
    * ``eta = -ln(tan(theta / 2))`` (pseudorapidity), appended to ``etas``.

    The integer encoding ``round(eta * 2**BITSETA / pi)`` is appended to
    ``etaLUT`` as a decimal string, but only when ``|eta| < pi``.

    NOTE(review): a skipped entry would shift every later ``etaLUT``
    position relative to its code, because GetLUTModified uses the list
    position as the input index. With maxTanL = 8.0, eta stays below about
    2.78, so nothing is skipped.

    The lists are appended to and never cleared, so calling this twice
    duplicates their contents.
    """
    n_codes = 1 << BITSTTTANL
    for i in range(n_codes):
        tanL = (maxTanL * i) / n_codes
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        etas.append(eta)
        etaINT = int(round(eta * (1 << BITSETA) / math.pi))
        # Range check on the magnitude of eta. abs() must wrap eta itself;
        # wrapping the comparison would take abs() of a bool.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
0067
def consecutive(data, stepsize=1):
    """Split ``data`` into runs whose neighbours differ by exactly ``stepsize``.

    A new run starts after every position where the difference to the next
    element is not ``stepsize``.

    Parameters
    ----------
    data : array_like
        Sequence to split.
    stepsize : number
        Expected difference between neighbouring elements of a run.

    Returns
    -------
    list of numpy.ndarray
        The runs, in order. A contiguous input gives a single run.

    >>> [run.tolist() for run in consecutive([1, 2, 3, 7, 8, 10])]
    [[1, 2, 3], [7, 8], [10]]
    """
    gaps = np.diff(data) != stepsize
    split_points = np.nonzero(gaps)[0] + 1
    return np.split(data, split_points)
0070
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Build a shift map that compresses the index space of a float LUT.

    The input index range is cut into consecutive segments, one per entry
    of ``nbits``. Inside the segment for shift value ``nb``, the index is
    right-shifted by ``nb - 1`` bits (no shift for ``nb`` of 0 or 1), so
    neighbouring indices share one compressed entry. Segment ends are
    derived by comparing the step between neighbouring LUT values with
    ``lsb / 2**nb``. Optional bounds split indices whose value is at/above
    ``upperbound`` or at/below ``lowerbound`` into dedicated rows.

    Parameters
    ----------
    fltLUT : sequence of float
        LUT values, one per input position.
    lsb : float
        Size of one LSB of the integer LUT output.
    isPT : bool
        True for a LUT whose values decrease along the index (the pt LUT).
        This flips the sign of the steps and swaps the bound codes.
    minshift, maxshift : int
        Used to build ``nbits = range(minshift, maxshift)`` when ``nbits``
        is None.
    lowerbound, upperbound : float or None
        Optional value bounds; see Returns.
    nbits : sequence of int or None
        Shift values, one segment each, in index order.

    Returns
    -------
    list of list
        ``[start, stop, shift, offset, payload]`` rows, the format read by
        Modification and LookUp.

        * Regular rows have ``shift >= 0``; ``offset`` is added after the
          shift and ``payload`` is 0.
        * Bound rows have ``shift`` -1 or -2, ``offset`` 0 and
          ``payload = bound / lsb``. The upper bound gets -1 when isPT and
          -2 otherwise; the lower bound gets the opposite code.
    """
    length = len(fltLUT)
    steps = np.diff(fltLUT)
    if isPT:
        # isPT values decrease along the index: flip so steps are positive
        steps = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # cross[i] = length minus the number of trailing steps smaller than
    # lsb / 2**nbits[i]. np.searchsorted needs steps[::-1] ascending, i.e.
    # the steps must shrink along the index for this to be a clean crossing.
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    x = []       # compressed indices emitted so far (used for offsets)
    val = []     # float value kept for each compressed index
    shifts = []  # rows returned to the caller
    bfrom = 0    # first input index of the current segment
    for i, nb in enumerate(nbits):
        shifty_= 0
        bitcross = cross[i]
        # NOTE(review): presumably "no index range for this shift" — confirm
        if bitcross == 1:
            continue
        # the last shift value always runs to the end of the LUT
        if nb == nbits[-1]:
            bitcross = length

        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            # split the values >= upperbound into their own bound row
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2

            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # NOTE(review): <= keeps a value exactly equal to upperbound in
            # both the bound row and the regular segment
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            # shrink the segment from the side where the bound values sat
            if isPT:
                bfrom=orgx[0]
            else:
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            # split the values <= lowerbound into their own bound row
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            # payload is left as a float here (no int(), unlike the upper bound)
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:

            shiftx_ = ((orgx >> (nb-1)) )
            # empty segment: skip it (bfrom is not advanced)
            if len(shiftx_) == 0:
                continue
            # offset so this segment's first compressed index directly
            # follows the last one already emitted
            if len(x) > 0:
                shifty_ = int(x[-1] + 1 - shiftx_[0])
                shiftx_ = shiftx_ + shifty_
        else:
            shiftx_ = orgx
        # keep the first input position (and its value) per compressed index
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_)
        x = np.append(x, x_)
        val = np.append(val, val_)
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        # NOTE(review): unless a bound reset bitcross to orgx[-1], orgx ends
        # at bitcross - 1 while this row covers bitcross as well
        # (stop = bitcross + 1); confirm intended
        shifts.append(sht)


        bfrom = bitcross+1

    return shifts
0140
def Modification(inval, intINT, config):
    """Map one input index to its compressed LUT index using ``config``.

    ``config`` is a sequence of ``[start, stop, shift, offset, payload]``
    rows (as produced by GetLUTwrtLSB). Rows are tried in order; a row
    applies when ``start <= inval < stop``:

    * ``shift >= 0``: index ``(inval >> shift) + offset``, value ``intINT``;
    * ``shift < 0`` and ``payload == 0``: index ``inval``, value ``intINT``;
    * ``shift == -1`` and ``payload != 0``: sentinel index -1, value ``payload``;
    * ``shift == -2`` and ``payload != 0``: sentinel index 9999999, value
      ``payload``.

    A matching row with any other negative shift and a non-zero payload is
    passed over, and the search continues with the next row.

    Parameters
    ----------
    inval : int
        Input index.
    intINT : object
        LUT value carried along for regular rows.
    config : sequence of rows
        The shift map.

    Returns
    -------
    tuple
        ``(shift, new_index, value)``.

    Raises
    ------
    ValueError
        If no row of ``config`` handles ``inval``. Without this, the caller's
        tuple unpacking would fail with an unhelpful TypeError on None.
    """
    for cfg in config:
        if not (cfg[0] <= inval < cfg[1]):
            continue
        shift, offset, payload = cfg[2], cfg[3], cfg[4]
        if shift >= 0:
            return shift, (inval >> shift) + offset, intINT
        if payload == 0:
            return shift, inval, intINT
        if shift == -1:
            return shift, -1, payload
        if shift == -2:
            return shift, 9999999, payload
        # other negative codes carrying a payload: try the next row
    raise ValueError(
        "Modification: index {} is not covered by any shift-map row".format(inval))
0153
def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Compress ``orgLUT`` according to ``shiftMap``.

    Every entry of ``orgLUT`` is routed through Modification to a compressed
    index. Entries landing on the same compressed index are merged into one
    value. Entries routed to the bound sentinels (-1 and 9999999) are
    dropped here.

    Merge rule for a compressed index that received several values:

    * a single distinct value -> that value;
    * otherwise, if every distinct value occurs equally often -> the
      numerically smallest one;
    * otherwise -> the most frequent one (``Counter.most_common``; among
      equal top counts the first one seen wins).

    A bin with more than three distinct values is reported with a warning,
    but it still receives a value so that ``x`` and ``y`` stay aligned.

    Parameters
    ----------
    orgLUT : sequence
        Uncompressed LUT values. In this file they are decimal strings,
        which is why the tie-break compares them numerically and not as
        strings ("100" < "99" lexicographically).
    shiftMap : sequence of rows
        ``[start, stop, shift, offset, payload]`` rows understood by
        Modification.
    isPT : bool
        If True, position ``i`` of ``orgLUT`` is treated as input index
        ``i + 1``.

    Returns
    -------
    (list, list)
        Compressed indices ``x`` (in first-seen order) and the merged value
        ``y`` for each of them.
    """
    origin = 1 if isPT else 0
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        _, newidx, newINT = Modification(i + origin, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for k, v in tempmap.items():
        if k == -1 or k == 9999999:
            continue
        x.append(k)
        counts = Counter(v)
        if len(counts) == 1:
            y.append(v[0])
            continue
        if len(counts) > 3:
            print("----- allow up to 3 values per bins")
        if len(set(counts.values())) == 1:
            # all values equally frequent: take the numerically smallest
            y.append(min(counts, key=float))
        else:
            y.append(counts.most_common(1)[0][0])
    return x, y
0186
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Turn a raw shift map into the final (shift map, compressed LUT) pair.

    Parameters
    ----------
    LUT : sequence
        Uncompressed LUT values (decimal strings in this script).
    k : sequence of rows
        ``[start, stop, shift, offset, payload]`` rows from GetLUTwrtLSB.
        They are converted to an integer NumPy array, so float payloads are
        truncated by ``astype(int)``.
    isPT : bool
        Moves the row ranges by +1 to match GetLUTModified, which treats LUT
        position ``i`` as input ``i + 1`` when isPT is True.
    bounderidx : bool
        If True, the bound values (payloads of the shift -1 / -2 rows) are
        stored as real entries of the compressed LUT. The -1 payload is
        inserted at position 0, and every non-negative-shift offset is
        bumped by one to make room. The -2 payload is appended at the end.
        Each bound row's payload column is then rewritten to the position
        of its entry.

    Returns
    -------
    (numpy.ndarray, list, list)
        The adjusted shift map, the compressed indices ``x`` from
        GetLUTModified (bound entries not included), and the compressed LUT
        values ``y``.
    """
    k = np.asarray(k).astype(int)
    if isPT:
        # GetLUTModified treats pt LUT position i as input index i + 1,
        # so move the row ranges to the same origin.
        k[:, [0, 1]] +=1
        # NOTE(review): offsets of right-shifted rows are bumped by one as
        # well; presumably to stay in step with the shift-0 rows — confirm.
        k[k[:, 2] > 0, 3] +=1
    x, y = GetLUTModified(LUT, k,isPT)
    # x[0] is the first compressed index GetLUTModified produced (IndexError
    # if x is empty). If it is not 0, rebase the offsets of the regular rows
    # so the compressed table starts at 0, then rebuild it.
    if x[0] != 0:
        for i in k:
            # rows of a 2-D array are views: this edits k in place
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
        # order the rows by start index
        k = k[k[:, 0].argsort()]
        x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:
        # store the bound values inside the LUT and point the bound rows'
        # payload at their position
        if np.any(k[:,2] == -1):
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            # every regular entry moved up by one slot
            k[k[:, 2] >= 0, 3] +=1
        if np.any(k[:,2] == -2):
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    return k, x, y
0211
0212
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Look up input index ``inpt`` in a compressed LUT.

    The first ``[start, stop, shift, offset, payload]`` row of ``shiftmap``
    with ``start <= inpt < stop`` is used:

    * ``shift >= 0``: compressed index ``(inpt >> shift) + offset``;
      returns ``(index, LUT[index])``.
    * ``shift < 0`` and ``bounderidx``: the payload is the position of the
      bound value inside ``LUT``; returns ``(payload, LUT[payload])``.
    * ``shift < 0`` otherwise: the payload is the value itself; returns
      ``(-1, payload)``.

    Parameters
    ----------
    inpt : int
        Input index.
    shiftmap : sequence of rows
        The shift map.
    LUT : sequence
        The compressed LUT.
    bounderidx : bool
        Whether bound values are stored inside ``LUT``.

    Returns
    -------
    tuple
        ``(compressed_index, value)`` as described above.

    Raises
    ------
    ValueError
        If no row covers ``inpt``. Without this, the caller's tuple
        unpacking would fail with an unhelpful TypeError on None.
    """
    for row in shiftmap:
        if not (row[0] <= inpt < row[1]):
            continue
        shift = row[2]
        if shift < 0:
            if bounderidx:
                return row[4], LUT[row[4]]
            return -1, row[4]
        idx = (inpt >> shift) + row[3]
        return idx, LUT[idx]
    raise ValueError(
        "LookUp: input {} is not covered by any shift-map row".format(inpt))
0223
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Compare compressed-LUT pt look-ups with the exact pt.

    For every |1/R| code ``c`` in ``[1, 2**BITSABSRinv - 1)``:

    * the exact pt is recomputed with the GetPtLUT formula and ``LUT`` is
      queried through LookUp;
    * codes whose exact pt is below 2 or above ``2**BITSPT * ptLSB`` are
      ignored;
    * for the rest, a look-up more than one ptLSB away from the exact pt is
      printed, along with the uncompressed ``pts`` / ``ptLUT`` entries.

    When the module flag ``drawPT`` is True, a two-panel comparison plot
    (rounded vs looked-up pt, and their difference in LSBs) is saved to
    ``pt_check.png``.
    """
    exact = []
    rounded = []
    looked_up = []
    pt_ceiling = (1 << BITSPT) * ptLSB
    for code in range(1, (1 << BITSABSRinv) - 1):
        pOB = (SpeedOfLight / 10) * Bfield * 0.01 / (code * stepRinv)
        idx, pINT = LookUp(code, shiftmap, LUT, bounderidx)

        if pOB < 2 or pOB > pt_ceiling:
            continue
        decoded = float(pINT) * ptLSB
        exact.append(pOB)
        rounded.append(int(round(pOB / ptLSB)) * ptLSB)
        looked_up.append(decoded)

        if abs(pOB - decoded) > ptLSB:
            print("pt : ", code, pOB, pts[code - 1], ptLUT[code - 1], idx, pINT, int(pINT) * ptLSB)

    if drawPT is not True:
        return

    import matplotlib.pyplot as plt
    fig, (top, bottom) = plt.subplots(2, 1, layout='constrained')
    top.plot(exact, rounded, 'bo', label="Round up")
    top.plot(exact, looked_up, 'r+', label="Look up")
    top.set_xscale('log')
    top.legend()
    top.set_xlabel("input pt")
    top.set_ylabel("convert pt")

    diff = (np.array(rounded) - np.array(looked_up)) / ptLSB
    bottom.plot(exact, diff, 'bo', label="Difference")
    bottom.set_xscale('log')
    bottom.legend()
    bottom.set_xlabel("input pt")
    bottom.set_ylabel("pt LSB")
    bottom.set_ylim([-2.5, 2.5])
    plt.savefig("pt_check.png")
0260
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Report eta look-ups that deviate too far from the exact eta.

    For every tan(lambda) code, eta is recomputed with the GetEtaLUT
    formula; codes with eta > 2.45 are skipped. ``LUT`` is queried through
    LookUp, and a line is printed when the decoded value differs from eta
    by more than one etaLSB (eta < 1.59) or two etaLSBs (eta >= 1.59).
    """
    n_codes = 1 << BITSTTTANL
    for code in range(n_codes):
        tanL = (maxTanL * code) / n_codes
        theta = math.pi / 2.0 - math.atan(tanL)
        eta = -math.log(math.tan(theta / 2.0))

        if eta > 2.45:
            continue
        eINT = int(eta * (1 << BITSETA) / math.pi)
        idx, etaINT = LookUp(code, shiftmap, LUT, bounderidx)
        decoded = int(etaINT) * etaLSB
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if abs(eta - decoded) > tolerance:
            print("eta : ", code, eta, eINT, idx, etaINT, decoded)
0277
0278
def PrintPTLUT(k, ptLUT):
    """Print the pt shift map and compressed pt LUT as C++ initialisers.

    Each row of ``k`` becomes one ``{a,b,...}`` initialiser of
    ``int ptShifts[len(k)][5]``. The declaration hard-codes 5 columns, so
    rows are expected to have 5 entries. ``ptLUT`` must contain strings; it
    is emitted as ``ap_uint<BITSPT> ptLUT[len(ptLUT)]``.
    """
    rows = ", ".join("{" + ",".join(map(str, row)) + "}" for row in k)
    shifts_decl = "int ptShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shifts_decl + rows + "};")
    lut_decl = "ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT))
    print(lut_decl + ', '.join(ptLUT) + '};')
0287
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift map and compressed eta LUT as C++ initialisers.

    Each row of ``k`` becomes one ``{a,b,...}`` initialiser of
    ``int etaShifts[len(k)][5]``. The declaration hard-codes 5 columns, so
    rows are expected to have 5 entries. ``etaLUT`` must contain strings;
    it is emitted as ``ap_uint<BITSETA> etaLUT[len(etaLUT)]``.
    """
    rows = ", ".join("{" + ",".join(map(str, row)) + "}" for row in k)
    shifts_decl = "int etaShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shifts_decl + rows + "};")
    lut_decl = "ap_uint<BITSETA> etaLUT[{address}]={{".format(address=len(etaLUT))
    print(lut_decl + ', '.join(etaLUT) + '};')
0296
0297
0298
0299
if __name__ == "__main__":
    # With bounderidx True, ProducedFinalLUT stores the out-of-bound values
    # as real entries of the compressed LUT.
    bounderidx = True

    # ---- pt LUT: build, compress, check contiguity, print ----------------
    GetPtLUT()
    k = GetLUTwrtLSB(
        pts,
        ptLSB,
        isPT=True,
        nbits=[1, 2, 3, 4, 5, 6, 7],
        lowerbound=1.9,
        upperbound=((1 << BITSPT) - 1) * ptLSB,
    )
    k, x, y = ProducedFinalLUT(ptLUT, k, isPT=True, bounderidx=bounderidx)
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    PrintPTLUT(k, y)

    # ---- eta LUT: build, compress, check contiguity, print ---------------
    GetEtaLUT()
    k = GetLUTwrtLSB(
        etas,
        etaLSB,
        isPT=False,
        nbits=[0, 1, 2, 3, 5],
        upperbound=2.45,
    )
    k, x, y = ProducedFinalLUT(etaLUT, k, bounderidx=bounderidx)
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    PrintEtaLUT(k, y)
0324 PrintEtaLUT(k, y)