From aec22855b3895e90e2831d9bf4332624129eba63 Mon Sep 17 00:00:00 2001 From: Attila Afra Date: Fri, 15 Mar 2024 03:26:46 +0200 Subject: [PATCH] CUDA: CURTN: fix context management bugs --- CHANGELOG.md | 1 + devices/cuda/curtn.cpp | 69 +++++++++++++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15253ffd..508e682d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Version History ### Changes in v2.2.2: - Fully fixed GPU memory leak when releasing SYCL, CUDA and HIP device objects +- Fixed CUDA context error in some cases when using the CUDA driver API - Fixed crash on systems with unsupported AMD Vega integrated GPUs ### Changes in v2.2.1: diff --git a/devices/cuda/curtn.cpp b/devices/cuda/curtn.cpp index 9f300594..cb9bd8fa 100644 --- a/devices/cuda/curtn.cpp +++ b/devices/cuda/curtn.cpp @@ -219,9 +219,20 @@ namespace curtn ~Runtime() { - // We must release all primary contexts we've retained - for (const auto& i : contexts) - cuDevicePrimaryCtxRelease(i.first); + // Unload all modules in the primary contexts and release all primary contexts + // We can't clean up other contexts too because we don't know whether those are still alive + for (const auto& primaryContextItem : primaryContexts) + { + const CUcontext context = primaryContextItem.second; + if (cuCtxPushCurrent(context) == CUDA_SUCCESS) + { + for (const auto& moduleItem : contextStates[context].modules) + cuModuleUnload(moduleItem.second); + cuCtxPopCurrent(nullptr); + } + + cuDevicePrimaryCtxRelease(primaryContextItem.first); + } } // Initializes the context for the given device (without setting it on the current thread) @@ -237,14 +248,14 @@ namespace curtn std::lock_guard lock(mutex); - auto contextIter = contexts.find(device); - if (contextIter == contexts.end()) + auto contextIter = primaryContexts.find(device); + if (contextIter == primaryContexts.end()) { result = cuDevicePrimaryCtxRetain(&context, device); if (result != CUDA_SUCCESS) return result; - contexts[device] = context; + primaryContexts[device] = context; } else context = contextIter->second; @@ -263,15 +274,39 @@ namespace curtn if (result != CUDA_SUCCESS) return result; - if (context != nullptr) - return CUDA_SUCCESS; + if (context) + { + // If the current context is a primary context, and we're seeing it the first time, we need + // to retain it. Unfortunately, we can't tell whether it's a primary context so we retain + // the primary context for the device corresponding to the current context in either case. + CUdevice device; + result = cuCtxGetDevice(&device); + if (result != CUDA_SUCCESS) + return result; - // No current context, use device 0 - result = initContext(0, context); - if (result != CUDA_SUCCESS) - return result; + std::lock_guard lock(mutex); + + if (primaryContexts.find(device) == primaryContexts.end()) + { + CUcontext primaryContext; + result = cuDevicePrimaryCtxRetain(&primaryContext, device); + if (result != CUDA_SUCCESS) + return result; - return cuCtxSetCurrent(context); + primaryContexts[device] = primaryContext; + } + + return CUDA_SUCCESS; + } + else + { + // No current context, use device 0 + result = initContext(0, context); + if (result != CUDA_SUCCESS) + return result; + + return cuCtxSetCurrent(context); + } } CUresult initCurrentContext() @@ -368,7 +403,7 @@ namespace curtn CUresult initResult = CUDA_ERROR_NOT_INITIALIZED; // result of CUDA initialization std::unordered_map deviceOrdinals; // device ordinals by device handle - std::unordered_map contexts; // primary contexts by device handle + std::unordered_map primaryContexts; // primary contexts by device handle std::unordered_map contextStates; // per-context states std::unordered_map funcDescs; // function descriptors by symbol std::mutex mutex; @@ -422,9 +457,9 @@ namespace curtn } cudaError_t CUDARTAPI __cudaPopCallConfiguration(dim3* gridDim, - dim3* blockDim, - size_t* sharedMem, - cudaStream_t* stream) + dim3* blockDim, + size_t* sharedMem, + cudaStream_t* stream) { if (Runtime::callConfigs.empty()) return Runtime::setError(cudaErrorUnknown);