Skip to content

Commit

Permalink
2118 [CUDAX] Change the RAII device swapper to use driver API and add…
Browse files Browse the repository at this point in the history
… it in places where it was missing (#2192)

* Change __scoped_device to use driver API

* Switch to use driver API based dev setter

* Remove constexpr from operator device()

* Fix comments and includes

* Fallback to non-versioned get entry point pre 12.5
We need to use versioned version to get correct cuStreamGetCtx.
There is v2 version of it in 12.5, fortunatelly the versioned
get entry point is available there too

* Fix unused local variable

* Fix warnings in ensure_current_device test

* Move ensure current device out of detail

* Add LIBCUDACXX_ENABLE_EXCEPTIONS to tests cmake
  • Loading branch information
pciolkosz authored Aug 6, 2024
1 parent d1e7c1c commit 75929cb
Show file tree
Hide file tree
Showing 18 changed files with 412 additions and 100 deletions.
32 changes: 30 additions & 2 deletions cudax/include/cuda/experimental/__device/device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
# pragma system_header
#endif // no system header

#include <cuda.h>

#include <cuda/experimental/__device/device_ref.cuh>
#include <cuda/experimental/__utility/driver_api.cuh>

#include <cassert>
#include <mutex>

namespace cuda::experimental
{
Expand All @@ -33,7 +39,7 @@ struct __emplace_device
{
int __id_;

_CCCL_NODISCARD constexpr operator device() const noexcept;
_CCCL_NODISCARD operator device() const noexcept;

_CCCL_NODISCARD constexpr const __emplace_device* operator->() const noexcept;
};
Expand All @@ -56,13 +62,35 @@ public:
# endif
#endif

CUcontext primary_context() const
{
::std::call_once(__init_once, [this]() {
__device = detail::driver::deviceGet(__id_);
__primary_ctx = detail::driver::primaryCtxRetain(__device);
});
assert(__primary_ctx != nullptr);
return __primary_ctx;
}

~device()
{
if (__primary_ctx)
{
detail::driver::primaryCtxRelease(__device);
}
}

private:
// TODO: put a mutable thread-safe (or thread_local) cache of device
// properties here.

friend class device_ref;
friend struct detail::__emplace_device;

mutable CUcontext __primary_ctx = nullptr;
mutable CUdevice __device{};
mutable ::std::once_flag __init_once;

explicit constexpr device(int __id) noexcept
: device_ref(__id)
{}
Expand All @@ -76,7 +104,7 @@ private:

namespace detail
{
_CCCL_NODISCARD inline constexpr __emplace_device::operator device() const noexcept
_CCCL_NODISCARD inline __emplace_device::operator device() const noexcept
{
return device(__id_);
}
Expand Down
64 changes: 0 additions & 64 deletions cudax/include/cuda/experimental/__device/device_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#endif // no system header

#include <cuda/std/__cuda/api_wrapper.h>
#include <cuda/std/__type_traits/decay.h>

namespace cuda::experimental
{
Expand Down Expand Up @@ -103,69 +102,6 @@ public:
}
};

#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document

//! @brief RAII helper which saves the current device and switches to the
//! specified device on construction and switches to the saved device on
//! destruction.
//!
struct __scoped_device
{
private:
// The original device ordinal, or -1 if the device was not changed.
int const __old_device;

//! @brief Returns the current device ordinal.
//!
//! @throws cuda_error if the device query fails.
static int __current_device()
{
int device = -1;
_CCCL_TRY_CUDA_API(cudaGetDevice, "failed to get the current device", &device);
return device;
}

explicit __scoped_device(int new_device, int old_device) noexcept
: __old_device(new_device == old_device ? -1 : old_device)
{}

public:
//! @brief Construct a new `__scoped_device` object and switch to the specified
//! device.
//!
//! @param new_device The device to switch to
//!
//! @throws cuda_error if the device switch fails
explicit __scoped_device(device_ref new_device)
: __scoped_device(new_device.get(), __current_device())
{
if (__old_device != -1)
{
_CCCL_TRY_CUDA_API(cudaSetDevice, "failed to set the current device", new_device.get());
}
}

__scoped_device(__scoped_device&&) = delete;
__scoped_device(__scoped_device const&) = delete;
__scoped_device& operator=(__scoped_device&&) = delete;
__scoped_device& operator=(__scoped_device const&) = delete;

//! @brief Destroy the `__scoped_device` object and switch back to the original
//! device.
//!
//! @throws cuda_error if the device switch fails. If the destructor is called
//! during stack unwinding, the program is automatically terminated.
~__scoped_device() noexcept(false)
{
if (__old_device != -1)
{
_CCCL_TRY_CUDA_API(cudaSetDevice, "failed to restore the current device", __old_device);
}
}
};

#endif // DOXYGEN_SHOULD_SKIP_THIS

} // namespace cuda::experimental

#endif // _CUDAX__DEVICE_DEVICE_REF
10 changes: 7 additions & 3 deletions cudax/include/cuda/experimental/__event/event.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <cuda/experimental/__detail/utility.cuh>
#include <cuda/experimental/__event/event_ref.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

namespace cuda::experimental
{
Expand All @@ -54,7 +55,7 @@ public:
//!
//! @throws cuda_error if the event creation fails.
explicit event(stream_ref __stream, flags __flags = flags::none)
: event(static_cast<unsigned int>(__flags) | cudaEventDisableTiming)
: event(__stream, static_cast<unsigned int>(__flags) | cudaEventDisableTiming)
{
record(__stream);
}
Expand Down Expand Up @@ -85,7 +86,9 @@ public:
{
if (__event_ != nullptr)
{
[[maybe_unused]] auto __status = ::cudaEventDestroy(__event_);
// Needs to call driver API in case current device is not set, runtime version would set dev 0 current
// Alternative would be to store the device and push/pop here
[[maybe_unused]] auto __status = detail::driver::eventDestroy(__event_);
}
}

Expand Down Expand Up @@ -144,9 +147,10 @@ private:
: event_ref(__evnt)
{}

explicit event(unsigned int __flags)
explicit event(stream_ref __stream, unsigned int __flags)
: event_ref(::cudaEvent_t{})
{
[[maybe_unused]] __ensure_current_device __dev_setter(__stream);
_CCCL_TRY_CUDA_API(
::cudaEventCreateWithFlags, "Failed to create CUDA event", &__event_, static_cast<unsigned int>(__flags));
}
Expand Down
5 changes: 4 additions & 1 deletion cudax/include/cuda/experimental/__event/event_ref.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <cuda/std/utility>
#include <cuda/stream_ref>

#include <cuda/experimental/__utility/driver_api.cuh>

namespace cuda::experimental
{
class event;
Expand Down Expand Up @@ -74,7 +76,8 @@ public:
{
assert(__event_ != nullptr);
assert(__stream.get() != nullptr);
_CCCL_TRY_CUDA_API(::cudaEventRecord, "Failed to record CUDA event", __event_, __stream.get());
// Need to use driver API, cudaEventRecord will push dev 0 if stack is empty
detail::driver::eventRecord(__event_, __stream.get());
}

//! @brief Waits until all the work in the stream prior to the record of the
Expand Down
2 changes: 1 addition & 1 deletion cudax/include/cuda/experimental/__event/timed_event.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public:
//!
//! @throws cuda_error if the event creation fails.
explicit timed_event(stream_ref __stream, flags __flags = flags::none)
: event(static_cast<unsigned int>(__flags))
: event(__stream, static_cast<unsigned int>(__flags))
{
record(__stream);
}
Expand Down
7 changes: 7 additions & 0 deletions cudax/include/cuda/experimental/__launch/launch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cuda/stream_ref>

#include <cuda/experimental/__launch/configuration.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

#if _CCCL_STD_VER >= 2017
namespace cuda::experimental
Expand Down Expand Up @@ -119,6 +120,7 @@ template <typename... Args, typename... Config, typename Dimensions, typename Ke
void launch(
::cuda::stream_ref stream, const kernel_config<Dimensions, Config...>& conf, const Kernel& kernel, Args... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status;
if constexpr (::cuda::std::is_invocable_v<Kernel, kernel_config<Dimensions, Config...>, Args...>)
{
Expand Down Expand Up @@ -181,6 +183,7 @@ void launch(
template <typename... Args, typename... Levels, typename Kernel>
void launch(::cuda::stream_ref stream, const hierarchy_dimensions<Levels...>& dims, const Kernel& kernel, Args... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status;
if constexpr (::cuda::std::is_invocable_v<Kernel, hierarchy_dimensions<Levels...>, Args...>)
{
Expand Down Expand Up @@ -245,6 +248,7 @@ void launch(::cuda::stream_ref stream,
void (*kernel)(kernel_config<Dimensions, Config...>, ExpArgs...),
ActArgs&&... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status = [&](ExpArgs... args) {
return detail::launch_impl(stream, conf, kernel, conf, args...);
}(std::forward<ActArgs>(args)...);
Expand Down Expand Up @@ -299,6 +303,7 @@ void launch(::cuda::stream_ref stream,
void (*kernel)(hierarchy_dimensions<Levels...>, ExpArgs...),
ActArgs&&... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status = [&](ExpArgs... args) {
return detail::launch_impl(stream, kernel_config(dims), kernel, dims, args...);
}(std::forward<ActArgs>(args)...);
Expand Down Expand Up @@ -354,6 +359,7 @@ void launch(::cuda::stream_ref stream,
void (*kernel)(ExpArgs...),
ActArgs&&... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status = [&](ExpArgs... args) {
return detail::launch_impl(stream, conf, kernel, args...);
}(std::forward<ActArgs>(args)...);
Expand Down Expand Up @@ -406,6 +412,7 @@ template <typename... ExpArgs, typename... ActArgs, typename... Levels>
void launch(
::cuda::stream_ref stream, const hierarchy_dimensions<Levels...>& dims, void (*kernel)(ExpArgs...), ActArgs&&... args)
{
[[maybe_unused]] __ensure_current_device __dev_setter(stream);
cudaError_t status = [&](ExpArgs... args) {
return detail::launch_impl(stream, kernel_config(dims), kernel, args...);
}(std::forward<ActArgs>(args)...);
Expand Down
17 changes: 11 additions & 6 deletions cudax/include/cuda/experimental/__stream/stream.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <cuda/experimental/__device/device_ref.cuh>
#include <cuda/experimental/__event/timed_event.cuh>
#include <cuda/experimental/__utility/ensure_current_device.cuh>

namespace cuda::experimental
{
Expand All @@ -51,7 +52,7 @@ struct stream : stream_ref
//! @throws cuda_error if stream creation fails
explicit stream(device_ref __dev, int __priority = default_priority)
{
__scoped_device dev_setter(__dev);
[[maybe_unused]] __ensure_current_device __dev_setter(__dev);
_CCCL_TRY_CUDA_API(
::cudaStreamCreateWithPriority, "Failed to create a stream", &__stream, cudaStreamDefault, __priority);
}
Expand Down Expand Up @@ -89,7 +90,9 @@ struct stream : stream_ref
{
if (__stream != detail::invalid_stream)
{
[[maybe_unused]] auto status = ::cudaStreamDestroy(__stream);
// Needs to call driver API in case current device is not set, runtime version would set dev 0 current
// Alternative would be to store the device and push/pop here
[[maybe_unused]] auto status = detail::driver::streamDestroy(__stream);
}
}

Expand Down Expand Up @@ -139,18 +142,20 @@ struct stream : stream_ref
void wait(event_ref __ev) const
{
assert(__ev.get() != nullptr);
_CCCL_TRY_CUDA_API(::cudaStreamWaitEvent, "Failed to make a stream wait for an event", get(), __ev.get());
// Need to use driver API, cudaStreamWaitEvent would push dev 0 if stack was empty
detail::driver::streamWaitEvent(get(), __ev.get());
}

//! @brief Make all future work submitted into this stream depend on completion of all work from the specified stream
//! @brief Make all future work submitted into this stream depend on completion of all work from the specified
//! stream
//!
//! @param __other Stream that this stream should wait for
//!
//! @throws cuda_error if inserting the dependency fails
void wait(stream_ref __other) const
{
// TODO consider an optimization to not create an event every time and instead have one persistent event or one per
// stream
// TODO consider an optimization to not create an event every time and instead have one persistent event or one
// per stream
assert(__stream != detail::invalid_stream);
event __tmp(__other);
wait(__tmp);
Expand Down
Loading

0 comments on commit 75929cb

Please sign in to comment.