Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-07-16 02:43:05

0001 #!/usr/bin/env python
0002 # encoding: utf-8
0003 
0004 # File        : makeTrackConversionLUTs.py
0005 # Author      : Zhenbin Wu
0006 # Contact     : zhenbin.wu@gmail.com
0007 # Date        : 2021 Apr 13
0008 #
0009 # Description : 
0010 
0011 from __future__ import division
0012 import math
0013 import numpy as np
0014 from collections import defaultdict, Counter
0015 from pprint import pprint
0016 
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Constant for PT ~~~~~
## Constants taken from the L1 Track Trigger configuration
#https://github.com/cms-sw/cmssw/blob/master/L1Trigger/TrackTrigger/python/ProducerSetup_cfi.py#L98
Bfield = 3.81120228767395  # in T
SpeedOfLight = 2.99792458  # in e8 m/s
# https://github.com/cms-sw/cmssw/blob/master/DataFormats/L1TrackTrigger/interface/TTTrack_TrackWord.h#L87
minRinv = 0.006
# Number of bins used to derive the Rinv step size (see stepRinv below).
BITSRinv = 15
# GetPtLUT and ptChecks iterate over input indices below 1 << BITSABSRinv.
BITSABSRinv=14
# Width of the integer PT output; GetPtLUT saturates at (1 << BITSPT) - 1.
BITSPT=13
# PT value represented by one integer count of the PT output.
ptLSB=0.03125
# Filled by GetPtLUT: integer PT per input index, stored as strings.
ptLUT=[]
# Filled by GetPtLUT: floating-point PT per input index.
pts = []
# Not referenced anywhere else in the visible source.
ptshifts = []
# Rinv increment per input index: the span 2 * |minRinv| split into 1 << BITSRinv steps.
stepRinv = (2. * abs(minRinv)) / (1 << BITSRinv)
# When True, ptChecks also writes a comparison plot to "pt_check.png".
drawPT = False

## Constants for Eta
# GetEtaLUT and etaChecks iterate over 1 << BITSTTTANL input indices.
# NOTE(review): "16-1" presumably means a 16-bit tanL word minus its sign bit — confirm.
BITSTTTANL=16-1
# NOTE(review): "13-1" presumably means a 13-bit eta word minus its sign bit — confirm.
BITSETA=13-1
# tanL value that the index 1 << BITSTTTANL would map to (exclusive upper end).
maxTanL=8.0
# Eta value represented by one integer count of the eta output.
etaLSB = math.pi/ (1<<BITSETA)
# Filled by GetEtaLUT: integer eta per input index, stored as strings.
etaLUT=[]
# Filled by GetEtaLUT: floating-point eta per input index.
etas = []
# Not referenced anywhere else in the visible source.
etashifts = []
0042 
def GetPtLUT():
    """Fill the module-level ``pts`` and ``ptLUT`` lists, one entry per Rinv bin.

    For every input index from 1 up to (but excluding) ``1 << BITSABSRinv``
    the floating-point PT is appended to ``pts`` and its integer
    representation (in units of ``ptLSB``, saturated to the largest value
    that fits in ``BITSPT`` bits) is appended to ``ptLUT`` as a string.

    The function returns nothing; calling it again appends a second copy of
    every entry to the same lists.
    """
    # Largest integer that fits in the BITSPT-wide output.
    saturated = (1 << BITSPT) - 1
    # Floating PT numerator, see
    # https://github.com/cms-sw/cmssw/blob/master/DataFormats/L1TrackTrigger/interface/TTTrack.h#L226
    numerator = (SpeedOfLight / 10)*Bfield*0.01
    for rinv_bin in range(1, 1 << BITSABSRinv):
        curvature = rinv_bin * stepRinv
        pt_float = numerator/(curvature)
        pts.append(pt_float)
        pt_int = int(round(pt_float/ptLSB))
        # Values that do not fit in the output width are clamped.
        ptLUT.append(str(min(pt_int, saturated)))
