diff --git a/cosmos/__init__.py b/cosmos/__init__.py index 82ed9ef3f..e034cb31d 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -12,6 +12,12 @@ from cosmos.constants import LoadMode, TestBehavior, ExecutionMode from cosmos.dataset import get_dbt_dataset from cosmos.operators.lazy_load import MissingPackage +from cosmos.config import ( + ProjectConfig, + ProfileConfig, + ExecutionConfig, + RenderConfig, +) from cosmos.operators.local import ( DbtDepsLocalOperator, @@ -79,6 +85,10 @@ ) __all__ = [ + "ProjectConfig", + "ProfileConfig", + "ExecutionConfig", + "RenderConfig", "DbtLSLocalOperator", "DbtRunOperationLocalOperator", "DbtRunLocalOperator", diff --git a/cosmos/config.py b/cosmos/config.py new file mode 100644 index 000000000..a6e78c071 --- /dev/null +++ b/cosmos/config.py @@ -0,0 +1,135 @@ +"""Module that contains all Cosmos config classes.""" + +from __future__ import annotations + +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from logging import getLogger + +from cosmos.constants import TestBehavior, ExecutionMode, LoadMode +from cosmos.exceptions import CosmosValueError + +logger = getLogger(__name__) + + +@dataclass +class RenderConfig: + """ + Class for setting general Cosmos config. + + :param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG + dependencies + :param test_behavior: The behavior for running tests. Defaults to after each + :param load_method: The method Cosmos uses to load the dbt project. Defaults to automatic + :param select: A list of dbt select arguments (e.g. 'config.materialized:incremental') + :param exclude: A list of dbt exclude arguments (e.g. 
'tag:nightly') + """ + + emit_datasets: bool = True + test_behavior: TestBehavior = TestBehavior.AFTER_EACH + load_method: LoadMode = LoadMode.AUTOMATIC + select: list[str] = field(default_factory=list) + exclude: list[str] = field(default_factory=list) + + +@dataclass +class ProjectConfig: + """ + Class for setting project config. + + :param dbt_project: The path to the dbt project directory. Example: /path/to/dbt/project + :param models: The path to the dbt models directory within the project. Defaults to models + :param seeds: The path to the dbt seeds directory within the project. Defaults to seeds + :param snapshots: The path to the dbt snapshots directory within the project. Defaults to + snapshots + :param manifest: The path to the dbt manifest file. Defaults to None + """ + + dbt_project: str | Path + models: str | Path = "models" + seeds: str | Path = "seeds" + snapshots: str | Path = "snapshots" + manifest: str | Path | None = None + + manifest_path: Path | None = None + + def __post_init__(self) -> None: + "Converts paths to `Path` objects." + self.dbt_project_path = Path(self.dbt_project) + self.models_path = self.dbt_project_path / Path(self.models) + self.seeds_path = self.dbt_project_path / Path(self.seeds) + self.snapshots_path = self.dbt_project_path / Path(self.snapshots) + + if self.manifest: + self.manifest_path = Path(self.manifest) + + def validate_project(self) -> None: + "Validates that dbt_project.yml, the models directory, and the manifest (if set) exist." 
+ project_yml_path = self.dbt_project_path / "dbt_project.yml" + if not project_yml_path.exists(): + raise CosmosValueError(f"Could not find dbt_project.yml at {project_yml_path}") + + if not self.models_path.exists(): + raise CosmosValueError(f"Could not find models directory at {self.models_path}") + + if self.manifest_path and not self.manifest_path.exists(): + raise CosmosValueError(f"Could not find manifest at {self.manifest_path}") + + def is_manifest_available(self) -> bool: + """ + Check if the `dbt` project manifest is set and if the file exists. + """ + if not self.manifest_path: + return False + + return self.manifest_path.exists() + + @property + def project_name(self) -> str: + "The name of the dbt project." + return self.dbt_project_path.stem + + +@dataclass +class ProfileConfig: + """ + Class for setting profile config. + + :param profile_name: The name of the dbt profile to use. + :param target_name: The name of the dbt target to use. + :param conn_id: The Airflow connection ID to use. + """ + + # should always be set to be explicit + profile_name: str + target_name: str + conn_id: str + profile_args: dict[str, str] = field(default_factory=dict) + + +@dataclass +class ExecutionConfig: + """ + Contains configuration about how to execute dbt. + + :param execution_mode: The execution mode for dbt. Defaults to local + :param dbt_executable_path: The path to the dbt executable. Defaults to dbt-ol or dbt if + available on the path. + :param dbt_cli_flags: A list of extra dbt cli flags to pass to dbt. Defaults to [] + :param append_env: If True, append the env dictionary to the existing environment. If False, + replace the existing environment with the env dictionary. Defaults to False + :param cancel_query_on_kill: If True, cancel the query when the dbt process is killed. If False, + do not cancel the query when the dbt process is killed. Defaults to True + :param install_deps: If True, install dbt dependencies before running dbt. 
Defaults to False + :param skip_exit_code: If the dbt process exits with this exit code, do not raise an exception. + Defaults to None + """ + + execution_mode: ExecutionMode = ExecutionMode.LOCAL + dbt_executable_path: str | Path = shutil.which("dbt-ol") or shutil.which("dbt") or "dbt" + dbt_cli_flags: list[str] = field(default_factory=list) + append_env: bool = False + cancel_query_on_kill: bool = True + install_deps: bool = False + skip_exit_code: int | None = None diff --git a/cosmos/converter.py b/cosmos/converter.py index c8efa3cd3..c66552720 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -7,18 +7,18 @@ import inspect import logging from typing import Any, Callable +from pathlib import Path from airflow.exceptions import AirflowException from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup -from pathlib import Path from cosmos.airflow.graph import build_airflow_graph -from cosmos.constants import ExecutionMode, LoadMode, TestBehavior from cosmos.dbt.executable import get_system_dbt from cosmos.dbt.graph import DbtGraph from cosmos.dbt.project import DbtProject from cosmos.dbt.selector import retrieve_by_label +from cosmos.config import ProjectConfig, ExecutionConfig, RenderConfig, ProfileConfig logger = logging.getLogger(__name__) @@ -102,67 +102,48 @@ class DbtToAirflowConverter: :param dag: Airflow DAG to be populated :param task_group (optional): Airflow Task Group to be populated - :param dbt_project_name: The name of the dbt project - :param dbt_root_path: The path to the dbt root directory - :param dbt_models_dir: The path to the dbt models directory within the project - :param dbt_seeds_dir: The path to the dbt seeds directory within the project - :param conn_id: The Airflow connection ID to use for the dbt profile - :param profile_args: Arguments to pass to the dbt profile - :param profile_name_override: A name to use for the dbt profile. 
If not provided, and no profile target is found - in your project's dbt_project.yml, "cosmos_profile" is used. - :param target_name_override: A name to use for the dbt target. If not provided, "cosmos_target" is used. - :param dbt_args: Parameters to pass to the underlying dbt operators, can include dbt_executable_path to utilize venv + :param project_config: The dbt project configuration + :param execution_config: The dbt execution configuration + :param render_config: The dbt render configuration :param operator_args: Parameters to pass to the underlying operators, can include KubernetesPodOperator or DockerOperator parameters - :param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies - :param test_behavior: When to run `dbt` tests. Default is TestBehavior.AFTER_EACH, that runs tests after each model. - :param select: A list of dbt select arguments (e.g. 'config.materialized:incremental') - :param exclude: A list of dbt exclude arguments (e.g. 'tag:nightly') - :param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES). - Default is ExecutionMode.LOCAL. :param on_warning_callback: A callback function called on warnings with additional Context variables "test_names" and "test_results" of type `List`. Each index in "test_names" corresponds to the same index in "test_results". 
""" def __init__( self, - dbt_project_name: str, - conn_id: str, + project_config: ProjectConfig, + profile_config: ProfileConfig, + execution_config: ExecutionConfig = ExecutionConfig(), + render_config: RenderConfig = RenderConfig(), dag: DAG | None = None, task_group: TaskGroup | None = None, - profile_args: dict[str, str] = {}, - dbt_args: dict[str, Any] = {}, - profile_name_override: str | None = None, - target_name_override: str | None = None, - operator_args: dict[str, Any] = {}, - emit_datasets: bool = True, - dbt_root_path: str = "/usr/local/airflow/dags/dbt", - dbt_models_dir: str | None = None, - dbt_seeds_dir: str | None = None, - dbt_snapshots_dir: str | None = None, - test_behavior: str | TestBehavior = TestBehavior.AFTER_EACH, - select: list[str] | None = None, - exclude: list[str] | None = None, - execution_mode: str | ExecutionMode = ExecutionMode.LOCAL, - load_mode: str | LoadMode = LoadMode.AUTOMATIC, - manifest_path: str | Path | None = None, + operator_args: dict[str, Any] | None = None, on_warning_callback: Callable[..., Any] | None = None, *args: Any, **kwargs: Any, ) -> None: - select = select or [] - exclude = exclude or [] - - test_behavior = convert_value_to_enum(test_behavior, TestBehavior) - execution_mode = convert_value_to_enum(execution_mode, ExecutionMode) - load_mode = convert_value_to_enum(load_mode, LoadMode) - - test_behavior = convert_value_to_enum(test_behavior, TestBehavior) - execution_mode = convert_value_to_enum(execution_mode, ExecutionMode) - load_mode = convert_value_to_enum(load_mode, LoadMode) - - if type(manifest_path) == str: - manifest_path = Path(manifest_path) + conn_id = profile_config.conn_id + profile_args = profile_config.profile_args + profile_name_override = profile_config.profile_name + target_name_override = profile_config.target_name + emit_datasets = render_config.emit_datasets + dbt_root_path = project_config.dbt_project_path.parent + dbt_project_name = project_config.dbt_project_path.name + 
dbt_models_dir = project_config.models_path + dbt_seeds_dir = project_config.seeds_path + dbt_snapshots_dir = project_config.snapshots_path + test_behavior = render_config.test_behavior + select = render_config.select + exclude = render_config.exclude + execution_mode = execution_config.execution_mode + load_mode = render_config.load_method + manifest_path = project_config.manifest_path + dbt_executable_path = execution_config.dbt_executable_path or get_system_dbt() + + if not operator_args: + operator_args = {} dbt_project = DbtProject( name=dbt_project_name, @@ -177,12 +158,11 @@ def __init__( project=dbt_project, exclude=exclude, select=select, - dbt_cmd=dbt_args.get("dbt_executable_path", get_system_dbt()), + dbt_cmd=dbt_executable_path, ) dbt_graph.load(method=load_mode, execution_mode=execution_mode) task_args = { - **dbt_args, **operator_args, "profile_args": profile_args, "profile_name": profile_name_override, diff --git a/cosmos/exceptions.py b/cosmos/exceptions.py new file mode 100644 index 000000000..74091f4a1 --- /dev/null +++ b/cosmos/exceptions.py @@ -0,0 +1,5 @@ +"Contains exceptions that Cosmos uses" + + +class CosmosValueError(ValueError): + """Raised when a Cosmos config value is invalid.""" diff --git a/dev/Dockerfile b/dev/Dockerfile index 64dece432..90c49ed6c 100644 --- a/dev/Dockerfile +++ b/dev/Dockerfile @@ -7,7 +7,7 @@ COPY ./README.rst ${AIRFLOW_HOME}/astronomer_cosmos/ COPY ./cosmos/ ${AIRFLOW_HOME}/astronomer_cosmos/cosmos/ # install the package in editable mode -RUN pip install -e "${AIRFLOW_HOME}/astronomer_cosmos"[dbt-postgres] +RUN pip install -e "${AIRFLOW_HOME}/astronomer_cosmos"[dbt-postgres,dbt-databricks] # make sure astro user owns the package RUN chown -R astro:astro ${AIRFLOW_HOME}/astronomer_cosmos diff --git a/dev/dags/basic_cosmos_dag.py b/dev/dags/basic_cosmos_dag.py index 6702d6ec9..f372454ba 100644 --- a/dev/dags/basic_cosmos_dag.py +++ b/dev/dags/basic_cosmos_dag.py @@ -6,7 +6,7 @@ from datetime import datetime from 
pathlib import Path -from cosmos import DbtDag +from cosmos import DbtDag, ProjectConfig, ProfileConfig DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -14,14 +14,15 @@ # [START local_example] basic_cosmos_dag = DbtDag( # dbt/cosmos-specific parameters - dbt_root_path=DBT_ROOT_PATH, - dbt_project_name="jaffle_shop", - conn_id="airflow_db", - profile_args={ - "schema": "public", - }, - profile_name_override="airflow", - target_name_override="dev_target", + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=ProfileConfig( + profile_name="default", + target_name="dev", + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), # normal dag parameters schedule_interval="@daily", start_date=datetime(2023, 1, 1), diff --git a/dev/dags/basic_cosmos_task_group.py b/dev/dags/basic_cosmos_task_group.py index 5d2581d3e..d6d95ebf7 100644 --- a/dev/dags/basic_cosmos_task_group.py +++ b/dev/dags/basic_cosmos_task_group.py @@ -8,7 +8,7 @@ from airflow.decorators import dag from airflow.operators.empty import EmptyOperator -from cosmos import DbtTaskGroup +from cosmos import DbtTaskGroup, ProjectConfig, ProfileConfig DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -26,10 +26,15 @@ def basic_cosmos_task_group() -> None: pre_dbt = EmptyOperator(task_id="pre_dbt") jaffle_shop = DbtTaskGroup( - dbt_root_path=DBT_ROOT_PATH, - dbt_project_name="jaffle_shop", - conn_id="airflow_db", - profile_args={"schema": "public"}, + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=ProfileConfig( + profile_name="default", + target_name="dev", + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), ) post_dbt = EmptyOperator(task_id="post_dbt") diff --git a/dev/dags/example_cosmos_python_models.py b/dev/dags/example_cosmos_python_models.py index 
2876e38c2..b829cee37 100644 --- a/dev/dags/example_cosmos_python_models.py +++ b/dev/dags/example_cosmos_python_models.py @@ -17,7 +17,7 @@ from datetime import datetime from pathlib import Path -from cosmos import DbtDag +from cosmos import DbtDag, ProjectConfig, ProfileConfig, ExecutionConfig DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -26,15 +26,18 @@ # [START example_cosmos_python_models] example_cosmos_python_models = DbtDag( # dbt/cosmos-specific parameters - dbt_root_path=DBT_ROOT_PATH, - dbt_project_name="jaffle_shop_python", - conn_id="databricks_default", - profile_args={ - "schema": SCHEMA, - }, - operator_args={"append_env": True}, - profile_name_override="airflow", - target_name_override="dev_target", + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop_python", + ), + profile_config=ProfileConfig( + profile_name="default", + target_name="dev", + conn_id="databricks_default", + profile_args={"schema": SCHEMA}, + ), + execution_config=ExecutionConfig( + append_env=True, + ), # normal dag parameters schedule_interval="@daily", start_date=datetime(2023, 1, 1), diff --git a/dev/dags/example_virtualenv.py b/dev/dags/example_virtualenv.py index 7ad287022..83c56112d 100644 --- a/dev/dags/example_virtualenv.py +++ b/dev/dags/example_virtualenv.py @@ -5,7 +5,7 @@ from datetime import datetime from pathlib import Path -from cosmos import DbtDag, ExecutionMode +from cosmos import DbtDag, ExecutionMode, ExecutionConfig, ProjectConfig, ProfileConfig DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) @@ -15,13 +15,19 @@ # [START virtualenv_example] example_virtualenv = DbtDag( # dbt/cosmos-specific parameters - dbt_root_path=DBT_ROOT_PATH, - dbt_project_name=PROJECT_NAME, - conn_id=CONNECTION_ID, - dbt_args={"schema": "public"}, - execution_mode=ExecutionMode.VIRTUALENV, + 
project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=ProfileConfig( + profile_name="default", + target_name="dev", + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.VIRTUALENV, + ), operator_args={ - "project_dir": DBT_ROOT_PATH / PROJECT_NAME, "py_system_site_packages": False, "py_requirements": ["dbt-postgres==1.6.0b1"], },