Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-05-27 01:56:29

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// CMSSW imports
#include "FWCore/MessageLogger/interface/MessageLogger.h"

// Alpaka imports
//#include <alpaka/alpaka.hpp>
#include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"

// HGCal imports
#include "RecoLocalCalo/HGCalRecAlgos/interface/alpaka/HGCalRecHitCalibrationAlgorithms.h"
#include "DataFormats/HGCalDigi/interface/HGCalRawDataDefinitions.h"
0014 
0015 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0016 
0017   using namespace cms::alpakatools;
0018 
0019   //
0020   struct HGCalRecHitCalibrationKernel_flagRecHits {
0021     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0022                                   HGCalDigiDevice::View digis,
0023                                   HGCalRecHitDevice::View recHits,
0024                                   HGCalCalibParamDevice::ConstView calibs) const {
0025       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0026         auto calib = calibs[idx];
0027         bool calibvalid = calib.valid();
0028         auto digi = digis[idx];
0029         auto digiflags = digi.flags();
0030         //recHits[idx].flags() = digiflags;
0031         bool isAvailable((digiflags != ::hgcal::DIGI_FLAG::Invalid) &&
0032                          (digiflags != ::hgcal::DIGI_FLAG::NotAvailable) && calibvalid);
0033         bool isToAavailable((digiflags != ::hgcal::DIGI_FLAG::ZS_ToA) &&
0034                             (digiflags != ::hgcal::DIGI_FLAG::ZS_ToA_ADCm1));
0035         recHits[idx].flags() = (!isAvailable) * hgcalrechit::HGCalRecHitFlags::EnergyInvalid +
0036                                (!isToAavailable) * hgcalrechit::HGCalRecHitFlags::TimeInvalid;
0037       }
0038     }
0039   };
0040 
0041   //
0042   struct HGCalRecHitCalibrationKernel_adcToCharge {
0043     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0044                                   HGCalDigiDevice::View digis,
0045                                   HGCalRecHitDevice::View recHits,
0046                                   HGCalCalibParamDevice::ConstView calibs) const {
0047       auto adc_denoise =
0048           [&](uint32_t adc, uint32_t cm, uint32_t adcm1, float adc_ped, float cm_slope, float cm_ped, float bxm1_slope) {
0049             float cmf = cm_slope * (0.5 * float(cm) - cm_ped);
0050             return ((adc - adc_ped) - cmf - bxm1_slope * (adcm1 - adc_ped - cmf));
0051           };
0052 
0053       auto tot_linearization =
0054           [&](uint32_t tot, float tot_lin, float tot2adc, float tot_ped, float tot_p0, float tot_p1, float tot_p2) {
0055             bool isLin(tot > tot_lin);
0056             bool isNotLin(!isLin);
0057             return isLin * (tot2adc * (tot - tot_ped)) + isNotLin * (tot_p0 + tot_p1 * tot + tot_p2 * tot * tot);
0058           };
0059 
0060       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0061         auto calib = calibs[idx];
0062         bool calibvalid = calib.valid();
0063         auto digi = digis[idx];
0064         auto digiflags = digi.flags();
0065         bool isAvailable((digiflags != ::hgcal::DIGI_FLAG::Invalid) &&
0066                          (digiflags != ::hgcal::DIGI_FLAG::NotAvailable) && calibvalid);
0067         bool useTOT((digi.tctp() == 3) && isAvailable);
0068         bool useADC(!useTOT && isAvailable);
0069         recHits[idx].energy() = useADC * adc_denoise(digi.adc(),
0070                                                      digi.cm(),
0071                                                      digi.adcm1(),
0072                                                      calib.ADC_ped(),
0073                                                      calib.CM_slope(),
0074                                                      calib.CM_ped(),
0075                                                      calib.BXm1_slope()) +
0076                                 useTOT * tot_linearization(digi.tot(),
0077                                                            calib.TOT_lin(),
0078                                                            calib.TOTtoADC(),
0079                                                            calib.TOT_ped(),
0080                                                            calib.TOT_P0(),
0081                                                            calib.TOT_P1(),
0082                                                            calib.TOT_P2());
0083 
0084         //after denoising/linearization apply the MIP scale
0085         recHits[idx].energy() *= calib.MIPS_scale();
0086       }
0087     }
0088   };
0089 
0090   //
0091   struct HGCalRecHitCalibrationKernel_toaToTime {
0092     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0093                                   HGCalDigiDevice::View digis,
0094                                   HGCalRecHitDevice::View recHits,
0095                                   HGCalCalibParamDevice::ConstView calibs) const {
0096       auto toa_inl_corr = [&](uint32_t toa, hgcalrechit::Vector32f ctdc_p, hgcalrechit::Vector8f ftdc_p) {
0097         auto gray = toa / 256;
0098         auto ptdc = toa % 256;
0099         auto ctdc = ptdc / 8;
0100         auto ftdc = ptdc % 8;
0101         auto ctdc_corr = uint(ctdc - ctdc_p[ctdc]) % 32;
0102         auto ftdc_corr = uint(ftdc - ftdc_p[ftdc]) % 8;
0103         return (ftdc_corr + 8 * ctdc_corr + 256 * gray) % 1024;
0104       };
0105 
0106       auto toa_tw_corr = [&](uint32_t toa, float energy, hgcalrechit::Vector3f p) {
0107         return toa - ((energy > p[2]) ? (p[0] + p[1] * std::log(energy - p[2])) : 0.f);
0108       };
0109 
0110       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0111         auto calib = calibs[idx];
0112         bool calibvalid = calib.valid();
0113         auto digi = digis[idx];
0114         auto digiflags = digi.flags();
0115         bool isAvailable((digiflags != ::hgcal::DIGI_FLAG::Invalid) &&
0116                          (digiflags != ::hgcal::DIGI_FLAG::NotAvailable) && calibvalid);
0117         bool isToAavailable((digiflags != ::hgcal::DIGI_FLAG::ZS_ToA) &&
0118                             (digiflags != ::hgcal::DIGI_FLAG::ZS_ToA_ADCm1));
0119         bool isGood(isAvailable && isToAavailable);
0120         //INL correction
0121         auto toa = isGood * toa_inl_corr(digi.toa(), calib.TOA_CTDC(), calib.TOA_FTDC());
0122         //timewalk correction
0123         toa = isGood * toa_tw_corr(toa, recHits[idx].energy(), calib.TOA_TW());
0124         //toa to ps
0125         recHits[idx].time() = toa * hgcalrechit::TOAtops;
0126       }
0127     }
0128   };
0129 
0130   struct HGCalRecHitCalibrationKernel_handleCalibCell {
0131     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0132                                   HGCalDigiDevice::View digis,
0133                                   HGCalRecHitDevice::View recHits,
0134                                   HGCalCalibParamDevice::ConstView calibs,
0135                                   HGCalMappingCellParamDevice::ConstView maps,
0136                                   HGCalDenseIndexInfoDevice::ConstView index) const {
0137       auto time_average = [&](float time_surr, float time_calib, float energy_surr, float energy_calib) {
0138         bool is_time_surr(time_surr > 0);
0139         bool is_time_calib(time_calib > 0);
0140         float totalEn = (is_time_surr * energy_surr + is_time_calib * energy_calib);
0141         float weighted_average =
0142             (totalEn > 0)
0143                 ? (is_time_surr * energy_surr * time_surr + is_time_calib * energy_calib * time_calib) / totalEn
0144                 : 0.0f;
0145         return weighted_average;
0146       };
0147 
0148       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0149         auto calib = calibs[idx];
0150         bool calibvalid = calib.valid();
0151         auto digi = digis[idx];
0152         auto digiflags = digi.flags();
0153         bool isAvailable((digiflags != ::hgcal::DIGI_FLAG::Invalid) &&
0154                          (digiflags != ::hgcal::DIGI_FLAG::NotAvailable) && calibvalid);
0155         bool isToAavailable((digiflags != ::hgcal::DIGI_FLAG::ZS_ToA) &&
0156                             (digiflags != ::hgcal::DIGI_FLAG::ZS_ToA_ADCm1));
0157 
0158         auto cellIndex = index[idx].cellInfoIdx();
0159         bool isCalibCell(maps[cellIndex].iscalib());
0160         int offset = maps[cellIndex].caliboffset();  //Calibration-to-surrounding cell offset
0161         bool is_surr_cell((offset != 0) && isAvailable && isCalibCell);
0162 
0163         //Effectively operate only on the cell that surrounds the calibration cells
0164         if (!is_surr_cell) {
0165           continue;
0166         }
0167 
0168         recHits[idx + offset].flags() = hgcalrechit::HGCalRecHitFlags::Normal;
0169 
0170         recHits[idx + offset].time() = isToAavailable * time_average(recHits[idx + offset].time(),
0171                                                                      recHits[idx].time(),
0172                                                                      recHits[idx + offset].energy(),
0173                                                                      recHits[idx].energy());
0174 
0175         bool is_negative_surr_energy(recHits[idx + offset].energy() < 0);
0176         auto negative_energy_correction = (-1.0 * recHits[idx + offset].energy()) * is_negative_surr_energy;
0177 
0178         recHits[idx + offset].energy() += (negative_energy_correction + recHits[idx].energy());
0179       }
0180     }
0181   };
0182 
0183   struct HGCalRecHitCalibrationKernel_printRecHits {
0184     ALPAKA_FN_ACC void operator()(Acc1D const& acc, HGCalRecHitDevice::ConstView view, int size) const {
0185       for (int i = 0; i < size; ++i) {
0186         auto const& recHit = view[i];
0187         printf("%d\t%f\t%f\t%d\n", i, recHit.energy(), recHit.time(), recHit.flags());
0188       }
0189     }
0190   };
0191 
0192   HGCalRecHitDevice HGCalRecHitCalibrationAlgorithms::calibrate(Queue& queue,
0193                                                                 HGCalDigiHost const& host_digis,
0194                                                                 HGCalCalibParamDevice const& device_calib,
0195                                                                 HGCalMappingCellParamDevice const& device_mapping,
0196                                                                 HGCalDenseIndexInfoDevice const& device_index) const {
0197     LogDebug("HGCalRecHitCalibrationAlgorithms") << "\n\nINFO -- Start of calibrate\n\n" << std::endl;
0198 
0199     LogDebug("HGCalRecHitCalibrationAlgorithms") << "\n\nINFO -- Copying the digis to the device\n\n" << std::endl;
0200     HGCalDigiDevice device_digis(host_digis.view().metadata().size(), queue);
0201     alpaka::memcpy(queue, device_digis.buffer(), host_digis.const_buffer());
0202 
0203     LogDebug("HGCalRecHitCalibrationAlgorithms")
0204         << "\n\nINFO -- Allocating rechits buffer and initiating values" << std::endl;
0205     HGCalRecHitDevice device_recHits(device_digis.view().metadata().size(), queue);
0206 
0207     // number of items per group
0208     uint32_t items = n_threads_;
0209     // use as many groups as needed to cover the whole problem
0210     uint32_t groups = divide_up_by(device_digis.view().metadata().size(), items);
0211     // map items to
0212     //   - threads with a single element per thread on a GPU backend
0213     //   - elements within a single thread on a CPU backend
0214     auto grid = make_workdiv<Acc1D>(groups, items);
0215     LogDebug("HGCalRecHitCalibrationAlgorithms") << "N groups: " << groups << "\tN items: " << items << std::endl;
0216 
0217     alpaka::exec<Acc1D>(queue,
0218                         grid,
0219                         HGCalRecHitCalibrationKernel_flagRecHits{},
0220                         device_digis.view(),
0221                         device_recHits.view(),
0222                         device_calib.view());
0223     alpaka::exec<Acc1D>(queue,
0224                         grid,
0225                         HGCalRecHitCalibrationKernel_adcToCharge{},
0226                         device_digis.view(),
0227                         device_recHits.view(),
0228                         device_calib.view());
0229     alpaka::exec<Acc1D>(queue,
0230                         grid,
0231                         HGCalRecHitCalibrationKernel_toaToTime{},
0232                         device_digis.view(),
0233                         device_recHits.view(),
0234                         device_calib.view());
0235     alpaka::exec<Acc1D>(queue,
0236                         grid,
0237                         HGCalRecHitCalibrationKernel_handleCalibCell{},
0238                         device_digis.view(),
0239                         device_recHits.view(),
0240                         device_calib.view(),
0241                         device_mapping.view(),
0242                         device_index.view());
0243 
0244     LogDebug("HGCalRecHitCalibrationAlgorithms") << "Input recHits: " << std::endl;
0245 #ifdef EDM_ML_DEBUG
0246     int n_hits_to_print = 10;
0247     print_recHit_device(queue, *device_recHits, n_hits_to_print);
0248 #endif
0249 
0250     return device_recHits;
0251   }
0252 
0253   void HGCalRecHitCalibrationAlgorithms::print(HGCalDigiHost const& digis, int max) const {
0254     int max_ = max > 0 ? max : digis.view().metadata().size();
0255     for (int i = 0; i < max_; i++) {
0256       LogDebug("HGCalRecHitCalibrationAlgorithms")
0257           << i << digis.view()[i].tot() << "\t" << digis.view()[i].toa() << "\t" << digis.view()[i].cm() << "\t"
0258           << digis.view()[i].flags() << std::endl;
0259     }
0260   }
0261 
0262   void HGCalRecHitCalibrationAlgorithms::print_digi_device(HGCalDigiDevice const& digis, int max) const {
0263     int max_ = max > 0 ? max : digis.view().metadata().size();
0264     for (int i = 0; i < max_; i++) {
0265       LogDebug("HGCalRecHitCalibrationAlgorithms")
0266           << i << digis.view()[i].tot() << "\t" << digis.view()[i].toa() << "\t" << digis.view()[i].cm() << "\t"
0267           << digis.view()[i].flags() << std::endl;
0268     }
0269   }
0270 
0271   void HGCalRecHitCalibrationAlgorithms::print_recHit_device(
0272       Queue& queue, PortableHostCollection<hgcalrechit::HGCalRecHitSoALayout<> >::View const& recHits, int max) const {
0273     auto grid = make_workdiv<Acc1D>(1, 1);
0274     auto size = max > 0 ? max : recHits.metadata().size();
0275     alpaka::exec<Acc1D>(queue, grid, HGCalRecHitCalibrationKernel_printRecHits{}, recHits, size);
0276 
0277     // ensure that the print operations are complete before returning
0278     alpaka::wait(queue);
0279   }
0280 
0281 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE