Skip to content

Commit

Permalink
Store BQ profile keyfile_dict as an envvar
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana committed Jul 23, 2023
1 parent 19f4a70 commit 9ebda05
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
4 changes: 3 additions & 1 deletion cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
from __future__ import annotations

import json
from abc import ABC, abstractmethod

from logging import getLogger
from typing import Any

Expand Down Expand Up @@ -77,6 +77,8 @@ def env_vars(self) -> dict[str, str]:
for field in self.secret_fields:
env_var_name = self.get_env_var_name(field)
value = self.get_dbt_value(field)
if isinstance(value, dict):
env_vars[env_var_name] = json.dumps(value)
if value is not None:
env_vars[env_var_name] = str(value)

Expand Down
6 changes: 5 additions & 1 deletion cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
"keyfile_dict",
]

secret_fields = [
"keyfile_dict",
]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
Expand All @@ -39,6 +43,6 @@ def profile(self) -> dict[str, Any | None]:
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
"keyfile_json": self.get_env_var_format("keyfile_dict"),
**self.profile_args,
}
18 changes: 17 additions & 1 deletion tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,23 @@ def test_connection_claiming_succeeds(mock_bigquery_conn_with_dict: Connection):


def test_connection_claiming_fails(mock_bigquery_conn_with_dict: Connection):
# Remove the dataset key, which is mandatory
# Remove the `dataset` key, which is mandatory
mock_bigquery_conn_with_dict.extra = json.dumps({"project": "my_project", "keyfile_dict": {"key": "value"}})
profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
assert not profile_mapping.can_claim_connection()


def test_profile_env_vars(
mock_bigquery_conn_with_dict: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_profile_mapping(
mock_bigquery_conn_with_dict.conn_id,
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_GOOGLE_CLOUD_PLATFORM_KEYFILE_DICT": str(
mock_bigquery_conn_with_dict.extra_dejson["keyfile_dict"]
),
}

0 comments on commit 9ebda05

Please sign in to comment.