File indexing completed on 2024-04-06 12:24:45
0001 #include "RecoEcal/EgammaCoreTools/interface/GraphMap.h"
0002 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0003
0004 #include <iostream>
0005 #include <iomanip>
0006
0007 using namespace reco;
0008
0009 GraphMap::GraphMap(uint nNodes) : nNodes_(nNodes) {
0010
0011 edgesIn_.resize(nNodes);
0012 edgesOut_.resize(nNodes);
0013 }
0014
0015 void GraphMap::addNode(const uint index, const NodeCategory category) {
0016 nodesCategories_[category].push_back(index);
0017 nodesCount_[category] += 1;
0018 }
0019
0020 void GraphMap::addNodes(const std::vector<uint> &indices, const std::vector<NodeCategory> &categories) {
0021 for (size_t i = 0; i < indices.size(); i++) {
0022 addNode(indices[i], categories[i]);
0023 }
0024 }
0025
0026 void GraphMap::addEdge(const uint i, const uint j) {
0027
0028 edgesOut_.at(i).push_back(j);
0029 edgesIn_.at(j).push_back(i);
0030
0031 adjMatrix_[{i, j}] = 1.;
0032 }
0033
0034 void GraphMap::setAdjMatrix(const uint i, const uint j, const float score) { adjMatrix_[{i, j}] = score; };
0035
0036 void GraphMap::setAdjMatrixSym(const uint i, const uint j, const float score) {
0037 adjMatrix_[{i, j}] = score;
0038 adjMatrix_[{j, i}] = score;
0039 };
0040
0041 const std::vector<uint> &GraphMap::getOutEdges(const uint i) const { return edgesOut_.at(i); };
0042
0043 const std::vector<uint> &GraphMap::getInEdges(const uint i) const { return edgesIn_.at(i); };
0044
0045 uint GraphMap::getAdjMatrix(const uint i, const uint j) const { return adjMatrix_.at({i, j}); };
0046
0047 std::vector<float> GraphMap::getAdjMatrixRow(const uint i) const {
0048 std::vector<float> out;
0049 for (const auto &j : getOutEdges(i)) {
0050 out.push_back(adjMatrix_.at({i, j}));
0051 }
0052 return out;
0053 };
0054
0055 std::vector<float> GraphMap::getAdjMatrixCol(const uint j) const {
0056 std::vector<float> out;
0057 for (const auto &i : getInEdges(j)) {
0058 out.push_back(adjMatrix_.at({i, j}));
0059 }
0060 return out;
0061 };
0062
0063
0064
0065 void GraphMap::printGraphMap() {
0066 edm::LogVerbatim("GraphMap") << "OUT edges" << std::endl;
0067 uint seed = 0;
0068 for (const auto &s : edgesOut_) {
0069 edm::LogVerbatim("GraphMap") << "cl: " << seed << " --> ";
0070 for (const auto &e : s) {
0071 edm::LogVerbatim("GraphMap") << e << " (" << adjMatrix_[{seed, e}] << ") ";
0072 }
0073 edm::LogVerbatim("GraphMap") << std::endl;
0074 seed++;
0075 }
0076 edm::LogVerbatim("GraphMap") << std::endl << "IN edges" << std::endl;
0077 seed = 0;
0078 for (const auto &s : edgesIn_) {
0079 edm::LogVerbatim("GraphMap") << "cl: " << seed << " <-- ";
0080 for (const auto &e : s) {
0081 edm::LogVerbatim("GraphMap") << e << " (" << adjMatrix_[{e, seed}] << ") ";
0082 }
0083 edm::LogVerbatim("GraphMap") << std::endl;
0084 seed++;
0085 }
0086 edm::LogVerbatim("GraphMap") << std::endl << "AdjMatrix" << std::endl;
0087 for (const auto &s : nodesCategories_[NodeCategory::kSeed]) {
0088 for (size_t n = 0; n < nNodes_; n++) {
0089 edm::LogVerbatim("GraphMap") << std::setprecision(2) << adjMatrix_[{s, n}] << " ";
0090 }
0091 edm::LogVerbatim("GraphMap") << std::endl;
0092 }
0093 }
0094
0095
0096
0097 void GraphMap::collectNodes(GraphMap::CollectionStrategy strategy, float threshold) {
0098
0099 graphOutput_.clear();
0100
0101 if (strategy == GraphMap::CollectionStrategy::Cascade) {
0102
0103
0104 collectCascading(threshold);
0105 } else if (strategy == GraphMap::CollectionStrategy::CollectAndMerge) {
0106
0107
0108
0109
0110
0111 assignHighestScoreEdge();
0112 const auto &[seedsGraph, simpleNodesMap] = collectSeparately(threshold);
0113 mergeSubGraphs(threshold, seedsGraph, simpleNodesMap);
0114 } else if (strategy == GraphMap::CollectionStrategy::SeedsFirst) {
0115
0116
0117
0118 resolveSuperNodesEdges(threshold);
0119 assignHighestScoreEdge();
0120 collectCascading(threshold);
0121 } else if (strategy == GraphMap::CollectionStrategy::CascadeHighest) {
0122
0123
0124
0125
0126 assignHighestScoreEdge();
0127 collectCascading(threshold);
0128 }
0129 }
0130
0131
0132
0133
0134 void GraphMap::collectCascading(float threshold) {
0135
0136
0137 const auto &seeds = nodesCategories_[NodeCategory::kSeed];
0138
0139 LogDebug("GraphMap") << "Cascading...";
0140 for (const auto &s : seeds) {
0141 LogTrace("GraphMap") << "seed: " << s;
0142 std::vector<uint> collectedNodes;
0143
0144 if (adjMatrix_[{s, s}] < threshold)
0145 continue;
0146
0147 for (const auto &out : edgesOut_[s]) {
0148
0149 if (adjMatrix_[{s, out}] >= threshold) {
0150 LogTrace("GraphMap") << "\tOut edge: " << s << " --> " << out;
0151
0152 collectedNodes.push_back(out);
0153
0154
0155 for (const auto &out_in : edgesIn_[out]) {
0156
0157
0158
0159
0160
0161 if (out != s && out_in == s)
0162 continue;
0163 adjMatrix_[{out_in, out}] = 0.;
0164 LogTrace("GraphMap") << "\t\t Deleted edge: " << out << " <-- " << out_in;
0165 }
0166 }
0167 }
0168 graphOutput_.push_back({s, collectedNodes});
0169 }
0170 }
0171
0172 void GraphMap::assignHighestScoreEdge() {
0173
0174
0175 LogTrace("GraphMap") << "Keep only highest score edge";
0176 for (const auto &cl : nodesCategories_[NodeCategory::kNode]) {
0177 std::pair<uint, float> maxPair{0, 0};
0178 bool found = false;
0179 for (const auto &seed : edgesIn_[cl]) {
0180 float score = adjMatrix_[{seed, cl}];
0181 if (score > maxPair.second) {
0182 maxPair = {seed, score};
0183 found = true;
0184 }
0185 }
0186 if (!found)
0187 continue;
0188 LogTrace("GraphMap") << "cluster: " << cl << " edge from " << maxPair.first;
0189
0190 for (const auto &seed : edgesIn_[cl]) {
0191 if (seed != maxPair.first) {
0192 adjMatrix_[{seed, cl}] = 0.;
0193 }
0194 }
0195 }
0196 }
0197
0198 std::pair<GraphMap::GraphOutput, GraphMap::GraphOutputMap> GraphMap::collectSeparately(float threshold) {
0199
0200 GraphOutput seedsGraph;
0201
0202 GraphOutputMap simpleNodesGraphMap;
0203 LogDebug("GraphMap") << "Collecting separately each seed...";
0204
0205 for (const auto &s : nodesCategories_[NodeCategory::kSeed]) {
0206 LogTrace("GraphMap") << "seed: " << s;
0207 std::vector<uint> collectedNodes;
0208 std::vector<uint> collectedSeeds;
0209
0210 if (adjMatrix_[{s, s}] < threshold)
0211 continue;
0212
0213 for (const auto &out : edgesOut_[s]) {
0214
0215
0216 if (out != s && adjMatrix_[{out, out}] > 0) {
0217
0218 collectedSeeds.push_back(out);
0219
0220
0221 continue;
0222 }
0223
0224 if (adjMatrix_[{s, out}] >= threshold) {
0225 LogTrace("GraphMap") << "\tOut edge: " << s << " --> " << out << " (" << adjMatrix_[{s, out}] << " )";
0226
0227 collectedNodes.push_back(out);
0228
0229
0230
0231
0232
0233
0234 }
0235 }
0236 simpleNodesGraphMap[s] = collectedNodes;
0237 seedsGraph.push_back({s, collectedSeeds});
0238 }
0239 return std::make_pair(seedsGraph, simpleNodesGraphMap);
0240 }
0241
0242 void GraphMap::mergeSubGraphs(float threshold, GraphOutput seedsGraph, GraphOutputMap nodesGraphMap) {
0243
0244
0245
0246 LogTrace("GraphMap") << "Starting merging";
0247 for (const auto &[s, other_seeds] : seedsGraph) {
0248 LogTrace("GraphMap") << "seed: " << s;
0249
0250 if (adjMatrix_[{s, s}] < threshold)
0251 continue;
0252
0253 std::vector<uint> collectedNodes;
0254
0255 const auto &simpleNodes = nodesGraphMap[s];
0256 collectedNodes.insert(std::end(collectedNodes), std::begin(simpleNodes), std::end(simpleNodes));
0257
0258 for (const auto &out_s : other_seeds) {
0259
0260
0261 if (adjMatrix_[{out_s, out_s}] > threshold && adjMatrix_[{s, out_s}] > threshold) {
0262 LogTrace("GraphMap") << "\tMerging nodes from seed: " << out_s;
0263
0264 const auto &otherNodes = nodesGraphMap[out_s];
0265
0266
0267 collectedNodes.insert(std::end(collectedNodes), std::begin(otherNodes), std::end(otherNodes));
0268
0269 adjMatrix_[{out_s, out_s}] = 0.;
0270
0271
0272 }
0273 }
0274
0275
0276 adjMatrix_[{s, s}] = 0;
0277 graphOutput_.push_back({s, collectedNodes});
0278 }
0279 }
0280
0281 void GraphMap::resolveSuperNodesEdges(float threshold) {
0282 LogTrace("GraphMap") << "Resolving seeds";
0283 for (const auto &s : nodesCategories_[NodeCategory::kSeed]) {
0284 LogTrace("GraphMap") << "seed: " << s;
0285
0286 if (adjMatrix_[{s, s}] < threshold)
0287 continue;
0288
0289 for (const auto &out : edgesOut_[s]) {
0290 if (out != s && adjMatrix_[{out, out}] > 0 && adjMatrix_[{s, out}] > threshold) {
0291
0292
0293 LogTrace("GraphMap") << "\tdisable seed: " << out;
0294 adjMatrix_[{out, out}] = 0.;
0295
0296
0297
0298 for (const auto &c : edgesOut_[out]) {
0299 adjMatrix_[{out, c}] = 0.;
0300
0301 }
0302 }
0303 }
0304 }
0305 }