Skip to content

Commit

Permalink
Add support for BQ SA keyfile_dict
Browse files Browse the repository at this point in the history
Related issue: astronomer#350
  • Loading branch information
JoeSham committed Jun 28, 2023
1 parent ebc70e3 commit 682c34f
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 13 deletions.
2 changes: 2 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .base import BaseProfileMapping
from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
Expand All @@ -20,6 +21,7 @@

profile_mappings: list[Type[BaseProfileMapping]] = [
GoogleCloudServiceAccountFileProfileMapping,
GoogleCloudServiceAccountDictProfileMapping,
DatabricksTokenProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
Expand Down
19 changes: 7 additions & 12 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BaseProfileMapping(ABC):
airflow_connection_type: str = "generic"
is_community: bool = False

required_fields: list[str] = []
required_fields: list[str | list[str]] = []
secret_fields: list[str] = []
airflow_param_mapping: dict[str, str | list[str]] = {}

Expand All @@ -42,20 +42,15 @@ def can_claim_connection(self) -> bool:
if self.conn.conn_type != self.airflow_connection_type:
return False

for field in self.required_fields:
try:
if not getattr(self, field):
logger.info(
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
return False
except AttributeError:
# Check if all required fields exist. If a sublist is provided, then it means 1 of the fields has to exist
for field_list in self.required_fields:
if isinstance(field_list, str):
field_list = [field_list]
if not any([getattr(self, field, False) for field in field_list]):
logger.info(
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
field_list,
)
return False

Expand Down
6 changes: 5 additions & 1 deletion cosmos/profiles/bigquery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"BigQuery Airflow connection -> dbt profile mappings"

from .service_account_file import GoogleCloudServiceAccountFileProfileMapping
from .service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping

__all__ = ["GoogleCloudServiceAccountFileProfileMapping"]
__all__ = [
"GoogleCloudServiceAccountFileProfileMapping",
"GoogleCloudServiceAccountDictProfileMapping",
]
43 changes: 43 additions & 0 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json."
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping


class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"""
Maps Airflow GCP connections to dbt BigQuery profiles if they use a service account keyfile dict/json.
https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup#service-account-file
https://airflow.apache.org/docs/apache-airflow-providers-google/stable/connections/gcp.html
"""

airflow_connection_type: str = "google_cloud_platform"

required_fields = [
"project",
"dataset",
"keyfile_dict",
]

airflow_param_mapping = {
"project": "extra.project",
"dataset": "dataset",
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

@property
def profile(self) -> dict[str, Any | None]:
"""Generates profile. Defaults `threads` to 1."""
return {
"type": "bigquery",
"method": "service-account-json",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
**self.profile_args,
}
8 changes: 8 additions & 0 deletions docs/dbt/connections-profiles.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ Service Account File
:members:


Service Account Dict
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: cosmos.profiles.bigquery.GoogleCloudServiceAccountDictProfileMapping
:undoc-members:
:members:


Databricks
----------

Expand Down

0 comments on commit 682c34f

Please sign in to comment.