Merge remote-tracking branch 'upstream/main' into MINOR-add-testCaseResult-ReindexApp
TeddyCr committed Oct 16, 2024
2 parents 647990b + 3049752 commit cb92794
Showing 63 changed files with 1,116 additions and 311 deletions.
@@ -0,0 +1,17 @@
UPDATE test_definition
SET json = JSON_MERGE_PRESERVE(
    json,
    JSON_OBJECT(
        'parameterDefinition',
        JSON_ARRAY(
            JSON_OBJECT(
                'name', 'caseSensitiveColumns',
                'dataType', 'BOOLEAN',
                'required', false,
                'description', 'Use case sensitivity when comparing the columns.',
                'displayName', 'Case sensitive columns'
            )
        )
    )
)
WHERE name = 'tableDiff';
@@ -0,0 +1,8 @@
UPDATE test_definition
SET json = jsonb_set(
    json,
    '{parameterDefinition}',
    (json->'parameterDefinition')::jsonb ||
        '{"name": "caseSensitiveColumns", "dataType": "BOOLEAN", "required": false, "description": "Use case sensitivity when comparing the columns.", "displayName": "Case sensitive columns"}'::jsonb
)
WHERE name = 'tableDiff';
@@ -42,4 +42,9 @@ SET json = JSON_REMOVE(json, '$.testCaseResult');
UPDATE installed_apps SET json = JSON_SET(json, '$.supportsInterrupt', true) where name = 'SearchIndexingApplication';
UPDATE apps_marketplace SET json = JSON_SET(json, '$.supportsInterrupt', true) where name = 'SearchIndexingApplication';

ALTER TABLE apps_extension_time_series ADD COLUMN appName VARCHAR(256) GENERATED ALWAYS AS (json ->> '$.appName') STORED NOT NULL;

-- Add supportsDataDiff for Athena, BigQuery, Mssql, Mysql, Oracle, Postgres, Redshift, SapHana, Snowflake, Trino
UPDATE dbservice_entity
SET json = JSON_SET(json, '$.connection.config.supportsDataDiff', 'true')
WHERE serviceType IN ('Athena','BigQuery','Mssql','Mysql','Oracle','Postgres','Redshift','SapHana','Snowflake','Trino');
@@ -59,4 +59,9 @@ SET json = jsonb_set(
)
where name = 'SearchIndexingApplication';

ALTER TABLE apps_extension_time_series ADD COLUMN appName VARCHAR(256) GENERATED ALWAYS AS (json ->> 'appName') STORED NOT NULL;

-- Add supportsDataDiff for Athena, BigQuery, Mssql, Mysql, Oracle, Postgres, Redshift, SapHana, Snowflake, Trino
UPDATE dbservice_entity
SET json = jsonb_set(json::jsonb, '{connection,config,supportsDataDiff}', 'true'::jsonb)
WHERE serviceType IN ('Athena','BigQuery','Mssql','Mysql','Oracle','Postgres','Redshift','SapHana','Snowflake','Trino');
4 changes: 2 additions & 2 deletions docker/run_local_docker.sh
@@ -50,8 +50,8 @@ echo "Running local docker using mode [$mode] database [$database] and skipping
cd ../

echo "Stopping any previous Local Docker Containers"
docker compose -f docker/development/docker-compose-postgres.yml down
docker compose -f docker/development/docker-compose.yml down
docker compose -f docker/development/docker-compose-postgres.yml down --remove-orphans
docker compose -f docker/development/docker-compose.yml down --remove-orphans

if [[ $skipMaven == "false" ]]; then
if [[ $mode == "no-ui" ]]; then
@@ -19,6 +19,7 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Optional, Type, TypeVar, Union

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
@@ -65,37 +66,18 @@ def run_validation(self) -> TestCaseResult:
"""
raise NotImplementedError

@staticmethod
def get_test_case_param_value(
self,
test_case_param_vals: list[TestCaseParameterValue],
test_case_param_vals: List[TestCaseParameterValue],
name: str,
type_: T,
default: Optional[R] = None,
pre_processor: Optional[Callable] = None,
) -> Optional[Union[R, T]]:
"""Give a column and a type return the value with the appropriate type casting for the
test case definition.
Args:
test_case: the test case
type_ (Union[float, int, str]): type for the value
name (str): column name
default (_type_, optional): Default value to return if column is not found
pre_processor: pre processor function/type to use against the value before casting to type_
"""
value = next(
(param.value for param in test_case_param_vals if param.name == name), None
return utils.get_test_case_param_value(
test_case_param_vals, name, type_, default, pre_processor
)

if not value:
return default if default is not None else None

if not pre_processor:
return type_(value)

pre_processed_value = pre_processor(value)
return type_(pre_processed_value)

def get_test_case_result_object( # pylint: disable=too-many-arguments
self,
execution_date: Timestamp,
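The inline implementation removed above now lives in metadata/data_quality/validations/utils.py, imported at the top of the file. A minimal sketch of what the extracted helpers presumably look like, reconstructed from the removed code; the parsing and default behaviour of get_bool_test_case_param is an assumption, not shown in this diff:

from typing import Callable, List, Optional, Type, TypeVar, Union

T = TypeVar("T")
R = TypeVar("R")


def get_test_case_param_value(
    test_case_param_vals: List,  # items expose .name and .value (TestCaseParameterValue)
    name: str,
    type_: Type[T],
    default: Optional[R] = None,
    pre_processor: Optional[Callable] = None,
) -> Optional[Union[R, T]]:
    """Return the value of parameter `name` cast to `type_`, or `default` when absent."""
    value = next(
        (param.value for param in test_case_param_vals if param.name == name), None
    )
    if not value:
        return default if default is not None else None
    if not pre_processor:
        return type_(value)
    return type_(pre_processor(value))


def get_bool_test_case_param(test_case_param_vals: List, name: str) -> bool:
    # Assumption: boolean parameters arrive as strings ("true"/"false") and
    # default to False when missing; the real helper may differ in detail.
    value = get_test_case_param_value(test_case_param_vals, name, str, default="false")
    return str(value).strip().lower() in ("true", "1", "yes")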
@@ -20,6 +20,7 @@

from sqlalchemy import Column

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.generated.schema.tests.basic import (
TestCaseResult,
@@ -50,11 +51,8 @@ def run_validation(self) -> TestCaseResult:
literal_eval,
)

match_enum = self.get_test_case_param_value(
self.test_case.parameterValues, # type: ignore
"matchEnum",
bool,
default=False,
match_enum = utils.get_bool_test_case_param(
self.test_case.parameterValues, "matchEnum"
)

try:
@@ -13,6 +13,7 @@
from typing import List, Optional
from urllib.parse import urlparse

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.models import (
Column,
TableDiffRuntimeParameters,
@@ -27,6 +28,7 @@
from metadata.ingestion.source.connections import get_connection
from metadata.profiler.orm.registry import Dialects
from metadata.utils import fqn
from metadata.utils.collections import CaseInsensitiveList


class TableDiffParamsSetter(RuntimeParameterSetter):
@@ -58,6 +60,9 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
DatabaseService, self.table_entity.service.id, nullable=False
)
table2_fqn = self.get_parameter(test_case, "table2")
case_sensitive_columns: bool = utils.get_bool_test_case_param(
test_case.parameterValues, "caseSensitiveColumns"
)
if table2_fqn is None:
raise ValueError("table2 not set")
table2: Table = self.ometa_client.get_by_name(
@@ -82,7 +87,10 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
override_url=service1_url,
),
columns=self.filter_relevant_columns(
self.table_entity.columns, key_columns, extra_columns
self.table_entity.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
),
table2=TableParameter(
@@ -94,7 +102,10 @@ def get_parameters(self, test_case) -> TableDiffRuntimeParameters:
or service2_url,
),
columns=self.filter_relevant_columns(
table2.columns, key_columns, extra_columns
table2.columns,
key_columns,
extra_columns,
case_sensitive=case_sensitive_columns,
),
),
keyColumns=key_columns,
@@ -156,9 +167,17 @@ def get_key_columns(self, test_case) -> List[str]:

@staticmethod
def filter_relevant_columns(
columns: List[Column], key_columns: List[str], extra_columns: List[str]
columns: List[Column],
key_columns: List[str],
extra_columns: List[str],
case_sensitive: bool,
) -> List[Column]:
return [c for c in columns if c.name.root in [*key_columns, *extra_columns]]
validated_columns = (
[*key_columns, *extra_columns]
if case_sensitive
else CaseInsensitiveList([*key_columns, *extra_columns])
)
return [c for c in columns if c.name.root in validated_columns]

@staticmethod
def get_parameter(test_case: TestCase, key: str, default=None):
@@ -195,7 +214,7 @@ def get_data_diff_url(
if hasattr(db_service.connection.config, "supportsDatabase"):
kwargs["path"] = f"/{database}"
if kwargs["scheme"] in {Dialects.MSSQL, Dialects.Snowflake}:
kwargs["path"] += f"/{schema}"
kwargs["path"] = f"/{database}/{schema}"
return url._replace(**kwargs).geturl()

@staticmethod
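CaseInsensitiveList is imported from metadata.utils.collections but not defined in this diff; for filter_relevant_columns it only needs to make membership checks ignore case so that key and extra column names match regardless of casing. A minimal sketch under that assumption (an illustration, not the project's actual implementation):

class CaseInsensitiveList(list):
    """List whose `in` checks ignore case for string items."""

    def __contains__(self, item) -> bool:
        if isinstance(item, str):
            return item.casefold() in (
                element.casefold() for element in self if isinstance(element, str)
            )
        return super().__contains__(item)


# With caseSensitiveColumns=false, filter_relevant_columns keeps a column named
# "ID" even when the test case lists it as "id":
assert "ID" in CaseInsensitiveList(["id", "customer_name"])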
@@ -11,17 +11,19 @@
# pylint: disable=missing-module-docstring
import logging
import traceback
from decimal import Decimal
from itertools import islice
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import urlparse

import data_diff
import sqlalchemy.types
from data_diff.diff_tables import DiffResultWrapper
from data_diff.errors import DataDiffMismatchingKeyTypesError
from data_diff.utils import ArithAlphanumeric
from data_diff.utils import ArithAlphanumeric, CaseInsensitiveDict
from sqlalchemy import Column as SAColumn

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.mixins.sqa_validator_mixin import (
SQAValidatorMixin,
@@ -75,6 +77,18 @@ def masked(s: str, mask: bool = True) -> str:
return "***" if mask else s


def is_numeric(t: type) -> bool:
"""Check if a type is numeric.
Args:
t: type to check
Returns:
True if the type is numeric otherwise False
"""
return t in [int, float, Decimal]


class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
"""
Compare two tables and fail if the number of differences exceeds a threshold
@@ -167,12 +181,14 @@ def get_incomparable_columns(self) -> List[str]:
self.runtime_params.table1.path,
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
).with_schema()
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns,
extra_columns=self.runtime_params.extraColumns,
case_sensitive=self.get_case_sensitive(),
).with_schema()
result = []
for column in table1.key_columns + table1.extra_columns:
Expand All @@ -185,7 +201,8 @@ def get_incomparable_columns(self) -> List[str]:
col2_type = self._get_column_python_type(
table2._schema[column] # pylint: disable=protected-access
)

if is_numeric(col1_type) and is_numeric(col2_type):
continue
if col1_type != col2_type:
result.append(column)
return result
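The new is_numeric guard means two columns are no longer reported as incomparable when their Python-mapped types differ only within the numeric family, for example int on one side and Decimal on the other. A standalone illustration of the effect, using plain types rather than data_diff schema objects:

from decimal import Decimal


def is_numeric(t: type) -> bool:
    return t in [int, float, Decimal]


# Column type pairs as they might be mapped from the two tables' schemas.
pairs = {"amount": (int, Decimal), "score": (float, float), "code": (int, str)}
incomparable = [
    name
    for name, (left, right) in pairs.items()
    if not (is_numeric(left) and is_numeric(right)) and left != right
]
print(incomparable)  # ['code'] -- numeric type mismatches are skipped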
@@ -228,11 +245,13 @@ def get_table_diff(self) -> DiffResultWrapper:
self.runtime_params.table1.serviceUrl,
self.runtime_params.table1.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
)
table2 = data_diff.connect_to_table(
self.runtime_params.table2.serviceUrl,
self.runtime_params.table2.path,
self.runtime_params.keyColumns, # type: ignore
case_sensitive=self.get_case_sensitive(),
)
data_diff_kwargs = {
"key_columns": self.runtime_params.keyColumns,
@@ -308,7 +327,9 @@ def _validate_dialects(self):
def get_column_diff(self) -> Optional[TestCaseResult]:
"""Get the column diff between the two tables. If there are no differences, return None."""
removed, added = self.get_changed_added_columns(
self.runtime_params.table1.columns, self.runtime_params.table2.columns
self.runtime_params.table1.columns,
self.runtime_params.table2.columns,
self.get_case_sensitive(),
)
changed = self.get_incomparable_columns()
if removed or added or changed:
@@ -321,7 +342,7 @@ def get_column_diff(self) -> Optional[TestCaseResult]:

@staticmethod
def get_changed_added_columns(
left: List[Column], right: List[Column]
left: List[Column], right: List[Column], case_sensitive: bool
) -> Optional[Tuple[List[str], List[str]]]:
"""Given a list of columns from two tables, return the columns that are removed and added.
@@ -335,6 +356,10 @@ def get_changed_added_columns(
removed: List[str] = []
added: List[str] = []
right_columns_dict: Dict[str, Column] = {c.name.root: c for c in right}
if not case_sensitive:
right_columns_dict = cast(
Dict[str, Column], CaseInsensitiveDict(right_columns_dict)
)
for column in left:
table2_column = right_columns_dict.get(column.name.root)
if table2_column is None:
@@ -345,7 +370,10 @@ def get_changed_added_columns(
return removed, added

def column_validation_result(
self, removed: List[str], added: List[str], changed: List[str]
self,
removed: List[str],
added: List[str],
changed: List[str],
) -> TestCaseResult:
"""Build the result for a column validation result. Messages will only be added
for non-empty categories. Values will be populated reported for all categories.
@@ -367,13 +395,18 @@ def column_validation_result(
message += f"\n Added columns: {','.join(added)}\n"
if changed:
message += "\n Changed columns:"
table1_columns = {
c.name.root: c for c in self.runtime_params.table1.columns
}
table2_columns = {
c.name.root: c for c in self.runtime_params.table2.columns
}
if not self.get_case_sensitive():
table1_columns = CaseInsensitiveDict(table1_columns)
table2_columns = CaseInsensitiveDict(table2_columns)
for col in changed:
col1 = next(
c for c in self.runtime_params.table1.columns if c.name.root == col
)
col2 = next(
c for c in self.runtime_params.table2.columns if c.name.root == col
)
col1 = table1_columns[col]
col2 = table2_columns[col]
message += (
f"\n {col}: {col1.dataType.value} -> {col2.dataType.value}"
)
@@ -432,3 +465,8 @@ def safe_table_diff_iterator(self) -> DiffResultWrapper:
if str(ex) == "2":
# This is a known issue in data_diff where the diff object is closed
pass

def get_case_sensitive(self):
return utils.get_bool_test_case_param(
self.test_case.parameterValues, "caseSensitiveColumns"
)
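Putting it together, get_case_sensitive resolves the caseSensitiveColumns test parameter and forwards it to the data_diff connections shown above. A condensed sketch of the resulting call shape; the connection URL, table path, and column names below are placeholders, not values from this change:

import data_diff

# Placeholder connection details for illustration only.
table1 = data_diff.connect_to_table(
    "postgresql://user:password@localhost:5432/shop",  # serviceUrl
    ("public", "dim_customer"),                        # table path
    ("customer_id",),                                  # key columns
    extra_columns=("customer_name",),
    case_sensitive=False,  # resolved from the caseSensitiveColumns parameter
)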