diff --git a/CHANGELOG.md b/CHANGELOG.md index f4a9e134..b8a25c24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # NVIDIA Container Toolkit Library and CLI Changelog +* Use D3DKMTEnumAdapters3 to enumerate adpaters on WSL2 if available. + ## 1.15.0~rc.2 * Added detection of libnvdxgdmal.so.1 on WSL2 diff --git a/src/dxcore.c b/src/dxcore.c index e65225f1..4d42204d 100644 --- a/src/dxcore.c +++ b/src/dxcore.c @@ -29,14 +29,17 @@ static const char * const dxcore_nvidia_driver_store_components[] = { */ struct dxcore_enumAdapters2; +struct dxcore_enumAdapters3; struct dxcore_queryAdapterInfo; typedef int(*pfnDxcoreEnumAdapters2)(struct dxcore_enumAdapters2* pParams); +typedef int(*pfnDxcoreEnumAdapters3)(struct dxcore_enumAdapters3* pParams); typedef int(*pfnDxcoreQueryAdapterInfo)(struct dxcore_queryAdapterInfo* pParams); struct dxcore_lib { void* hDxcoreLib; pfnDxcoreEnumAdapters2 pDxcoreEnumAdapters2; + pfnDxcoreEnumAdapters3 pDxcoreEnumAdapters3; pfnDxcoreQueryAdapterInfo pDxcoreQueryAdapterInfo; }; @@ -54,6 +57,15 @@ struct dxcore_enumAdapters2 struct dxcore_adapterInfo *pAdapters; }; +#define ENUMADAPTER3_FILTER_COMPUTE_ONLY (0x0000000000000001) + +struct dxcore_enumAdapters3 +{ + unsigned long long Filter; + unsigned int NumAdapters; + struct dxcore_adapterInfo *pAdapters; +}; + enum dxcore_kmtqueryAdapterInfoType { DXCORE_QUERYDRIVERVERSION = 13, @@ -270,7 +282,37 @@ static void dxcore_add_adapter(struct dxcore_context* pCtx, struct dxcore_lib* p log_infof("Adding new adapter via dxcore hAdapter:%x luid:%llx wddm version:%d", pAdapterInfo->hAdapter, *((unsigned long long*)&pAdapterInfo->AdapterLuid), wddmVersion); } -static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* pLib) +static int dxcore_enum_adapters3(struct dxcore_context* pCtx, struct dxcore_lib* pLib) +{ + struct dxcore_enumAdapters3 params = {0}; + unsigned int adapterIndex = 0; + + // Include compute-only in addition to display+compute adapters + params.Filter = ENUMADAPTER3_FILTER_COMPUTE_ONLY; + params.NumAdapters = 0; + params.pAdapters = NULL; + + if (pLib->pDxcoreEnumAdapters3(¶ms)) { + log_err("Failed to enumerate adapters via enumAdapers3"); + return 1; + } + + params.pAdapters = malloc(sizeof(struct dxcore_adapterInfo) * params.NumAdapters); + if (pLib->pDxcoreEnumAdapters3(¶ms)) { + free(params.pAdapters); + log_err("Failed to enumerate adapters via enumAdapers3"); + return 1; + } + + for (adapterIndex = 0; adapterIndex < params.NumAdapters; adapterIndex++) { + dxcore_add_adapter(pCtx, pLib, ¶ms.pAdapters[adapterIndex]); + } + + free(params.pAdapters); + return 0; +} + +static int dxcore_enum_adapters2(struct dxcore_context* pCtx, struct dxcore_lib* pLib) { struct dxcore_enumAdapters2 params = {0}; unsigned int adapterIndex = 0; @@ -279,15 +321,15 @@ static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* params.pAdapters = NULL; if (pLib->pDxcoreEnumAdapters2(¶ms)) { - log_err("Failed to enumerate adapters via dxcore"); - return; + log_err("Failed to enumerate adapters via enumAdapters2"); + return 1; } params.pAdapters = malloc(sizeof(struct dxcore_adapterInfo) * params.NumAdapters); if (pLib->pDxcoreEnumAdapters2(¶ms)) { free(params.pAdapters); - log_err("Failed to enumerate adapters via dxcore"); - return; + log_err("Failed to enumerate adapters via enumAdapters2"); + return 1; } for (adapterIndex = 0; adapterIndex < params.NumAdapters; adapterIndex++) { @@ -295,6 +337,27 @@ static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* } free(params.pAdapters); + return 0; +} + +static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* pLib) +{ + int status; + if (pLib->pDxcoreEnumAdapters3) { + status = dxcore_enum_adapters3(pCtx, pLib); + if (status == 0) { + return; + } + } + + // Fall back to EnumAdapters2 if the OS doesn't support EnumAdapters3 + if (pLib->pDxcoreEnumAdapters2) { + status = dxcore_enum_adapters2(pCtx, pLib); + if (status == 0) { + return; + } + } + log_err("Failed to enumerate adapters via dxcore"); } int dxcore_init_context(struct dxcore_context* pCtx) @@ -311,8 +374,9 @@ int dxcore_init_context(struct dxcore_context* pCtx) } lib.pDxcoreEnumAdapters2 = (pfnDxcoreEnumAdapters2)dlsym(lib.hDxcoreLib, "D3DKMTEnumAdapters2"); - if (!lib.pDxcoreEnumAdapters2) { - log_err("dxcore library is present but the symbol D3DKMTEnumAdapters2 is missing"); + lib.pDxcoreEnumAdapters3 = (pfnDxcoreEnumAdapters3)dlsym(lib.hDxcoreLib, "D3DKMTEnumAdapters3"); + if (!lib.pDxcoreEnumAdapters2 && !lib.pDxcoreEnumAdapters3) { + log_err("dxcore library is present but the symbols D3DKMTEnumAdapters2 and D3DKMTEnumAdapters3 are missing"); goto error; }