Fix unittests
tatiana committed Sep 30, 2024
1 parent 9ca5e85 commit 99bf7c0
Showing 4 changed files with 31 additions and 41 deletions.
10 changes: 6 additions & 4 deletions cosmos/config.py
@@ -287,9 +287,9 @@ def validate_profiles_yml(self) -> None:
         if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
             raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")
 
-    def get_profile_type(self):
-        if self.profile_mapping.dbt_profile_type:
-            return self.profile_mapping.dbt_profile_type
+    def get_profile_type(self) -> str:
+        if self.profile_mapping is not None and hasattr(self.profile_mapping, "dbt_profile_type"):
+            return str(self.profile_mapping.dbt_profile_type)
 
         profile_path = self._get_profile_path()
 
@@ -298,7 +298,9 @@ def get_profile_type(self):
 
         profile = profiles[self.profile_name]
         target_type = profile["outputs"][self.target_name]["type"]
-        return target_type
+        return str(target_type)
+
+        return "undefined"
 
     def _get_profile_path(self, use_mock_values: bool = False) -> Path:
         """
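For context, the new get_profile_type() resolves the adapter type in three steps: the profile mapping if one is set, otherwise the target entry in profiles.yml, otherwise the sentinel "undefined". A minimal standalone sketch of that fallback order (illustrative names, not Cosmos source):

from pathlib import Path

import yaml  # assumption: PyYAML is available, as config.py already parses profiles.yml


def resolve_profile_type(profile_mapping, profiles_yml: Path, profile_name: str, target_name: str) -> str:
    # 1. Prefer the type declared by a configured profile mapping.
    if profile_mapping is not None and hasattr(profile_mapping, "dbt_profile_type"):
        return str(profile_mapping.dbt_profile_type)
    # 2. Otherwise read the adapter type from the profiles.yml target.
    if profiles_yml.exists():
        profiles = yaml.safe_load(profiles_yml.read_text())
        return str(profiles[profile_name]["outputs"][target_name]["type"])
    # 3. Fall back to a sentinel rather than raising.
    return "undefined"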
28 changes: 16 additions & 12 deletions cosmos/operators/airflow_async.py
@@ -2,11 +2,11 @@
 
 from typing import Any
 
-from airflow.io.path import ObjectStoragePath
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
 from airflow.utils.context import Context
 
+from cosmos import settings
 from cosmos.exceptions import CosmosValueError
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
@@ -43,22 +43,22 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
     pass
 
 
-class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):  #
-    def __init__(self, *args, full_refresh: bool = False, **kwargs):
+class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):  # type: ignore
+    def __init__(self, *args, full_refresh: bool = False, **kwargs):  # type: ignore
         # dbt task param
         self.profile_config = kwargs.get("profile_config")
         self.project_dir = kwargs.get("project_dir")
         self.file_path = kwargs.get("extra_context", {}).get("dbt_node_config", {}).get("file_path")
-        self.profile_type = self.profile_config.get_profile_type()
+        self.profile_type: str = self.profile_config.get_profile_type()  # type: ignore
         self.full_refresh = full_refresh
 
         # airflow task param
         self.async_op_args = kwargs.pop("async_op_args", {})
-        self.configuration = {}
+        self.configuration: dict[str, object] = {}
         self.job_id = self.async_op_args.get("job_id", "")
         self.impersonation_chain = self.async_op_args.get("impersonation_chain", "")
         self.gcp_project = self.async_op_args.get("project_id", "astronomer-dag-authoring")
-        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id
+        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
         self.dataset = self.async_op_args.get("dataset", "my_dataset")
         self.location = self.async_op_args.get("location", "US")
         self.async_op_args["deferrable"] = True
@@ -67,9 +67,13 @@ def __init__(self, *args, full_refresh: bool = False, **kwargs):
         super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)
 
         if self.profile_type not in _SUPPORTED_DATABASES:
-            raise f"Async run are only supported: {_SUPPORTED_DATABASES}"
+            raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")
 
-    def get_remote_sql(self):
+    def get_remote_sql(self) -> str:
+        if not settings.AIRFLOW_IO_AVAILABLE:
+            raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
+        from airflow.io.path import ObjectStoragePath
+
         if not self.file_path or not self.project_dir:
             raise CosmosValueError("file_path and project_dir are required to be set on the task for async execution")
         project_dir_parent = str(self.project_dir.parent)
@@ -78,10 +82,10 @@ def get_remote_sql(self):
 
         print("remote_model_path: ", remote_model_path)
         object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
-        with object_storage_path.open() as fp:
-            return fp.read()
+        with object_storage_path.open() as fp:  # type: ignore
+            return fp.read()  # type: ignore
 
-    def drop_table_sql(self):
+    def drop_table_sql(self) -> None:
         model_name = self.task_id.split(".")[0]
         sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"
         hook = BigQueryHook(
@@ -113,7 +117,7 @@ def execute(self, context: Context) -> Any | None:
                 "useLegacySql": False,
             }
         }
-        super().execute(context)
+        return super().execute(context)
 
 
 class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
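The get_remote_sql() change defers the airflow.io import behind settings.AIRFLOW_IO_AVAILABLE, since ObjectStoragePath only exists from Airflow 2.8 onwards. A self-contained sketch of that guard pattern, with the function and parameter names below being illustrative rather than the operator's API:

from cosmos import settings
from cosmos.exceptions import CosmosValueError


def read_remote_text(path: str, conn_id: str) -> str:
    # Fail fast on Airflow < 2.8, where airflow.io cannot be imported.
    if not settings.AIRFLOW_IO_AVAILABLE:
        raise CosmosValueError("Cosmos async support is only available starting in Airflow 2.8 or later.")
    # Deferred import: only evaluated once the version guard has passed.
    from airflow.io.path import ObjectStoragePath

    with ObjectStoragePath(path, conn_id=conn_id).open() as fp:
        return fp.read()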
32 changes: 8 additions & 24 deletions tests/operators/test_airflow_async.py
@@ -1,10 +1,10 @@
+import pytest
+from airflow import __version__ as airflow_version
+from packaging import version
+
 from cosmos.operators.airflow_async import (
     DbtBuildAirflowAsyncOperator,
     DbtCompileAirflowAsyncOperator,
-    DbtDocsAirflowAsyncOperator,
-    DbtDocsAzureStorageAirflowAsyncOperator,
-    DbtDocsGCSAirflowAsyncOperator,
-    DbtDocsS3AirflowAsyncOperator,
     DbtLSAirflowAsyncOperator,
     DbtRunAirflowAsyncOperator,
     DbtRunOperationAirflowAsyncOperator,
@@ -16,10 +16,6 @@
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
     DbtCompileLocalOperator,
-    DbtDocsAzureStorageLocalOperator,
-    DbtDocsGCSLocalOperator,
-    DbtDocsLocalOperator,
-    DbtDocsS3LocalOperator,
     DbtLSLocalOperator,
     DbtRunLocalOperator,
     DbtRunOperationLocalOperator,
@@ -50,6 +46,10 @@ def test_dbt_source_airflow_async_operator_inheritance():
     assert issubclass(DbtSourceAirflowAsyncOperator, DbtSourceLocalOperator)
 
 
+@pytest.mark.skipif(
+    version.parse(airflow_version) < version.parse("2.8"),
+    reason="Cosmos Async operators only work with Airflow 2.8 onwards.",
+)
 def test_dbt_run_airflow_async_operator_inheritance():
     assert issubclass(DbtRunAirflowAsyncOperator, DbtRunLocalOperator)
 
@@ -62,21 +62,5 @@ def test_dbt_run_operation_airflow_async_operator_inheritance():
     assert issubclass(DbtRunOperationAirflowAsyncOperator, DbtRunOperationLocalOperator)
 
 
-def test_dbt_docs_airflow_async_operator_inheritance():
-    assert issubclass(DbtDocsAirflowAsyncOperator, DbtDocsLocalOperator)
-
-
-def test_dbt_docs_s3_airflow_async_operator_inheritance():
-    assert issubclass(DbtDocsS3AirflowAsyncOperator, DbtDocsS3LocalOperator)
-
-
-def test_dbt_docs_azure_storage_airflow_async_operator_inheritance():
-    assert issubclass(DbtDocsAzureStorageAirflowAsyncOperator, DbtDocsAzureStorageLocalOperator)
-
-
-def test_dbt_docs_gcs_airflow_async_operator_inheritance():
-    assert issubclass(DbtDocsGCSAirflowAsyncOperator, DbtDocsGCSLocalOperator)
-
-
 def test_dbt_compile_airflow_async_operator_inheritance():
     assert issubclass(DbtCompileAirflowAsyncOperator, DbtCompileLocalOperator)
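Should more tests in this module need the same version guard, pytest also accepts the condition once at module level; a sketch of that alternative (a suggestion, not part of this commit):

import pytest
from airflow import __version__ as airflow_version
from packaging import version

# Applies the skip condition to every test in the module.
pytestmark = pytest.mark.skipif(
    version.parse(airflow_version) < version.parse("2.8"),
    reason="Cosmos Async operators only work with Airflow 2.8 onwards.",
)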
2 changes: 1 addition & 1 deletion tests/operators/test_local.py
@@ -1147,7 +1147,7 @@ def test_dbt_compile_local_operator_initialisation():
 
 
 @patch("cosmos.operators.local.remote_target_path", new="s3://some-bucket/target")
-@patch("cosmos.operators.local.AIRFLOW_IO_AVAILABLE", new=False)
+@patch("cosmos.settings.AIRFLOW_IO_AVAILABLE", new=False)
 def test_configure_remote_target_path_object_storage_unavailable_on_earlier_airflow_versions():
     operator = DbtCompileLocalOperator(
         task_id="fake-task",
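The patch target moves to cosmos.settings, presumably because the flag is read through the settings module at call time; unittest.mock has to replace the attribute where it is looked up. A minimal sketch of the pattern (test name and assertion are illustrative):

from unittest.mock import patch


@patch("cosmos.settings.AIRFLOW_IO_AVAILABLE", new=False)
def test_airflow_io_flag_disabled() -> None:
    from cosmos import settings

    # The attribute is replaced on the settings module itself, so code reading
    # settings.AIRFLOW_IO_AVAILABLE at call time sees the patched value.
    assert settings.AIRFLOW_IO_AVAILABLE is False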
