Skip to content

Commit

Permalink
Rebase changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek Mohan committed Jul 24, 2023
1 parent 299e889 commit cfb2123
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 16 deletions.
4 changes: 0 additions & 4 deletions cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.constants import LoadMode, TestBehavior, ExecutionMode
from cosmos.dataset import get_dbt_dataset

# re-export the dag and task group
from cosmos.airflow.dag import DbtDag
from cosmos.airflow.task_group import DbtTaskGroup
from cosmos.operators.lazy_load import MissingPackage

from cosmos.operators.local import (
Expand Down
4 changes: 2 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
return leaves


def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict) -> TaskMetadata:
def create_task_metadata(node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any]) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand Down Expand Up @@ -88,7 +88,7 @@ def create_test_task_metadata(
test_task_name: str,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
on_warning_callback: callable,
on_warning_callback: Callable[..., Any] | None = None,
model_name: str | None = None,
) -> TaskMetadata:
"""
Expand Down
5 changes: 4 additions & 1 deletion cosmos/converter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# mypy: ignore-errors
# ignoring enum Mypy errors

from __future__ import annotations
from enum import Enum

import inspect
import logging
import sys
from typing import Any, Callable

from airflow.exceptions import AirflowException
Expand Down
10 changes: 8 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from subprocess import Popen, PIPE
from typing import Any

Expand Down Expand Up @@ -34,7 +35,7 @@ class DbtNode:

name: str
unique_id: str
resource_type: str
resource_type: DbtResourceType
depends_on: list[str]
file_path: Path
tags: list[str] = field(default_factory=lambda: [])
Expand Down Expand Up @@ -130,7 +131,12 @@ def load_via_dbt_ls(self) -> None:
logger.info(f"Running command: {command}")
try:
process = Popen(
command, stdout=PIPE, stderr=PIPE, cwd=self.project.dir, universal_newlines=True, env=os.environ # type: ignore[arg-type]
command, # type: ignore[arg-type]
stdout=PIPE,
stderr=PIPE,
cwd=self.project.dir,
universal_newlines=True,
env=os.environ,
)
except FileNotFoundError as exception:
raise CosmosLoadDbtException(f"Unable to run the command due to the error:\n{exception}")
Expand Down
15 changes: 8 additions & 7 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ def extract_python_file_upstream_requirements(code: str) -> list[str]:
source_code = ast.parse(code)

upstream_entities = []
model_function = ""
model_function = None
for node in source_code.body:
if isinstance(node, ast.FunctionDef) and node.name == DBT_PY_MODEL_METHOD_NAME:
model_function = node
break

for item in ast.walk(model_function):
if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME:
upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
if upstream_entity_id:
upstream_entities.append(upstream_entity_id)
if model_function:
for item in ast.walk(model_function):
if isinstance(item, ast.Call) and item.func.attr == DBT_PY_DEP_METHOD_NAME: # type: ignore[attr-defined]
upstream_entity_id = hasattr(item.args[-1], "value") and item.args[-1].value
if upstream_entity_id:
upstream_entities.append(upstream_entity_id)

return upstream_entities

Expand Down Expand Up @@ -153,7 +154,7 @@ def __post_init__(self) -> None:
code = code.split("{%")[0]

elif self.type == DbtModelType.DBT_SEED:
code = None
code = ""

if self.path.suffix == PYTHON_FILE_SUFFIX:
config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
Expand Down
1 change: 1 addition & 0 deletions dev/dags/dbt/jaffle_shop_python/.user.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
id: 215ac1e7-601b-45dd-9867-587a3e282bce

0 comments on commit cfb2123

Please sign in to comment.