0055 
0056 
def GetEtaLUT():
    """Fill the module-level ``etas`` and ``etaLUT`` lists, one entry per tanL bin.

    For every input index ``i`` in ``[0, 1 << BITSTTTANL)`` the value
    ``tanL = maxTanL * i / (1 << BITSTTTANL)`` is converted to
    pseudorapidity via ``lambda = atan(tanL)``, ``theta = pi/2 - lambda``,
    ``eta = -ln(tan(theta / 2))``.  The floating-point eta is appended to
    ``etas`` and the rounded integer (in units of ``pi / (1 << BITSETA)``)
    is appended to ``etaLUT`` as a string.

    The function returns nothing; calling it again appends a second copy of
    every entry to the same lists.
    """
    for i in range(0, (1 << BITSTTTANL)):
        tanL = (maxTanL*i)/(1 << BITSTTTANL)
        lam = math.atan(tanL)
        theta = math.pi/2.0 - lam
        eta = -math.log(math.tan(theta/2.0))
        etas.append(eta)
        etaINT = int(round(eta*(1 << BITSETA)/math.pi))
        # Only |eta| < pi is representable.  The previous form
        # ``abs(eta<math.pi)`` took the absolute value of the comparison
        # result instead of comparing |eta| against pi.
        # NOTE(review): an entry failing this test is appended to ``etas``
        # but not to ``etaLUT``, which would shift later indices; with
        # maxTanL = 8 eta stays below pi, so this should not trigger.
        if abs(eta) < math.pi:
            etaLUT.append(str(etaINT))
0067 
def consecutive(data, stepsize=1):
    """Split ``data`` into runs whose neighbouring elements differ by ``stepsize``.

    Returns the list of sub-arrays produced by ``np.split``; a new run
    starts right after every position where the difference to the next
    element is not equal to ``stepsize``.
    """
    is_break = np.diff(data) != stepsize
    # A break between elements j and j+1 means the next run starts at j+1.
    run_starts = np.nonzero(is_break)[0] + 1
    return np.split(data, run_starts)
