Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix account credentials #1064

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 25 additions & 62 deletions aws-lambda/src/databricks_cdk/resources/account/credentials.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import logging
from typing import Optional

from databricks.sdk.service.provisioning import CreateCredentialAwsCredentials, CreateCredentialStsRole
from pydantic import BaseModel

from databricks_cdk.utils import (
ACCOUNTS_BASE_URL,
CnfResponse,
delete_request,
get_account_id,
get_request,
post_request,
)
from databricks_cdk.utils import CnfResponse, get_account_client

logger = logging.getLogger(__name__)

Expand All @@ -29,79 +22,49 @@ class CredentialsResponse(CnfResponse):
creation_time: int


def get_credentials_url():
"""Getting url for credentials requests"""
account_id = get_account_id()

# api-endpoint
return f"{ACCOUNTS_BASE_URL}/api/2.0/accounts/{account_id}/credentials"


def get_credentials_by_id(credentials_id: str) -> Optional[dict]:
URL = get_credentials_url()
return get_request(url=f"{URL}/{credentials_id}")


def get_credentials_by_name(credentials_name: str) -> Optional[dict]:
"""Getting credentials based on name"""
URL = get_credentials_url()
get_response = get_request(url=URL)
current = None
for r in get_response:
if r.get("credentials_name") == credentials_name:
current = r
return current


def create_or_update_credentials(
properties: CredentialsProperties,
) -> CredentialsResponse:
"""Create credentials config at databricks"""
url = get_credentials_url()
account_client = get_account_client()
all_credentials = account_client.credentials.list()
current = next((c for c in all_credentials if c.credentials_name == properties.credentials_name), None)

current = get_credentials_by_name(properties.credentials_name)
if current is None:

# Json data
body = {
"credentials_name": properties.credentials_name,
"aws_credentials": {"sts_role": {"role_arn": properties.role_arn}},
}
response = post_request(url, body=body)
external_id = response.get("aws_credentials", {}).get("sts_role", "").get("external_id")
# create credential
credential = account_client.credentials.create(
credentials_name=properties.credentials_name,
aws_credentials=CreateCredentialAwsCredentials(CreateCredentialStsRole(properties.role_arn)),
)
return CredentialsResponse(
physical_resource_id=response["credentials_id"],
physical_resource_id=credential.credentials_id,
credentials_name=properties.credentials_name,
credentials_id=response["credentials_id"],
credentials_id=credential.credentials_id,
role_arn=properties.role_arn,
external_id=external_id,
creation_time=response["creation_time"],
external_id=credential.aws_credentials.sts_role.external_id,
creation_time=credential.creation_time,
)
else:
current_role_arn = current.get("aws_credentials", {}).get("sts_role", "").get("role_arn")
external_id = current.get("aws_credentials", {}).get("sts_role", "").get("external_id")
if current_role_arn != properties.role_arn:
if current.aws_credentials.sts_role.role_arn != properties.role_arn:
raise AttributeError("Role arn can't be changed after deployment")
return CredentialsResponse(
physical_resource_id=current["credentials_id"],
physical_resource_id=current.credentials_id,
credentials_name=properties.credentials_name,
credentials_id=current["credentials_id"],
credentials_id=current.credentials_id,
role_arn=properties.role_arn,
external_id=external_id,
creation_time=current["creation_time"],
external_id=current.aws_credentials.sts_role.external_id,
creation_time=current.creation_time,
)


def delete_credentials(properties: CredentialsProperties, physical_resource_id: str) -> CnfResponse:
def delete_credentials(physical_resource_id: str) -> CnfResponse:
"""Deletes credentials config at databricks"""
URL = get_credentials_url()
account_client = get_account_client()

current = get_credentials_by_id(physical_resource_id)
if current is None:
current = get_credentials_by_name(properties.credentials_name)
current = account_client.credentials.get(credentials_id=physical_resource_id)
if current is not None:
credentials_id = current["credentials_id"]
delete_request(f"{URL}/{credentials_id}")
account_client.credentials.delete(credentials_id=physical_resource_id)
else:
logger.warning("Already removed")
return CnfResponse(physical_resource_id=properties.credentials_name)

