Skip to content

Commit

Permalink
PR revisions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek Mohan committed Jul 24, 2023
1 parent 0ff015a commit 299e889
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 57 deletions.
5 changes: 2 additions & 3 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import inspect
import logging
import sys
from enum import Enum
from typing import Any, Callable, Optional
from typing import Any, Callable

from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
Expand Down Expand Up @@ -144,7 +143,7 @@ def __init__(
execution_mode: str | ExecutionMode = ExecutionMode.LOCAL,
load_mode: str | LoadMode = LoadMode.AUTOMATIC,
manifest_path: str | Path | None = None,
on_warning_callback: Optional[Callable[..., Any]] = None,
on_warning_callback: Callable[..., Any] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
Expand Down
3 changes: 1 addition & 2 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import logging
from typing import Optional

from airflow.models import BaseOperator
from airflow.models.dag import DAG
Expand All @@ -11,7 +10,7 @@
logger = logging.getLogger(__name__)


def get_airflow_task(task: Task, dag: DAG, task_group: Optional[TaskGroup] = None) -> BaseOperator:
def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
"""
Get the Airflow Operator class for a Task.
Expand Down
2 changes: 1 addition & 1 deletion cosmos/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, id: str, *args: Tuple[Any], **kwargs: str):
logger.warning("Datasets are not supported in Airflow < 2.5.0")

def __eq__(self, other: "Dataset") -> bool:
return self.id == other.id # type: ignore[no-any-return]
return bool(self.id == other.id)


def get_dbt_dataset(connection_id: str, project_name: str, model_name: str) -> Dataset:
Expand Down
6 changes: 3 additions & 3 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from dataclasses import dataclass, field
from subprocess import Popen, PIPE
from typing import Any, Optional
from typing import Any

from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode
from cosmos.dbt.executable import get_system_dbt
Expand Down Expand Up @@ -65,8 +65,8 @@ class DbtGraph:
def __init__(
self,
project: DbtProject,
exclude: Optional[list[str]] = None,
select: Optional[list[str]] = None,
exclude: list[str] | None = None,
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
):
self.project = project
Expand Down
12 changes: 6 additions & 6 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Set
from typing import Any, ClassVar, Dict, List, Set

import jinja2
import yaml
Expand Down Expand Up @@ -63,7 +63,7 @@ def _config_selector_ooo(
self,
sql_configs: Set[str],
properties_configs: Set[str],
prefixes: Optional[List[str]] = None,
prefixes: List[str] | None = None,
) -> Set[str]:
"""
this will force values from the sql files to override whatever is in the properties.yml. So ooo:
Expand Down Expand Up @@ -221,10 +221,10 @@ class DbtProject:
project_name: str

# optional, user-specified instance variables
dbt_root_path: Optional[str] = None
dbt_models_dir: Optional[str] = None
dbt_snapshots_dir: Optional[str] = None
dbt_seeds_dir: Optional[str] = None
dbt_root_path: str | None = None
dbt_models_dir: str | None = None
dbt_snapshots_dir: str | None = None
dbt_seeds_dir: str | None = None

# private instance variables for managing state
models: Dict[str, DbtModel] = field(default_factory=dict)
Expand Down
28 changes: 13 additions & 15 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import shutil
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Sequence, Tuple

import yaml
from airflow.models.baseoperator import BaseOperator
Expand Down Expand Up @@ -78,26 +78,26 @@ def __init__(
self,
project_dir: str,
conn_id: str,
base_cmd: Optional[str | list[str]] = None,
select: Optional[str] = None,
exclude: Optional[str] = None,
selector: Optional[str] = None,
vars: Optional[dict[str, str]] = None,
models: Optional[str] = None,
base_cmd: list[str] | None = None,
select: str | None = None,
exclude: str | None = None,
selector: str | None = None,
vars: dict[str, str] | None = None,
models: str | None = None,
cache_selected_only: bool = False,
no_version_check: bool = False,
fail_fast: bool = False,
quiet: bool = False,
warn_error: bool = False,
db_name: Optional[str] = None,
schema: Optional[str] = None,
env: Optional[dict[str, Any]] = None,
db_name: str | None = None,
schema: str | None = None,
env: dict[str, Any] | None = None,
append_env: bool = False,
output_encoding: str = "utf-8",
skip_exit_code: int = 99,
cancel_query_on_kill: bool = True,
dbt_executable_path: str = "dbt",
dbt_cmd_flags: Optional[list[str]] = None,
dbt_cmd_flags: list[str] | None = None,
**kwargs: str,
) -> None:
self.project_dir = project_dir
Expand Down Expand Up @@ -205,12 +205,10 @@ def build_cmd(
self,
context: Context,
cmd_flags: list[str] | None = None,
) -> Tuple[list[Optional[str]], dict[str, str | bytes | os.PathLike[Any]]]:
) -> Tuple[list[str | None], dict[str, str | bytes | os.PathLike[Any]]]:
dbt_cmd = [self.dbt_executable_path]

if isinstance(self.base_cmd, str):
dbt_cmd.append(self.base_cmd)
elif isinstance(self.base_cmd, list):
if self.base_cmd:
dbt_cmd.extend(self.base_cmd)

dbt_cmd.extend(self.add_global_flags())
Expand Down
16 changes: 8 additions & 8 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Sequence

import yaml
from airflow.utils.context import Context
Expand Down Expand Up @@ -62,7 +62,7 @@ class DbtLSDockerOperator(DbtDockerBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "ls"
self.base_cmd = ["ls"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -80,7 +80,7 @@ class DbtSeedDockerOperator(DbtDockerBaseOperator):
def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
self.base_cmd = "seed"
self.base_cmd = ["seed"]

def add_cmd_flags(self) -> list[str]:
flags = []
Expand All @@ -104,7 +104,7 @@ class DbtSnapshotDockerOperator(DbtDockerBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "snapshot"
self.base_cmd = ["snapshot"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -120,7 +120,7 @@ class DbtRunDockerOperator(DbtDockerBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "run"
self.base_cmd = ["run"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -133,9 +133,9 @@ class DbtTestDockerOperator(DbtDockerBaseOperator):

ui_color = "#8194E0"

def __init__(self, on_warning_callback: Optional[Callable[..., Any]] = None, **kwargs: str) -> None:
def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "test"
self.base_cmd = ["test"]
# as of now, on_warning_callback in docker executor does nothing
self.on_warning_callback = on_warning_callback

Expand All @@ -155,7 +155,7 @@ class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: Optional[dict[str, Any]] = None, **kwargs: str) -> None:
def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)
Expand Down
16 changes: 8 additions & 8 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from os import PathLike
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Sequence

import yaml
from airflow.utils.context import Context
Expand Down Expand Up @@ -70,7 +70,7 @@ class DbtLSKubernetesOperator(DbtKubernetesBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "ls"
self.base_cmd = ["ls"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -88,7 +88,7 @@ class DbtSeedKubernetesOperator(DbtKubernetesBaseOperator):
def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
self.base_cmd = "seed"
self.base_cmd = ["seed"]

def add_cmd_flags(self) -> list[str]:
flags = []
Expand All @@ -112,7 +112,7 @@ class DbtSnapshotKubernetesOperator(DbtKubernetesBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "snapshot"
self.base_cmd = ["snapshot"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -128,7 +128,7 @@ class DbtRunKubernetesOperator(DbtKubernetesBaseOperator):

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "run"
self.base_cmd = ["run"]

def execute(self, context: Context) -> Any:
return self.build_and_run_cmd(context=context)
Expand All @@ -141,9 +141,9 @@ class DbtTestKubernetesOperator(DbtKubernetesBaseOperator):

ui_color = "#8194E0"

def __init__(self, on_warning_callback: Optional[Callable[..., Any]] = None, **kwargs: str) -> None:
def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = "test"
self.base_cmd = ["test"]
# as of now, on_warning_callback in kubernetes executor does nothing
self.on_warning_callback = on_warning_callback

Expand All @@ -163,7 +163,7 @@ class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator):
ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: Optional[dict[str, Any]] = None, **kwargs: str) -> None:
def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)
Expand Down
20 changes: 10 additions & 10 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import signal
import tempfile
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple
from typing import Any, Callable, Sequence, Tuple

import yaml
from airflow.compat.functools import cached_property
Expand Down Expand Up @@ -53,7 +53,7 @@ class DbtLocalBaseOperator(DbtBaseOperator):
def __init__(
self,
install_deps: bool = False,
callback: Optional[Callable[[str], None]] = None,
callback: Callable[[str], None] | None = None,
profile_args: dict[str, str] = {},
profile_name: str | None = None,
target_name: str | None = None,
Expand Down Expand Up @@ -174,7 +174,7 @@ def get_target_name(self) -> str:

def run_command(
self,
cmd: list[Optional[str]],
cmd: list[str | None],
env: dict[str, str | bytes | os.PathLike[Any]],
context: Context,
) -> FullOutputSubprocessResult:
Expand Down Expand Up @@ -264,7 +264,7 @@ class DbtLSLocalOperator(DbtLocalBaseOperator):

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.base_cmd = "ls"
self.base_cmd = ["ls"]

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
Expand All @@ -283,7 +283,7 @@ class DbtSeedLocalOperator(DbtLocalBaseOperator):
def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
self.base_cmd = "seed"
self.base_cmd = ["seed"]

def add_cmd_flags(self) -> list[str]:
flags = []
Expand All @@ -308,7 +308,7 @@ class DbtSnapshotLocalOperator(DbtLocalBaseOperator):

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.base_cmd = "snapshot"
self.base_cmd = ["snapshot"]

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
Expand All @@ -325,7 +325,7 @@ class DbtRunLocalOperator(DbtLocalBaseOperator):

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.base_cmd = "run"
self.base_cmd = ["run"]

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
Expand All @@ -343,11 +343,11 @@ class DbtTestLocalOperator(DbtLocalBaseOperator):

def __init__(
self,
on_warning_callback: Optional[Callable[..., Any]] = None,
on_warning_callback: Callable[..., Any] | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.base_cmd = "test"
self.base_cmd = ["test"]
self.on_warning_callback = on_warning_callback

def _should_run_tests(
Expand Down Expand Up @@ -406,7 +406,7 @@ class DbtRunOperationLocalOperator(DbtLocalBaseOperator):
ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: Optional[dict[str, Any]] = None, **kwargs: Any) -> None:
def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)
Expand Down
1 change: 1 addition & 0 deletions dev/dags/dbt/jaffle_shop/.user.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
id: 3ff83655-3dec-45dc-b611-05c39b864ceb
2 changes: 1 addition & 1 deletion tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_dbt_base_operator_add_user_supplied_flags() -> None:
conn_id="my_airflow_connection",
task_id="my-task",
project_dir="my/dir",
base_cmd="run",
base_cmd=["run"],
dbt_cmd_flags=["--full-refresh"],
)

Expand Down

0 comments on commit 299e889

Please sign in to comment.