diff --git a/src/gha_runner/clouddeployment.py b/src/gha_runner/clouddeployment.py index edfa945..7f04c39 100644 --- a/src/gha_runner/clouddeployment.py +++ b/src/gha_runner/clouddeployment.py @@ -1,9 +1,13 @@ +import importlib.resources +import os from abc import ABC, abstractmethod -from gha_runner.gh import GitHubInstance from dataclasses import dataclass, field -import importlib.resources -import boto3 from string import Template +from typing import Optional + +import boto3 + +from gha_runner.gh import GitHubInstance class CloudDeployment(ABC): @@ -89,7 +93,6 @@ class AWS(CloudDeployment): instance_type: str home_dir: str repo: str - region_name: str runner_release: str = "" tags: list[dict[str, str]] = field(default_factory=list) gh_runner_tokens: list[str] = field(default_factory=list) @@ -98,6 +101,7 @@ class AWS(CloudDeployment): security_group_id: str = "" iam_role: str = "" script: str = "" + region_name: Optional[str] = None def _build_aws_params(self, user_data_params: dict) -> dict: """Build the parameters for the AWS API call. @@ -153,6 +157,10 @@ def create_instances(self) -> dict[str, str]: raise ValueError( "No instance type provided, cannot create instances." ) + if self.region_name is None and "AWS_DEFAULT_REGION" not in os.environ: + raise ValueError( + "No region name provided, cannot create instances." + ) ec2 = boto3.client("ec2", region_name=self.region_name) id_dict = {} for token in self.gh_runner_tokens: @@ -180,6 +188,10 @@ def create_instances(self) -> dict[str, str]: return id_dict def remove_instances(self, ids: list[str]): + if self.region_name is None and "AWS_DEFAULT_REGION" not in os.environ: + raise ValueError( + "No region name provided, cannot create instances." + ) ec2 = boto3.client("ec2", self.region_name) params = { "InstanceIds": ids, diff --git a/tests/test_aws.py b/tests/test_aws.py index cd8f017..bd867ee 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -142,6 +142,15 @@ def test_create_instances_missing_instance_type(aws): aws.create_instances() +def test_create_instances_missing_region(aws): + aws.region_name = None + os.environ.pop("AWS_DEFAULT_REGION") + with pytest.raises( + ValueError, match="No region name provided, cannot create instances." + ): + aws.create_instances() + + def test_instance_running(aws): ids = aws.create_instances() assert len(ids) == 1 @@ -202,6 +211,18 @@ def test_remove_instances(aws): assert not aws.instance_running(ids[0]) +def test_remove_instances_missing_region(aws): + ids = aws.create_instances() + assert len(ids) == 1 + ids = list(ids) + aws.region_name = None + os.environ.pop("AWS_DEFAULT_REGION") + with pytest.raises( + ValueError, match="No region name provided, cannot create instances." + ): + aws.remove_instances(ids) + + def test_wait_until_removed(aws): ids = aws.create_instances() assert len(ids) == 1