0070 
def GetLUTwrtLSB(fltLUT, lsb, isPT=False, minshift=0,maxshift=5,
                 lowerbound=None, upperbound=None, nbits=None):
    """Build the list of index-shift records used to compress a LUT.

    Each record is a 5-element list
    ``[start, stop, shift, offset, boundary]``:

    * ``start`` / ``stop`` -- input index range the record applies to
      (``stop`` is exclusive in Modification and LookUp).
    * ``shift`` -- right-shift applied to the input index, or a negative
      code (-1 / -2) marking a range clamped to a boundary value.
    * ``offset`` -- value added to the shifted index.
    * ``boundary`` -- for clamped ranges, the bound divided by ``lsb``;
      0 for regular ranges.

    Parameters:
        fltLUT: sequence of floating-point LUT values, indexed by input.
        lsb: value of one least-significant bit of the integer output.
        isPT: selects the PT conventions (step sign flipped, boundary
            codes swapped, ``bfrom`` moved after the upper clamp).
        minshift, maxshift: used only when ``nbits`` is None, in which
            case ``nbits = range(minshift, maxshift)``.
        lowerbound, upperbound: optional value bounds; entries at or
            beyond a bound get their own clamped record.
        nbits: sequence of candidate bit counts, processed in order.

    Returns:
        The list of shift records, in the order they were created.
    """
    length = len(fltLUT)
    # Difference between neighbouring LUT values.
    steps  = np.diff(fltLUT)
    if isPT:
        # Flip the sign of the steps for the PT table.
        # NOTE(review): presumably because PT decreases with index — confirm.
        steps  = -1* steps
    if nbits is None:
        nbits = range(minshift, maxshift)
    # For each candidate bit count, locate where the step size crosses
    # lsb / 2**nbits.  The search runs on the reversed step array and the
    # result is converted back to an index counted from the front.
    # NOTE(review): np.searchsorted requires a sorted first argument, so
    # this assumes the reversed steps are ascending — confirm.
    cross = length - np.searchsorted(steps[::-1], [lsb/(2**i) for i in nbits])

    # x / val accumulate the compressed indices and their values; only
    # x[-1] is read back (to chain the offsets), val is never returned.
    x = []
    val = []
    shifts = []
    # First input index of the region currently being processed.
    bfrom = 0
    for i, nb in enumerate(nbits):
        shifty_= 0
        bitcross = cross[i]
        if bitcross == 1:
            continue
        if nb == nbits[-1]:
            # The last bit count always extends to the end of the table.
            bitcross = length
            # bitcross = length if isPT else length-1
        # Input indices of this region and the matching LUT values.
        orgx = np.arange(bfrom, bitcross)
        val_ = np.take(fltLUT, orgx)
        if upperbound is not None and np.any(val_ >= upperbound):
            # Entries at or above the upper bound get a clamped record.
            sel = orgx[val_>=upperbound]
            uppershift = -1 if isPT else -2
            ## sel[-1]+1 , since we will use < in the final function
            sht = [sel[0], sel[-1]+1, uppershift, 0, int(upperbound/lsb)]
            shifts.append(sht)
            # Keep only the entries within the bound for the regular record.
            orgx = orgx[val_<=upperbound]
            val_ = val_[val_<=upperbound]
            if isPT:
                bfrom=orgx[0]
            else:
                bitcross=orgx[-1]
        if lowerbound is not None and np.any(val_ <= lowerbound):
            # Entries at or below the lower bound get a clamped record.
            # Note the boundary value is stored without an int() cast here.
            sel = orgx[val_<=lowerbound]
            lowershift = -2 if isPT else -1
            sht = [sel[0], sel[-1]+1, lowershift, 0, lowerbound/lsb]
            shifts.append(sht)
            orgx = orgx[val_>=lowerbound]
            val_ = val_[val_>= lowerbound]
            bitcross=orgx[-1]

        if nb > 1:
            ### Important: We can only shift by nb-1 bit to keep the precision
            shiftx_ = ((orgx >> (nb-1)) )
            if len(shiftx_) == 0:
                continue
            if len(x) > 0:
                # Offset so the compressed indices continue right after
                # the last index already emitted.
                shifty_ = int(x[-1] + 1 - shiftx_[0]) ## +1 to make sure it won't overlap
                shiftx_ = shiftx_ + shifty_
        else:
            # No compression for nb <= 1: indices are used as they are.
            shiftx_ = orgx
        # Collapse inputs that map to the same compressed index, keeping
        # the value at the first occurrence of each index.
        x_, pickx_ = np.unique(shiftx_, return_index=True)
        val_ = np.take(val_, pickx_) 
        x = np.append(x, x_)
        val = np.append(val, val_)
        if nb == 0:
            sht = [bfrom, bitcross+1, 0, shifty_, 0 ]
        else:
            sht = [bfrom, bitcross+1, nb-1, shifty_, 0 ]
        shifts.append(sht)
        # print("Shifting {nbit} bits with intercept {itsect}, input start from {bfrom} ({ffrom}) to {bto} ({fto}), LUT size {nLUT} ".format(
            # nbit=nb, itsect=shifty_, bfrom=bfrom, bto=bitcross, ffrom=fltLUT[bfrom], fto=fltLUT[bitcross],  nLUT =len(val_)))
        # The next region starts right after this one.
        bfrom = bitcross+1

    return shifts
0140 
def Modification(inval, intINT, config):
    """Map one input index through the shift configuration.

    ``config`` is a sequence of records indexed as
    ``[start, stop, shift, offset, boundary]``.  The first record with
    ``start <= inval < stop`` decides the result, a tuple
    ``(shift, new_index, new_value)``:

    * ``shift >= 0``: ``new_index = (inval >> shift) + offset`` and the
      value ``intINT`` is passed through.
    * ``shift < 0`` with ``boundary == 0``: index and value pass through
      unchanged.
    * ``shift == -1`` with a non-zero boundary: ``(-1, -1, boundary)``.
    * ``shift == -2`` with a non-zero boundary: ``(-2, 9999999, boundary)``.

    Returns None when no record produces a result.
    """
    for record in config:
        if not (record[0] <= inval < record[1]):
            continue
        shift = record[2]
        if shift >= 0:
            # Regular range: compress the index, keep the value.
            return shift, (inval >> shift) + record[3], intINT
        boundary = record[4]
        if boundary == 0:
            return shift, inval, intINT
        # Clamped ranges are tagged with sentinel indices.
        if shift == -1:
            return shift, -1, boundary
        if shift == -2:
            return shift, 9999999, boundary
    return None
