diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 573b61c48..83cd7fb5f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -58,12 +58,13 @@ repos:
       - id: blacken-docs
         alias: black
         additional_dependencies: [black>=22.10.0]
-  #- repo: https://github.com/pre-commit/mirrors-mypy
-  #  rev: 'v1.3.0'
-  #  hooks:
-  #  - id: mypy
-  #    name: mypy-python-sdk
-  #    additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v1.3.0'
+    hooks:
+      - id: mypy
+        name: mypy-python-sdk
+        additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil]
+        files: ^cosmos

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
diff --git a/cosmos/__init__.py b/cosmos/__init__.py
index a4508fc2a..82ed9ef3f 100644
--- a/cosmos/__init__.py
+++ b/cosmos/__init__.py
@@ -1,16 +1,17 @@
+# type: ignore # ignores "Cannot assign to a type" MyPy error
+
 """
 Astronomer Cosmos is a library for rendering dbt workflows in Airflow.

 Contains dags, task groups, and operators.
 """
-
 __version__ = "0.7.5"

 from cosmos.airflow.dag import DbtDag
 from cosmos.airflow.task_group import DbtTaskGroup
 from cosmos.constants import LoadMode, TestBehavior, ExecutionMode
 from cosmos.dataset import get_dbt_dataset
-
+from cosmos.operators.lazy_load import MissingPackage

 from cosmos.operators.local import (
     DbtDepsLocalOperator,
@@ -32,8 +33,6 @@
         DbtTestDockerOperator,
     )
 except ImportError:
-    from cosmos.operators.lazy_load import MissingPackage
-
     DbtLSDockerOperator = MissingPackage("cosmos.operators.docker.DbtLSDockerOperator", "docker")
     DbtRunDockerOperator = MissingPackage("cosmos.operators.docker.DbtRunDockerOperator", "docker")
     DbtRunOperationDockerOperator = MissingPackage(
@@ -54,8 +53,6 @@
         DbtTestKubernetesOperator,
     )
 except ImportError:
-    from cosmos.operators.lazy_load import MissingPackage
-
     DbtLSKubernetesOperator = MissingPackage(
         "cosmos.operators.kubernetes.DbtLSKubernetesOperator",
         "kubernetes",
diff --git a/cosmos/airflow/dag.py b/cosmos/airflow/dag.py
index 8ab955a1f..948a3558b 100644
--- a/cosmos/airflow/dag.py
+++ b/cosmos/airflow/dag.py
@@ -10,7 +10,7 @@
 from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtDag(DAG, DbtToAirflowConverter):
+class DbtDag(DAG, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
     """
     Render a dbt project as an Airflow DAG.
""" @@ -21,4 +21,5 @@ def __init__( **kwargs: Any, ) -> None: DAG.__init__(self, *args, **airflow_kwargs(**kwargs)) - DbtToAirflowConverter.__init__(self, *args, dag=self, **specific_kwargs(**kwargs)) + kwargs["dag"] = self + DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs)) diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index d9942c1e9..d6a67c663 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Callable +from typing import Any, Callable from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup @@ -42,7 +42,8 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st parents = [] leaves = [] materialized_nodes = [node for node in nodes.values() if node.unique_id in tasks_ids] - [parents.extend(node.depends_on) for node in materialized_nodes] + for node in materialized_nodes: + parents.extend(node.depends_on) parents_ids = set(parents) for node in materialized_nodes: if node.unique_id not in parents_ids: @@ -50,7 +51,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st return leaves -def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict) -> TaskMetadata: +def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -80,13 +81,14 @@ def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dic return task_metadata else: logger.error(f"Unsupported resource type {node.resource_type} (node {node.unique_id}).") + return None def create_test_task_metadata( test_task_name: str, execution_mode: ExecutionMode, - task_args: dict, - on_warning_callback: callable, + task_args: dict[str, Any], + on_warning_callback: Callable[..., Any] | None = None, model_name: str | None = None, ) -> TaskMetadata: """ @@ -118,12 +120,12 @@ def build_airflow_graph( nodes: dict[str, DbtNode], dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups execution_mode: ExecutionMode, # Cosmos-specific - decide what which class to use - task_args: dict[str, str], # Cosmos/DBT - used to instantiate tasks + task_args: dict[str, Any], # Cosmos/DBT - used to instantiate tasks test_behavior: TestBehavior, # Cosmos-specific: how to inject tests to Airflow DAG dbt_project_name: str, # DBT / Cosmos - used to name test task if mode is after_all, conn_id: str, # Cosmos, dataset URI task_group: TaskGroup | None = None, - on_warning_callback: Callable | None = None, # argument specific to the DBT test command + on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command emit_datasets: bool = True, # Cosmos ) -> None: """ @@ -191,7 +193,7 @@ def build_airflow_graph( f"{dbt_project_name}_test", execution_mode, task_args=task_args, on_warning_callback=on_warning_callback ) test_task = create_airflow_task(test_meta, dag, task_group=task_group) - leaves_ids = calculate_leaves(tasks_ids=tasks_map.keys(), nodes=nodes) + leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes) for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task diff --git a/cosmos/airflow/task_group.py b/cosmos/airflow/task_group.py index 67746a9bc..0a3234744 100644 --- a/cosmos/airflow/task_group.py +++ b/cosmos/airflow/task_group.py @@ -9,7 
@@ -9,7 +9,7 @@
 from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
+class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
     """
     Render a dbt project as an Airflow Task Group.
     """
@@ -21,4 +21,5 @@ def __init__(
     ) -> None:
         group_id = kwargs.get("group_id", kwargs.get("dbt_project_name", "dbt_task_group"))
         TaskGroup.__init__(self, group_id, *args, **airflow_kwargs(**kwargs))
