Skip to content

Commit

Permalink
Rebase changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek Mohan committed Jul 24, 2023
1 parent 299e889 commit 858b298
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 36 deletions.
4 changes: 0 additions & 4 deletions cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.constants import LoadMode, TestBehavior, ExecutionMode
from cosmos.dataset import get_dbt_dataset

# re-export the dag and task group
from cosmos.airflow.dag import DbtDag
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.operators.lazy_load import MissingPackage

from cosmos.operators.local import (
Expand Down
4 changes: 2 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
return leaves


def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict) -> TaskMetadata:
def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand Down Expand Up @@ -88,7 +88,7 @@ def create_test_task_metadata(
test_task_name: str,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
on_warning_callback: callable,
on_warning_callback: Callable[..., Any] | None = None,
model_name: str | None = None,
) -> TaskMetadata:
"""
Expand Down
7 changes: 5 additions & 2 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# mypy: ignore-errors
# ignoring enum Mypy errors

from __future__ import annotations
from enum import Enum

import inspect
import logging
import sys
from typing import Any, Callable

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -167,7 +170,7 @@ def __init__(
models_dir=Path(dbt_models_dir) if dbt_models_dir else None,
seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None,
snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None,
manifest_path=manifest_path, # type: ignore[arg-type]
manifest_path=manifest_path,
)

dbt_graph = DbtGraph(
Expand Down
10 changes: 8 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from subprocess import Popen, PIPE
from typing import Any

Expand Down Expand Up @@ -34,7 +35,7 @@ class DbtNode:

name: str
unique_id: str
resource_type: str
resource_type: DbtResourceType
depends_on: list[str]
file_path: Path
tags: list[str] = field(default_factory=lambda: [])
Expand Down Expand Up @@ -130,7 +131,12 @@ def load_via_dbt_ls(self) -> None:
logger.info(f"Running command: {command}")
try:
process = Popen(
command, stdout=PIPE, stderr=PIPE, cwd=self.project.dir, universal_newlines=True, env=os.environ # type: ignore[arg-type]
command, # type: ignore[arg-type]
stdout=PIPE,
stderr=PIPE,
cwd=self.project.dir,
universal_newlines=True,
env=os.environ,
)
except FileNotFoundError as exception:
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}")
Expand Down
15 changes: 8 additions & 7 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ def extract_python_file_upstream_requirements(code: str) -> list[str]:
source_code = ast.parse(code)

upstream_entities = []
model_function = ""
model_function = None
for node in source_code.body:
if isinstance(node, ast.FunctionDef) and node.name == DBT_PY_MODEL_METHOD_NAME:
model_function = node
break

for item in ast.walk(model_function):
if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME:
upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
if upstream_entity_id:
upstream_entities.append(upstream_entity_id)
if model_function:
for item in ast.walk(model_function):
if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME: # type: ignore[attr-defined]
upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
if upstream_entity_id:
upstream_entities.append(upstream_entity_id)

return upstream_entities

Expand Down Expand Up @@ -153,7 +154,7 @@ def __post_init__(self) -> None:
code = code.split("{%")[0]

elif self.type == DbtModelType.DBT_SEED:
code = None
code = ""

if self.path.suffix == PYTHON_FILE_SUFFIX:
config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
Expand Down
8 changes: 6 additions & 2 deletions cosmos/hooks/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import contextlib
import os
import signal
from collections import namedtuple
from typing import NamedTuple
from subprocess import PIPE, STDOUT, Popen
from tempfile import TemporaryDirectory, gettempdir

from airflow.hooks.base import BaseHook

FullOutputSubprocessResult = namedtuple("FullOutputSubprocessResult", ["exit_code", "output", "full_output"])

class FullOutputSubprocessResult(NamedTuple):
exit_code: int
output: str
full_output: list[str]


class FullOutputSubprocessHook(BaseHook): # type: ignore[misc] # ignores subclass MyPy error
Expand Down
21 changes: 11 additions & 10 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
session.add(rtif)

def run_subprocess(self, *args: Tuple[Any], **kwargs: Any) -> FullOutputSubprocessResult:
return self.subprocess_hook.run_command(*args, **kwargs) # type: ignore[no-any-return]
subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs)
return subprocess_result

def get_profile_name(self, project_dir: str) -> str:
"""
Expand Down Expand Up @@ -244,7 +245,7 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None

