diff --git a/aws-lambda/src/databricks_cdk/resources/account/credentials.py b/aws-lambda/src/databricks_cdk/resources/account/credentials.py index 8b561936..f357967e 100644 --- a/aws-lambda/src/databricks_cdk/resources/account/credentials.py +++ b/aws-lambda/src/databricks_cdk/resources/account/credentials.py @@ -1,16 +1,9 @@ import logging -from typing import Optional +from databricks.sdk.service.provisioning import CreateCredentialAwsCredentials, CreateCredentialStsRole from pydantic import BaseModel -from databricks_cdk.utils import ( - ACCOUNTS_BASE_URL, - CnfResponse, - delete_request, - get_account_id, - get_request, - post_request, -) +from databricks_cdk.utils import CnfResponse, get_account_client logger = logging.getLogger(__name__) @@ -29,79 +22,49 @@ class CredentialsResponse(CnfResponse): creation_time: int -def get_credentials_url(): - """Getting url for credentials requests""" - account_id = get_account_id() - - # api-endpoint - return f"{ACCOUNTS_BASE_URL}/api/2.0/accounts/{account_id}/credentials" - - -def get_credentials_by_id(credentials_id: str) -> Optional[dict]: - URL = get_credentials_url() - return get_request(url=f"{URL}/{credentials_id}") - - -def get_credentials_by_name(credentials_name: str) -> Optional[dict]: - """Getting credentials based on name""" - URL = get_credentials_url() - get_response = get_request(url=URL) - current = None - for r in get_response: - if r.get("credentials_name") == credentials_name: - current = r - return current - - def create_or_update_credentials( properties: CredentialsProperties, ) -> CredentialsResponse: """Create credentials config at databricks""" - url = get_credentials_url() + account_client = get_account_client() + all_credentials = account_client.credentials.list() + current = next((c for c in all_credentials if c.credentials_name == properties.credentials_name), None) - current = get_credentials_by_name(properties.credentials_name) if current is None: - - # Json data - body = { - "credentials_name": properties.credentials_name, - "aws_credentials": {"sts_role": {"role_arn": properties.role_arn}}, - } - response = post_request(url, body=body) - external_id = response.get("aws_credentials", {}).get("sts_role", "").get("external_id") + # create credential + credential = account_client.credentials.create( + credentials_name=properties.credentials_name, + aws_credentials=CreateCredentialAwsCredentials(CreateCredentialStsRole(properties.role_arn)), + ) return CredentialsResponse( - physical_resource_id=response["credentials_id"], + physical_resource_id=credential.credentials_id, credentials_name=properties.credentials_name, - credentials_id=response["credentials_id"], + credentials_id=credential.credentials_id, role_arn=properties.role_arn, - external_id=external_id, - creation_time=response["creation_time"], + external_id=credential.aws_credentials.sts_role.external_id, + creation_time=credential.creation_time, ) else: - current_role_arn = current.get("aws_credentials", {}).get("sts_role", "").get("role_arn") - external_id = current.get("aws_credentials", {}).get("sts_role", "").get("external_id") - if current_role_arn != properties.role_arn: + if current.aws_credentials.sts_role.role_arn != properties.role_arn: raise AttributeError("Role arn can't be changed after deployment") return CredentialsResponse( - physical_resource_id=current["credentials_id"], + physical_resource_id=current.credentials_id, credentials_name=properties.credentials_name, - credentials_id=current["credentials_id"], + credentials_id=current.credentials_id, role_arn=properties.role_arn, - external_id=external_id, - creation_time=current["creation_time"], + external_id=current.aws_credentials.sts_role.external_id, + creation_time=current.creation_time, ) -def delete_credentials(properties: CredentialsProperties, physical_resource_id: str) -> CnfResponse: +def delete_credentials(physical_resource_id: str) -> CnfResponse: """Deletes credentials config at databricks""" - URL = get_credentials_url() + account_client = get_account_client() - current = get_credentials_by_id(physical_resource_id) - if current is None: - current = get_credentials_by_name(properties.credentials_name) + current = account_client.credentials.get(credentials_id=physical_resource_id) if current is not None: - credentials_id = current["credentials_id"] - delete_request(f"{URL}/{credentials_id}") + account_client.credentials.delete(credentials_id=physical_resource_id) else: logger.warning("Already removed") - return CnfResponse(physical_resource_id=properties.credentials_name) + + return CnfResponse(physical_resource_id=physical_resource_id) diff --git a/aws-lambda/src/databricks_cdk/resources/handler.py b/aws-lambda/src/databricks_cdk/resources/handler.py index 07f687b6..480604ce 100644 --- a/aws-lambda/src/databricks_cdk/resources/handler.py +++ b/aws-lambda/src/databricks_cdk/resources/handler.py @@ -229,7 +229,7 @@ def delete_resource(event: DatabricksEvent) -> CnfResponse: """Delete a given resource""" action = event.action() if action == "credentials": - return delete_credentials(CredentialsProperties(**event.ResourceProperties), event.PhysicalResourceId) + return delete_credentials(event.PhysicalResourceId) elif action == "storage-configurations": return delete_storage_configuration( StorageConfigProperties(**event.ResourceProperties), diff --git a/aws-lambda/src/databricks_cdk/utils.py b/aws-lambda/src/databricks_cdk/utils.py index 9d78d78c..d9a9bec9 100644 --- a/aws-lambda/src/databricks_cdk/utils.py +++ b/aws-lambda/src/databricks_cdk/utils.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional import boto3 -from databricks.sdk import WorkspaceClient +from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.core import Config from pydantic import BaseModel from requests import request @@ -143,3 +143,17 @@ def get_workspace_client(workspace_url: str, config: Optional[Config] = None) -> return WorkspaceClient(config=config) return WorkspaceClient(username=get_deploy_user(), password=get_password(), host=workspace_url) + + +def get_account_client( + config: Optional[Config] = None, host: str = "https://accounts.cloud.databricks.com" +) -> AccountClient: + """Get Databricks AccountClient instance, either from config defaulting to ssm params + :param host: Url to account url to, defaults to 'https://accounts.cloud.databricks.com' + :param config: Optional config to use, when provided overwrites workspace_url provided, + defaults to None + """ + if config: + return AccountClient(config=config) + + return AccountClient(username=get_deploy_user(), password=get_password(), host=host, account_id=get_account_id()) diff --git a/aws-lambda/tests/conftest.py b/aws-lambda/tests/conftest.py index 7dd573c4..297f4f8d 100644 --- a/aws-lambda/tests/conftest.py +++ b/aws-lambda/tests/conftest.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from databricks.sdk import ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient +from databricks.sdk import AccountClient, CredentialsAPI, ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient @pytest.fixture(scope="function", autouse=True) @@ -26,3 +26,13 @@ def workspace_client(): workspace_client.volumes = MagicMock(spec=VolumesAPI) return workspace_client + + +@pytest.fixture(scope="function") +def account_client(): + account_client = MagicMock(spec=AccountClient) + + # mock all of the underlying service api's + account_client.credentials = MagicMock(spec=CredentialsAPI) + + return account_client diff --git a/aws-lambda/tests/resources/account/test_credentials.py b/aws-lambda/tests/resources/account/test_credentials.py new file mode 100644 index 00000000..f419027d --- /dev/null +++ b/aws-lambda/tests/resources/account/test_credentials.py @@ -0,0 +1,132 @@ +from unittest.mock import patch + +import pytest +from databricks.sdk.service.provisioning import ( + AwsCredentials, + CreateCredentialAwsCredentials, + CreateCredentialStsRole, + Credential, + StsRole, +) + +from databricks_cdk.resources.account.credentials import ( + CredentialsProperties, + CredentialsResponse, + create_or_update_credentials, + delete_credentials, +) +from databricks_cdk.utils import CnfResponse + + +@patch("databricks_cdk.resources.account.credentials.get_account_client") +def test_create_credentials_non_existing(patched_get_account_client, account_client): + aws_credentials = AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")) + account_client.credentials.create.return_value = Credential( + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + aws_credentials=aws_credentials, + creation_time=0, + account_id="1234", + ) + + patched_get_account_client.return_value = account_client + + props = CredentialsProperties( + credentials_name="test", + role_arn="arn:aws:iam::xxxxxx:role/xxx", + ) + + response = create_or_update_credentials(props) + assert response == CredentialsResponse( + physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + role_arn="arn:aws:iam::xxxxxx:role/xxx", + external_id="test", + creation_time=0, + ) + account_client.credentials.create.assert_called_once_with( + credentials_name="test", + aws_credentials=CreateCredentialAwsCredentials( + sts_role=CreateCredentialStsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx") + ), + ) + + +@patch("databricks_cdk.resources.account.credentials.get_account_client") +def test_create_credentials_exists(patched_get_account_client, account_client): + aws_credentials = AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")) + current = Credential( + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + aws_credentials=aws_credentials, + creation_time=0, + account_id="1234", + ) + + account_client.credentials.list.return_value = [current] + patched_get_account_client.return_value = account_client + + props = CredentialsProperties( + credentials_name="test", + role_arn="arn:aws:iam::xxxxxx:role/xxx", + ) + + response = create_or_update_credentials(props) + assert response == CredentialsResponse( + physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + role_arn="arn:aws:iam::xxxxxx:role/xxx", + external_id="test", + creation_time=0, + ) + + # make sure create is not called + assert account_client.credentials.create.call_count == 0 + + +@patch("databricks_cdk.resources.account.credentials.get_account_client") +def test_create_credentials_different_role_arn(patched_get_account_client, account_client): + current = Credential( + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + aws_credentials=AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")), + creation_time=0, + account_id="1234", + ) + + account_client.credentials.list.return_value = [current] + patched_get_account_client.return_value = account_client + + # existing has different role_arn + props = CredentialsProperties( + credentials_name="test", + role_arn="arn:aws:iam::xxxxxx:role/different", + ) + + with pytest.raises(AttributeError): + create_or_update_credentials(props) + + # make sure create is not called + assert account_client.credentials.create.call_count == 0 + + +@patch("databricks_cdk.resources.account.credentials.get_account_client") +def test_delete_credentials(patched_get_account_client, account_client): + current = Credential( + credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + credentials_name="test", + aws_credentials=AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")), + creation_time=0, + account_id="1234", + ) + account_client.credentials.get.return_value = current + patched_get_account_client.return_value = account_client + + response = delete_credentials("2206cf0e-ddeb-4982-a6d1-28bc887c8479") + assert response == CnfResponse( + physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479", + ) + + account_client.credentials.delete.assert_called_once_with(credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479") diff --git a/typescript/tests/resources/account/accountCredentials.test.ts b/typescript/tests/resources/account/accountCredentials.test.ts new file mode 100644 index 00000000..7f3d5223 --- /dev/null +++ b/typescript/tests/resources/account/accountCredentials.test.ts @@ -0,0 +1,30 @@ +import {Template} from "aws-cdk-lib/assertions"; +import * as cdk from "aws-cdk-lib"; + +import {DatabricksDeployLambda, AccountCredentials} from "../../../src"; + +describe("Credentials", () => { + test("Credentials Custom Resource synthesizes the way we expect", () => { + const app = new cdk.App(); + const databricksStack = new cdk.Stack(app, "DatabricksStack"); + const deployLambda = DatabricksDeployLambda.fromServiceToken(databricksStack, "DeployLambda", "some-arn"); + + new AccountCredentials(databricksStack, "Credentials", { + serviceToken: deployLambda.serviceToken.toString(), + roleArn: "some-role-arn", + credentialsName: "some-credentials-name", + }); + + const template = Template.fromStack(databricksStack); + template.hasResourceProperties("AWS::CloudFormation::CustomResource", + { + "ServiceToken": "some-arn", + "action": "credentials", + "credentials_name": "some-credentials-name", + "role_arn": "some-role-arn", + } + ); + + }); + +}); \ No newline at end of file