Iterator

MultiVectorManager

Macros

Line Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Author: Felice Pantaleo (CERN), 2023, felice.pantaleo@cern.ch
#ifndef MultiVectorManager_h
#define MultiVectorManager_h

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <span>
#include <utility>
#include <vector>

/// Presents a sequence of non-owning std::span views as one concatenated
/// array addressed by a single "global" index.
///
/// Element i of the k-th added span has global index offsets[k] + i, where
/// offsets[k] is the total size of all spans added before it.
///
/// Lifetime: only the span views are stored, never the data. Every buffer
/// passed to addVector() must outlive this manager and any Iterator from it.
///
/// Mutability caveat: the views are stored as std::span<const T>. The
/// non-const operator[] and Iterator::operator* cast that constness away, so
/// writing through them is only valid when the referenced storage is not
/// itself const.
template <typename T>
class MultiVectorManager {
public:
  /// Appends a view; its elements receive the next vec.size() global indices.
  /// Empty spans are accepted and occupy no global indices.
  void addVector(std::span<const T> vec) {
    vectors.emplace_back(vec);
    offsets.push_back(totalSize);
    totalSize += vec.size();
  }

  /// Mutable element access by global index (see the mutability caveat above).
  T& operator[](size_t globalIndex) { return const_cast<T&>(std::as_const(*this)[globalIndex]); }

  /// Element access by global index; O(log(number of added spans)).
  /// Precondition (asserted): globalIndex < size().
  const T& operator[](size_t globalIndex) const {
    const auto [vectorIndex, localIndex] = getVectorAndLocalIndex(globalIndex);
    return vectors[vectorIndex][localIndex];
  }

  /// Converts (span index, position within that span) into a global index.
  /// Preconditions (asserted): vectorIndex < number of added spans, and
  /// localIndex < size of that span.
  size_t getGlobalIndex(size_t vectorIndex, size_t localIndex) const {
    assert(vectorIndex < vectors.size() && "Vector index out of range");
    assert(localIndex < vectors[vectorIndex].size() && "Local index out of range");

    return offsets[vectorIndex] + localIndex;
  }

  /// Inverse of getGlobalIndex(): maps a global index to {span index, local index}.
  /// Precondition (asserted): globalIndex < size().
  std::pair<size_t, size_t> getVectorAndLocalIndex(size_t globalIndex) const {
    assert(globalIndex < totalSize && "Global index out of range");

    // offsets is non-decreasing; equal neighbours come from empty spans.
    // upper_bound finds the first offset > globalIndex. Stepping back one
    // yields the LAST offset <= globalIndex, which skips any empty spans that
    // share that offset and lands on the span that actually holds the element.
    const auto it = std::upper_bound(offsets.begin(), offsets.end(), globalIndex);
    const size_t vectorIndex = static_cast<size_t>(std::distance(offsets.begin(), it)) - 1;

    return {vectorIndex, globalIndex - offsets[vectorIndex]};
  }

  /// Total number of elements across all added spans.
  size_t size() const { return totalSize; }

  /// Forward iterator over all elements in global-index order.
  ///
  /// It holds a pointer (not a reference) to the manager, so it is
  /// default-constructible and copy-assignable as forward iterators must be.
  /// Each dereference performs an O(log(number of added spans)) lookup.
  class Iterator {
  public:
    using iterator_category = std::forward_iterator_tag;
    using difference_type = std::ptrdiff_t;
    using value_type = T;
    using pointer = T*;
    using reference = T&;

    Iterator() = default;

    Iterator(const MultiVectorManager& owner, size_t index) : manager(&owner), currentIndex(index) {}

    // Only positions are compared; comparing iterators of different managers is meaningless.
    bool operator==(const Iterator& other) const { return currentIndex == other.currentIndex; }

    bool operator!=(const Iterator& other) const { return !(*this == other); }

    // Constness is cast away to match the `reference` typedef (see the mutability caveat).
    reference operator*() const { return const_cast<T&>((*manager)[currentIndex]); }

    pointer operator->() const { return std::addressof(**this); }

    Iterator& operator++() {
      ++currentIndex;
      return *this;
    }

    Iterator operator++(int) {
      Iterator previous = *this;
      ++currentIndex;
      return previous;
    }

  private:
    const MultiVectorManager* manager = nullptr;
    size_t currentIndex = 0;
  };

  /// Iterator positioned at global index 0.
  Iterator begin() const { return Iterator(*this, 0); }

  /// Past-the-end iterator (global index size()).
  Iterator end() const { return Iterator(*this, totalSize); }

private:
  std::vector<std::span<const T>> vectors;  // non-owning views, in insertion order
  std::vector<size_t> offsets;              // offsets[k] = global index where vectors[k] starts
  size_t totalSize = 0;                     // sum of the sizes of all added spans
};

#endif