-        DbtToAirflowConverter.__init__(self, *args, task_group=self, **specific_kwargs(**kwargs))
+        kwargs["task_group"] = self
+        DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
diff --git a/cosmos/converter.py b/cosmos/converter.py
index d6a0a9a34..c8efa3cd3 100644
--- a/cosmos/converter.py
+++ b/cosmos/converter.py
@@ -1,14 +1,17 @@
+# mypy: ignore-errors
+# ignoring enum Mypy errors
+
 from __future__ import annotations

+from enum import Enum
 import inspect
 import logging
-import pathlib
-from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, Callable

 from airflow.exceptions import AirflowException
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup
+from pathlib import Path

 from cosmos.airflow.graph import build_airflow_graph
 from cosmos.constants import ExecutionMode, LoadMode, TestBehavior
@@ -142,8 +145,8 @@ def __init__(
         exclude: list[str] | None = None,
         execution_mode: str | ExecutionMode = ExecutionMode.LOCAL,
         load_mode: str | LoadMode = LoadMode.AUTOMATIC,
-        manifest_path: str | pathlib.Path | None = None,
-        on_warning_callback: Optional[Callable] = None,
+        manifest_path: str | Path | None = None,
+        on_warning_callback: Callable[..., Any] | None = None,
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -154,12 +157,19 @@ def __init__(
         execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
         load_mode = convert_value_to_enum(load_mode, LoadMode)
+        test_behavior = convert_value_to_enum(test_behavior, TestBehavior)
+        execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
+        load_mode = convert_value_to_enum(load_mode, LoadMode)
+
+        if type(manifest_path) == str:
+            manifest_path = Path(manifest_path)
+
         dbt_project = DbtProject(
             name=dbt_project_name,
-            root_dir=dbt_root_path,
-            models_dir=dbt_models_dir,
-            seeds_dir=dbt_seeds_dir,
-            snapshots_dir=dbt_snapshots_dir,
+            root_dir=Path(dbt_root_path),
+            models_dir=Path(dbt_models_dir) if dbt_models_dir else None,
+            seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None,
+            snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None,
             manifest_path=manifest_path,
         )
diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py
index f88139ff0..4eaeb2df5 100644
--- a/cosmos/core/airflow.py
+++ b/cosmos/core/airflow.py
@@ -1,6 +1,5 @@
 import importlib
 import logging
-from typing import Optional

 from airflow.models import BaseOperator
 from airflow.models.dag import DAG
@@ -11,7 +10,7 @@
 logger = logging.getLogger(__name__)


-def get_airflow_task(task: Task, dag: DAG, task_group: Optional[TaskGroup] = None) -> BaseOperator:
+def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
     """
     Get the Airflow Operator class for a Task.
diff --git a/cosmos/dataset.py b/cosmos/dataset.py
index 0e68be7bd..5927de319 100644
--- a/cosmos/dataset.py
+++ b/cosmos/dataset.py
@@ -1,22 +1,25 @@
+from typing import Any, Tuple
+
+
 try:
     from airflow.datasets import Dataset
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     from logging import getLogger

     logger = getLogger(__name__)

-    class Dataset:
+    class Dataset:  # type: ignore[no-redef]
         cosmos_override = True

-        def __init__(self, id: str, *args, **kwargs):
+        def __init__(self, id: str, *args: Tuple[Any], **kwargs: str):
             self.id = id
             logger.warning("Datasets are not supported in Airflow < 2.5.0")

-        def __eq__(self, other) -> bool:
-            return self.id == other.id
+        def __eq__(self, other: "Dataset") -> bool:
+            return bool(self.id == other.id)


-def get_dbt_dataset(connection_id: str, project_name: str, model_name: str):
+def get_dbt_dataset(connection_id: str, project_name: str, model_name: str) -> Dataset:
     return Dataset(f"DBT://{connection_id.upper()}/{project_name.upper()}/{model_name.upper()}")
diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py
index 29af4f731..7ba4a52e3 100644
--- a/cosmos/dbt/graph.py
+++ b/cosmos/dbt/graph.py
@@ -4,6 +4,7 @@
 import logging
 import os
 from dataclasses import dataclass, field
+from pathlib import Path
 from subprocess import Popen, PIPE
 from typing import Any

@@ -34,9 +35,9 @@ class DbtNode:
     name: str
     unique_id: str
-    resource_type: str
+    resource_type: DbtResourceType
     depends_on: list[str]
-    file_path: str
+    file_path: Path
     tags: list[str] = field(default_factory=lambda: [])
     config: dict[str, Any] = field(default_factory=lambda: {})

@@ -66,7 +67,7 @@ def __init__(
         self,
         project: DbtProject,
         exclude: list[str] | None = None,
-        select: list[str] = None,
+        select: list[str] | None = None,
         dbt_cmd: str = get_system_dbt(),
     ):
         self.project = project
@@ -112,7 +113,7 @@ def load(self, method: LoadMode = LoadMode.AUTOMATIC, execution_mode: ExecutionMode = ExecutionMode.LOCAL) -> None:

         load_method[method]()

-    def load_via_dbt_ls(self):
+    def load_via_dbt_ls(self) -> None:
         """
         This is the most accurate way of loading `dbt` projects and filtering them out, since it uses the `dbt`
         command line for both parsing and filtering the nodes.
@@ -130,7 +131,12 @@ def load_via_dbt_ls(self):
         logger.info(f"Running command: {command}")
         try:
             process = Popen(
-                command, stdout=PIPE, stderr=PIPE, cwd=self.project.dir, universal_newlines=True, env=os.environ
+                command,  # type: ignore[arg-type]
+                stdout=PIPE,
+                stderr=PIPE,
+                cwd=self.project.dir,
+                universal_newlines=True,
+                env=os.environ,
             )
         except FileNotFoundError as exception:
             raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}")
@@ -164,7 +170,7 @@ def load_via_dbt_ls(self):
         self.nodes = nodes
         self.filtered_nodes = nodes

-    def load_via_custom_parser(self):
+    def load_via_custom_parser(self) -> None:
         """
         This is the least accurate way of loading `dbt` projects and filtering them out, since it uses custom Cosmos
         logic, which is usually a subset of what is available in `dbt`.
@@ -177,11 +183,11 @@ def load_via_custom_parser(self):
         * self.filtered_nodes
         """
         logger.info("Trying to parse the dbt project using a custom Cosmos method...")
+
         project = LegacyDbtProject(
-            dbt_root_path=self.project.root_dir,
-            dbt_models_dir=self.project.models_dir.stem,
-            dbt_snapshots_dir=self.project.snapshots_dir.stem,
-            dbt_seeds_dir=self.project.seeds_dir.stem,
+            dbt_root_path=str(self.project.root_dir),
+            dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None,
+            dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None,
             project_name=self.project.name,
         )
         nodes = {}
@@ -192,7 +198,7 @@ def load_via_custom_parser(self):
                 name=model_name,
                 unique_id=model_name,
                 resource_type=DbtResourceType(model.type.value),
-                depends_on=model.config.upstream_models,
+                depends_on=list(model.config.upstream_models),
                 file_path=model.path,
                 tags=[],
                 config=config,
@@ -204,7 +210,7 @@ def load_via_custom_parser(self):
             project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
         )

