File indexing completed on 2024-08-21 04:46:42
0001
0002
0003
0004
0005 #ifndef __RecoHGCal_TICL_SuperclusteringDNNInputs_H__
0006 #define __RecoHGCal_TICL_SuperclusteringDNNInputs_H__
0007
#include <iterator>
#include <memory>
#include <string>
#include <vector>
0011
namespace ticl {
  // Forward declaration: this header only passes Trackster by const reference,
  // so the full definition is not needed here.
  class Trackster;

  // Abstract interface for building the input feature vector of the
  // superclustering DNN from a pair of tracksters.
  //
  // Contract for derived classes: the default featureCount() calls
  // featureNames() and the default featureNames() calls featureCount().
  // A derived class MUST override at least one of the two, otherwise calling
  // either of them recurses without end.
  class AbstractSuperclusteringDNNInput {
  public:
    virtual ~AbstractSuperclusteringDNNInput() = default;

    // Number of features.
    // Default: the number of entries returned by featureNames().
    virtual unsigned int featureCount() const { return static_cast<unsigned int>(featureNames().size()); }

    // Name of each feature.
    // Default: "nb_1", "nb_2", ..., "nb_<featureCount()>".
    virtual std::vector<std::string> featureNames() const {
      // Query the (virtual) count once instead of on every loop iteration.
      const unsigned int count = featureCount();
      std::vector<std::string> defaultNames;
      defaultNames.reserve(count);
      // Loop with `i < count` and label with i + 1 (1-based names) so the loop
      // cannot wrap around when count is the largest unsigned int.
      for (unsigned int i = 0; i < count; ++i) {
        defaultNames.push_back("nb_" + std::to_string(i + 1));
      }
      return defaultNames;
    }

    // Builds the DNN input features for the trackster pair (ts_base, ts_toCluster).
    // NOTE(review): implementations are not in this header; presumably the result
    // has featureCount() entries in featureNames() order — confirm in the .cpp.
    virtual std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) = 0;
  };

  // Version 1 of the superclustering DNN inputs (9 features, see kFeatureNames).
  class SuperclusteringDNNInputV1 : public AbstractSuperclusteringDNNInput {
  public:
    // Derived from the name table so the count and the names cannot drift apart.
    unsigned int featureCount() const override { return static_cast<unsigned int>(std::size(kFeatureNames)); }

    // Defined out of line (not in this header).
    std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) override;

    std::vector<std::string> featureNames() const override {
      return std::vector<std::string>(std::begin(kFeatureNames), std::end(kFeatureNames));
    }

  private:
    // Single source of truth for the V1 feature names and their order.
    static constexpr const char* kFeatureNames[] = {"DeltaEtaBaryc",
                                                    "DeltaPhiBaryc",
                                                    "multi_en",
                                                    "multi_eta",
                                                    "multi_pt",
                                                    "seedEta",
                                                    "seedPhi",
                                                    "seedEn",
                                                    "seedPt"};
  };

  // Version 2 of the superclustering DNN inputs (17 features, see kFeatureNames):
  // the V1 features followed by angular and explained-variance features.
  class SuperclusteringDNNInputV2 : public AbstractSuperclusteringDNNInput {
  public:
    // Derived from the name table so the count and the names cannot drift apart.
    unsigned int featureCount() const override { return static_cast<unsigned int>(std::size(kFeatureNames)); }

    // Defined out of line (not in this header).
    std::vector<float> computeVector(ticl::Trackster const& ts_base, ticl::Trackster const& ts_toCluster) override;

    std::vector<std::string> featureNames() const override {
      return std::vector<std::string>(std::begin(kFeatureNames), std::end(kFeatureNames));
    }

  private:
    // Single source of truth for the V2 feature names and their order.
    static constexpr const char* kFeatureNames[] = {"DeltaEtaBaryc",
                                                    "DeltaPhiBaryc",
                                                    "multi_en",
                                                    "multi_eta",
                                                    "multi_pt",
                                                    "seedEta",
                                                    "seedPhi",
                                                    "seedEn",
                                                    "seedPt",
                                                    "theta",
                                                    "theta_xz_seedFrame",
                                                    "theta_yz_seedFrame",
                                                    "theta_xy_cmsFrame",
                                                    "theta_yz_cmsFrame",
                                                    "theta_xz_cmsFrame",
                                                    "explVar",
                                                    "explVarRatio"};
  };

  // Factory: returns an owning pointer to the input builder selected by dnnVersion.
  // NOTE(review): the accepted version strings and the behavior for an unknown
  // string are defined in the implementation file (not visible here) — confirm.
  std::unique_ptr<AbstractSuperclusteringDNNInput> makeSuperclusteringDNNInputFromString(std::string dnnVersion);
}  // namespace ticl
0092
0093 #endif