0153 
def GetLUTModified(orgLUT, shiftMap, isPT=False):
    """Compress ``orgLUT`` according to ``shiftMap``.

    Every entry of ``orgLUT`` is sent through ``Modification`` to obtain
    its new (compressed) index; entries landing on the same index are
    merged into a single value.

    Parameters:
        orgLUT: sequence of LUT values, one per input index.
        shiftMap: sequence of shift records as consumed by Modification.
        isPT: when True the input index of entry ``i`` is ``i + 1``
            (GetPtLUT starts its table at index 1).

    Returns:
        ``(x, y)`` -- two lists of equal length: the compressed indices and
        the merged value for each of them.  Indices -1 and 9999999 (the
        boundary sentinels produced by Modification) are left out.

    Raises:
        TypeError: if an index is not covered by any record in
            ``shiftMap`` (Modification then returns None, which cannot be
            unpacked).
    """
    tempmap = defaultdict(list)
    for i, pINT in enumerate(orgLUT):
        ii = i + 1 if isPT else i
        _nshift, newidx, newINT = Modification(ii, pINT, shiftMap)
        tempmap[newidx].append(newINT)

    x = []
    y = []
    for k, v in tempmap.items():
        if k == -1 or k == 9999999:
            continue
        contv = Counter(v)
        if len(contv) > 3:
            # Diagnostic only.  A value is still emitted below so that x
            # and y stay aligned; previously this case appended to x but
            # not to y, shifting every later LUT entry by one.
            print("----- allow up to 3 values per bins")
        x.append(k)
        ## The counter was sorted, descending in python3 and ascending in python2
        ## This will result in slightly different LUT when running
        isallequal = (len(set(contv.values())) == 1)
        if isallequal:
            ## Using min for now. To be decided
            # NOTE(review): when the values are strings, min() compares
            # them lexicographically, not numerically.
            y.append(min(contv.keys()))
        else:
            y.append(contv.most_common(1)[0][0])
    return x, y
0186 
def ProducedFinalLUT(LUT, k, isPT=False, bounderidx=False):
    """Turn the raw shift records into the final shift map and compressed LUT.

    Parameters:
        LUT: the original LUT values, one per input index.
        k: shift records ``[start, stop, shift, offset, boundary]`` as
            returned by GetLUTwrtLSB.
        isPT: apply the PT index conventions (ranges and offsets moved by
            one).
        bounderidx: when True, the boundary values of the clamped records
            (shift code -1 / -2) are stored inside the LUT and the records
            are rewritten to hold the LUT position instead of the value.

    Returns:
        ``(k, x, y)`` -- the final shift map (integer numpy array sorted by
        range start), the compressed indices from GetLUTModified, and the
        LUT as a list.  ``x`` is not updated for the boundary entries that
        ``bounderidx`` adds to ``y``.
    """
    # Work on an integer array copy; non-integer entries are truncated.
    k = np.asarray(k).astype(int)
    if isPT:
        # Move start/stop of every record up by one ...
        k[:, [0, 1]] +=1
        # ... and bump the offset of every record with a positive shift.
        k[k[:, 2] > 0, 3] +=1
    x, y = GetLUTModified(LUT, k,isPT)
    if x[0] != 0:
        # Rebase the offsets of the regular records so that the first
        # compressed index becomes 0.  Each ``i`` is a row of ``k``, so
        # the in-place subtraction updates ``k`` itself.
        for i in k:
            if i[2] < 0:
                continue
            else:
                i[3] -= x[0]
    # Order the records by the start of their input range.
    k = k[k[:, 0].argsort()]
    # Recompute the compressed LUT with the adjusted shift map.
    x, y = GetLUTModified(LUT, k, isPT)
    if bounderidx:
        ### Store the boundary values inside the LUT itself
        if np.any(k[:,2] == -1):
            # The boundary value of the first -1 record becomes LUT entry 0;
            # the record then points to position 0 and every regular
            # offset moves up by one to make room for it.
            y.insert(0, str(k[k[:,2] == -1, 4][0]))
            k[k[:,2] == -1, 4] = 0
            k[k[:, 2] >= 0, 3] +=1
        if np.any(k[:,2] == -2):
            # The boundary value of the first -2 record becomes the last
            # LUT entry and the record points to that position.
            y.append(str(k[k[:,2] == -2, 4][0]))
            k[k[:,2] == -2, 4] = len(y)-1
    return k, x, y