def execute(self, context: Context) -> str:
# TODO is this going to put loads of unnecessary stuff in to xcom?
return self.build_and_run_cmd(context=context).output # type: ignore[no-any-return]
return self.build_and_run_cmd(context=context).output

def on_kill(self) -> None:
if self.cancel_query_on_kill:
Expand All @@ -268,7 +269,7 @@ def __init__(self, **kwargs: Any) -> None:

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
return result.output # type: ignore[no-any-return]
return result.output


class DbtSeedLocalOperator(DbtLocalBaseOperator):
Expand All @@ -295,7 +296,7 @@ def add_cmd_flags(self) -> list[str]:
def execute(self, context: Context) -> str:
cmd_flags = self.add_cmd_flags()
result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
return result.output # type: ignore[no-any-return]
return result.output


class DbtSnapshotLocalOperator(DbtLocalBaseOperator):
Expand All @@ -312,7 +313,7 @@ def __init__(self, **kwargs: Any) -> None:

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
return result.output # type: ignore[no-any-return]
return result.output


class DbtRunLocalOperator(DbtLocalBaseOperator):
Expand All @@ -329,7 +330,7 @@ def __init__(self, **kwargs: Any) -> None:

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
return result.output # type: ignore[no-any-return]
return result.output


class DbtTestLocalOperator(DbtLocalBaseOperator):
Expand Down Expand Up @@ -385,13 +386,13 @@ def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)

if not self._should_run_tests(result):
return result.output # type: ignore[no-any-return]
return result.output

warnings = parse_output(result, "WARN")
if warnings > 0:
self._handle_warnings(result, context)

return result.output # type: ignore[no-any-return]
return result.output


class DbtRunOperationLocalOperator(DbtLocalBaseOperator):
Expand Down Expand Up @@ -422,7 +423,7 @@ def add_cmd_flags(self) -> list[str]:
def execute(self, context: Context) -> str:
cmd_flags = self.add_cmd_flags()
result = self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
return result.output # type: ignore[no-any-return]
return result.output


class DbtDocsLocalOperator(DbtLocalBaseOperator):
Expand All @@ -441,7 +442,7 @@ def __init__(self, **kwargs: Any) -> None:

def execute(self, context: Context) -> str:
result = self.build_and_run_cmd(context=context)
return result.output # type: ignore[no-any-return]
return result.output


class DbtDocsS3LocalOperator(DbtDocsLocalOperator):
Expand Down
7 changes: 2 additions & 5 deletions cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,8 @@ def run_subprocess( # type: ignore[override]
if self.py_requirements:
command[0] = self.venv_dbt_path

return self.subprocess_hook.run_command( # type: ignore[no-any-return]
command,
*args,
**kwargs,
)
subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(command, *args, **kwargs)
return subprocess_result

def execute(self, context: Context) -> str:
output = super().execute(context)
Expand Down
1 change: 1 addition & 0 deletions dev/dags/dbt/jaffle_shop_python/.user.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
id: 215ac1e7-601b-45dd-9867-587a3e282bce
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_dag(
dag.clear(
start_date=execution_date,
end_date=execution_date,
dag_run_state=False, # type: ignore
dag_run_state=False,
session=session,
)
dag.log.debug("Getting dagrun for dag %s", dag.dag_id)
Expand Down Expand Up @@ -164,7 +164,7 @@ def _get_or_create_dagrun(
run_id=run_id,
start_date=start_date or execution_date,
session=session,
conf=conf, # type: ignore
conf=conf,
)
log.info("created dagrun %s", str(dr))
return dr

0 comments on commit 858b298

Please sign in to comment.