-    def load_from_dbt_manifest(self):
+    def load_from_dbt_manifest(self) -> None:
         """
         This approach accurately loads `dbt` projects using the `manifest.yml` file.
@@ -217,7 +223,7 @@ def load_from_dbt_manifest(self):
         """
         logger.info("Trying to parse the dbt project using a dbt manifest...")
         nodes = {}
-        with open(self.project.manifest_path) as fp:
+        with open(self.project.manifest_path) as fp:  # type: ignore[arg-type]
             manifest = json.load(fp)

         for unique_id, node_dict in manifest.get("nodes", {}).items():
diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py
index a390308e2..3eff9935c 100644
--- a/cosmos/dbt/parser/project.py
+++ b/cosmos/dbt/parser/project.py
@@ -10,10 +10,10 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import ClassVar, Dict, List, Set
+from typing import Any, ClassVar, Dict, List, Set

 import jinja2
-import yaml  # type: ignore
+import yaml

 logger = logging.getLogger(__name__)

@@ -63,7 +63,7 @@ def _config_selector_ooo(
         self,
         sql_configs: Set[str],
         properties_configs: Set[str],
-        prefixes: List[str] = None,
+        prefixes: List[str] | None = None,
     ) -> Set[str]:
         """
         this will force values from the sql files to override whatever is in the properties.yml.
         So ooo:
@@ -103,17 +103,18 @@ def extract_python_file_upstream_requirements(code: str) -> list[str]:
     source_code = ast.parse(code)

     upstream_entities = []
-    model_function = ""
+    model_function = None
    for node in source_code.body:
         if isinstance(node, ast.FunctionDef) and node.name == DBT_PY_MODEL_METHOD_NAME:
             model_function = node
             break

-    for item in ast.walk(model_function):
-        if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME:
-            upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
-            if upstream_entity_id:
-                upstream_entities.append(upstream_entity_id)
+    if model_function:
+        for item in ast.walk(model_function):
+            if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME:  # type: ignore[attr-defined]
+                upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
+                if upstream_entity_id:
+                    upstream_entities.append(upstream_entity_id)

     return upstream_entities

@@ -153,7 +154,7 @@ def __post_init__(self) -> None:
                 code = code.split("{%")[0]

         elif self.type == DbtModelType.DBT_SEED:
-            code = None
+            code = ""

         if self.path.suffix == PYTHON_FILE_SUFFIX:
             config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
@@ -186,7 +187,7 @@ def __post_init__(self) -> None:
         self.config = config

     # TODO following needs coverage:
-    def _extract_config(self, kwarg, config_name: str):
+    def _extract_config(self, kwarg: Any, config_name: str) -> Any:
         if hasattr(kwarg, "key") and kwarg.key == config_name:
             try:
                 # try to convert it to a constant and get the value
@@ -221,10 +222,10 @@ class DbtProject:
     project_name: str

     # optional, user-specified instance variables
-    dbt_root_path: str = "/usr/local/airflow/dags/dbt"
-    dbt_models_dir: str = "models"
-    dbt_snapshots_dir: str = "snapshots"
-    dbt_seeds_dir: str = "seeds"
+    dbt_root_path: str | None = None
+    dbt_models_dir: str | None = None
+    dbt_snapshots_dir: str | None = None
+    dbt_seeds_dir: str | None = None

     # private instance variables for managing state
     models: Dict[str, DbtModel] = field(default_factory=dict)
@@ -239,6 +240,15 @@ def __post_init__(self) -> None:
         """
         Initializes the parser.
