Skip to content

Commit

Permalink
In Google Batch, install GPU drivers for GPU VMs
Browse files Browse the repository at this point in the history
Also clean up some old logic for container options when using GPUs.
These are now automatically handled by Google Cloud.

Fixes nextflow-io#5372.

Signed-off-by: Siddhartha Bagaria <[email protected]>
  • Loading branch information
Sid Bagaria authored and siddharthab committed Oct 17, 2024
1 parent 0c9b333 commit 2860b09
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class GoogleBatchMachineTypeSelector {

private static final List<String> DEFAULT_FAMILIES = ['n1-*', 'n2-*', 'n2d-*', 'c2-*', 'c2d-*', 'm1-*', 'm2-*', 'm3-*', 'e2-*']

/*
* Accelerator optimized families. See: https://cloud.google.com/compute/docs/accelerator-optimized-machines
* LAST UPDATE 2024-10-16
*/
private static final List<String> ACCELERATOR_OPTIMIZED_FAMILIES = ['a2-*', 'a3-*', 'g2-*']

@Immutable
static class MachineType {
String type
Expand All @@ -86,6 +92,7 @@ class GoogleBatchMachineTypeSelector {
float onDemandPrice
int cpusPerVm
int memPerVm
int gpusPerVm
PriceModel priceModel
}

Expand All @@ -97,7 +104,7 @@ class GoogleBatchMachineTypeSelector {
if (families.size() == 1) {
final familyOrType = families.get(0)
if (familyOrType.contains("custom-"))
return new MachineType(type: familyOrType, family: 'custom', cpusPerVm: cpus, memPerVm: memoryMB, location: region, priceModel: spot ? PriceModel.spot : PriceModel.standard)
return new MachineType(type: familyOrType, family: 'custom', cpusPerVm: cpus, memPerVm: memoryMB, gpusPerVm: 0, location: region, priceModel: spot ? PriceModel.spot : PriceModel.standard)

final machineType = getAvailableMachineTypes(region, spot).find { it.type == familyOrType }
if( machineType )
Expand Down Expand Up @@ -156,6 +163,7 @@ class GoogleBatchMachineTypeSelector {
onDemandPrice: it.onDemandPrice as float,
cpusPerVm: it.cpusPerVm as int,
memPerVm: it.memPerVm as int,
gpusPerVm: it.gpusPerVm as int,
location: region,
priceModel: priceModel
)
Expand Down Expand Up @@ -249,4 +257,19 @@ class GoogleBatchMachineTypeSelector {
return new MemoryUnit( numberOfDisks * 375L * (1<<30) )
}

/**
 * Determine whether GPU drivers should be installed for the given machine type.
 *
 * Drivers are requested when the machine type reports at least one GPU, or when its
 * type name matches one of the patterns in {@code ACCELERATOR_OPTIMIZED_FAMILIES}.
 *
 * @param machineType Machine type
 * @return {@code true} if GPU drivers should be installed, {@code false} otherwise.
 */
protected boolean installGpuDrivers(MachineType machineType) {
    if( machineType.gpusPerVm > 0 ) {
        return true
    }
    // The Cloud Info service does not currently return gpusPerVm values (or the user
    // could have disabled use of the service), so also check against a known set of families.
    // `any` yields a real boolean instead of relying on the truthiness of the String
    // (or null) that `find` would return.
    return ACCELERATOR_OPTIMIZED_FAMILIES.any { String family -> matchType(family, machineType.type) }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,8 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
.addAllCommands( cmd )
.addAllVolumes( launcher.getContainerMounts() )

final accel = task.config.getAccelerator()
// add nvidia specific driver paths
// see https://cloud.google.com/batch/docs/create-run-job#create-job-gpu
if( accel && accel.type.toLowerCase().startsWith('nvidia-') ) {
container
.addVolumes('/var/lib/nvidia/lib64:/usr/local/nvidia/lib64')
.addVolumes('/var/lib/nvidia/bin:/usr/local/nvidia/bin')
}

def containerOptions = task.config.getContainerOptions() ?: ''
// accelerator requires privileged option
// https://cloud.google.com/batch/docs/create-run-job#create-job-gpu
if( task.config.getAccelerator() || fusionEnabled() ) {
if( fusionEnabled() ) {
if( containerOptions ) containerOptions += ' '
containerOptions += '--privileged'
}
Expand Down Expand Up @@ -324,17 +313,6 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
else {
final instancePolicy = AllocationPolicy.InstancePolicy.newBuilder()

if( task.config.getAccelerator() ) {
final accelerator = AllocationPolicy.Accelerator.newBuilder()
.setCount( task.config.getAccelerator().getRequest() )

if( task.config.getAccelerator().getType() )
accelerator.setType( task.config.getAccelerator().getType() )

instancePolicy.addAccelerators(accelerator)
instancePolicyOrTemplate.setInstallGpuDrivers(true)
}

if( executor.config.getBootDiskImage() )
instancePolicy.setBootDisk( AllocationPolicy.Disk.newBuilder().setImage( executor.config.getBootDiskImage() ) )

Expand All @@ -347,13 +325,27 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {

if( machineType ) {
instancePolicy.setMachineType(machineType.type)
instancePolicyOrTemplate.setInstallGpuDrivers(
GoogleBatchMachineTypeSelector.INSTANCE.installGpuDrivers(machineType)
)
machineInfo = new CloudMachineInfo(
type: machineType.type,
zone: machineType.location,
priceModel: machineType.priceModel
)
}

if( task.config.getAccelerator() ) {
final accelerator = AllocationPolicy.Accelerator.newBuilder()
.setCount( task.config.getAccelerator().getRequest() )

if( task.config.getAccelerator().getType() )
accelerator.setType( task.config.getAccelerator().getType() )

instancePolicy.addAccelerators(accelerator)
instancePolicyOrTemplate.setInstallGpuDrivers(true)
}

// When using local SSD not all the disk sizes are valid and depends on the machine type
if( disk?.type == 'local-ssd' && machineType ) {
final validSize = GoogleBatchMachineTypeSelector.INSTANCE.findValidLocalSSDSize(disk.request, machineType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,18 @@ class GoogleBatchMachineTypeSelectorTest extends Specification {
'50 GB' | 'c2d-highmem-56' | 'c2d' | 56 | '1500 GB'
'750 GB' | 'm3-megamem-64' | 'm3' | 64 | '1500 GB'
}

def 'should know when to install GPU drivers'() {
    // Build the machine type under test from the current data-table row.
    given:
    final candidate = new MachineType(type: TYPE, gpusPerVm: GPUS)

    // Drivers are expected either when GPUs are reported or when the type
    // belongs to an accelerator-optimized family (a2, a3, g2).
    expect:
    GoogleBatchMachineTypeSelector.INSTANCE.installGpuDrivers(candidate) == EXPECTED

    where:
    TYPE            | GPUS | EXPECTED
    'n2-standard-4' | 0    | false
    'n2-standard-4' | 1    | true
    'a2-highgpu-1g' | 0    | true
    'a3-highgpu-1g' | 0    | true
    'g2-standard-4' | 0    | true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GoogleBatchTaskHandlerTest extends Specification {
and:
!instancePolicyOrTemplate.getInstanceTemplate()
and:
!instancePolicyOrTemplate.getInstallGpuDrivers()
instancePolicy.getAcceleratorsCount() == 0
instancePolicy.getDisksCount() == 0
!instancePolicy.getMachineType()
Expand Down Expand Up @@ -196,11 +197,9 @@ class GoogleBatchTaskHandlerTest extends Specification {
and:
runnable.getContainer().getCommandsList().join(' ') == '/bin/bash -o pipefail -c bash .command.run'
runnable.getContainer().getImageUri() == CONTAINER_IMAGE
runnable.getContainer().getOptions() == '--this --that --privileged'
runnable.getContainer().getOptions() == '--this --that'
runnable.getContainer().getVolumesList() == [
'/mnt/disks/foo/scratch:/mnt/disks/foo/scratch:rw',
'/var/lib/nvidia/lib64:/usr/local/nvidia/lib64',
'/var/lib/nvidia/bin:/usr/local/nvidia/bin'
]
and:
runnable.getEnvironment().getVariablesMap() == [:]
Expand Down

0 comments on commit 2860b09

Please sign in to comment.