Macros

Line Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
#ifndef HeterogeneousCore_AlpakaInterface_interface_getDeviceCachingAllocator_h
#define HeterogeneousCore_AlpakaInterface_interface_getDeviceCachingAllocator_h

#include <cassert>
#include <memory>

#include <alpaka/alpaka.hpp>

#include "FWCore/Utilities/interface/thread_safety_macros.h"
#include "HeterogeneousCore/AlpakaInterface/interface/AllocatorConfig.h"
#include "HeterogeneousCore/AlpakaInterface/interface/CachingAllocator.h"
#include "HeterogeneousCore/AlpakaInterface/interface/devices.h"

namespace cms::alpakatools {

  namespace detail {

    // Build one CachingAllocator<TDev, TQueue> for every device returned by
    // cms::alpakatools::devices<alpaka::Platform<TDev>>().
    //
    // Each allocator is constructed in place from (device, config, true, debug), where the
    // literal `true` is the reuseSameQueueAllocations argument. Raw storage is obtained from
    // std::allocator and the elements are constructed one at a time, so that every element
    // can receive its own device.
    //
    // If constructing an element throws, the elements already built are destroyed in reverse
    // order, the storage is released, and the exception is propagated to the caller.
    //
    // Returns a std::unique_ptr<Allocator[], Deleter> owning the array; its deleter destroys
    // the elements from last to first and then gives the storage back to std::allocator.
    template <typename TDev,
              typename TQueue,
              typename = std::enable_if_t<alpaka::isDevice<TDev> and alpaka::isQueue<TQueue>>>
    auto allocate_device_allocators(AllocatorConfig const& config, bool debug) {
      using Allocator = CachingAllocator<TDev, TQueue>;
      using Storage = std::allocator<Allocator>;

      auto const& platformDevices = cms::alpakatools::devices<alpaka::Platform<TDev>>();
      ssize_t const count = platformDevices.size();

      // uninitialised storage for `count` allocators
      Allocator* const storage = Storage().allocate(count);

      // number of elements whose construction has completed
      ptrdiff_t constructed = 0;
      try {
        while (constructed < count) {
#if __cplusplus >= 202002L
          std::construct_at(
#else
          Storage().construct(
#endif
              storage + constructed,
              platformDevices[constructed],
              config,
              true,  // reuseSameQueueAllocations
              debug);
          // only counted once the constructor has returned normally
          ++constructed;
        }
      } catch (...) {
        // roll back: tear down the successfully constructed elements, last one first
        while (constructed > 0) {
          --constructed;
          std::destroy_at(storage + constructed);
        }
        // give the storage back before propagating the original exception
        Storage().deallocate(storage, count);
        throw;
      }

      // the owning pointer needs a custom deleter, because the array was not created by new[]
      auto deleter = [count](Allocator* allocators) {
        // walk backwards from one-past-the-end, destroying each element
        for (Allocator* element = allocators + count; element != allocators;) {
          --element;
          std::destroy_at(element);
        }
        std::allocator<Allocator>().deallocate(allocators, count);
      };

      return std::unique_ptr<Allocator[], decltype(deleter)>(storage, deleter);
    }

  }  // namespace detail

  // Return the caching allocator associated with `device`.
  //
  // The allocators (one per device of the platform of TDev) live in a function-local static
  // that is created by detail::allocate_device_allocators() the first time this function is
  // called for a given <TDev, TQueue> pair. Because a static is initialised only once,
  // `config` and `debug` are used by that first call only; the values passed by later calls
  // do not affect the existing allocators.
  //
  // The allocator is selected by using alpaka::getNativeHandle(device) as an index into the
  // array. The bounds check is an assert, so it is compiled out when NDEBUG is defined.
  template <typename TDev,
            typename TQueue,
            typename = std::enable_if_t<alpaka::isDevice<TDev> and alpaka::isQueue<TQueue>>>
  inline CachingAllocator<TDev, TQueue>& getDeviceCachingAllocator(TDev const& device,
                                                                   AllocatorConfig const& config = AllocatorConfig{},
                                                                   bool debug = false) {
    // one allocator per device, built lazily on the first call
    CMS_THREAD_SAFE static auto perDeviceAllocators =
        detail::allocate_device_allocators<TDev, TQueue>(config, debug);

    // the native handle of the device doubles as its position in the array
    size_t const index = alpaka::getNativeHandle(device);
    assert(index < cms::alpakatools::devices<alpaka::Platform<TDev>>().size());

    // the public interface of the allocator is thread safe
    auto& allocator = perDeviceAllocators[index];
    return allocator;
  }

}  // namespace cms::alpakatools

#endif  // HeterogeneousCore_AlpakaInterface_interface_getDeviceCachingAllocator_h