Skip to content

Commit

Permalink
feat: use AWS_DEFAULT_REGION by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanholz committed Aug 22, 2024
1 parent c590fbd commit 8c278c2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/gha_runner/clouddeployment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import importlib.resources
import os
from abc import ABC, abstractmethod
from gha_runner.gh import GitHubInstance
from dataclasses import dataclass, field
import importlib.resources
import boto3
from string import Template
from typing import Optional

import boto3

from gha_runner.gh import GitHubInstance


class CloudDeployment(ABC):
Expand Down Expand Up @@ -89,7 +93,6 @@ class AWS(CloudDeployment):
instance_type: str
home_dir: str
repo: str
region_name: str
runner_release: str = ""
tags: list[dict[str, str]] = field(default_factory=list)
gh_runner_tokens: list[str] = field(default_factory=list)
Expand All @@ -98,6 +101,7 @@ class AWS(CloudDeployment):
security_group_id: str = ""
iam_role: str = ""
script: str = ""
region_name: Optional[str] = None

def _build_aws_params(self, user_data_params: dict) -> dict:
"""Build the parameters for the AWS API call.
Expand Down Expand Up @@ -153,6 +157,10 @@ def create_instances(self) -> dict[str, str]:
raise ValueError(
"No instance type provided, cannot create instances."
)
if self.region_name is None and "AWS_DEFAULT_REGION" not in os.environ:
raise ValueError(
"No region name provided, cannot create instances."
)
ec2 = boto3.client("ec2", region_name=self.region_name)
id_dict = {}
for token in self.gh_runner_tokens:
Expand Down Expand Up @@ -180,6 +188,10 @@ def create_instances(self) -> dict[str, str]:
return id_dict

def remove_instances(self, ids: list[str]):
if self.region_name is None and "AWS_DEFAULT_REGION" not in os.environ:
raise ValueError(
"No region name provided, cannot create instances."
)
ec2 = boto3.client("ec2", self.region_name)
params = {
"InstanceIds": ids,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ def test_create_instances_missing_instance_type(aws):
aws.create_instances()


def test_create_instances_missing_region(aws):
aws.region_name = None
os.environ.pop("AWS_DEFAULT_REGION")
with pytest.raises(
ValueError, match="No region name provided, cannot create instances."
):
aws.create_instances()


def test_instance_running(aws):
ids = aws.create_instances()
assert len(ids) == 1
Expand Down Expand Up @@ -202,6 +211,18 @@ def test_remove_instances(aws):
assert not aws.instance_running(ids[0])


def test_remove_instances_missing_region(aws):
ids = aws.create_instances()
assert len(ids) == 1
ids = list(ids)
aws.region_name = None
os.environ.pop("AWS_DEFAULT_REGION")
with pytest.raises(
ValueError, match="No region name provided, cannot create instances."
):
aws.remove_instances(ids)


def test_wait_until_removed(aws):
ids = aws.create_instances()
assert len(ids) == 1
Expand Down

0 comments on commit 8c278c2

Please sign in to comment.