diff --git a/cosmos/__init__.py b/cosmos/__init__.py index 2113b4522..82ed9ef3f 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -11,10 +11,6 @@ from cosmos.airflow.task_group import DbtTaskGroup from cosmos.constants import LoadMode, TestBehavior, ExecutionMode from cosmos.dataset import get_dbt_dataset - -# re-export the dag and task group -from cosmos.airflow.dag import DbtDag -from cosmos.airflow.task_group import DbtTaskGroup from cosmos.operators.lazy_load import MissingPackage from cosmos.operators.local import ( diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 5fb24b5fe..d6a67c663 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -51,7 +51,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st return leaves -def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict) -> TaskMetadata: +def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -88,7 +88,7 @@ def create_test_task_metadata( test_task_name: str, execution_mode: ExecutionMode, task_args: dict[str, Any], - on_warning_callback: callable, + on_warning_callback: Callable[..., Any] | None = None, model_name: str | None = None, ) -> TaskMetadata: """ diff --git a/cosmos/converter.py b/cosmos/converter.py index 7ea065780..c8efa3cd3 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -1,8 +1,11 @@ +# mypy: ignore-errors +# ignoring enum Mypy errors + from __future__ import annotations +from enum import Enum import inspect import logging -import sys from typing import Any, Callable from airflow.exceptions import AirflowException @@ -167,7 +170,7 @@ def __init__( models_dir=Path(dbt_models_dir) if dbt_models_dir else None, seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None, snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None, - manifest_path=manifest_path, # type: ignore[arg-type] + manifest_path=manifest_path, ) dbt_graph = DbtGraph( diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 527723b3a..7ba4a52e3 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -4,6 +4,7 @@ import logging import os from dataclasses import dataclass, field +from pathlib import Path from subprocess import Popen, PIPE from typing import Any @@ -34,7 +35,7 @@ class DbtNode: name: str unique_id: str - resource_type: str + resource_type: DbtResourceType depends_on: list[str] file_path: Path tags: list[str] = field(default_factory=lambda: []) @@ -130,7 +131,12 @@ def load_via_dbt_ls(self) -> None: logger.info(f"Running command: {command}") try: process = Popen( - command, stdout=PIPE, stderr=PIPE, cwd=self.project.dir, universal_newlines=True, env=os.environ # type: ignore[arg-type] + command, # type: ignore[arg-type] + stdout=PIPE, + stderr=PIPE, + cwd=self.project.dir, + universal_newlines=True, + env=os.environ, ) except FileNotFoundError as exception: raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}") diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py index ec8966903..3eff9935c 100644 --- a/cosmos/dbt/parser/project.py +++ b/cosmos/dbt/parser/project.py @@ -103,17 +103,18 @@ def extract_python_file_upstream_requirements(code: str) -> list[str]: source_code = ast.parse(code) upstream_entities = [] - model_function = "" + model_function = None for node in source_code.body: if isinstance(node, ast.FunctionDef) and node.name == DBT_PY_MODEL_METHOD_NAME: model_function = node break - for item in ast.walk(model_function): - if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME: - upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value - if upstream_entity_id: - upstream_entities.append(upstream_entity_id) + if model_function: + for item in ast.walk(model_function): + if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME: # type: ignore[attr-defined] + upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value + if upstream_entity_id: + upstream_entities.append(upstream_entity_id) return upstream_entities @@ -153,7 +154,7 @@ def __post_init__(self) -> None: code = code.split("{%")[0] elif self.type == DbtModelType.DBT_SEED: - code = None + code = "" if self.path.suffix == PYTHON_FILE_SUFFIX: config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code))) diff --git a/cosmos/hooks/subprocess.py b/cosmos/hooks/subprocess.py index f12bcef29..f0bf39ce3 100644 --- a/cosmos/hooks/subprocess.py +++ b/cosmos/hooks/subprocess.py @@ -7,13 +7,17 @@ import contextlib import os import signal -from collections import namedtuple +from typing import NamedTuple from subprocess import PIPE, STDOUT, Popen from tempfile import TemporaryDirectory, gettempdir from airflow.hooks.base import BaseHook -FullOutputSubprocessResult = namedtuple("FullOutputSubprocessResult", ["exit_code", "output", "full_output"]) + +class FullOutputSubprocessResult(NamedTuple): + exit_code: int + output: str + full_output: list[str] class FullOutputSubprocessHook(BaseHook): # type: ignore[misc] # ignores subclass MyPy error diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1fbce1d4c..beb1858bd 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -127,7 +127,8 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se session.add(rtif) def run_subprocess(self, *args: Tuple[Any], **kwargs: Any) -> FullOutputSubprocessResult: - return self.subprocess_hook.run_command(*args, **kwargs) # type: ignore[no-any-return] + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs) + return subprocess_result def get_profile_name(self, project_dir: str) -> str: """ @@ -244,7 +245,7 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None def execute(self, context: Context) -> str: # TODO is this going to put loads of unnecessary stuff in to xcom? - return self.build_and_run_cmd(context=context).output # type: ignore[no-any-return] + return self.build_and_run_cmd(context=context).output def on_kill(self) -> None: if self.cancel_query_on_kill: @@ -268,7 +269,7 @@ def __init__(self, **kwargs: Any) -> None: def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) - return result.output # type: ignore[no-any-return] + return result.output class DbtSeedLocalOperator(DbtLocalBaseOperator): @@ -295,7 +296,7 @@ def add_cmd_flags(self) -> list[str]: def execute(self, context: Context) -> str: cmd_flags = self.add_cmd_flags() result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) - return result.output # type: ignore[no-any-return] + return result.output class DbtSnapshotLocalOperator(DbtLocalBaseOperator): @@ -312,7 +313,7 @@ def __init__(self, **kwargs: Any) -> None: def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) - return result.output # type: ignore[no-any-return] + return result.output class DbtRunLocalOperator(DbtLocalBaseOperator): @@ -329,7 +330,7 @@ def __init__(self, **kwargs: Any) -> None: def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) - return result.output # type: ignore[no-any-return] + return result.output class DbtTestLocalOperator(DbtLocalBaseOperator): @@ -385,13 +386,13 @@ def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) if not self._should_run_tests(result): - return result.output # type: ignore[no-any-return] + return result.output warnings = parse_output(result, "WARN") if warnings > 0: self._handle_warnings(result, context) - return result.output # type: ignore[no-any-return] + return result.output class DbtRunOperationLocalOperator(DbtLocalBaseOperator): @@ -422,7 +423,7 @@ def add_cmd_flags(self) -> list[str]: def execute(self, context: Context) -> str: cmd_flags = self.add_cmd_flags() result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) - return result.output # type: ignore[no-any-return] + return result.output class DbtDocsLocalOperator(DbtLocalBaseOperator): @@ -441,7 +442,7 @@ def __init__(self, **kwargs: Any) -> None: def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) - return result.output # type: ignore[no-any-return] + return result.output class DbtDocsS3LocalOperator(DbtDocsLocalOperator): diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 22507c472..2566097d7 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -91,11 +91,8 @@ def run_subprocess( # type: ignore[override] if self.py_requirements: command[0] = self.venv_dbt_path - return self.subprocess_hook.run_command( # type: ignore[no-any-return] - command, - *args, - **kwargs, - ) + subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(command, *args, **kwargs) + return subprocess_result def execute(self, context: Context) -> str: output = super().execute(context) diff --git a/dev/dags/dbt/jaffle_shop_python/.user.yml b/dev/dags/dbt/jaffle_shop_python/.user.yml new file mode 100644 index 000000000..99944580e --- /dev/null +++ b/dev/dags/dbt/jaffle_shop_python/.user.yml @@ -0,0 +1 @@ +id: 215ac1e7-601b-45dd-9867-587a3e282bce diff --git a/tests/utils.py b/tests/utils.py index 99804b540..37f7a3223 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,7 +57,7 @@ def test_dag( dag.clear( start_date=execution_date, end_date=execution_date, - dag_run_state=False, # type: ignore + dag_run_state=False, session=session, ) dag.log.debug("Getting dagrun for dag %s", dag.dag_id) @@ -164,7 +164,7 @@ def _get_or_create_dagrun( run_id=run_id, start_date=start_date or execution_date, session=session, - conf=conf, # type: ignore + conf=conf, ) log.info("created dagrun %s", str(dr)) return dr