Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-21 04:46:42

0001 /** Computation of input features for superclustering DNN. Used by plugins/TracksterLinkingBySuperClustering.cc and plugins/SuperclusteringSampleDumper.cc */
0002 // Author: Theo Cuisset - theo.cuisset@cern.ch
0003 // Date: 11/2023
0004 
0005 #ifndef __RecoHGCal_TICL_SuperclusteringDNNInputs_H__
0006 #define __RecoHGCal_TICL_SuperclusteringDNNInputs_H__
0007 
0008 #include <vector>
0009 #include <string>
0010 #include <memory>
0011 
namespace ticl {
  class Trackster;

  /** Abstract base class for DNN input preparation.
   *
   * Subclasses must implement computeVector() and must override AT LEAST ONE of featureCount() / featureNames():
   * the two default implementations are written in terms of each other, so a subclass overriding neither of them
   * would recurse without bound on the first call to either.
   */
  class AbstractSuperclusteringDNNInput {
  public:
    virtual ~AbstractSuperclusteringDNNInput() = default;

    /** Number of features produced by computeVector().
     * The default implementation returns the number of entries of featureNames().
     */
    virtual unsigned int featureCount() const { return static_cast<unsigned int>(featureNames().size()); }

    /** Get name of features. Used for SuperclusteringSampleDumper branch names (inference does not use the names, only the indices).
     * The default implementation is meant to be overridden by inheriting classes; it returns the placeholder names
     * "nb_1", "nb_2", ..., "nb_<featureCount()>".
     */
    virtual std::vector<std::string> featureNames() const {
      // featureCount() is virtual: query it once instead of on every loop iteration.
      const unsigned int nFeatures = featureCount();
      std::vector<std::string> defaultNames;
      defaultNames.reserve(nFeatures);
      for (unsigned int i = 1; i <= nFeatures; ++i) {
        defaultNames.push_back("nb_" + std::to_string(i));
      }
      return defaultNames;
    }

    /** Compute features for a seed and candidate pair.
     * @param ts_base the seed trackster
     * @param ts_toCluster the candidate trackster to be (possibly) clustered with the seed
     * @return the feature values; expected to hold featureCount() entries, ordered as featureNames()
     */
    virtual std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) = 0;
  };

  /* First version of DNN by Alessandro Tarabini. Meant as a DNN equivalent of Mustache algorithm (superclustering algo in ECAL)
  Uses features : DeltaEtaBaryc, DeltaPhiBaryc, multi_en, multi_eta, multi_pt, seedEta, seedPhi, seedEn, seedPt
  (NOTE(review): "Baryc" presumably means the deltas are computed between trackster barycenters — confirm in computeVector.)
  */
  class SuperclusteringDNNInputV1 : public AbstractSuperclusteringDNNInput {
  public:
    unsigned int featureCount() const override { return static_cast<unsigned int>(std::size(kFeatureNames)); }

    std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) override;

    std::vector<std::string> featureNames() const override {
      return std::vector<std::string>(std::begin(kFeatureNames), std::end(kFeatureNames));
    }

  private:
    // Single source of truth for the feature list: featureCount() is derived from it, so count and names cannot drift.
    static constexpr const char* kFeatureNames[] = {"DeltaEtaBaryc",
                                                    "DeltaPhiBaryc",
                                                    "multi_en",
                                                    "multi_eta",
                                                    "multi_pt",
                                                    "seedEta",
                                                    "seedPhi",
                                                    "seedEn",
                                                    "seedPt"};
  };

  /* Second version of DNN by Alessandro Tarabini, making use of HGCAL-specific features.
  Uses features : DeltaEtaBaryc, DeltaPhiBaryc, multi_en, multi_eta, multi_pt, seedEta, seedPhi, seedEn, seedPt, theta,
  theta_xz_seedFrame, theta_yz_seedFrame, theta_xy_cmsFrame, theta_yz_cmsFrame, theta_xz_cmsFrame, explVar, explVarRatio
  */
  class SuperclusteringDNNInputV2 : public AbstractSuperclusteringDNNInput {
  public:
    unsigned int featureCount() const override { return static_cast<unsigned int>(std::size(kFeatureNames)); }

    std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) override;

    std::vector<std::string> featureNames() const override {
      return std::vector<std::string>(std::begin(kFeatureNames), std::end(kFeatureNames));
    }

  private:
    // Single source of truth for the feature list: featureCount() is derived from it, so count and names cannot drift.
    static constexpr const char* kFeatureNames[] = {"DeltaEtaBaryc",
                                                    "DeltaPhiBaryc",
                                                    "multi_en",
                                                    "multi_eta",
                                                    "multi_pt",
                                                    "seedEta",
                                                    "seedPhi",
                                                    "seedEn",
                                                    "seedPt",
                                                    "theta",
                                                    "theta_xz_seedFrame",
                                                    "theta_yz_seedFrame",
                                                    "theta_xy_cmsFrame",
                                                    "theta_yz_cmsFrame",
                                                    "theta_xz_cmsFrame",
                                                    "explVar",
                                                    "explVarRatio"};
  };

  /** Factory returning the feature-computation object that matches a DNN version string.
   * NOTE(review): the accepted version strings and the behaviour on an unknown string are defined with the
   * implementation (not visible in this header) — presumably one string per SuperclusteringDNNInputVx above; confirm there.
   * The parameter is kept by value so this declaration keeps matching its out-of-line definition.
   */
  std::unique_ptr<AbstractSuperclusteringDNNInput> makeSuperclusteringDNNInputFromString(std::string dnnVersion);
}  // namespace ticl
0092 
0093 #endif