diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index 8fb39888..1a9b3c62 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -32,6 +32,12 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" ) +type driverVersionDiscoverer struct { + discover.Discover + nvidiaCDIHookPath string + version string +} + // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // The supplied NVML Library is used to query the expected driver version. func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCDIHookPath string, ldconfigPath string, nvmllib nvml.Interface) (discover.Discover, error) { @@ -100,7 +106,11 @@ func NewDriverLibraryDiscoverer(logger logger.Interface, driver *root.Driver, nv hooks, _ := discover.NewLDCacheUpdateHook(logger, libraries, nvidiaCDIHookPath, ldconfigPath) d := discover.Merge( - libraries, + &driverVersionDiscoverer{ + Discover: libraries, + nvidiaCDIHookPath: nvidiaCDIHookPath, + version: version, + }, hooks, ) @@ -220,3 +230,37 @@ func getVersionLibs(logger logger.Interface, driver *root.Driver, version string return relative, nil } + +func (d driverVersionDiscoverer) Hooks() ([]discover.Hook, error) { + mounts, err := d.Discover.Mounts() + if err != nil { + return nil, fmt.Errorf("failed to get library mounts: %v", err) + } + + var links []string + for _, mount := range mounts { + dir, filename := filepath.Split(mount.Path) + // TODO: We should include the other libraries as is done here: + // https://github.com/NVIDIA/nvidia-container-toolkit/blob/79c59aeb7f59dd612793ac80a8d7022c554634bb/internal/platform-support/tegra/symlinks.go#L84-L97 + if d.isDriverLibrary(filename, "libcuda.so") { + // create libcuda.so -> libcuda.so.RM_VERSION symlink + links = append(links, fmt.Sprintf("%s::%s", filename, filepath.Join(dir, "libcuda.so"))) + } + } + + if len(links) == 0 { + return nil, nil + } + + hooks := discover.CreateCreateSymlinkHook(d.nvidiaCDIHookPath, links) + + return hooks.Hooks() + +} + +// isDriverLibrary checks whether the specified filename is a specific driver library. +func (d driverVersionDiscoverer) isDriverLibrary(filename string, libraryName string) bool { + pattern := strings.TrimSuffix(libraryName, ".") + d.version + match, _ := filepath.Match(pattern, filename) + return match +}