""" + if self.dbt_root_path is None: + self.dbt_root_path = "/usr/local/airflow/dags/dbt" + if self.dbt_models_dir is None: + self.dbt_models_dir = "models" + if self.dbt_snapshots_dir is None: + self.dbt_snapshots_dir = "snapshots" + if self.dbt_seeds_dir is None: + self.dbt_seeds_dir = "seeds" + # set the project and model dirs self.project_dir = Path(os.path.join(self.dbt_root_path, self.project_name)) self.models_dir = self.project_dir / self.dbt_models_dir @@ -333,7 +343,9 @@ def _handle_config_file(self, path: Path) -> None: if isinstance(config_value, str): config_selectors.append(f"{selector}:{config_value}") else: - [config_selectors.append(f"{selector}:{item}") for item in config_value if item] + for item in config_value: + if item: + config_selectors.append(f"{selector}:{item}") # dbt default ensures "materialized:view" is set for all models if nothing is specified so that it will # work in a select/exclude list diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index df3a8bcaa..fe60f5751 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -17,7 +17,7 @@ class DbtProject: profile_path: Path | None = None _cosmos_created_profile_file: bool = False - def __post_init__(self): + def __post_init__(self) -> None: if self.models_dir is None: self.models_dir = self.dir / "models" if self.seeds_dir is None: @@ -38,10 +38,10 @@ def is_manifest_available(self) -> bool: """ Check if the `dbt` project manifest is set and if the file exists. """ - return self.manifest_path and Path(self.manifest_path).exists() + return self.manifest_path is not None and Path(self.manifest_path).exists() def is_profile_yml_available(self) -> bool: """ Check if the `dbt` profiles.yml file exists. """ - return Path(self.profile_path).exists() + return Path(self.profile_path).exists() if self.profile_path else False diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index e0098acea..ec1cd6ae5 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -2,6 +2,11 @@ import logging from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from cosmos.dbt.graph import DbtNode + SUPPORTED_CONFIG = ["materialized", "schema", "tags"] PATH_SELECTOR = "path:" @@ -30,13 +35,13 @@ def __init__(self, project_dir: Path, statement: str): https://docs.getdbt.com/reference/node-selection/yaml-selectors """ self.project_dir = project_dir - self.paths: list[str] = [] + self.paths: list[Path] = [] self.tags: list[str] = [] self.config: dict[str, str] = {} self.other: list[str] = [] self.load_from_statement(statement) - def load_from_statement(self, statement: str): + def load_from_statement(self, statement: str) -> None: """ Load in-place select parameters. Raises an exception if they are not yet implemented in Cosmos. @@ -65,7 +70,7 @@ def load_from_statement(self, statement: str): logger.warning("Unsupported select statement: %s", item) -def select_nodes_ids_by_intersection(nodes: dict, config: SelectorConfig) -> list[str]: +def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]: """ Return a list of node ids which matches the configuration defined in config. @@ -93,7 +98,7 @@ def select_nodes_ids_by_intersection(nodes: dict, config: SelectorConfig) -> lis return selected_nodes -def retrieve_by_label(statement_list: list[str], label: str) -> set: +def retrieve_by_label(statement_list: list[str], label: str) -> set[str]: """ Return a set of values associated with a label. 
@@ -102,17 +107,18 @@ def retrieve_by_label(statement_list: list[str], label: str) -> set:
     >>> values
     {"a", "b"}
     """
-    label_values = set()
+    label_values: set[str] = set()
     for statement in statement_list:
         config = SelectorConfig(Path(), statement)
         item_values = getattr(config, label)
         label_values = label_values.union(item_values)
+
     return label_values


 def select_nodes(
-    project_dir: Path, nodes: dict[str, str], select: list[str] | None = None, exclude: list[str] | None = None
-) -> dict[str, str]:
+    project_dir: Path, nodes: dict[str, DbtNode], select: list[str] | None = None, exclude: list[str] | None = None
+) -> dict[str, DbtNode]:
     """
     Given a group of nodes within a project, apply select and exclude filters using
     dbt node selection.
@@ -126,7 +132,7 @@ def select_nodes(
     if not select and not exclude:
         return nodes

-    subset_ids = set()
+    subset_ids: set[str] = set()

     for statement in select:
         config = SelectorConfig(project_dir, statement)
diff --git a/cosmos/hooks/subprocess.py b/cosmos/hooks/subprocess.py
index 027b46e85..f0bf39ce3 100644
--- a/cosmos/hooks/subprocess.py
+++ b/cosmos/hooks/subprocess.py
@@ -7,16 +7,20 @@
 import contextlib
 import os
 import signal
-from collections import namedtuple
+from typing import NamedTuple
 from subprocess import PIPE, STDOUT, Popen
 from tempfile import TemporaryDirectory, gettempdir

 from airflow.hooks.base import BaseHook

-FullOutputSubprocessResult = namedtuple("FullOutputSubprocessResult", ["exit_code", "output", "full_output"])
+class FullOutputSubprocessResult(NamedTuple):
+    exit_code: int
+    output: str
+    full_output: list[str]

-class FullOutputSubprocessHook(BaseHook):
+
+class FullOutputSubprocessHook(BaseHook):  # type: ignore[misc] # ignores subclass MyPy error
     """Hook for running processes with the ``subprocess`` module."""

     def __init__(self) -> None:
@@ -56,7 +60,7 @@ def run_command(
         if cwd is None:
             cwd = stack.enter_context(TemporaryDirectory(prefix="airflowtmp"))

-            def pre_exec():
+            def pre_exec() -> None:
                 # Restore default signal disposition and invoke setsid
                 for sig in ("SIGPIPE", "SIGXFZ", "SIGXFSZ"):
                     if hasattr(signal, sig):
@@ -93,7 +97,7 @@ def pre_exec():

             return FullOutputSubprocessResult(exit_code=return_code, output=line, full_output=log_lines)

-    def send_sigterm(self):
+    def send_sigterm(self) -> None:
         """Sends SIGTERM signal to ``self.sub_process`` if one exists."""
         self.log.info("Sending SIGTERM signal to process group")
         if self.sub_process and hasattr(self.sub_process, "pid"):
diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py
index 777e727d6..bc8444798 100644
--- a/cosmos/operators/base.py
+++ b/cosmos/operators/base.py
@@ -13,7 +13,7 @@
 logger = logging.getLogger(__name__)


-class DbtBaseOperator(BaseOperator):
+class DbtBaseOperator(BaseOperator):  # type: ignore[misc] # ignores subclass MyPy error
     """
     Executes a dbt core cli command.
@@ -78,27 +78,27 @@ def __init__(
         self,
         project_dir: str,
         conn_id: str,
-        base_cmd: str | list[str] = None,
-        select: str = None,
-        exclude: str = None,
-        selector: str = None,
-        vars: dict = None,
-        models: str = None,
+        base_cmd: list[str] | None = None,
+        select: str | None = None,
+        exclude: str | None = None,
+        selector: str | None = None,
+        vars: dict[str, str] | None = None,
+        models: str | None = None,
         cache_selected_only: bool = False,
         no_version_check: bool = False,
         fail_fast: bool = False,
         quiet: bool = False,
         warn_error: bool = False,
-        db_name: str = None,
-        schema: str = None,
-        env: dict = None,
+        db_name: str | None = None,
+        schema: str | None = None,
+        env: dict[str, Any] | None = None,
         append_env: bool = False,
         output_encoding: str = "utf-8",
         skip_exit_code: int = 99,
         cancel_query_on_kill: bool = True,
         dbt_executable_path: str = "dbt",
-        dbt_cmd_flags: list[str] = None,
-        **kwargs,
+        dbt_cmd_flags: list[str] | None = None,
+        **kwargs: str,
     ) -> None:
         self.project_dir = project_dir
         self.conn_id = conn_id
@@ -132,7 +132,7 @@ def __init__(
         self.dbt_cmd_flags = dbt_cmd_flags
         super().__init__(**kwargs)

-    def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike]:
+    def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
         """
         Builds the set of environment variables to be exposed for the bash command.
@@ -159,7 +159,7 @@ def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike]:
         # filter out invalid types and give a warning when a value is removed
         accepted_types = (str, bytes, os.PathLike)

-        filtered_env: dict[str, str | bytes | os.PathLike] = {}
+        filtered_env: dict[str, str | bytes | os.PathLike[Any]] = {}

         for key, val in env.items():
             if isinstance(key, accepted_types) and isinstance(val, accepted_types):
@@ -205,12 +205,10 @@ def build_cmd(
         self,
         context: Context,
         cmd_flags: list[str] | None = None,
-    ) -> Tuple[list[str], dict]:
+    ) -> Tuple[list[str | None], dict[str, str | bytes | os.PathLike[Any]]]:
         dbt_cmd = [self.dbt_executable_path]

-        if isinstance(self.base_cmd, str):
-            dbt_cmd.append(self.base_cmd)
-        else:
+        if self.base_cmd:
             dbt_cmd.extend(self.base_cmd)

         dbt_cmd.extend(self.add_global_flags())
diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py
index 5b8844f1c..a5839c715 100644
--- a/cosmos/operators/docker.py
+++ b/cosmos/operators/docker.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Callable, Optional, Sequence
+from typing import Any, Callable, Sequence

 import yaml
 from airflow.utils.context import Context
@@ -20,7 +20,7 @@
 )


-class DbtDockerBaseOperator(DockerOperator, DbtBaseOperator):
+class DbtDockerBaseOperator(DockerOperator, DbtBaseOperator):  # type: ignore[misc] # ignores subclass MyPy error
     """
     Executes a dbt core cli command in a Docker container.
@@ -33,23 +33,23 @@ class DbtDockerBaseOperator(DockerOperator, DbtBaseOperator):
     def __init__(
         self,
         image: str,  # Make image a required argument since it's required by DockerOperator
-        **kwargs,
+        **kwargs: Any,
     ) -> None:
         super().__init__(image=image, **kwargs)

-    def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None):
-        self.build_command(cmd_flags, context)
+    def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
+        self.build_command(context, cmd_flags)
         self.log.info(f"Running command: {self.command}")  # type: ignore[has-type]
         return super().execute(context)

-    def build_command(self, cmd_flags, context):
+    def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None:
         # For the first round, we're going to assume that the command is dbt
         # This means that we don't have openlineage support, but we will create a ticket
         # to add that in the future
         self.dbt_executable_path = "dbt"
         dbt_cmd, env_vars = self.build_cmd(context=context, cmd_flags=cmd_flags)
         # set env vars
-        self.environment = {**env_vars, **self.environment}
+        self.environment = {**env_vars, **self.environment}  # type: ignore[has-type]
         self.command = dbt_cmd
@@ -60,11 +60,11 @@ class DbtLSDockerOperator(DbtDockerBaseOperator):

     ui_color = "#DBCDF6"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "ls"
+        self.base_cmd = ["ls"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -77,19 +77,19 @@ class DbtSeedDockerOperator(DbtDockerBaseOperator):

     ui_color = "#F58D7E"

-    def __init__(self, full_refresh: bool = False, **kwargs) -> None:
+    def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
         self.full_refresh = full_refresh
         super().__init__(**kwargs)
-        self.base_cmd = "seed"
+        self.base_cmd = ["seed"]

-    def add_cmd_flags(self):
+    def add_cmd_flags(self) -> list[str]:
         flags = []
         if self.full_refresh is True:
             flags.append("--full-refresh")
         return flags

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         cmd_flags = self.add_cmd_flags()
         return self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)


@@ -102,11 +102,11 @@ class DbtSnapshotDockerOperator(DbtDockerBaseOperator):

     ui_color = "#964B00"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "snapshot"
+        self.base_cmd = ["snapshot"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -118,11 +118,11 @@ class DbtRunDockerOperator(DbtDockerBaseOperator):

     ui_color = "#7352BA"
     ui_fgcolor = "#F4F2FC"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "run"
+        self.base_cmd = ["run"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -133,13 +133,13 @@ class DbtTestDockerOperator(DbtDockerBaseOperator):

     ui_color = "#8194E0"

-    def __init__(self, on_warning_callback: Optional[Callable] = None, **kwargs) -> None:
+    def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "test"
+        self.base_cmd = ["test"]
         # as of now, on_warning_callback in docker executor does nothing
         self.on_warning_callback = on_warning_callback

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -153,21 +153,21 @@ class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
     """

     ui_color = "#8194E0"
-    template_fields: Sequence[str] = "args"
+    template_fields: Sequence[str] = ("args",)

-    def __init__(self, macro_name: str, args: dict = None, **kwargs) -> None:
+    def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
         self.macro_name = macro_name
         self.args = args
         super().__init__(**kwargs)
         self.base_cmd = ["run-operation", macro_name]

-    def add_cmd_flags(self):
+    def add_cmd_flags(self) -> list[str]:
         flags = []
         if self.args is not None:
             flags.append("--args")
             flags.append(yaml.dump(self.args))
         return flags

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         cmd_flags = self.add_cmd_flags()
         return self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py
index a85559780..995aa70a7 100644
--- a/cosmos/operators/kubernetes.py
+++ b/cosmos/operators/kubernetes.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import Callable, Optional, Sequence
+from os import PathLike
+from typing import Any, Callable, Sequence

 import yaml
 from airflow.utils.context import Context
@@ -16,7 +17,6 @@
         convert_env_vars,
     )
     from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
-    from kubernetes.client import models as k8s
 except ImportError:
     raise ImportError(
         "Could not import KubernetesPodOperator. Ensure you've installed the Kubernetes provider "
@@ -24,7 +24,7 @@
     )


-class DbtKubernetesBaseOperator(KubernetesPodOperator, DbtBaseOperator):
+class DbtKubernetesBaseOperator(KubernetesPodOperator, DbtBaseOperator):  # type: ignore[misc]
     """
     Executes a dbt core cli command in a Kubernetes Pod.
@@ -34,25 +34,23 @@ class DbtKubernetesBaseOperator(KubernetesPodOperator, DbtBaseOperator):

     intercept_flag = False

-    def __init__(
-        self,
-        **kwargs,
-    ) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)

-    def build_env_args(self, env: dict) -> list[k8s.V1EnvVar]:
-        env_vars_dict = {}
-        for env_var in self.env_vars:
+    def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None:
+        env_vars_dict = dict()
+
+        for env_var in self.env_vars:  # type: ignore[has-type]
             env_vars_dict[env_var.name] = env_var.value

         self.env_vars = convert_env_vars({**env, **env_vars_dict})

-    def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None):
-        self.build_kube_args(cmd_flags, context)
+    def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
+        self.build_kube_args(context, cmd_flags)
         self.log.info(f"Running command: {self.arguments}")
         return super().execute(context)

-    def build_kube_args(self, cmd_flags, context):
+    def build_kube_args(self, context: Context, cmd_flags: list[str] | None = None) -> None:
         # For the first round, we're going to assume that the command is dbt
         # This means that we don't have openlineage support, but we will create a ticket
         # to add that in the future
@@ -70,11 +68,11 @@ class DbtLSKubernetesOperator(DbtKubernetesBaseOperator):

     ui_color = "#DBCDF6"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "ls"
+        self.base_cmd = ["ls"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -87,19 +85,19 @@ class DbtSeedKubernetesOperator(DbtKubernetesBaseOperator):

     ui_color = "#F58D7E"

-    def __init__(self, full_refresh: bool = False, **kwargs) -> None:
+    def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
         self.full_refresh = full_refresh
         super().__init__(**kwargs)
-        self.base_cmd = "seed"
+        self.base_cmd = ["seed"]

-    def add_cmd_flags(self):
+    def add_cmd_flags(self) -> list[str]:
         flags = []
         if self.full_refresh is True:
             flags.append("--full-refresh")
         return flags

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         cmd_flags = self.add_cmd_flags()
         return self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)


@@ -112,11 +110,11 @@ class DbtSnapshotKubernetesOperator(DbtKubernetesBaseOperator):

     ui_color = "#964B00"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "snapshot"
+        self.base_cmd = ["snapshot"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -128,11 +126,11 @@ class DbtRunKubernetesOperator(DbtKubernetesBaseOperator):

     ui_color = "#7352BA"
     ui_fgcolor = "#F4F2FC"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "run"
+        self.base_cmd = ["run"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -143,13 +141,13 @@ class DbtTestKubernetesOperator(DbtKubernetesBaseOperator):

     ui_color = "#8194E0"

-    def __init__(self, on_warning_callback: Optional[Callable] = None, **kwargs) -> None:
+    def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "test"
+        self.base_cmd = ["test"]
         # as of now, on_warning_callback in kubernetes executor does nothing
         self.on_warning_callback = on_warning_callback

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         return self.build_and_run_cmd(context=context)


@@ -163,21 +161,21 @@ class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator):
     """

     ui_color = "#8194E0"
-    template_fields: Sequence[str] = "args"
+    template_fields: Sequence[str] = ("args",)

-    def __init__(self, macro_name: str, args: dict = None, **kwargs) -> None:
+    def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
         self.macro_name = macro_name
         self.args = args
         super().__init__(**kwargs)
         self.base_cmd = ["run-operation", macro_name]

-    def add_cmd_flags(self):
+    def add_cmd_flags(self) -> list[str]:
         flags = []
         if self.args is not None:
             flags.append("--args")
             flags.append(yaml.dump(self.args))
         return flags

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> Any:
         cmd_flags = self.add_cmd_flags()
         return self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
diff --git a/cosmos/operators/lazy_load.py b/cosmos/operators/lazy_load.py
index eba2ecdd6..2ebc084f2 100644
--- a/cosmos/operators/lazy_load.py
+++ b/cosmos/operators/lazy_load.py
@@ -1,5 +1,8 @@
-def MissingPackage(module_name, optional_dependency_name):
-    def raise_error(**kwargs):
+from typing import Any
+
+
+def MissingPackage(module_name: str, optional_dependency_name: str) -> Any:
+    def raise_error(**kwargs: Any) -> None:
         raise RuntimeError(
             f"Error loading the module {module_name},"
             f" please make sure the right optional dependencies are installed."
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 605fd4ce3..beb1858bd 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -6,7 +6,7 @@
 import signal
 import tempfile
 from pathlib import Path
-from typing import Callable, Optional, Sequence
+from typing import Any, Callable, Sequence, Tuple

 import yaml
 from airflow.compat.functools import cached_property
@@ -45,7 +45,7 @@ class DbtLocalBaseOperator(DbtBaseOperator):
     :param should_store_compiled_sql: If true, store the compiled SQL in the compiled_sql rendered template.
""" - template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("compiled_sql",) + template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("compiled_sql",) # type: ignore[operator] template_fields_renderers = { "compiled_sql": "sql", } @@ -53,12 +53,12 @@ class DbtLocalBaseOperator(DbtBaseOperator): def __init__( self, install_deps: bool = False, - callback: Optional[Callable[[str], None]] = None, + callback: Callable[[str], None] | None = None, profile_args: dict[str, str] = {}, profile_name: str | None = None, target_name: str | None = None, should_store_compiled_sql: bool = True, - **kwargs, + **kwargs: Any, ) -> None: self.install_deps = install_deps self.profile_args = profile_args @@ -69,12 +69,12 @@ def __init__( self.should_store_compiled_sql = should_store_compiled_sql super().__init__(**kwargs) - @cached_property - def subprocess_hook(self): + @cached_property # type: ignore[misc] # ignores internal untyped decorator + def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() - def exception_handling(self, result: FullOutputSubprocessResult): + def exception_handling(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.") elif result.exit_code != 0: @@ -83,7 +83,7 @@ def exception_handling(self, result: FullOutputSubprocessResult): *result.full_output, ) - @provide_session + @provide_session # type: ignore[misc] # ignores internal untyped decorator def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: """ Takes the compiled SQL files from the dbt run and stores them in the compiled_sql rendered template. 
@@ -126,8 +126,9 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
             ).delete()
             session.add(rtif)

-    def run_subprocess(self, *args, **kwargs):
-        return self.subprocess_hook.run_command(*args, **kwargs)
+    def run_subprocess(self, *args: Tuple[Any], **kwargs: Any) -> FullOutputSubprocessResult:
+        subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs)
+        return subprocess_result

     def get_profile_name(self, project_dir: str) -> str:
         """
@@ -174,8 +175,8 @@ def get_target_name(self) -> str:

     def run_command(
         self,
-        cmd: list[str],
-        env: dict[str, str],
+        cmd: list[str | None],
+        env: dict[str, str | bytes | os.PathLike[Any]],
         context: Context,
     ) -> FullOutputSubprocessResult:
         """
@@ -262,11 +263,11 @@ class DbtLSLocalOperator(DbtLocalBaseOperator):

     ui_color = "#DBCDF6"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "ls"
+        self.base_cmd = ["ls"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str:
         result = self.build_and_run_cmd(context=context)
         return result.output


@@ -280,19 +281,19 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator):

     ui_color = "#F58D7E"

-    def __init__(self, full_refresh: bool = False, **kwargs) -> None:
+    def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
         self.full_refresh = full_refresh
         super().__init__(**kwargs)
-        self.base_cmd = "seed"
+        self.base_cmd = ["seed"]

-    def add_cmd_flags(self):
+    def add_cmd_flags(self) -> list[str]:
         flags = []
         if self.full_refresh is True:
             flags.append("--full-refresh")
         return flags

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str:
         cmd_flags = self.add_cmd_flags()
         result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
         return result.output
@@ -306,11 +307,11 @@ class DbtSnapshotLocalOperator(DbtLocalBaseOperator):

     ui_color = "#964B00"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "snapshot"
