Skip to content

Commit

Permalink
Merge pull request #1533 from jemrobinson/tidy-azure-api
Browse files Browse the repository at this point in the history
Fix typing of external APIs
  • Loading branch information
jemrobinson authored Aug 1, 2023
2 parents d8b1667 + 01ba734 commit 188e491
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 41 deletions.
85 changes: 57 additions & 28 deletions data_safe_haven/external/api/azure_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from collections.abc import Sequence
from contextlib import suppress
from typing import Any
from typing import Any, cast

from azure.core.exceptions import (
HttpResponseError,
Expand All @@ -21,26 +21,34 @@
from azure.mgmt.automation.models import (
DscCompilationJobCreateParameters,
DscConfigurationAssociationProperty,
Module,
)
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.compute.models import RunCommandInput, RunCommandInputParameter
from azure.mgmt.dns import DnsManagementClient
from azure.mgmt.dns.models import RecordSet, TxtRecord
from azure.mgmt.keyvault import KeyVaultManagementClient
from azure.mgmt.keyvault.models import (
from azure.mgmt.compute.v2021_07_01 import ComputeManagementClient
from azure.mgmt.compute.v2021_07_01.models import (
ResourceSkuCapabilities,
RunCommandInput,
RunCommandInputParameter,
RunCommandResult,
)
from azure.mgmt.dns.v2018_05_01 import DnsManagementClient
from azure.mgmt.dns.v2018_05_01.models import RecordSet, TxtRecord
from azure.mgmt.keyvault.v2021_06_01_preview import KeyVaultManagementClient
from azure.mgmt.keyvault.v2021_06_01_preview.models import (
AccessPolicyEntry,
Permissions,
Sku as KeyVaultSku,
Vault,
VaultCreateOrUpdateParameters,
VaultProperties,
)
from azure.mgmt.msi import ManagedServiceIdentityClient
from azure.mgmt.msi.models import Identity
from azure.mgmt.resource import ResourceManagementClient, SubscriptionClient
from azure.mgmt.resource.resources.models import ResourceGroup
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.storage.models import (
from azure.mgmt.msi.v2022_01_31_preview import ManagedServiceIdentityClient
from azure.mgmt.msi.v2022_01_31_preview.models import Identity
from azure.mgmt.resource.resources.v2021_04_01 import ResourceManagementClient
from azure.mgmt.resource.resources.v2021_04_01.models import ResourceGroup
from azure.mgmt.resource.subscriptions import SubscriptionClient
from azure.mgmt.resource.subscriptions.models import Location
from azure.mgmt.storage.v2021_08_01 import StorageManagementClient
from azure.mgmt.storage.v2021_08_01.models import (
BlobContainer,
Kind as StorageAccountKind,
PublicAccess,
Expand Down Expand Up @@ -86,8 +94,12 @@ def compile_desired_state(
automation_client = AutomationClient(self.credential, self.subscription_id)
# Wait until all modules are available
while True:
available_modules = automation_client.module.list_by_automation_account(
resource_group_name, automation_account_name
# Cast to correct spurious type hint in Azure libraries
available_modules = cast(
list[Module],
automation_client.module.list_by_automation_account(
resource_group_name, automation_account_name
),
)
available_module_names = [
module.name
Expand Down Expand Up @@ -224,7 +236,6 @@ def ensure_keyvault(
) -> Vault:
"""Ensure that a KeyVault exists
Raises:
DataSafeHavenAzureError if the existence of the KeyVault could not be verified
"""
Expand Down Expand Up @@ -276,8 +287,11 @@ def ensure_keyvault(
),
),
)
# Cast to correct spurious type hint in Azure libraries
key_vaults = [
kv for kv in key_vault_client.vaults.list() if kv.name == key_vault_name
kv
for kv in cast(list[Vault], key_vault_client.vaults.list())
if kv.name == key_vault_name
]
self.logger.info(
f"Ensured that key vault [green]{key_vaults[0].name}[/] exists.",
Expand Down Expand Up @@ -429,7 +443,6 @@ def ensure_managed_identity(
msi_client = ManagedServiceIdentityClient(
self.credential, self.subscription_id
)
# mypy erroneously thinks that create_or_update returns Any rather than Identity
managed_identity = msi_client.user_assigned_identities.create_or_update(
resource_group_name,
identity_name,
Expand Down Expand Up @@ -468,9 +481,12 @@ def ensure_resource_group(
resource_group_name,
ResourceGroup(location=location, tags=tags),
)
# Cast to correct spurious type hint in Azure libraries
resource_groups = [
rg
for rg in resource_client.resource_groups.list()
for rg in cast(
list[ResourceGroup], resource_client.resource_groups.list()
)
if rg.name == resource_group_name
]
self.logger.info(
Expand Down Expand Up @@ -615,9 +631,12 @@ def get_locations(self) -> list[str]:
try:
subscription_client = SubscriptionClient(self.credential)
return [
location.name
for location in subscription_client.subscriptions.list_locations(
subscription_id=self.subscription_id
str(location.name)
for location in cast(
list[Location],
subscription_client.subscriptions.list_locations(
subscription_id=self.subscription_id
),
)
]
except Exception as exc:
Expand Down Expand Up @@ -672,7 +691,10 @@ def get_vm_sku_details(self, sku: str) -> tuple[str, str, str]:
for resource_sku in compute_client.resource_skus.list():
if resource_sku.name == sku:
if resource_sku.capabilities:
for capability in resource_sku.capabilities:
# Cast to correct spurious type hint in Azure libraries
for capability in cast(
list[ResourceSkuCapabilities], resource_sku.capabilities
):
if capability.name == "vCPUs":
cpus = capability.value
if capability.name == "GPUs":
Expand Down Expand Up @@ -746,7 +768,10 @@ def list_available_vm_skus(self, location: str) -> dict[str, dict[str, Any]]:
"GPUs": 0
} # default to 0 GPUs, overriding if appropriate
if resource_sku.capabilities:
for capability in resource_sku.capabilities:
# Cast to correct spurious type hint in Azure libraries
for capability in cast(
list[ResourceSkuCapabilities], resource_sku.capabilities
):
skus[resource_sku.name][capability.name] = capability.value
return skus
except Exception as exc:
Expand Down Expand Up @@ -889,9 +914,12 @@ def remove_resource_group(self, resource_group_name: str) -> None:
)
while not poller.done():
poller.wait(10)
# Cast to correct spurious type hint in Azure libraries
resource_groups = [
rg
for rg in resource_client.resource_groups.list()
for rg in cast(
list[ResourceGroup], resource_client.resource_groups.list()
)
if rg.name == resource_group_name
]
if resource_groups:
Expand Down Expand Up @@ -971,9 +999,10 @@ def run_remote_script(
poller = compute_client.virtual_machines.begin_run_command(
resource_group_name, vm_name, run_command_parameters
)
result = poller.result()
# Return stdout/stderr from the command
return str(result.value[0].message)
# Cast to correct spurious type hint in Azure libraries
result = cast(RunCommandResult, poller.result())
# Return any stdout/stderr from the command
return str(result.value[0].message) if result.value else ""
except Exception as exc:
msg = f"Failed to run command on '{vm_name}'.\n{exc}"
raise DataSafeHavenAzureError(msg) from exc
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/external/api/graph_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def verify_custom_domain(
# Check whether all expected nameservers are active
with suppress(resolver.NXDOMAIN):
active_nameservers = [
str(ns) for ns in resolver.resolve(domain_name, "NS")
str(ns) for ns in iter(resolver.resolve(domain_name, "NS"))
]
self.logger.info("Checking domain verification status.")
if all(
Expand Down
9 changes: 7 additions & 2 deletions data_safe_haven/external/interface/azure_authenticator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Standalone utility class for anything that needs to authenticate against Azure"""
from typing import cast

from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential
from azure.mgmt.resource import SubscriptionClient
from azure.mgmt.resource.subscriptions import SubscriptionClient
from azure.mgmt.resource.subscriptions.models import Subscription

from data_safe_haven.exceptions import (
DataSafeHavenAzureError,
Expand Down Expand Up @@ -53,7 +56,9 @@ def login(self) -> None:

# Check that the Azure credentials are valid
try:
for subscription in list(subscription_client.subscriptions.list()):
for subscription in cast(
list[Subscription], subscription_client.subscriptions.list()
):
if subscription.display_name == self.subscription_name:
self.subscription_id_ = subscription.subscription_id
self.tenant_id_ = subscription.tenant_id
Expand Down
2 changes: 0 additions & 2 deletions data_safe_haven/pulumi/components/automation_dsc_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Register a VM as an Azure Automation DSC node"""
import pathlib
import time
from collections.abc import Sequence

Expand Down Expand Up @@ -57,7 +56,6 @@ def __init__(
):
super().__init__("dsh:common:AutomationDscNode", name, {}, opts)
child_opts = ResourceOptions.merge(ResourceOptions(parent=self), opts)
pathlib.Path(__file__).parent.parent.parent / "resources"

# Upload the primary domain controller DSC
dsc = automation.DscConfiguration(
Expand Down
2 changes: 1 addition & 1 deletion data_safe_haven/pulumi/components/shm_bastion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(
# self.automation_account_name = automation_account_name
self.location = location
self.resource_group_name = resource_group_name
self.subnet_id = Output.from_input(subnet).apply(lambda s: s.id)
self.subnet_id = Output.from_input(subnet).apply(lambda s: s.id if s.id else "")


class SHMBastionComponent(ComponentResource):
Expand Down
14 changes: 7 additions & 7 deletions data_safe_haven/pulumi/components/shm_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def __init__(
sku=automation.SkuArgs(name=automation.SkuNameEnum.FREE),
opts=child_opts,
)
automation_keys = automation.list_key_by_automation_account(
automation_account.name, resource_group_name=resource_group.name
)
automation_keys = Output.all(
automation_account_name=automation_account.name,
resource_group_name=resource_group.name,
).apply(lambda kwargs: automation.list_key_by_automation_account(**kwargs))

# List of modules as 'name: (version, SHA256 hash)'
# Note that we exclude ComputerManagementDsc which is already present (https://docs.microsoft.com/en-us/azure/automation/shared-resources/modules#default-modules)
Expand Down Expand Up @@ -156,10 +157,9 @@ def __init__(
workspace_name=f"{stack_name}-log",
opts=child_opts,
)
log_analytics_keys = operationalinsights.get_shared_keys(
resource_group_name=resource_group.name,
workspace_name=log_analytics.name,
)
log_analytics_keys = Output.all(
resource_group_name=resource_group.name, workspace_name=log_analytics.name
).apply(lambda kwargs: operationalinsights.get_shared_keys(**kwargs))

# Set up a private linkscope and endpoint for the log analytics workspace
log_analytics_private_link_scope = insights.PrivateLinkScope(
Expand Down

0 comments on commit 188e491

Please sign in to comment.