Skip to content

Commit

Permalink
CUDA: CURTN: fix context management bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
atafra committed Mar 15, 2024
1 parent d548568 commit aec2285
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Version History
### Changes in v2.2.2:

- Fully fixed GPU memory leak when releasing SYCL, CUDA and HIP device objects
- Fixed CUDA context error in some cases when using the CUDA driver API
- Fixed crash on systems with unsupported AMD Vega integrated GPUs

### Changes in v2.2.1:
Expand Down
69 changes: 52 additions & 17 deletions devices/cuda/curtn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,20 @@ namespace curtn

~Runtime()
{
// We must release all primary contexts we've retained
for (const auto& i : contexts)
cuDevicePrimaryCtxRelease(i.first);
// Unload all modules in the primary contexts and release all primary contexts
// We can't clean up other contexts too because we don't know whether those are still alive
for (const auto& primaryContextItem : primaryContexts)
{
const CUcontext context = primaryContextItem.second;
if (cuCtxPushCurrent(context) == CUDA_SUCCESS)
{
for (const auto& moduleItem : contextStates[context].modules)
cuModuleUnload(moduleItem.second);
cuCtxPopCurrent(nullptr);
}

cuDevicePrimaryCtxRelease(primaryContextItem.first);
}
}

// Initializes the context for the given device (without setting it on the current thread)
Expand All @@ -237,14 +248,14 @@ namespace curtn

std::lock_guard<std::mutex> lock(mutex);

auto contextIter = contexts.find(device);
if (contextIter == contexts.end())
auto contextIter = primaryContexts.find(device);
if (contextIter == primaryContexts.end())
{
result = cuDevicePrimaryCtxRetain(&context, device);
if (result != CUDA_SUCCESS)
return result;

contexts[device] = context;
primaryContexts[device] = context;
}
else
context = contextIter->second;
Expand All @@ -263,15 +274,39 @@ namespace curtn
if (result != CUDA_SUCCESS)
return result;

if (context != nullptr)
return CUDA_SUCCESS;
if (context)
{
// If the current context is a primary context, and we're seeing it the first time, we need
// to retain it. Unfortunately, we can't tell whether it's a primary context so we retain
// the primary context for the device corresponding to the current context in either case.
CUdevice device;
result = cuCtxGetDevice(&device);
if (result != CUDA_SUCCESS)
return result;

// No current context, use device 0
result = initContext(0, context);
if (result != CUDA_SUCCESS)
return result;
std::lock_guard<std::mutex> lock(mutex);

if (primaryContexts.find(device) == primaryContexts.end())
{
CUcontext primaryContext;
result = cuDevicePrimaryCtxRetain(&primaryContext, device);
if (result != CUDA_SUCCESS)
return result;

return cuCtxSetCurrent(context);
primaryContexts[device] = primaryContext;
}

return CUDA_SUCCESS;
}
else
{
// No current context, use device 0
result = initContext(0, context);
if (result != CUDA_SUCCESS)
return result;

return cuCtxSetCurrent(context);
}
}

CUresult initCurrentContext()
Expand Down Expand Up @@ -368,7 +403,7 @@ namespace curtn

CUresult initResult = CUDA_ERROR_NOT_INITIALIZED; // result of CUDA initialization
std::unordered_map<CUdevice, int> deviceOrdinals; // device ordinals by device handle
std::unordered_map<CUdevice, CUcontext> contexts; // primary contexts by device handle
std::unordered_map<CUdevice, CUcontext> primaryContexts; // primary contexts by device handle
std::unordered_map<CUcontext, ContextState> contextStates; // per-context states
std::unordered_map<const void*, FunctionDesc> funcDescs; // function descriptors by symbol
std::mutex mutex;
Expand Down Expand Up @@ -422,9 +457,9 @@ namespace curtn
}

cudaError_t CUDARTAPI __cudaPopCallConfiguration(dim3* gridDim,
dim3* blockDim,
size_t* sharedMem,
cudaStream_t* stream)
dim3* blockDim,
size_t* sharedMem,
cudaStream_t* stream)
{
if (Runtime::callConfigs.empty())
return Runtime::setError(cudaErrorUnknown);
Expand Down

0 comments on commit aec2285

Please sign in to comment.