Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-27 03:17:56

0001 #!/usr/bin/env python
0002 # encoding: utf-8
0003 
0004 # File        : makeTrackConversionLUTs.py
0005 # Author      : Zhenbin Wu
0006 # Contact     : zhenbin.wu@gmail.com
0007 # Date        : 2021 Apr 13
0008 #
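# Description : Build compressed pT and eta conversion LUTs (with the shift maps that address them) for L1 track words and print them as C++ array initializers.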
0010 
0011 import math
0012 import numpy as np
0013 from collections import defaultdict, Counter
0014 from pprint import pprint
0015 
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Constant for PT ~~~~~
## Constant from L1 Track Trigger
#https://github.com/cms-sw/cmssw/blob/master/L1Trigger/TrackTrigger/python/ProducerSetup_cfi.py#L98
Bfield = 3.81120228767395  # in T
SpeedOfLight = 2.99792458  # in e8 m/s
# https://github.com/cms-sw/cmssw/blob/master/DataFormats/L1TrackTrigger/interface/TTTrack_TrackWord.h#L87
# 1/R encoding: the full range 2*|minRinv| is split into 2**BITSRinv codes (see stepRinv)
minRinv = 0.006
BITSRinv = 15        # bits of the 1/R word (presumably signed -- confirm against TTTrack_TrackWord.h)
BITSABSRinv=14       # |1/R| codes 1 .. 2**BITSABSRinv - 1 are tabulated by GetPtLUT
BITSPT=13            # bits of one pT LUT entry; larger values saturate at 2**BITSPT - 1
ptLSB=0.03125        # pT value of one LUT count
ptLUT=[]             # quantized pT codes (strings), filled by GetPtLUT
pts = []             # floating-point pT values, filled by GetPtLUT
ptshifts = []        # NOTE(review): not referenced elsewhere in this file
stepRinv = (2. * abs(minRinv)) / (1 << BITSRinv)  # 1/R value of one code
drawPT = False       # if True, ptChecks saves a comparison plot to pt_check.png

## Constant for Eta
BITSTTTANL=16-1      # tan(lambda) code bits tabulated by GetEtaLUT (16 - 1; presumably without the sign bit -- confirm)
BITSETA=13-1         # eta bits (13 - 1) used to define etaLSB
maxTanL=8.0          # tan(lambda) = maxTanL * code / 2**BITSTTTANL
etaLSB = math.pi/ (1<<BITSETA)  # eta value of one LUT count
etaLUT=[]            # quantized eta codes (strings), filled by GetEtaLUT
etas = []            # floating-point eta values, filled by GetEtaLUT
etashifts = []       # NOTE(review): not referenced elsewhere in this file
0041 
def GetPtLUT():
    """Fill the module-level ``pts`` and ``ptLUT`` lists.

    For every |1/R| code from 1 to ``2**BITSABSRinv - 1`` the floating pT is
    computed from 1/R = code * stepRinv and appended to ``pts``. Its value in
    units of ``ptLSB`` (rounded, saturated at ``2**BITSPT - 1``) is appended
    to ``ptLUT`` as a string. Entries are appended, so calling the function
    twice duplicates them.
    """
    # Largest value a BITSPT-bit LUT entry can hold
    pt_code_max = (1 << BITSPT) - 1
    for rinv_code in range(1, 1 << BITSABSRinv):
        #Floating Pt
        #https://github.com/cms-sw/cmssw/blob/master/DataFormats/L1TrackTrigger/interface/TTTrack.h#L226
        pt_float = (SpeedOfLight / 10) * Bfield * 0.01 / (rinv_code * stepRinv)
        pts.append(pt_float)
        ptLUT.append(str(min(int(round(pt_float / ptLSB)), pt_code_max)))
0054 
0055 
def GetEtaLUT():
    """Fill the module-level ``etas`` and ``etaLUT`` lists.

    For every tan(lambda) code ``i`` in ``[0, 2**BITSTTTANL)`` with
    tan(lambda) = maxTanL * i / 2**BITSTTTANL, the pseudorapidity
    eta = -ln(tan(theta / 2)), theta = pi/2 - atan(tan(lambda)), is appended
    to ``etas``. Its value rounded to units of pi / 2**BITSETA (etaLSB) is
    appended to ``etaLUT`` as a string. Entries are appended, so calling the
    function twice duplicates them.
    """
    for i in range(0, 1 << BITSTTTANL):
        tanL = (maxTanL * i) / (1 << BITSTTTANL)
        lam = math.atan(tanL)
        theta = math.pi / 2.0 - lam
        eta = -math.log(math.tan(theta / 2.0))
        etas.append(eta)
        etaINT = int(round(eta * (1 << BITSETA) / math.pi))
        # Keep only |eta| < pi. The guard previously read abs(eta<math.pi),
        # i.e. abs() of a boolean; the parenthesis belongs around eta.
        # NOTE(review): a skipped value would leave etaLUT shorter than etas,
        # so indices would no longer line up. With maxTanL = 8.0, eta stays
        # below asinh(8) ~ 2.78, so this never triggers here.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
0066 
def consecutive(data, stepsize=1):
    """Split ``data`` into runs whose neighbouring elements differ by ``stepsize``.

    Returns the list of numpy arrays produced by ``np.split``. More than one
    run means the sequence has a gap (or a jump of a different size).
    """
    # A new run starts right after every position where the step breaks
    breaks = np.nonzero(np.diff(data) != stepsize)[0] + 1
    return np.split(data, breaks)
0069 
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Build a shift map that compresses the index space of a float LUT.

    The raw index range of ``fltLUT`` is cut into consecutive levels, one per
    entry of ``nbits``. Level ``nb`` runs up to (approximately) the index at
    which the per-index step (``np.diff``) drops below ``lsb / 2**nb``. Its
    indices are right-shifted by ``nb - 1`` bits (0 bits for ``nb == 0``), so
    neighbouring entries share one compressed address. The last level always
    extends to the end of the table.

    Args:
        fltLUT: floating-point LUT values indexed by raw input code.
            NOTE(review): ``np.searchsorted`` on the reversed steps assumes
            the steps shrink monotonically along the index. Confirm this for
            new input tables.
        lsb: output LSB used to derive the per-level step thresholds.
        isPT: the values decrease with the index (steps are negated). This
            also selects the codes used for the bound rows.
        minshift, maxshift: used only when ``nbits`` is None, as
            ``nbits = range(minshift, maxshift)``.
        lowerbound, upperbound: if given, values ``<= lowerbound`` /
            ``>= upperbound`` are collected into one bound row instead of
            being shifted.
        nbits: indexable sequence of levels, processed in order.

    Returns:
        list of ``[start, stop, nshift, offset, bound]`` rows over raw index
        ranges ``[start, stop)``. For ``nshift >= 0`` the compressed address
        is ``(index >> nshift) + offset``. A negative ``nshift`` (-1 / -2)
        marks a bound row whose last field is the bound divided by ``lsb``.
    """
    length = len(fltLUT)
    steps  = np.diff(fltLUT)
    if isPT:
        # pT decreases with the index; flip the sign so the steps are positive
        steps  = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # cross[i]: approximate index from which the step is below lsb / 2**nbits[i]
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    x = []          # compressed addresses produced so far (used to chain offsets)
    val = []        # LUT values at those addresses (collected, not returned)
    shifts = []     # output rows [start, stop, nshift, offset, bound]
    bfrom = 0       # first raw index of the current level
    for i, nb in enumerate(nbits):
        shifty_= 0  # address offset added after shifting
        bitcross = cross[i]
        # NOTE(review): levels whose crossing index is 1 are skipped; intent not documented
        if bitcross == 1:
            continue
        # The last level always extends to the end of the table
        if nb == nbits[-1]:
            bitcross = length
            # bitcross = length if isPT else length-1
        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            # Entries at/above upperbound collapse into one bound row
            # (code -1 for pT, -2 otherwise) carrying int(upperbound / lsb)
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2
            ## sel[-1]+1 , since we will use < in the final function
            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # NOTE(review): values exactly equal to upperbound are both put in
            # the bound row (>=) and kept here (<=)
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            # pT values decrease with the index, so the bound entries precede the kept ones.
            # NOTE(review): orgx[0] / orgx[-1] raise IndexError if every entry was removed
            if isPT:
                bfrom=orgx[0]
            else:
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            # Entries at/below lowerbound collapse into one bound row
            # (code -2 for pT, -1 otherwise). The bound stays a float here; the
            # file's ProducedFinalLUT later casts the whole map with astype(int).
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            # NOTE(review): same >= / <= overlap as above; orgx[-1] fails if orgx is empty
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:
            ### Important: We can only shift by nb-1 bit to keep the precision
            shiftx_ = ((orgx >> (nb-1)) )
            if len(shiftx_) == 0:
                continue
            if len(x) > 0:
                # Continue the address space right after the previous level
                shifty_ = int(x[-1] + 1 - shiftx_[0]) ## +1 to make sure it won't overlap
                shiftx_ = shiftx_ + shifty_
        else:
            # nb <= 1: no shift, addresses equal the raw indices
            shiftx_ = orgx
        # Keep the first raw entry for each compressed address
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_) 
        x = np.append(x, x_)
        val = np.append(val, val_)
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        shifts.append(sht)
        # print("Shifting {nbit} bits with intercept {itsect}, input start from {bfrom} ({ffrom}) to {bto} ({fto}), LUT size {nLUT} ".format(
            # nbit=nb, itsect=shifty_, bfrom=bfrom, bto=bitcross, ffrom=fltLUT[bfrom], fto=fltLUT[bitcross],  nLUT =len(val_)))
        # The next level starts right after this row's [start, stop) range
        bfrom = bitcross+1

    return shifts
0139 
def Modification(inval, intINT, config):
    """Route one raw LUT entry through the first matching shift-map row.

    ``config`` rows are ``[start, stop, nshift, offset, bound]``. A row
    matches when ``start <= inval < stop``. Returns ``(nshift, address,
    value)``:

    * ``nshift >= 0``: address ``(inval >> nshift) + offset``, value ``intINT``;
    * ``nshift < 0`` and ``bound == 0``: address ``inval``, value ``intINT``;
    * ``nshift == -1`` and ``bound != 0``: sentinel address -1, value ``bound``;
    * ``nshift == -2`` and ``bound != 0``: sentinel address 9999999, value ``bound``.

    A matching row with any other negative ``nshift`` and a non-zero bound is
    skipped and the search continues. Returns None when no row matches.
    """
    for cfg in config:
        if not (cfg[0] <= inval < cfg[1]):
            continue
        nshift = cfg[2]
        if nshift >= 0:
            return nshift, (inval >> nshift) + cfg[3], intINT
        bound = cfg[4]
        if bound == 0:
            return nshift, inval, intINT
        if nshift == -1:
            return nshift, -1, bound
        if nshift == -2:
            return nshift, 9999999, bound
    return None
0152 
def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Apply a shift map to a LUT and merge the entries that share an address.

    Every raw entry of ``orgLUT`` is routed through ``Modification`` and
    grouped by the compressed address it maps to. Entries routed to the
    bound-row sentinel addresses (-1 and 9999999) are dropped.

    Args:
        orgLUT: LUT entries indexed by raw input code (this file passes lists
            of integer strings). For ``isPT`` the first entry is raw code 1.
        shiftMap: rows ``[start, stop, nshift, offset, bound]`` as understood
            by ``Modification``.
        isPT: when True, the raw code of entry ``i`` is ``i + 1``.

    Returns:
        ``(x, y)``: the compressed addresses in first-seen order, and one
        representative value per address. A single distinct value is kept
        as is. With several values the most frequent one wins; if all are
        equally frequent, the numerically smallest is taken. More than three
        distinct values at one address still prints a warning, but a value
        is always emitted so that ``x`` and ``y`` stay aligned.
    """
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        ii = i + 1 if isPT else i
        _, newidx, newINT = Modification(ii, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for k, v in tempmap.items():
        # Sentinel addresses of the bound rows are not part of the LUT
        if k == -1 or k == 9999999:
            continue
        x.append(k)
        contv = Counter(v)
        if len(contv) == 1:
            y.append(v[0])
            continue
        if len(contv) > 3:
            # Previously nothing was appended to y here, shifting every later
            # LUT entry relative to x; now the bin is still filled.
            print("----- allow up to 3 values per bins")
        ## Counter ordering differs between python2 and python3; with tied
        ## top counts this can give slightly different LUTs.
        if len(set(contv.values())) == 1:
            ## All values equally frequent: take the numerically smallest.
            ## Entries are numeric strings, so compare as numbers, not as text
            ## (as text, "100" < "99").
            y.append(min(contv, key=float))
        else:
            y.append(contv.most_common(1)[0][0])
    return x, y
0185 
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Turn a raw shift map into the final (shift map, addresses, LUT) triple.

    Args:
        LUT: original LUT entries indexed by raw input code (this file passes
            the string lists ptLUT / etaLUT).
        k: shift-map rows ``[start, stop, nshift, offset, bound]`` as returned
            by GetLUTwrtLSB.
        isPT: the pT LUT starts at raw code 1, so start/stop are moved up by one.
        bounderidx: when True, the values of the bound rows (nshift -1 / -2)
            are stored inside the returned LUT, and each such row's ``bound``
            field is replaced by the LUT address holding that value.

    Returns:
        ``(k, x, y)``: the adjusted shift map as an int ndarray sorted by
        start index, the compressed addresses from GetLUTModified, and the LUT
        contents. NOTE(review): ``x`` is computed before the bound values are
        inserted into ``y``, so it does not reflect those extra entries.
    """
    # Cast the whole map to int (the lowerbound row may carry a float bound)
    k = np.asarray(k).astype(int)
    if isPT:
        # pts/ptLUT index 0 is raw code 1 (GetLUTModified uses i + 1 for isPT)
        k[:, [0, 1]] +=1
        # NOTE(review): offsets of rows with nshift > 0 are bumped by one;
        # rows with nshift <= 0 are not -- rationale not documented here
        k[k[:, 2] > 0, 3] +=1
    # First pass: find the first compressed address
    x, y = GetLUTModified(LUT, k,isPT)
    # NOTE(review): x[0] raises IndexError if every entry maps to a bound row
    if x[0] != 0:
        # Rebase the shift rows so addresses start at 0. Iterating a 2-D
        # ndarray yields row views, so this edits k in place.
        for i in k:
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
    k = k[k[:, 0].argsort()]
    # Second pass with the rebased, sorted map
    x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:
        ### Has bound rows: store their values inside the LUT itself
        if np.any(k[:,2] == -1):
            # The -1 row's value goes to address 0; its bound field becomes
            # that address, and every shift row moves up by one to make room
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            k[k[:, 2] >= 0, 3] +=1
        if np.any(k[:,2] == -2):
            # The -2 row's value is appended; its bound field becomes the last address
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    return k, x, y
0210 
0211 ### PT
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Resolve a raw input code to ``(address, value)`` through a shift map.

    The first row ``[start, stop, nshift, offset, bound]`` with
    ``start <= inpt < stop`` decides the result:

    * ``nshift >= 0``: address ``(inpt >> nshift) + offset`` and its LUT entry;
    * ``nshift < 0`` with ``bounderidx``: ``bound`` is used as the LUT address;
    * ``nshift < 0`` without ``bounderidx``: ``(-1, bound)``.

    Returns None when no row covers ``inpt``.
    """
    for row in shiftmap:
        if not (row[0] <= inpt < row[1]):
            continue
        if row[2] >= 0:
            addr = (inpt >> row[2]) + row[3]
            return addr, LUT[addr]
        if bounderidx:
            return row[4], LUT[row[4]]
        return -1, row[4]
    return None
0222 
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Compare the compressed pT LUT against the floating-point pT.

    For every |1/R| code ``i`` in ``[1, 2**BITSABSRinv - 1)`` the floating pT
    is recomputed as in GetPtLUT. If it lies inside ``[2, 2**BITSPT * ptLSB]``,
    it is compared with the value read back through
    ``LookUp(i, shiftmap, LUT, bounderidx)``. Differences larger than one
    ``ptLSB`` are printed. If the module-level ``drawPT`` flag is True, a
    two-panel comparison plot is written to ``pt_check.png``.

    Args:
        shiftmap: shift-map rows ``[start, stop, nshift, offset, bound]``.
        LUT: compressed pT LUT (entries convertible with ``float``/``int``).
        bounderidx: forwarded to ``LookUp``.
    """
    # Loop-invariant factor of pT = (c / 10) * B * 0.01 / (1/R)
    pt_scale = (SpeedOfLight / 10)*Bfield*0.01
    pt_max = (1<<BITSPT)*ptLSB
    l_ptOBs = []
    l_lkpts = []
    l_ptINTs = []
    # NOTE(review): the last code, 2**BITSABSRinv - 1, is not visited; with
    # the current constants its pT (~1.9) is below the 2.0 cut anyway.
    for i in range(1,(1<<BITSABSRinv)-1):
        k = i * stepRinv
        pOB = pt_scale/(k)
        ## We don't need to check beyond the boundary
        if pOB > pt_max or pOB < 2:
            continue
        # Look up only codes that are actually checked, so an entry outside
        # the checked range (LookUp -> None) cannot abort the scan
        idx, pINT = LookUp(i, shiftmap, LUT, bounderidx)
        lk_pt = float(pINT)*ptLSB
        l_ptOBs.append(pOB)
        l_ptINTs.append(int(round(pOB/ptLSB))*ptLSB)
        l_lkpts.append(lk_pt)
        # Allow +-1 1LSB
        if abs(pOB - lk_pt) > ptLSB:
            print("pt : ", i, pOB, pts[i-1], ptLUT[i-1], idx, pINT, int(pINT) *ptLSB)

    if drawPT is True:
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(2, 1, layout='constrained')
        ax1.plot(l_ptOBs, l_ptINTs, 'bo', label="Round up")
        ax1.plot(l_ptOBs, l_lkpts, 'r+', label="Look up")
        ax1.set_xscale('log')
        ax1.legend()
        ax1.set_xlabel("input pt")
        ax1.set_ylabel("convert pt")

        diff = (np.array(l_ptINTs) - np.array(l_lkpts))/ptLSB
        ax2.plot(l_ptOBs, diff, 'bo', label="Difference")
        ax2.set_xscale('log')
        ax2.legend()
        ax2.set_xlabel("input pt")
        ax2.set_ylabel("pt LSB")
        ax2.set_ylim([-2.5, 2.5])
        plt.savefig("pt_check.png")
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
0259 
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Compare the compressed eta LUT against the floating-point eta.

    Every tan(lambda) code below ``2**BITSTTTANL`` is converted to eta as in
    GetEtaLUT. Codes with eta above 2.45 are ignored. For the rest, the value
    read back through ``LookUp`` must agree within one ``etaLSB`` (eta < 1.59)
    or two ``etaLSB`` (eta >= 1.59); otherwise a diagnostic line is printed.
    """
    n_codes = 1 << BITSTTTANL
    for code in range(n_codes):
        tan_lambda = (maxTanL * code) / n_codes
        theta = math.pi / 2.0 - math.atan(tan_lambda)
        eta = -math.log(math.tan(theta / 2.0))
        ## We don't need to check beyond the boundary
        if eta > 2.45:
            continue
        trunc_code = int(eta * (1 << BITSETA) / math.pi)
        idx, etaINT = LookUp(code, shiftmap, LUT, bounderidx)
        lut_eta = int(etaINT) * etaLSB
        ## For high eta region, we allow up to 2LSB
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if abs(eta - lut_eta) > tolerance:
            print("eta : ", code, eta, trunc_code, idx, etaINT, lut_eta)
0276 
0277 
def PrintPTLUT(k, ptLUT):
    """Print the pT shift map and LUT as C++ array initializers.

    Each row of ``k`` becomes one ``{a,b,c,d,e}`` entry of ``ptShifts``.
    ``ptLUT`` must already contain strings (it is joined directly).
    """
    rows = ["{" + ",".join(map(str, row)) + "}" for row in k]
    header = "int ptShifts[{nOps}][5]={{".format(nOps=len(k))
    print(header + ", ".join(rows) + "};")
    print("ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT)) + ', '.join(ptLUT) + '};')
0286 
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift map and LUT as C++ array initializers.

    Each row of ``k`` becomes one ``{a,b,c,d,e}`` entry of ``etaShifts``.
    ``etaLUT`` must already contain strings (it is joined directly).
    """
    shift_text = ", ".join("{" + ",".join(str(field) for field in row) + "}" for row in k)
    print("int etaShifts[{nOps}][5]={{".format(nOps=len(k)) + shift_text + "};")
    lut_text = ', '.join(etaLUT)
    print("ap_uint<BITSETA> etaLUT[{address}]={{".format(address=len(etaLUT)) + lut_text + '};')
0295 
0296 
0297 
0298 
if __name__ == "__main__":
    # Store the bound-row values inside the LUTs (see ProducedFinalLUT)
    bounderidx=True
    GetPtLUT()
    # pT: levels nb = 1..7, i.e. right shifts of 0..6 bits. pT <= 1.9 and
    # pT >= the BITSPT-bit maximum ((2**BITSPT - 1) * ptLSB) go to bound rows.
    k = GetLUTwrtLSB(pts, ptLSB, isPT=True, nbits=[1, 2, 3, 4, 5, 6, 7], lowerbound=1.9, upperbound=((1<<BITSPT)-1)*ptLSB)
    k, x, y = ProducedFinalLUT(ptLUT, k, isPT=True, bounderidx=bounderidx)
    ## K is the shift map: rows [start, stop, nshift, offset, bound]
    ## X is the index to the LUT (computed before the bound values are inserted into y)
    ## Y is the LUT
    # Warn if the compressed LUT addresses have gaps
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    # ptChecks(k, y, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(y))
    PrintPTLUT(k, y)

    # ### Eta
    GetEtaLUT()
    # eta: levels nb = 0, 1, 2, 3, 5, i.e. right shifts of 0, 0, 1, 2, 4 bits;
    # eta >= 2.45 goes to a bound row
    k =  GetLUTwrtLSB(etas, etaLSB, isPT=False, nbits=[0, 1, 2, 3, 5], upperbound =2.45)
    k, x, y = ProducedFinalLUT(etaLUT, k, bounderidx=bounderidx)
    # Warn if the compressed LUT addresses have gaps
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    # etaChecks(k, y, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(y))
    PrintEtaLUT(k, y)