Skip to content

Commit

Permalink
Add tests for BQ SA keyfile_dict support
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana committed Jul 23, 2023
1 parent c966026 commit 19f4a70
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
3 changes: 2 additions & 1 deletion cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):

airflow_param_mapping = {
"project": "extra.project",
"dataset": "dataset",
# multiple options for dataset because of older Airflow versions
"dataset": ["extra.dataset", "dataset"],
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}
Expand Down
48 changes: 48 additions & 0 deletions tests/profiles/bigquery/test_bq_service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_profile_mapping
from cosmos.profiles.bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping


@pytest.fixture()
def mock_bigquery_conn_with_dict(): # type: ignore
"""
Mocks and returns an Airflow BigQuery connection.
"""
extra = {
"project": "my_project",
"dataset": "my_dataset",
"keyfile_dict": {"key": "value"},
}
conn = Connection(
conn_id="my_bigquery_connection",
conn_type="google_cloud_platform",
extra=json.dumps(extra),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_bigquery_mapping_selected(mock_bigquery_conn_with_dict: Connection):
profile_mapping = get_profile_mapping(
mock_bigquery_conn_with_dict.conn_id,
{"dataset": "my_dataset"},
)
assert isinstance(profile_mapping, GoogleCloudServiceAccountDictProfileMapping)


def test_connection_claiming_succeeds(mock_bigquery_conn_with_dict: Connection):
profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
assert profile_mapping.can_claim_connection()


def test_connection_claiming_fails(mock_bigquery_conn_with_dict: Connection):
# Remove the dataset key, which is mandatory
mock_bigquery_conn_with_dict.extra = json.dumps({"project": "my_project", "keyfile_dict": {"key": "value"}})
profile_mapping = GoogleCloudServiceAccountDictProfileMapping(mock_bigquery_conn_with_dict, {})
assert not profile_mapping.can_claim_connection()

0 comments on commit 19f4a70

Please sign in to comment.