Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-02-26 04:25:33

0001 #include <cstddef>
0002 
0003 // CMSSW imports
0004 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0005 
0006 // Alpaka imports
0007 //#include <alpaka/alpaka.hpp>
0008 #include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
0009 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0010 
0011 // HGCal imports
0012 #include "RecoLocalCalo/HGCalRecAlgos/interface/alpaka/HGCalRecHitCalibrationAlgorithms.h"
0013 #include "DataFormats/HGCalDigi/interface/HGCalRawDataDefinitions.h"
0014 
0015 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0016 
0017   using namespace cms::alpakatools;
0018 
0019   //
0020   struct HGCalRecHitCalibrationKernel_flagRecHits {
0021     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0022                                   HGCalDigiDevice::View digis,
0023                                   HGCalRecHitDevice::View recHits,
0024                                   HGCalCalibParamDevice::ConstView calibs) const {
0025       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0026         auto calib = calibs[idx];
0027         bool calibvalid = calib.valid();
0028         auto digi = digis[idx];
0029         auto digiflags = digi.flags();
0030         //recHits[idx].flags() = digiflags;
0031         bool isAvailable((digiflags != hgcal::DIGI_FLAG::Invalid) && (digiflags != hgcal::DIGI_FLAG::NotAvailable) &&
0032                          calibvalid);
0033         bool isToAavailable((digiflags != hgcal::DIGI_FLAG::ZS_ToA) && (digiflags != hgcal::DIGI_FLAG::ZS_ToA_ADCm1));
0034         recHits[idx].flags() = (!isAvailable) * hgcalrechit::HGCalRecHitFlags::EnergyInvalid +
0035                                (!isToAavailable) * hgcalrechit::HGCalRecHitFlags::TimeInvalid;
0036       }
0037     }
0038   };
0039 
0040   //
0041   struct HGCalRecHitCalibrationKernel_adcToCharge {
0042     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0043                                   HGCalDigiDevice::View digis,
0044                                   HGCalRecHitDevice::View recHits,
0045                                   HGCalCalibParamDevice::ConstView calibs) const {
0046       auto adc_denoise =
0047           [&](uint32_t adc, uint32_t cm, uint32_t adcm1, float adc_ped, float cm_slope, float cm_ped, float bxm1_slope) {
0048             float cmf = cm_slope * (0.5 * float(cm) - cm_ped);
0049             return ((adc - adc_ped) - cmf - bxm1_slope * (adcm1 - adc_ped - cmf));
0050           };
0051 
0052       auto tot_linearization =
0053           [&](uint32_t tot, float tot_lin, float tot2adc, float tot_ped, float tot_p0, float tot_p1, float tot_p2) {
0054             bool isLin(tot > tot_lin);
0055             bool isNotLin(!isLin);
0056             return isLin * (tot2adc * (tot - tot_ped)) + isNotLin * (tot_p0 + tot_p1 * tot + tot_p2 * tot * tot);
0057           };
0058 
0059       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0060         auto calib = calibs[idx];
0061         bool calibvalid = calib.valid();
0062         auto digi = digis[idx];
0063         auto digiflags = digi.flags();
0064         bool isAvailable((digiflags != hgcal::DIGI_FLAG::Invalid) && (digiflags != hgcal::DIGI_FLAG::NotAvailable) &&
0065                          calibvalid);
0066         bool useTOT((digi.tctp() == 3) && isAvailable);
0067         bool useADC(!useTOT && isAvailable);
0068         recHits[idx].energy() = useADC * adc_denoise(digi.adc(),
0069                                                      digi.cm(),
0070                                                      digi.adcm1(),
0071                                                      calib.ADC_ped(),
0072                                                      calib.CM_slope(),
0073                                                      calib.CM_ped(),
0074                                                      calib.BXm1_slope()) +
0075                                 useTOT * tot_linearization(digi.tot(),
0076                                                            calib.TOT_lin(),
0077                                                            calib.TOTtoADC(),
0078                                                            calib.TOT_ped(),
0079                                                            calib.TOT_P0(),
0080                                                            calib.TOT_P1(),
0081                                                            calib.TOT_P2());
0082 
0083         //after denoising/linearization apply the MIP scale
0084         recHits[idx].energy() *= calib.MIPS_scale();
0085       }
0086     }
0087   };
0088 
0089   //
0090   struct HGCalRecHitCalibrationKernel_toaToTime {
0091     ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0092                                   HGCalDigiDevice::View digis,
0093                                   HGCalRecHitDevice::View recHits,
0094                                   HGCalCalibParamDevice::ConstView calibs) const {
0095       auto toa_inl_corr = [&](uint32_t toa, hgcalrechit::Vector32f ctdc_p, hgcalrechit::Vector8f ftdc_p) {
0096         auto gray = toa / 256;
0097         auto ptdc = toa % 256;
0098         auto ctdc = ptdc / 8;
0099         auto ftdc = ptdc % 8;
0100         auto ctdc_corr = uint(ctdc - ctdc_p[ctdc]) % 32;
0101         auto ftdc_corr = uint(ftdc - ftdc_p[ftdc]) % 8;
0102         return (ftdc_corr + 8 * ctdc_corr + 256 * gray) % 1024;
0103       };
0104 
0105       auto toa_tw_corr = [&](uint32_t toa, float energy, hgcalrechit::Vector3f p) {
0106         return toa - ((energy > p[2]) ? (p[0] + p[1] * std::log(energy - p[2])) : 0.f);
0107       };
0108 
0109       for (auto idx : uniform_elements(acc, digis.metadata().size())) {
0110         auto calib = calibs[idx];
0111         bool calibvalid = calib.valid();
0112         auto digi = digis[idx];
0113         auto digiflags = digi.flags();
0114         bool isAvailable((digiflags != hgcal::DIGI_FLAG::Invalid) && (digiflags != hgcal::DIGI_FLAG::NotAvailable) &&
0115                          calibvalid);
0116         bool isToAavailable((digiflags != hgcal::DIGI_FLAG::ZS_ToA) && (digiflags != hgcal::DIGI_FLAG::ZS_ToA_ADCm1));
0117         bool isGood(isAvailable && isToAavailable);
0118         //INL correction
0119         auto toa = isGood * toa_inl_corr(digi.toa(), calib.TOA_CTDC(), calib.TOA_FTDC());
0120         //timewalk correction
0121         toa = isGood * toa_tw_corr(toa, recHits[idx].energy(), calib.TOA_TW());
0122         //toa to ps
0123         recHits[idx].time() = toa * hgcalrechit::TOAtops;
0124       }
0125     }
0126   };
0127 
0128   struct HGCalRecHitCalibrationKernel_printRecHits {
0129     ALPAKA_FN_ACC void operator()(Acc1D const& acc, HGCalRecHitDevice::ConstView view, int size) const {
0130       for (int i = 0; i < size; ++i) {
0131         auto const& recHit = view[i];
0132         printf("%d\t%f\t%f\t%d\n", i, recHit.energy(), recHit.time(), recHit.flags());
0133       }
0134     }
0135   };
0136 
0137   HGCalRecHitDevice HGCalRecHitCalibrationAlgorithms::calibrate(Queue& queue,
0138                                                                 HGCalDigiHost const& host_digis,
0139                                                                 HGCalCalibParamDevice const& device_calib,
0140                                                                 HGCalConfigParamDevice const& device_config) const {
0141     LogDebug("HGCalRecHitCalibrationAlgorithms") << "\n\nINFO -- Start of calibrate\n\n" << std::endl;
0142 
0143     LogDebug("HGCalRecHitCalibrationAlgorithms") << "\n\nINFO -- Copying the digis to the device\n\n" << std::endl;
0144     HGCalDigiDevice device_digis(host_digis.view().metadata().size(), queue);
0145     alpaka::memcpy(queue, device_digis.buffer(), host_digis.const_buffer());
0146 
0147     LogDebug("HGCalRecHitCalibrationAlgorithms")
0148         << "\n\nINFO -- Allocating rechits buffer and initiating values" << std::endl;
0149     HGCalRecHitDevice device_recHits(device_digis.view().metadata().size(), queue);
0150 
0151     // number of items per group
0152     uint32_t items = n_threads_;
0153     // use as many groups as needed to cover the whole problem
0154     uint32_t groups = divide_up_by(device_digis.view().metadata().size(), items);
0155     // map items to
0156     //   - threads with a single element per thread on a GPU backend
0157     //   - elements within a single thread on a CPU backend
0158     auto grid = make_workdiv<Acc1D>(groups, items);
0159     LogDebug("HGCalRecHitCalibrationAlgorithms") << "N groups: " << groups << "\tN items: " << items << std::endl;
0160 
0161     alpaka::exec<Acc1D>(queue,
0162                         grid,
0163                         HGCalRecHitCalibrationKernel_flagRecHits{},
0164                         device_digis.view(),
0165                         device_recHits.view(),
0166                         device_calib.view());
0167     alpaka::exec<Acc1D>(queue,
0168                         grid,
0169                         HGCalRecHitCalibrationKernel_adcToCharge{},
0170                         device_digis.view(),
0171                         device_recHits.view(),
0172                         device_calib.view());
0173     alpaka::exec<Acc1D>(queue,
0174                         grid,
0175                         HGCalRecHitCalibrationKernel_toaToTime{},
0176                         device_digis.view(),
0177                         device_recHits.view(),
0178                         device_calib.view());
0179 
0180     LogDebug("HGCalRecHitCalibrationAlgorithms") << "Input recHits: " << std::endl;
0181 #ifdef EDM_ML_DEBUG
0182     int n_hits_to_print = 10;
0183     print_recHit_device(queue, *device_recHits, n_hits_to_print);
0184 #endif
0185 
0186     return device_recHits;
0187   }
0188 
0189   void HGCalRecHitCalibrationAlgorithms::print(HGCalDigiHost const& digis, int max) const {
0190     int max_ = max > 0 ? max : digis.view().metadata().size();
0191     for (int i = 0; i < max_; i++) {
0192       LogDebug("HGCalRecHitCalibrationAlgorithms")
0193           << i << digis.view()[i].tot() << "\t" << digis.view()[i].toa() << "\t" << digis.view()[i].cm() << "\t"
0194           << digis.view()[i].flags() << std::endl;
0195     }
0196   }
0197 
0198   void HGCalRecHitCalibrationAlgorithms::print_digi_device(HGCalDigiDevice const& digis, int max) const {
0199     int max_ = max > 0 ? max : digis.view().metadata().size();
0200     for (int i = 0; i < max_; i++) {
0201       LogDebug("HGCalRecHitCalibrationAlgorithms")
0202           << i << digis.view()[i].tot() << "\t" << digis.view()[i].toa() << "\t" << digis.view()[i].cm() << "\t"
0203           << digis.view()[i].flags() << std::endl;
0204     }
0205   }
0206 
0207   void HGCalRecHitCalibrationAlgorithms::print_recHit_device(
0208       Queue& queue, PortableHostCollection<hgcalrechit::HGCalRecHitSoALayout<> >::View const& recHits, int max) const {
0209     auto grid = make_workdiv<Acc1D>(1, 1);
0210     auto size = max > 0 ? max : recHits.metadata().size();
0211     alpaka::exec<Acc1D>(queue, grid, HGCalRecHitCalibrationKernel_printRecHits{}, recHits, size);
0212 
0213     // ensure that the print operations are complete before returning
0214     alpaka::wait(queue);
0215   }
0216 
0217 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE