Initial implementation of simplifying config interfaces
jlaneve committed Jul 25, 2023
1 parent 66af5bb commit 2bea709
Showing 9 changed files with 228 additions and 83 deletions.
10 changes: 10 additions & 0 deletions cosmos/__init__.py
@@ -12,6 +12,12 @@
from cosmos.constants import LoadMode, TestBehavior, ExecutionMode
from cosmos.dataset import get_dbt_dataset
from cosmos.operators.lazy_load import MissingPackage
from cosmos.config import (
ProjectConfig,
ProfileConfig,
ExecutionConfig,
RenderConfig,
)

from cosmos.operators.local import (
DbtDepsLocalOperator,
@@ -79,6 +85,10 @@
)

__all__ = [
"ProjectConfig",
"ProfileConfig",
"ExecutionConfig",
"RenderConfig",
"DbtLSLocalOperator",
"DbtRunOperationLocalOperator",
"DbtRunLocalOperator",
135 changes: 135 additions & 0 deletions cosmos/config.py
@@ -0,0 +1,135 @@
"""Module that contains all Cosmos config classes."""

from __future__ import annotations

import shutil
from dataclasses import dataclass, field
from pathlib import Path
from logging import getLogger

from cosmos.constants import TestBehavior, ExecutionMode, LoadMode
from cosmos.exceptions import CosmosValueError

logger = getLogger(__name__)


@dataclass
class RenderConfig:
"""
Class for setting general Cosmos config.
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG
dependencies
:param test_behavior: The behavior for running tests. Defaults to after each
:param load_method: The method Cosmos uses to load and parse the dbt project graph. Defaults to automatic
:param select: A list of dbt select arguments (e.g. 'config.materialized:incremental')
:param exclude: A list of dbt exclude arguments (e.g. 'tag:nightly')
"""

emit_datasets: bool = True
test_behavior: TestBehavior = TestBehavior.AFTER_EACH
load_method: LoadMode = LoadMode.AUTOMATIC
select: list[str] = field(default_factory=list)
exclude: list[str] = field(default_factory=list)
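
# Usage sketch (illustrative, not part of this diff): building a RenderConfig that
# narrows the rendered graph with dbt selectors. The selector strings mirror the
# docstring examples above; LoadMode and TestBehavior are already imported at the
# top of this module.
example_render_config = RenderConfig(
    emit_datasets=True,
    test_behavior=TestBehavior.AFTER_EACH,
    load_method=LoadMode.AUTOMATIC,
    select=["config.materialized:incremental"],
    exclude=["tag:nightly"],
)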


@dataclass
class ProjectConfig:
"""
Class for setting project config.
:param dbt_project: The path to the dbt project directory. Example: /path/to/dbt/project
:param models: The path to the dbt models directory within the project. Defaults to models
:param seeds: The path to the dbt seeds directory within the project. Defaults to seeds
:param snapshots: The path to the dbt snapshots directory within the project. Defaults to snapshots
:param manifest: The path to the dbt manifest file. Defaults to None
"""

dbt_project: str | Path
models: str | Path = "models"
seeds: str | Path = "seeds"
snapshots: str | Path = "snapshots"
manifest: str | Path | None = None

manifest_path: Path | None = None

def __post_init__(self) -> None:
"Converts paths to `Path` objects."
self.dbt_project_path = Path(self.dbt_project)
self.models_path = self.dbt_project_path / Path(self.models)
self.seeds_path = self.dbt_project_path / Path(self.seeds)
self.snapshots_path = self.dbt_project_path / Path(self.snapshots)

if self.manifest:
self.manifest_path = Path(self.manifest)

def validate_project(self) -> None:
"Validates that the project, models, and seeds directories exist."
project_yml_path = self.dbt_project_path / "dbt_project.yml"
if not project_yml_path.exists():
raise CosmosValueError(f"Could not find dbt_project.yml at {project_yml_path}")

if not self.models_path.exists():
raise CosmosValueError(f"Could not find models directory at {self.models_path}")

if self.manifest_path and not self.manifest_path.exists():
raise CosmosValueError(f"Could not find manifest at {self.manifest_path}")

def is_manifest_available(self) -> bool:
"""
Check if the `dbt` project manifest is set and if the file exists.
"""
if not self.manifest_path:
return False

return self.manifest_path.exists()

@property
def project_name(self) -> str:
"The name of the dbt project."
return self.dbt_project_path.stem
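
# Usage sketch (illustrative, not part of this diff): the project path and manifest
# location below are placeholders, not values from this commit.
example_project_config = ProjectConfig(
    dbt_project="/usr/local/airflow/dags/dbt/jaffle_shop",
    manifest="/usr/local/airflow/dags/dbt/jaffle_shop/target/manifest.json",
)
print(example_project_config.project_name)             # "jaffle_shop" (stem of the project path)
print(example_project_config.is_manifest_available())  # False unless the manifest file exists
example_project_config.validate_project()              # raises CosmosValueError if project files are missing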


@dataclass
class ProfileConfig:
"""
Class for setting profile config.
:param profile_name: The name of the dbt profile to use.
:param target_name: The name of the dbt target to use.
:param conn_id: The Airflow connection ID to use.
"""

# should always be set to be explicit
profile_name: str
target_name: str
conn_id: str
profile_args: dict[str, str] = field(default_factory=dict)
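
# Usage sketch (illustrative, not part of this diff): the values mirror the dev DAGs
# further down in this commit (airflow_db connection, public schema).
example_profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    conn_id="airflow_db",                # Airflow connection the dbt profile is built from
    profile_args={"schema": "public"},   # extra key/value pairs passed through to the profile
)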


@dataclass
class ExecutionConfig:
"""
Contains configuration about how to execute dbt.
:param execution_mode: The execution mode for dbt. Defaults to local
:param dbt_executable_path: The path to the dbt executable. Defaults to dbt-ol or dbt if
available on the path.
:param dbt_cli_flags: A list of extra dbt cli flags to pass to dbt. Defaults to []
:param append_env: If True, append the env dictionary to the existing environment. If False,
replace the existing environment with the env dictionary. Defaults to False
:param cancel_query_on_kill: If True, cancel the query when the dbt process is killed. If False,
do not cancel the query when the dbt process is killed. Defaults to True
:param install_deps: If True, install dbt dependencies before running dbt. Defaults to False
:param skip_exit_code: If the dbt process exits with this exit code, do not raise an exception.
Defaults to None
"""

execution_mode: ExecutionMode = ExecutionMode.LOCAL
dbt_executable_path: str | Path = shutil.which("dbt-ol") or shutil.which("dbt") or "dbt"
dbt_cli_flags: list[str] = field(default_factory=list)
append_env: bool = False
cancel_query_on_kill: bool = True
install_deps: bool = False
skip_exit_code: int | None = None
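
# Usage sketch (illustrative, not part of this diff): the virtualenv dbt path is a
# placeholder; by default the class falls back to dbt-ol or dbt found on PATH.
example_execution_config = ExecutionConfig(
    execution_mode=ExecutionMode.LOCAL,
    dbt_executable_path="/usr/local/airflow/dbt_venv/bin/dbt",
    install_deps=True,            # run dbt deps before executing dbt commands
    cancel_query_on_kill=True,    # attempt to cancel the warehouse query if the task is killed
)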
82 changes: 31 additions & 51 deletions cosmos/converter.py
@@ -7,18 +7,18 @@
import inspect
import logging
from typing import Any, Callable
from pathlib import Path

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from pathlib import Path

from cosmos.airflow.graph import build_airflow_graph
from cosmos.constants import ExecutionMode, LoadMode, TestBehavior
from cosmos.dbt.executable import get_system_dbt
from cosmos.dbt.graph import DbtGraph
from cosmos.dbt.project import DbtProject
from cosmos.dbt.selector import retrieve_by_label
from cosmos.config import ProjectConfig, ExecutionConfig, RenderConfig, ProfileConfig


logger = logging.getLogger(__name__)
@@ -102,67 +102,48 @@ class DbtToAirflowConverter:
:param dag: Airflow DAG to be populated
:param task_group (optional): Airflow Task Group to be populated
:param dbt_project_name: The name of the dbt project
:param dbt_root_path: The path to the dbt root directory
:param dbt_models_dir: The path to the dbt models directory within the project
:param dbt_seeds_dir: The path to the dbt seeds directory within the project
:param conn_id: The Airflow connection ID to use for the dbt profile
:param profile_args: Arguments to pass to the dbt profile
:param profile_name_override: A name to use for the dbt profile. If not provided, and no profile target is found
in your project's dbt_project.yml, "cosmos_profile" is used.
:param target_name_override: A name to use for the dbt target. If not provided, "cosmos_target" is used.
:param dbt_args: Parameters to pass to the underlying dbt operators, can include dbt_executable_path to utilize venv
:param project_config: The dbt project configuration
:param execution_config: The dbt execution configuration
:param render_config: The dbt render configuration
:param operator_args: Parameters to pass to the underlying operators, can include KubernetesPodOperator
or DockerOperator parameters
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:param test_behavior: When to run `dbt` tests. Default is TestBehavior.AFTER_EACH, that runs tests after each model.
:param select: A list of dbt select arguments (e.g. 'config.materialized:incremental')
:param exclude: A list of dbt exclude arguments (e.g. 'tag:nightly')
:param execution_mode: Where Cosmos should run each dbt task (e.g. ExecutionMode.LOCAL, ExecutionMode.KUBERNETES).
Default is ExecutionMode.LOCAL.
:param on_warning_callback: A callback function called on warnings with additional Context variables "test_names"
and "test_results" of type `List`. Each index in "test_names" corresponds to the same index in "test_results".
"""

def __init__(
self,
dbt_project_name: str,
conn_id: str,
project_config: ProjectConfig,
profile_config: ProfileConfig,
execution_config: ExecutionConfig = ExecutionConfig(),
render_config: RenderConfig = RenderConfig(),
dag: DAG | None = None,
task_group: TaskGroup | None = None,
profile_args: dict[str, str] = {},
dbt_args: dict[str, Any] = {},
profile_name_override: str | None = None,
target_name_override: str | None = None,
operator_args: dict[str, Any] = {},
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dags/dbt",
dbt_models_dir: str | None = None,
dbt_seeds_dir: str | None = None,
dbt_snapshots_dir: str | None = None,
test_behavior: str | TestBehavior = TestBehavior.AFTER_EACH,
select: list[str] | None = None,
exclude: list[str] | None = None,
execution_mode: str | ExecutionMode = ExecutionMode.LOCAL,
load_mode: str | LoadMode = LoadMode.AUTOMATIC,
manifest_path: str | Path | None = None,
operator_args: dict[str, Any] | None = None,
on_warning_callback: Callable[..., Any] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
select = select or []
exclude = exclude or []

test_behavior = convert_value_to_enum(test_behavior, TestBehavior)
execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
load_mode = convert_value_to_enum(load_mode, LoadMode)

test_behavior = convert_value_to_enum(test_behavior, TestBehavior)
execution_mode = convert_value_to_enum(execution_mode, ExecutionMode)
load_mode = convert_value_to_enum(load_mode, LoadMode)

if type(manifest_path) == str:
manifest_path = Path(manifest_path)
conn_id = profile_config.conn_id
profile_args = profile_config.profile_args
profile_name_override = profile_config.profile_name
target_name_override = profile_config.target_name
emit_datasets = render_config.emit_datasets
dbt_root_path = project_config.dbt_project_path.parent
dbt_project_name = project_config.dbt_project_path.name
dbt_models_dir = project_config.models_path
dbt_seeds_dir = project_config.seeds_path
dbt_snapshots_dir = project_config.snapshots_path
test_behavior = render_config.test_behavior
select = render_config.select
exclude = render_config.exclude
execution_mode = execution_config.execution_mode
load_mode = render_config.load_method
manifest_path = project_config.manifest_path
dbt_executable_path = execution_config.dbt_executable_path or get_system_dbt()

if not operator_args:
operator_args = {}

dbt_project = DbtProject(
name=dbt_project_name,
@@ -177,12 +158,11 @@ def __init__(
project=dbt_project,
exclude=exclude,
select=select,
dbt_cmd=dbt_args.get("dbt_executable_path", get_system_dbt()),
dbt_cmd=dbt_executable_path,
)
dbt_graph.load(method=load_mode, execution_mode=execution_mode)

task_args = {
**dbt_args,
**operator_args,
"profile_args": profile_args,
"profile_name": profile_name_override,
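
# Putting the pieces together (illustrative sketch, not part of this diff): the four
# config objects replace the long keyword-argument list removed above. This assumes
# DbtDag accepts the same constructor arguments as DbtToAirflowConverter, as the dev
# DAGs below suggest; the dag_id and project path are placeholders.
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig

example_dag = DbtDag(
    project_config=ProjectConfig("/usr/local/airflow/dags/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        conn_id="airflow_db",
    ),
    execution_config=ExecutionConfig(),   # defaults: local execution, dbt-ol/dbt from PATH
    render_config=RenderConfig(),         # defaults: emit datasets, test after each model
    # regular Airflow DAG parameters
    dag_id="example_jaffle_shop",
    schedule_interval="@daily",
    start_date=datetime(2023, 1, 1),
)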
5 changes: 5 additions & 0 deletions cosmos/exceptions.py
@@ -0,0 +1,5 @@
"Contains exceptions that Cosmos uses"


class CosmosValueError(ValueError):
"""Raised when a Cosmos config value is invalid."""
2 changes: 1 addition & 1 deletion dev/Dockerfile
@@ -7,7 +7,7 @@ COPY ./README.rst ${AIRFLOW_HOME}/astronomer_cosmos/
COPY ./cosmos/ ${AIRFLOW_HOME}/astronomer_cosmos/cosmos/

# install the package in editable mode
RUN pip install -e "${AIRFLOW_HOME}/astronomer_cosmos"[dbt-postgres]
RUN pip install -e "${AIRFLOW_HOME}/astronomer_cosmos"[dbt-postgres,dbt-databricks]

# make sure astro user owns the package
RUN chown -R astro:astro ${AIRFLOW_HOME}/astronomer_cosmos
19 changes: 10 additions & 9 deletions dev/dags/basic_cosmos_dag.py
@@ -6,22 +6,23 @@
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag
from cosmos import DbtDag, ProjectConfig, ProfileConfig

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

# [START local_example]
basic_cosmos_dag = DbtDag(
# dbt/cosmos-specific parameters
dbt_root_path=DBT_ROOT_PATH,
dbt_project_name="jaffle_shop",
conn_id="airflow_db",
profile_args={
"schema": "public",
},
profile_name_override="airflow",
target_name_override="dev_target",
project_config=ProjectConfig(
DBT_ROOT_PATH / "jaffle_shop",
),
profile_config=ProfileConfig(
profile_name="default",
target_name="dev",
conn_id="airflow_db",
profile_args={"schema": "public"},
),
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
15 changes: 10 additions & 5 deletions dev/dags/basic_cosmos_task_group.py
@@ -8,7 +8,7 @@
from airflow.decorators import dag
from airflow.operators.empty import EmptyOperator

from cosmos import DbtTaskGroup
from cosmos import DbtTaskGroup, ProjectConfig, ProfileConfig

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))
@@ -26,10 +26,15 @@ def basic_cosmos_task_group() -> None:
pre_dbt = EmptyOperator(task_id="pre_dbt")

jaffle_shop = DbtTaskGroup(
dbt_root_path=DBT_ROOT_PATH,
dbt_project_name="jaffle_shop",
conn_id="airflow_db",
profile_args={"schema": "public"},
project_config=ProjectConfig(
DBT_ROOT_PATH / "jaffle_shop",
),
profile_config=ProfileConfig(
profile_name="default",
target_name="dev",
conn_id="airflow_db",
profile_args={"schema": "public"},
),
)

post_dbt = EmptyOperator(task_id="post_dbt")