0211 
0212 ### PT
def LookUp(inpt, shiftmap, LUT, bounderidx):
    """Look up the LUT value for input index ``inpt`` using the shift map.

    ``shiftmap`` holds records indexed as
    ``[start, stop, shift, offset, boundary]``; the first record with
    ``start <= inpt < stop`` is used.

    Returns a tuple ``(address, value)``:

    * regular record (``shift >= 0``): ``address = (inpt >> shift) + offset``
      and ``value = LUT[address]``;
    * clamped record (``shift < 0``) with ``bounderidx`` true:
      ``boundary`` is a LUT position, giving ``(boundary, LUT[boundary])``;
    * clamped record with ``bounderidx`` false: ``(-1, boundary)``.

    Returns None when no record covers ``inpt``.
    """
    for record in shiftmap:
        if not (record[0] <= inpt < record[1]):
            continue
        shift = record[2]
        if shift >= 0:
            address = (inpt >> shift) + record[3]
            return address, LUT[address]
        if bounderidx:
            return record[4], LUT[record[4]]
        return -1, record[4]
    return None
0223 
def ptChecks(shiftmap, LUT, bounderidx=False):
    """Compare the compressed PT LUT against the floating-point PT.

    For every input index from 1 up to (but excluding)
    ``(1 << BITSABSRinv) - 1`` the PT is computed directly and looked up
    through ``LookUp``.  Indices whose PT lies above
    ``(1 << BITSPT) * ptLSB`` or below 2 are not checked.  Any remaining
    index where the two differ by more than one ``ptLSB`` is printed.

    When the module-level ``drawPT`` is True, a two-panel comparison plot
    is written to "pt_check.png".
    """
    true_pts = []
    lookup_pts = []
    rounded_pts = []
    upper_limit = (1<<BITSPT)*ptLSB
    for i in range(1, (1<<BITSABSRinv)-1):
        k = i * stepRinv
        pOB = (SpeedOfLight / 10)*Bfield*0.01/(k)
        idx, pINT = LookUp(i, shiftmap, LUT, bounderidx)
        ## We don't need to check beyond the boundary
        if not (pOB <= upper_limit and pOB >= 2):
            continue
        lookup_pt = float(pINT)*ptLSB
        true_pts.append(pOB)
        rounded_pts.append(int(round(pOB/ptLSB))*ptLSB)
        lookup_pts.append(lookup_pt)
        # Allow +-1 1LSB
        if abs(pOB - lookup_pt) > ptLSB:
            print("pt : ", i, pOB, pts[i-1], ptLUT[i-1], idx, pINT, int(pINT) *ptLSB)

    if drawPT is not True:
        return

    import matplotlib.pyplot as plt
    fig, (ax1, ax2) = plt.subplots(2, 1, layout='constrained')
    # Upper panel: direct rounding versus LUT result.
    ax1.plot(true_pts, rounded_pts, 'bo', label="Round up")
    ax1.plot(true_pts, lookup_pts, 'r+', label="Look up")
    ax1.set_xscale('log')
    ax1.legend()
    ax1.set_xlabel("input pt")
    ax1.set_ylabel("convert pt")

    # Lower panel: their difference in units of ptLSB.
    diff = (np.array(rounded_pts) - np.array(lookup_pts))/ptLSB
    ax2.plot(true_pts, diff, 'bo', label="Difference")
    ax2.set_xscale('log')
    ax2.legend()
    ax2.set_xlabel("input pt")
    ax2.set_ylabel("pt LSB")
    ax2.set_ylim([-2.5, 2.5])
    plt.savefig("pt_check.png")
