Skip to content

Commit

Permalink
Add support for BQ SA keyfile_dict
Browse files Browse the repository at this point in the history
Related issue: astronomer#350
  • Loading branch information
JoeSham committed Jun 28, 2023
1 parent ebc70e3 commit 4ee1b33
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
19 changes: 7 additions & 12 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BaseProfileMapping(ABC):
airflow_connection_type: str = "generic"
is_community: bool = False

required_fields: list[str] = []
required_fields: list[str | list[str]] = []
secret_fields: list[str] = []
airflow_param_mapping: dict[str, str | list[str]] = {}

Expand All @@ -42,20 +42,15 @@ def can_claim_connection(self) -> bool:
if self.conn.conn_type != self.airflow_connection_type:
return False

for field in self.required_fields:
try:
if not getattr(self, field):
logger.info(
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
return False
except AttributeError:
# Check if all required fields exist. If a sublist is provided, then it means 1 of the fields has to exist
for field_list in self.required_fields:
if isinstance(field_list, str):
field_list = [field_list]
if not any([getattr(self, field, False) for field in field_list]):
logger.info(
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
field_list,
)
return False

Expand Down
27 changes: 21 additions & 6 deletions cosmos/profiles/bigquery/service_account_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,39 @@ class GoogleCloudServiceAccountFileProfileMapping(BaseProfileMapping):
required_fields = [
"project",
"dataset",
"keyfile",
# plus one of:
["keyfile", "keyfile_dict"],
]

airflow_param_mapping = {
"project": "extra.project",
"dataset": "dataset",
"keyfile": "extra.key_path",
# multiple options for keyfile/keyfile_dict param name because of older Airflow versions
"keyfile": ["key_path", "extra__google_cloud_platform__key_path", "extra.key_path"],
"keyfile_dict": ["keyfile_dict", "extra__google_cloud_platform__keyfile_dict", "extra.keyfile_dict"],
}

@property
def profile(self) -> dict[str, Any | None]:
"Generates profile. Defaults `threads` to 1."
return {
"""
Generates profile. Defaults `threads` to 1.
Profile can either use keyfile as path to json file, or keyfile_dict as directly the json dict
"""

profile_dict = {
"type": "bigquery",
"method": "service-account",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile": self.keyfile,
**self.profile_args,
}

# use keyfile_dict if it exists, otherwise use keyfile
try:
profile_dict["keyfile_json"] = self.keyfile_dict
profile_dict["method"] = "service-account-json"
except AttributeError:
profile_dict["keyfile"] = self.keyfile
profile_dict["method"] = "service-account"

return profile_dict

0 comments on commit 4ee1b33

Please sign in to comment.