return CnfResponse(physical_resource_id=physical_resource_id)
2 changes: 1 addition & 1 deletion aws-lambda/src/databricks_cdk/resources/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def delete_resource(event: DatabricksEvent) -> CnfResponse:
"""Delete a given resource"""
action = event.action()
if action == "credentials":
return delete_credentials(CredentialsProperties(**event.ResourceProperties), event.PhysicalResourceId)
return delete_credentials(event.PhysicalResourceId)
elif action == "storage-configurations":
return delete_storage_configuration(
StorageConfigProperties(**event.ResourceProperties),
Expand Down
16 changes: 15 additions & 1 deletion aws-lambda/src/databricks_cdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, Optional

import boto3
from databricks.sdk import WorkspaceClient
from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.core import Config
from pydantic import BaseModel
from requests import request
Expand Down Expand Up @@ -143,3 +143,17 @@ def get_workspace_client(workspace_url: str, config: Optional[Config] = None) ->
return WorkspaceClient(config=config)

return WorkspaceClient(username=get_deploy_user(), password=get_password(), host=workspace_url)


def get_account_client(
config: Optional[Config] = None, host: str = "https://accounts.cloud.databricks.com"
) -> AccountClient:
"""Get Databricks AccountClient instance, either from config defaulting to ssm params
:param host: Url to account url to, defaults to 'https://accounts.cloud.databricks.com'
:param config: Optional config to use, when provided overwrites workspace_url provided,
defaults to None
"""
if config:
return AccountClient(config=config)

return AccountClient(username=get_deploy_user(), password=get_password(), host=host, account_id=get_account_id())
12 changes: 11 additions & 1 deletion aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import MagicMock

import pytest
from databricks.sdk import ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient
from databricks.sdk import AccountClient, CredentialsAPI, ExperimentsAPI, ModelRegistryAPI, VolumesAPI, WorkspaceClient


@pytest.fixture(scope="function", autouse=True)
Expand All @@ -26,3 +26,13 @@ def workspace_client():
workspace_client.volumes = MagicMock(spec=VolumesAPI)

return workspace_client


@pytest.fixture(scope="function")
def account_client():
account_client = MagicMock(spec=AccountClient)

# mock all of the underlying service api's
account_client.credentials = MagicMock(spec=CredentialsAPI)

return account_client
132 changes: 132 additions & 0 deletions aws-lambda/tests/resources/account/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from unittest.mock import patch

import pytest
from databricks.sdk.service.provisioning import (
AwsCredentials,
CreateCredentialAwsCredentials,
CreateCredentialStsRole,
Credential,
StsRole,
)

from databricks_cdk.resources.account.credentials import (
CredentialsProperties,
CredentialsResponse,
create_or_update_credentials,
delete_credentials,
)
from databricks_cdk.utils import CnfResponse


@patch("databricks_cdk.resources.account.credentials.get_account_client")
def test_create_credentials_non_existing(patched_get_account_client, account_client):
aws_credentials = AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test"))
account_client.credentials.create.return_value = Credential(
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
aws_credentials=aws_credentials,
creation_time=0,
account_id="1234",
)

patched_get_account_client.return_value = account_client

props = CredentialsProperties(
credentials_name="test",
role_arn="arn:aws:iam::xxxxxx:role/xxx",
)

response = create_or_update_credentials(props)
assert response == CredentialsResponse(
physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
role_arn="arn:aws:iam::xxxxxx:role/xxx",
external_id="test",
creation_time=0,
)
account_client.credentials.create.assert_called_once_with(
credentials_name="test",
aws_credentials=CreateCredentialAwsCredentials(
sts_role=CreateCredentialStsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx")
),
)


@patch("databricks_cdk.resources.account.credentials.get_account_client")
def test_create_credentials_exists(patched_get_account_client, account_client):
aws_credentials = AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test"))
current = Credential(
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
aws_credentials=aws_credentials,
creation_time=0,
account_id="1234",
)

account_client.credentials.list.return_value = [current]
patched_get_account_client.return_value = account_client

props = CredentialsProperties(
credentials_name="test",
role_arn="arn:aws:iam::xxxxxx:role/xxx",
)

response = create_or_update_credentials(props)
assert response == CredentialsResponse(
physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
role_arn="arn:aws:iam::xxxxxx:role/xxx",
external_id="test",
creation_time=0,
)

# make sure create is not called
assert account_client.credentials.create.call_count == 0


@patch("databricks_cdk.resources.account.credentials.get_account_client")
def test_create_credentials_different_role_arn(patched_get_account_client, account_client):
current = Credential(
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
aws_credentials=AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")),
creation_time=0,
account_id="1234",
)

account_client.credentials.list.return_value = [current]
patched_get_account_client.return_value = account_client

# existing has different role_arn
props = CredentialsProperties(
credentials_name="test",
role_arn="arn:aws:iam::xxxxxx:role/different",
)

with pytest.raises(AttributeError):
create_or_update_credentials(props)

# make sure create is not called
assert account_client.credentials.create.call_count == 0


@patch("databricks_cdk.resources.account.credentials.get_account_client")
def test_delete_credentials(patched_get_account_client, account_client):
current = Credential(
credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
credentials_name="test",
aws_credentials=AwsCredentials(sts_role=StsRole(role_arn="arn:aws:iam::xxxxxx:role/xxx", external_id="test")),
creation_time=0,
account_id="1234",
)
account_client.credentials.get.return_value = current
patched_get_account_client.return_value = account_client

response = delete_credentials("2206cf0e-ddeb-4982-a6d1-28bc887c8479")
assert response == CnfResponse(
physical_resource_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479",
)

account_client.credentials.delete.assert_called_once_with(credentials_id="2206cf0e-ddeb-4982-a6d1-28bc887c8479")
30 changes: 30 additions & 0 deletions typescript/tests/resources/account/accountCredentials.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import {Template} from "aws-cdk-lib/assertions";
import * as cdk from "aws-cdk-lib";

import {DatabricksDeployLambda, AccountCredentials} from "../../../src";

describe("Credentials", () => {
test("Credentials Custom Resource synthesizes the way we expect", () => {
const app = new cdk.App();
const databricksStack = new cdk.Stack(app, "DatabricksStack");
const deployLambda = DatabricksDeployLambda.fromServiceToken(databricksStack, "DeployLambda", "some-arn");

new AccountCredentials(databricksStack, "Credentials", {
serviceToken: deployLambda.serviceToken.toString(),
roleArn: "some-role-arn",
credentialsName: "some-credentials-name",
});

const template = Template.fromStack(databricksStack);
template.hasResourceProperties("AWS::CloudFormation::CustomResource",
{
"ServiceToken": "some-arn",
"action": "credentials",
"credentials_name": "some-credentials-name",
"role_arn": "some-role-arn",
}
);

});

});
Loading