Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-12-08 23:44:23

0001 #ifndef SimDataFormats_Associations_TICLAssociationMap_h
0002 #define SimDataFormats_Associations_TICLAssociationMap_h
0003 
#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

// CMSSW specific includes
#include "DataFormats/Common/interface/Ref.h"
#include "DataFormats/Common/interface/RefProd.h"
#include "FWCore/Framework/interface/Event.h"
0018 
0019 namespace ticl {
0020 
0021   // Define wrapper types to differentiate between fraction and shared energy
0022   struct FractionType {
0023     float value;
0024     FractionType(float v = 0.0f) : value(v) {}
0025     FractionType& operator+=(float v) {
0026       value += v;
0027       return *this;
0028     }
0029   };
0030 
0031   struct SharedEnergyType {
0032     float value;
0033     SharedEnergyType(float v = 0.0f) : value(v) {}
0034     SharedEnergyType& operator+=(float v) {
0035       value += v;
0036       return *this;
0037     }
0038   };
0039 
0040   // AssociationElement class to store index and value, and provide methods directly
0041   template <typename V>
0042   class AssociationElement {
0043   public:
0044     using value_type = V;
0045     AssociationElement() : index_(std::numeric_limits<unsigned int>::max()) {
0046       if constexpr (std::is_same_v<V, FractionType> || std::is_same_v<V, SharedEnergyType>) {
0047         value_.value = -1.f;
0048       } else if constexpr (std::is_same_v<V, std::pair<FractionType, float>> ||
0049                            std::is_same_v<V, std::pair<SharedEnergyType, float>>) {
0050         value_.first.value = -1.f;
0051       }
0052     }
0053     AssociationElement(unsigned int index, const V& value) : index_(index), value_(value) {}
0054 
0055     unsigned int index() const { return index_; }
0056 
0057     bool isValid() const {
0058       if constexpr (std::is_same_v<V, FractionType> || std::is_same_v<V, SharedEnergyType>) {
0059         return value_.value >= 0.f;
0060       } else if constexpr (std::is_same_v<V, std::pair<FractionType, float>> ||
0061                            std::is_same_v<V, std::pair<SharedEnergyType, float>>) {
0062         return value_.first.value >= 0.f;
0063       }
0064     }
0065 
0066     // Enable fraction() if ValueType is FractionType
0067     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, FractionType>, int> = 0>
0068     float fraction() const {
0069       return value_.value;
0070     }
0071 
0072     // Enable sharedEnergy() if ValueType is SharedEnergyType
0073     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, SharedEnergyType>, int> = 0>
0074     float sharedEnergy() const {
0075       return value_.value;
0076     }
0077 
0078     // Enable fraction() and score() if ValueType is std::pair<FractionType, float>
0079     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, std::pair<FractionType, float>>, int> = 0>
0080     float fraction() const {
0081       return value_.first.value;
0082     }
0083     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, std::pair<FractionType, float>>, int> = 0>
0084     float score() const {
0085       return value_.second;
0086     }
0087 
0088     // Enable sharedEnergy() and score() if ValueType is std::pair<SharedEnergyType, float>
0089     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, std::pair<SharedEnergyType, float>>, int> = 0>
0090     float sharedEnergy() const {
0091       return value_.first.value;
0092     }
0093     template <typename T = V, typename std::enable_if_t<std::is_same_v<T, std::pair<SharedEnergyType, float>>, int> = 0>
0094     float score() const {
0095       return value_.second;
0096     }
0097 
0098     // Method to accumulate values
0099     void accumulate(const V& other_value) {
0100       if constexpr (std::is_same_v<V, FractionType> || std::is_same_v<V, SharedEnergyType>) {
0101         value_.value += other_value.value;
0102       } else if constexpr (std::is_same_v<V, std::pair<FractionType, float>> ||
0103                            std::is_same_v<V, std::pair<SharedEnergyType, float>>) {
0104         value_.first.value += other_value.first.value;
0105         value_.second += other_value.second;
0106       }
0107     }
0108     bool operator==(const AssociationElement& other) const {
0109       return index_ == other.index_ && value_.value == other.value_.value;
0110     }
0111 
0112     bool operator!=(const AssociationElement& other) const { return !(*this == other); }
0113 
0114   private:
0115     unsigned int index_;
0116     V value_;
0117   };
0118 
0119   // Type traits to differentiate between one-to-one and one-to-many maps
0120   template <typename T>
0121   struct MapTraits;
0122 
0123   template <typename V>
0124   struct MapTraits<std::vector<AssociationElement<V>>> {
0125     static constexpr bool is_one_to_one = true;
0126     using AssociationElementType = AssociationElement<V>;
0127     using ValueType = V;
0128   };
0129 
0130   template <typename V>
0131   struct MapTraits<std::vector<std::vector<AssociationElement<V>>>> {
0132     static constexpr bool is_one_to_one = false;
0133     using AssociationElementType = AssociationElement<V>;
0134     using ValueType = V;
0135   };
0136 
0137   // Trait to check if V is a std::pair (i.e., has a score)
0138   template <typename T>
0139   struct IsValueTypeWithScore : std::false_type {};
0140 
0141   template <typename First>
0142   struct IsValueTypeWithScore<std::pair<First, float>> : std::true_type {};
0143 
0144   // Define map types using AssociationElement and container types
0145   using mapWithFraction = std::vector<std::vector<AssociationElement<FractionType>>>;
0146   using mapWithFractionAndScore = std::vector<std::vector<AssociationElement<std::pair<FractionType, float>>>>;
0147   using oneToOneMapWithFraction = std::vector<AssociationElement<FractionType>>;
0148   using oneToOneMapWithFractionAndScore = std::vector<AssociationElement<std::pair<FractionType, float>>>;
0149 
0150   using mapWithSharedEnergy = std::vector<std::vector<AssociationElement<SharedEnergyType>>>;
0151   using mapWithSharedEnergyAndScore = std::vector<std::vector<AssociationElement<std::pair<SharedEnergyType, float>>>>;
0152   using oneToOneMapWithSharedEnergy = std::vector<AssociationElement<SharedEnergyType>>;
0153   using oneToOneMapWithSharedEnergyAndScore = std::vector<AssociationElement<std::pair<SharedEnergyType, float>>>;
0154 
0155   // AssociationMap class templated on MapType
0156   template <typename MapType, typename Collection1 = void, typename Collection2 = void>
0157   class AssociationMap {
0158   private:
0159     MapType map_;
0160 
0161     // Type alias for conditionally including collectionRefProds
0162     using CollectionRefProdType =
0163         typename std::conditional_t<std::is_void_v<Collection1> || std::is_void_v<Collection2>,
0164                                     std::monostate,
0165                                     std::pair<edm::RefProd<Collection1>, edm::RefProd<Collection2>>>;
0166 
0167     CollectionRefProdType collectionRefProds;
0168 
0169     // Traits to deduce AssociationElementType and ValueType
0170     using Traits = MapTraits<MapType>;
0171     using AssociationElementType = typename Traits::AssociationElementType;
0172     using V = typename Traits::ValueType;
0173     static constexpr bool is_one_to_one = Traits::is_one_to_one;
0174 
0175   public:
0176     AssociationMap() : collectionRefProds() {}
0177 
0178     // Constructor for generic use
0179     template <typename C1 = Collection1,
0180               typename C2 = Collection2,
0181               typename std::enable_if_t<std::is_void_v<C1> && std::is_void_v<C2>, int> = 0>
0182     AssociationMap(const unsigned int size1 = 0) {
0183       map_.resize(size1);
0184     }
0185 
0186     // Constructor for CMSSW-specific use
0187     template <typename C1 = Collection1,
0188               typename C2 = Collection2,
0189               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0190     AssociationMap(const edm::RefProd<C1>& id1, const edm::RefProd<C2>& id2, const edm::Event& event)
0191         : collectionRefProds(std::make_pair(id1, id2)) {
0192       resize(event);
0193     }
0194 
0195     // Constructor for CMSSW-specific use
0196     template <typename C1 = Collection1,
0197               typename C2 = Collection2,
0198               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0199     AssociationMap(const edm::Handle<C1>& handle1, const edm::Handle<C2>& handle2, const edm::Event& event)
0200         : collectionRefProds(std::make_pair(edm::RefProd<C1>(handle1), edm::RefProd<C2>(handle2))) {
0201       resize(event);
0202     }
0203 
0204     MapType& getMap() { return map_; }
0205 
0206     const MapType& getMap() const { return map_; }
0207 
0208     const auto size() const { return map_.size(); }
0209 
0210     // CMSSW-specific method to get references
0211     template <typename C1 = Collection1,
0212               typename C2 = Collection2,
0213               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0214     edm::Ref<C1> getRefFirst(unsigned int index) const {
0215       return edm::Ref<C1>(collectionRefProds.first, index);
0216     }
0217 
0218     template <typename C1 = Collection1,
0219               typename C2 = Collection2,
0220               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0221     edm::Ref<C2> getRefSecond(unsigned int index) const {
0222       return edm::Ref<C2>(collectionRefProds.second, index);
0223     }
0224 
0225     // Method to get collection IDs for CMSSW-specific use
0226     template <typename C1 = Collection1,
0227               typename C2 = Collection2,
0228               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0229     std::pair<const edm::RefProd<C1>, const edm::RefProd<C2>> getCollectionIDs() const {
0230       return collectionRefProds;
0231     }
0232 
0233     void insert(unsigned int index1, unsigned int index2, float fraction_or_energy, float score = 0.0f) {
0234       assert(index1 < map_.size());
0235       V value;
0236       if constexpr (IsValueTypeWithScore<V>::value) {
0237         using FirstType = typename V::first_type;
0238         value = V(FirstType(fraction_or_energy), score);
0239       } else {
0240         value = V(fraction_or_energy);
0241       }
0242       AssociationElementType element(index2, value);
0243 
0244       if constexpr (is_one_to_one) {
0245         map_[index1] = element;
0246       } else {
0247         auto& vec = map_[index1];
0248         auto it =
0249             std::find_if(vec.begin(), vec.end(), [&](const AssociationElementType& e) { return e.index() == index2; });
0250         if (it != vec.end()) {
0251           // Accumulate value
0252           it->accumulate(value);
0253         } else {
0254           vec.push_back(element);
0255         }
0256       }
0257     }
0258 
0259     // Overload of insert for CMSSW-specific use
0260     template <typename C1 = Collection1,
0261               typename C2 = Collection2,
0262               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0263     void insert(const edm::Ref<C1>& ref1, const edm::Ref<C2>& ref2, float fraction_or_energy, float score = 0.0f) {
0264       insert(ref1.key(), ref2.key(), fraction_or_energy, score);
0265     }
0266 
0267     void sort(bool byScore = false) {
0268       if constexpr (is_one_to_one) {
0269         // Sorting not applicable for one-to-one maps
0270       } else {
0271         for (auto& vec : map_) {
0272           if (byScore && IsValueTypeWithScore<V>::value) {
0273             std::sort(vec.begin(), vec.end(), [](const auto& a, const auto& b) {
0274               if (a.score() != b.score()) {
0275                 return a.score() > b.score();
0276               } else {
0277                 return a.index() < b.index();
0278               }
0279             });
0280           } else {
0281             if constexpr (std::is_same_v<V, FractionType> || std::is_same_v<V, std::pair<FractionType, float>>) {
0282               std::sort(vec.begin(), vec.end(), [](const auto& a, const auto& b) {
0283                 if (a.fraction() != b.fraction()) {
0284                   return a.fraction() > b.fraction();
0285                 } else {
0286                   return a.index() < b.index();
0287                 }
0288               });
0289             } else {
0290               std::sort(vec.begin(), vec.end(), [](const auto& a, const auto& b) {
0291                 if (a.sharedEnergy() != b.sharedEnergy()) {
0292                   return a.sharedEnergy() > b.sharedEnergy();
0293                 } else {
0294                   return a.index() < b.index();
0295                 }
0296               });
0297             }
0298           }
0299         }
0300       }
0301     }
0302 
0303     // Overload of sort() that accepts a custom comparator
0304     template <typename Compare>
0305     void sort(Compare comp) {
0306       if constexpr (is_one_to_one) {
0307         // Sorting not applicable for one-to-one maps
0308       } else {
0309         for (auto& vec : map_) {
0310           std::sort(vec.begin(), vec.end(), comp);
0311         }
0312       }
0313     }
0314 
0315     // Access methods
0316     const auto& operator[](unsigned int index1) const { return map_[index1]; }
0317 
0318     auto& operator[](unsigned int index1) { return map_[index1]; }
0319 
0320     const auto& at(unsigned int index1) const {
0321       const auto& elem = map_.at(index1);
0322       if (!elem.isValid()) {
0323         throw std::out_of_range("Attempted to access an unset element in AssociationMap. Element index: " +
0324                                 std::to_string(index1));
0325       }
0326       return elem;
0327     }
0328 
0329     auto& at(unsigned int index1) {
0330       auto& elem = map_.at(index1);
0331       if (!elem.isValid()) {
0332         throw std::out_of_range("Attempted to access an unset element in AssociationMap. Element index: " +
0333                                 std::to_string(index1));
0334       }
0335       return elem;
0336     }
0337 
0338     // CMSSW-specific resize method
0339     template <typename C1 = Collection1,
0340               typename C2 = Collection2,
0341               typename std::enable_if_t<!std::is_void_v<C1> && !std::is_void_v<C2>, int> = 0>
0342     void resize(const edm::Event& event) {
0343       map_.resize(collectionRefProds.first->size());
0344     }
0345 
0346     // Generic resize method
0347     template <typename C1 = Collection1,
0348               typename C2 = Collection2,
0349               typename std::enable_if_t<std::is_void_v<C1> && std::is_void_v<C2>, int> = 0>
0350     void resize(const unsigned int size1) {
0351       map_.resize(size1);
0352     }
0353 
0354     // Method to print the entire map
0355     void print(std::ostream& os) const {
0356       for (size_t i = 0; i < map_.size(); ++i) {
0357         if constexpr (is_one_to_one) {
0358           const auto& elem = map_[i];
0359           if (!elem.isValid()) {
0360             continue;
0361           }
0362           os << "Index " << i << ":\n";
0363 
0364           os << "  (" << elem.index() << ", ";
0365           if constexpr (IsValueTypeWithScore<V>::value) {
0366             if constexpr (std::is_same_v<typename V::first_type, FractionType>) {
0367               os << "Fraction: " << elem.fraction() << ", Score: " << elem.score();
0368             } else {
0369               os << "SharedEnergy: " << elem.sharedEnergy() << ", Score: " << elem.score();
0370             }
0371           } else {
0372             if constexpr (std::is_same_v<V, FractionType>) {
0373               os << "Fraction: " << elem.fraction();
0374             } else if constexpr (std::is_same_v<V, SharedEnergyType>) {
0375               os << "SharedEnergy: " << elem.sharedEnergy();
0376             }
0377           }
0378           os << ")\n";
0379         } else {
0380           os << "Index " << i << ":\n";
0381           for (const auto& elem : map_[i]) {
0382             os << "  (" << elem.index() << ", ";
0383             if constexpr (IsValueTypeWithScore<V>::value) {
0384               if constexpr (std::is_same_v<typename V::first_type, FractionType>) {
0385                 os << "Fraction: " << elem.fraction() << ", Score: " << elem.score();
0386               } else {
0387                 os << "SharedEnergy: " << elem.sharedEnergy() << ", Score: " << elem.score();
0388               }
0389             } else {
0390               if constexpr (std::is_same_v<V, FractionType>) {
0391                 os << "Fraction: " << elem.fraction();
0392               } else if constexpr (std::is_same_v<V, SharedEnergyType>) {
0393                 os << "SharedEnergy: " << elem.sharedEnergy();
0394               }
0395             }
0396             os << ")\n";
0397           }
0398         }
0399       }
0400     }
0401   };
0402 
0403 }  // namespace ticl
0404 
0405 #endif