Skip to content

Commit

Permalink
Fix parsing test nodes when using the custom load method (LoadMethod.…
Browse files Browse the repository at this point in the history
…CUSTOM) (#563)

This PR fixes parsing test nodes when using LoadMethod.CUSTOM, since it didn't return any test nodes.
This issue surfaced after #543.

Closes: #561
  • Loading branch information
raffifu authored Sep 28, 2023
1 parent fb36be5 commit 0123c3e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 10 deletions.
4 changes: 3 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def load_via_custom_parser(self) -> None:
operator_args=self.operator_args,
)
nodes = {}
models = itertools.chain(project.models.items(), project.snapshots.items(), project.seeds.items())
models = itertools.chain(
project.models.items(), project.snapshots.items(), project.seeds.items(), project.tests.items()
)
for model_name, model in models:
config = {item.split(":")[0]: item.split(":")[-1] for item in model.config.config_selectors}
node = DbtNode(
Expand Down
39 changes: 32 additions & 7 deletions cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@

class DbtModelType(Enum):
"""
Represents type of dbt unit (model, snapshot, seed)
Represents type of dbt unit (model, snapshot, seed, test)
"""

DBT_MODEL = "model"
DBT_SNAPSHOT = "snapshot"
DBT_SEED = "seed"
DBT_TEST = "test"


@dataclass
Expand Down Expand Up @@ -155,8 +156,8 @@ def __post_init__(self) -> None:
code = code.split("%}")[1]
code = code.split("{%")[0]

elif self.type == DbtModelType.DBT_SEED:
code = ""
elif self.type == DbtModelType.DBT_SEED or self.type == DbtModelType.DBT_TEST:
return

if self.path.suffix == PYTHON_FILE_SUFFIX:
config.upstream_models = config.upstream_models.union(set(extract_python_file_upstream_requirements(code)))
Expand Down Expand Up @@ -250,6 +251,7 @@ class DbtProject:
models: Dict[str, DbtModel] = field(default_factory=dict)
snapshots: Dict[str, DbtModel] = field(default_factory=dict)
seeds: Dict[str, DbtModel] = field(default_factory=dict)
tests: Dict[str, DbtModel] = field(default_factory=dict)
project_dir: Path = field(init=False)
models_dir: Path = field(init=False)
snapshots_dir: Path = field(init=False)
Expand Down Expand Up @@ -349,19 +351,42 @@ def _handle_config_file(self, path: Path) -> None:
config_dict = yaml.safe_load(path.read_text())

# iterate over the models in the config
if not (config_dict and config_dict.get("models")):
if not config_dict:
return

for model in config_dict["models"]:
for model in config_dict.get("models", []):
model_name = model.get("name")

# if the model doesn't exist, we can't do anything
if model_name not in self.models:
if not model_name:
continue

# tests
for column in model.get("columns", []):
for test in column.get("tests", []):
if not column.get("name"):
continue

# Get the test name
if not isinstance(test, str):
test = list(test.keys())[0]

test_model = DbtModel(
name=f"{test}_{column['name']}_{model_name}",
type=DbtModelType.DBT_TEST,
path=path,
operator_args=self.operator_args,
config=DbtModelConfig(upstream_models=set({model_name})),
)

self.tests[test_model.name] = test_model

# config_selectors
if model_name not in self.models:
continue

config_selectors = []
for selector in self.models[model_name].config.config_types:
for selector in DbtModelConfig.config_types:
config_value = model.get("config", {}).get(selector)
if config_value:
if isinstance(config_value, str):
Expand Down
17 changes: 17 additions & 0 deletions tests/dbt/parser/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SAMPLE_CSV_PATH = DBT_PROJECT_PATH / "jaffle_shop/seeds/raw_customers.csv"
SAMPLE_MODEL_SQL_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/customers.sql"
SAMPLE_SNAPSHOT_SQL_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/orders.sql"
SAMPLE_YML_PATH = DBT_PROJECT_PATH / "jaffle_shop/models/schema.yml"


def test_dbtproject__handle_csv_file():
Expand Down Expand Up @@ -64,6 +65,22 @@ def test_dbtproject__handle_sql_file_snapshot():
assert raw_customers.path == SAMPLE_SNAPSHOT_SQL_PATH


def test_dbtproject__handle_config_file():
dbt_project = DbtProject(
project_name="jaffle_shop",
dbt_root_path=DBT_PROJECT_PATH,
)
dbt_project.tests = {}

dbt_project._handle_config_file(SAMPLE_YML_PATH)

assert len(dbt_project.tests) == 12
assert "not_null_customer_id_customers" in dbt_project.tests
sample_test = dbt_project.tests["not_null_customer_id_customers"]
assert sample_test.type == DbtModelType.DBT_TEST
assert sample_test.path == SAMPLE_YML_PATH


def test_dbtproject__handle_config_file_empty_file():
with NamedTemporaryFile("w") as tmp_fp:
tmp_fp.flush()
Expand Down
3 changes: 1 addition & 2 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ def test_load_via_load_via_custom_parser(pipeline_name):
dbt_graph.load_via_custom_parser()

assert dbt_graph.nodes == dbt_graph.filtered_nodes
# the custom parser does not add dbt test nodes
assert len(dbt_graph.nodes) == 8
assert len(dbt_graph.nodes) == 28


@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None)
Expand Down

0 comments on commit 0123c3e

Please sign in to comment.