0260 
def etaChecks(shiftmap, LUT, bounderidx=False):
    """Compare the compressed eta LUT against the floating-point eta.

    For every input index in ``[0, 1 << BITSTTTANL)`` eta is computed
    directly from tanL and looked up through ``LookUp``.  Indices with
    eta above 2.45 are not checked.  An index is printed when the two
    differ by more than one ``etaLSB`` (eta below 1.59) or by more than
    two ``etaLSB`` (eta of 1.59 and above).
    """
    for i in range(0, (1<<(BITSTTTANL))):
        tanL = (maxTanL*i)/(1<<BITSTTTANL)
        theta = math.pi/2.0-math.atan(tanL)
        eta = -math.log(math.tan(theta/2.0))
        ## We don't need to check beyond the boundary
        if eta > 2.45:
            continue
        # Truncated (not rounded) integer eta, used only in the printout.
        eINT = int(eta*(1<<BITSETA)/math.pi)
        idx, etaINT = LookUp(i, shiftmap, LUT, bounderidx)
        lookup_eta = int(etaINT)*etaLSB
        ## For high eta region, we allow up to 2LSB
        tolerance = etaLSB if eta < 1.59 else etaLSB * 2
        if abs(eta - lookup_eta) > tolerance:
            print("eta : ", i, eta, eINT, idx, etaINT, lookup_eta)
0277 
0278 
def PrintPTLUT(k, ptLUT):
    """Print the PT shift map and PT LUT as two C array definitions.

    ``k`` is a sequence of shift records; each is rendered as a
    brace-enclosed, comma-separated list.  ``ptLUT`` must be a sequence
    of strings, which are joined as they are.
    """
    shiftout = ["{" + ",".join(str(field) for field in record) +"}" for record in k]
    shifts_decl = "int ptShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shifts_decl + ", ".join(shiftout) + "};")
    lut_decl = "ap_uint<BITSPT> ptLUT[{address}]={{".format(address=len(ptLUT))
    print(lut_decl+', '.join(ptLUT)+'};')
0287 
def PrintEtaLUT(k, etaLUT):
    """Print the eta shift map and eta LUT as two C array definitions.

    ``k`` is a sequence of shift records; each is rendered as a
    brace-enclosed, comma-separated list.  ``etaLUT`` must be a sequence
    of strings, which are joined as they are.
    """
    shiftout = ["{" + ",".join(str(field) for field in record) +"}" for record in k]
    shifts_decl = "int etaShifts[{nOps}][5]={{".format(nOps=len(k))
    print(shifts_decl + ", ".join(shiftout) + "};")
    lut_decl = "ap_uint<BITSETA> etaLUT[{address}]={{".format(address=len(etaLUT))
    print(lut_decl +', '.join(etaLUT)+'};')
0296 
0297 
0298 
0299 
if __name__ == "__main__":
    # Store the boundary (clamped) values inside the LUTs themselves.
    bounderidx=True

    ### PT
    # Fill the module-level pts / ptLUT lists first; the calls below read them.
    GetPtLUT()
    # Shift records for the PT table, clamped below 1.9 and above the
    # largest value representable in BITSPT bits.
    k = GetLUTwrtLSB(pts, ptLSB, isPT=True, nbits=[1, 2, 3, 4, 5, 6, 7], lowerbound=1.9, upperbound=((1<<BITSPT)-1)*ptLSB)
    k, x, y = ProducedFinalLUT(ptLUT, k, isPT=True, bounderidx=bounderidx)
    ## K is the shift map
    ## X is the index to the LUT
    ## Y is the LUT
    # Warn if the compressed indices do not form a single run of
    # consecutive integers.
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    # Optional validation of the PT LUT, disabled by default:
    # ptChecks(k, y, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(y))
    PrintPTLUT(k, y)

    # ### Eta
    # Fill the module-level etas / etaLUT lists first; the calls below read them.
    GetEtaLUT()
    # Shift records for the eta table, clamped above 2.45.
    k =  GetLUTwrtLSB(etas, etaLSB, isPT=False, nbits=[0, 1, 2, 3, 5], upperbound =2.45)
    k, x, y = ProducedFinalLUT(etaLUT, k, bounderidx=bounderidx)
    con = consecutive(x)
    if len(con) > 1:
        print("index is not continuous: ", con)
    # Optional validation of the eta LUT, disabled by default:
    # etaChecks(k, y, bounderidx=bounderidx)
    # print("Total size of LUT is %d" % len(y))
    PrintEtaLUT(k, y)