+        self.base_cmd = ["snapshot"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str:
         result = self.build_and_run_cmd(context=context)
         return result.output
@@ -323,11 +324,11 @@ class DbtRunLocalOperator(DbtLocalBaseOperator):

     ui_color = "#7352BA"
     ui_fgcolor = "#F4F2FC"

-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "run"
+        self.base_cmd = ["run"]

-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str:
         result = self.build_and_run_cmd(context=context)
         return result.output
@@ -343,11 +344,11 @@ class DbtTestLocalOperator(DbtLocalBaseOperator):

     def __init__(
         self,
-        on_warning_callback: Optional[Callable] = None,
-        **kwargs,
+        on_warning_callback: Callable[..., Any] | None = None,
+        **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        self.base_cmd = "test"
+        self.base_cmd = ["test"]
         self.on_warning_callback = on_warning_callback

     def _should_run_tests(
@@ -362,7 +363,7 @@ def _should_run_tests(

         :param result: The output from the build and run command.
""" - return self.on_warning_callback and no_tests_message not in result.output + return self.on_warning_callback is not None and no_tests_message not in result.output def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: """ @@ -378,9 +379,10 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) warning_context["test_names"] = test_names warning_context["test_results"] = test_results - self.on_warning_callback(warning_context) + if self.on_warning_callback: + self.on_warning_callback(warning_context) - def execute(self, context: Context): + def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) if not self._should_run_tests(result): @@ -403,22 +405,22 @@ class DbtRunOperationLocalOperator(DbtLocalBaseOperator): """ ui_color = "#8194E0" - template_fields: Sequence[str] = "args" + template_fields: Sequence[str] = ("args",) - def __init__(self, macro_name: str, args: dict = None, **kwargs) -> None: + def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None: self.macro_name = macro_name self.args = args super().__init__(**kwargs) self.base_cmd = ["run-operation", macro_name] - def add_cmd_flags(self): + def add_cmd_flags(self) -> list[str]: flags = [] if self.args is not None: flags.append("--args") flags.append(yaml.dump(self.args)) return flags - def execute(self, context: Context): + def execute(self, context: Context) -> str: cmd_flags = self.add_cmd_flags() result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) return result.output @@ -434,11 +436,11 @@ class DbtDocsLocalOperator(DbtLocalBaseOperator): required_files = ["index.html", "manifest.json", "graph.gpickle", "catalog.json"] - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.base_cmd = ["docs", "generate"] - def execute(self, context: Context): + def execute(self, context: Context) -> str: result = self.build_and_run_cmd(context=context) return result.output @@ -460,7 +462,7 @@ def __init__( aws_conn_id: str, bucket_name: str, folder_dir: str | None = None, - **kwargs, + **kwargs: str, ) -> None: "Initializes the operator." self.aws_conn_id = aws_conn_id @@ -520,7 +522,7 @@ def __init__( azure_conn_id: str, container_name: str, folder_dir: str | None = None, - **kwargs, + **kwargs: str, ) -> None: "Initializes the operator." self.azure_conn_id = azure_conn_id @@ -571,7 +573,7 @@ class DbtDepsLocalOperator(DbtLocalBaseOperator): ui_color = "#8194E0" - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: str) -> None: raise DeprecationWarning( "The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead." 
         )
diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py
index 07728cd2f..2566097d7 100644
--- a/cosmos/operators/virtualenv.py
+++ b/cosmos/operators/virtualenv.py
@@ -3,11 +3,12 @@
 import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Tuple

 from airflow.compat.functools import cached_property
 from airflow.utils.python_virtualenv import prepare_virtualenv

+from cosmos.hooks.subprocess import FullOutputSubprocessResult
 from cosmos.operators.local import (
     DbtDocsLocalOperator,
@@ -45,14 +46,14 @@ def __init__(
         self,
         py_requirements: list[str] | None = None,
         py_system_site_packages: bool = False,
-        **kwargs,
+        **kwargs: Any,
     ) -> None:
         self.py_requirements = py_requirements or []
         self.py_system_site_packages = py_system_site_packages
         super().__init__(**kwargs)
-        self._venv_tmp_dir = ""
+        self._venv_tmp_dir = TemporaryDirectory()

-    @cached_property
+    @cached_property  # type: ignore[misc] # ignores internal untyped decorator
     def venv_dbt_path(
         self,
     ) -> str:
@@ -84,15 +85,14 @@ def venv_dbt_path(
         self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary)
         return str(dbt_binary)

-    def run_subprocess(self, command, *args, **kwargs):
+    def run_subprocess(  # type: ignore[override]
+        self, *args: Tuple[Any], command: list[str], **kwargs: Any
+    ) -> FullOutputSubprocessResult:
         if self.py_requirements:
             command[0] = self.venv_dbt_path

-        return self.subprocess_hook.run_command(
-            command,
-            *args,
-            **kwargs,
-        )
+        subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(command, *args, **kwargs)
+        return subprocess_result

     def execute(self, context: Context) -> str:
         output = super().execute(context)
diff --git a/cosmos/profiles/trino/jwt.py b/cosmos/profiles/trino/jwt.py
index 7cd6dc967..f94b787fc 100644
--- a/cosmos/profiles/trino/jwt.py
+++ b/cosmos/profiles/trino/jwt.py
@@ -28,7 +28,7 @@ class TrinoJWTProfileMapping(TrinoBaseProfileMapping):
     @property
     def profile(self) -> dict[str, Any | None]:
         "Gets profile."
-        common_profile_vars = super().profile
+        common_profile_vars: dict[str, Any] = super().profile

         # need to remove jwt from profile_args because it will be set as an environment variable
         profile_args = self.profile_args.copy()
diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py
index f5e049ae8..baacfef43 100644
--- a/tests/operators/test_local.py
+++ b/tests/operators/test_local.py
@@ -34,7 +34,7 @@ def test_dbt_base_operator_add_user_supplied_flags() -> None:
         conn_id="my_airflow_connection",
         task_id="my-task",
         project_dir="my/dir",
-        base_cmd="run",
+        base_cmd=["run"],
         dbt_cmd_flags=["--full-refresh"],
     )

diff --git a/tests/utils.py b/tests/utils.py
index 99804b540..37f7a3223 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -57,7 +57,7 @@ def test_dag(
     dag.clear(
         start_date=execution_date,
         end_date=execution_date,
-        dag_run_state=False,  # type: ignore
+        dag_run_state=False,
         session=session,
     )
     dag.log.debug("Getting dagrun for dag %s", dag.dag_id)
@@ -164,7 +164,7 @@ def _get_or_create_dagrun(
         run_id=run_id,
         start_date=start_date or execution_date,
         session=session,
-        conf=conf,  # type: ignore
+        conf=conf,
     )
     log.info("created dagrun %s", str(dr))
     return dr