Add support to GCP connections that define `keyfile_dict` instead of `keyfile` (#352)

Add support for Google Cloud Platform connections that define
`keyfile_dict` (the credentials content) instead of `keyfile` (a path to the credentials file).

A design decision in this implementation was not to add `keyfile_json`
to `secret_fields`. The reason is that this property holds JSON rather
than a plain string: while storing it in an environment variable is
simple, we would need more significant changes to our profile parsing to
render it correctly into the `profiles.yml` generated by Cosmos.
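
To make that trade-off concrete, here is a minimal, purely illustrative sketch (the environment variable name below is made up, not something Cosmos generates) of why a scalar secret is easy to defer to an environment variable while a JSON keyfile is not:
```python
# Illustrative contrast only; the env var name is hypothetical.

# A scalar secret (e.g. a password) can be kept out of profiles.yml by writing a
# Jinja env_var() placeholder and shipping the real value as an environment variable.
scalar_secret_profile = {
    "type": "postgres",
    "password": "{{ env_var('EXAMPLE_SECRET_PASSWORD') }}",
}

# keyfile_dict is a nested mapping: round-tripping it through an env var would hand
# dbt a JSON *string*, so this mapping instead writes the dict inline as keyfile_json.
keyfile_dict_profile = {
    "type": "bigquery",
    "method": "service-account-json",
    "keyfile_json": {"type": "service_account", "project_id": "example-project"},
}

print(scalar_secret_profile)
print(keyfile_dict_profile)
```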

This used to work in Cosmos 0.6.x and stopped working in 0.7.x as part
of an earlier profile refactor (#271).

Closes: #350

Co-authored-by:  Tatiana Al-Chueyr <[email protected]>

**How to validate this change**

1. Have a GCP BQ service account with the `BigQuery Data Editor` role
2. Create a dataset (namespace) that can be accessed by the service account from (1)
3. Create an Airflow GCP connection that uses `keyfile` (a path to the
service account credentials saved locally). Example (replace
`some-namespace` with the dataset from (2) and the `key_path` value with the path from (1)); an optional sanity check follows the block:
```
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='{"conn_type": "google_cloud_platform", "extra": {"key_path": "/home/some-user/key.json", "scope": "https://www.googleapis.com/auth/cloud-platform", "project": "astronomer-dag-authoring", "dataset": "some-namespace", "num_retries": 5}}'
```
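
(Optional) A quick sanity check, as a hedged sketch, that Airflow resolves the connection above and sees `key_path`:
```python
# Assumes AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT (from step 3) is exported in this shell.
from airflow.hooks.base import BaseHook

conn = BaseHook.get_connection("google_cloud_default")
print(conn.conn_type)                     # expected: google_cloud_platform
print(conn.extra_dejson.get("key_path"))  # expected: /home/some-user/key.json
```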
4. Update `basic_cosmos_dag.py` with the following lines, making
sure they reference the Airflow connection created in (3) and the dataset
created in (2); a fuller sketch of where these lines sit follows the snippet:
 ```
    conn_id="google_cloud_default",
    profile_args={
        "dataset": "some-namespace",
    },
```
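
For orientation, here is a rough, hypothetical sketch of where those two lines sit, assuming the `DbtDag` signature used by the example DAG at the time; every other argument is illustrative:
```python
# Hypothetical excerpt of basic_cosmos_dag.py; only conn_id and profile_args
# come from this step, everything else is illustrative.
from datetime import datetime

from cosmos import DbtDag

basic_cosmos_dag = DbtDag(
    dbt_project_name="jaffle_shop",  # illustrative project name
    conn_id="google_cloud_default",  # the Airflow connection created in (3)
    profile_args={
        "dataset": "some-namespace",  # the dataset created in (2)
    },
    dag_id="basic_cosmos_dag",
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
)
```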

5. Run the DAG, for instance:
```
PYTHONPATH=`pwd` AIRFLOW_HOME=`pwd` \
AIRFLOW__CORE__DAGBAG_IMPORT_TIMEOUT=20000 \
AIRFLOW__CORE__DAG_FILE_PROCESSOR_TIMEOUT=20000 \
airflow dags test basic_cosmos_dag `date -Iseconds`
```

6. Change the Airflow GCP connection to use `keyfile_dict` (hard-code the `keyfile` content in the Airflow connection, replacing `<your keyfile content here>`):
```
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='{"conn_type": "google_cloud_platform", "extra": {"keyfile_dict": <your keyfile content here>, "scope": "https://www.googleapis.com/auth/cloud-platform", "project": "astronomer-dag-authoring", "dataset": "cosmos", "num_retries": 5}}'
```

7. Run the Cosmos-powered DAG created earlier, as in (5), to confirm that the connection from (6) works.

---------

Co-authored-by: Tatiana Al-Chueyr <[email protected]>
JoeSham and tatiana authored Jul 24, 2023
1 parent 1058544 commit f153c49
Showing 8 changed files with 146 additions and 6 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/test.yml
@@ -99,8 +99,25 @@ jobs:
matrix:
python-version: ["3.10"]
airflow-version: ["2.6"]
if: >-
github.event_name == 'push' ||
(
github.event_name == 'pull_request' &&
github.event.pull_request.head.repo.fork == false
) ||
(
github.event_name == 'pull_request_target' &&
contains(github.event.pull_request.labels.*.name, 'safe')
)
steps:
- uses: actions/checkout@v3
if: github.event_name != 'pull_request_target'

- name: Checkout pull/${{ github.event.number }}
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
if: github.event_name == 'pull_request_target'

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
1 change: 0 additions & 1 deletion cosmos/operators/local.py
@@ -231,7 +231,6 @@ def run_command(
            output_encoding=self.output_encoding,
            cwd=tmp_project_dir,
        )

        self.exception_handling(result)
        self.store_compiled_sql(tmp_project_dir, context)
        if self.callback:
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
@@ -8,6 +8,7 @@

from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
@@ -21,6 +22,7 @@

profile_mappings: list[Type[BaseProfileMapping]] = [
    GoogleCloudServiceAccountFileProfileMapping,
    GoogleCloudServiceAccountDictProfileMapping,
    DatabricksTokenProfileMapping,
    PostgresUserPasswordProfileMapping,
    RedshiftUserPasswordProfileMapping,
9 changes: 5 additions & 4 deletions cosmos/profiles/base.py
@@ -5,7 +5,6 @@
from __future__ import annotations

from abc import ABC, abstractmethod

from logging import getLogger
from typing import Any

@@ -42,18 +41,21 @@ def can_claim_connection(self) -> bool:
        if self.conn.conn_type != self.airflow_connection_type:
            return False

        logger.info(dir(self.conn))
        logger.info(self.conn.__dict__)

        for field in self.required_fields:
            try:
                if not getattr(self, field):
                    logger.info(
                        "Not using mapping %s because %s is not set",
                        "1 Not using mapping %s because %s is not set",
                        self.__class__.__name__,
                        field,
                    )
                    return False
            except AttributeError:
                logger.info(
                    "Not using mapping %s because %s is not set",
                    "2 Not using mapping %s because %s is not set",
                    self.__class__.__name__,
                    field,
                )
@@ -97,7 +99,6 @@ def get_profile_file_contents(self, profile_name: str, target_name: str = "cosmo
"outputs": {target_name: profile_vars},
}
}

return str(yaml.dump(profile_contents, indent=4))

def get_dbt_value(self, name: str) -> Any:
6 changes: 5 additions & 1 deletion cosmos/profiles/bigquery/__init__.py
@@ -1,5 +1,9 @@
"BigQuery Airflow connection -> dbt profile mappings"

from .service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

__all__ = ["GoogleCloudServiceAccountFileProfileMapping"]
__all__ = [
    "GoogleCloudServiceAccountFileProfileMapping",
    "GoogleCloudServiceAccountDictProfileMapping",
]
48 changes: 48 additions & 0 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
@@ -0,0 +1,48 @@
"Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json."
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping


class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"""
Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json.
https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#service-account-file
https://airflow.apache.org/docs/apache-airflow-providers-google/stable/connections/gcp.html
"""

airflow_connection_type: str = "google_cloud_platform"

required_fields = [
"project",
"dataset",
"keyfile_dict",
]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
"dataset": ["extra.dataset", "dataset"],
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

@property
def profile(self) -> dict[str, Any | None]:
"""
Generates a GCP profile.
Even though the Airflow connection contains hard-coded Service account credentials,
we generate a temporary file and the DBT profile uses it.
"""
return {
"type": "bigquery",
"method": "service-account-json",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
**self.profile_args,
}
8 changes: 8 additions & 0 deletions docs/dbt/connections-profiles.rst
@@ -89,6 +89,14 @@ Service Account File
    :members:


Service Account Dict
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: cosmos.profiles.bigquery.GoogleCloudServiceAccountDictProfileMapping
    :undoc-members:
    :members:


Databricks
----------

61 changes: 61 additions & 0 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
@@ -0,0 +1,61 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping


@pytest.fixture()
def mock_bigquery_conn_with_dict():  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "dataset": "my_dataset",
        "keyfile_dict": {"key": "value"},
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn


def test_bigquery_mapping_selected(mock_bigquery_conn_with_dict: Connection):
    profile_mapping = get_profile_mapping(
        mock_bigquery_conn_with_dict.conn_id,
        {"dataset": "my_dataset"},
    )
    assert isinstance(profile_mapping, GoogleCloudServiceAccountDictProfileMapping)


def test_connection_claiming_succeeds(mock_bigquery_conn_with_dict: Connection):
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    assert profile_mapping.can_claim_connection()


def test_connection_claiming_fails(mock_bigquery_conn_with_dict: Connection):
    # Remove the `dataset` key, which is mandatory
    mock_bigquery_conn_with_dict.extra = json.dumps({"project": "my_project", "keyfile_dict": {"key": "value"}})
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    assert not profile_mapping.can_claim_connection()


def test_profile(mock_bigquery_conn_with_dict: Connection):
    profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
    expected = {
        "type": "bigquery",
        "method": "service-account-json",
        "project": "my_project",
        "dataset": "my_dataset",
        "threads": 1,
        "keyfile_json": {"key": "value"},
    }
    assert profile_mapping.profile == expected
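
Beyond the unit tests, a hedged end-to-end sketch of the new mapping, reusing the mocked connection shape from the fixture above and the `get_profile_file_contents` method touched in `base.py` (connection values are illustrative):
```python
# Hedged sketch: render the profiles.yml content produced for a keyfile_dict connection.
import json
from unittest.mock import patch

from airflow.models.connection import Connection

from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

conn = Connection(
    conn_id="my_bigquery_connection",
    conn_type="google_cloud_platform",
    extra=json.dumps({"project": "my_project", "dataset": "my_dataset", "keyfile_dict": {"key": "value"}}),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
    mapping = GoogleCloudServiceAccountDictProfileMapping(conn, {})
    # Should print a YAML document whose output contains keyfile_json inline.
    print(mapping.get_profile_file_contents(profile_name="my_profile"))
```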
