Skip to content

Commit

Permalink
moved functionality that is called multiple times to functions and an…
Browse files Browse the repository at this point in the history
…d added code to pre_state_hook
  • Loading branch information
loveeklund-osttra committed Sep 12, 2024
1 parent 81046a4 commit 66d94b3
Showing 1 changed file with 77 additions and 82 deletions.
159 changes: 77 additions & 82 deletions target_bigquery/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __init__(
self.config["project"],
)
self.client = bigquery_client_factory(self._credentials)
opts = {
self.table_opts = {
"project": self.config["project"],
"dataset": self.config["dataset"],
"jsonschema": self.schema,
Expand All @@ -317,7 +317,7 @@ def __init__(
self.config.get("schema_resolver_version", 1)
),
}
self.table = BigQueryTable(name=self.table_name, **opts)
self.table = BigQueryTable(name=self.table_name, **self.table_opts)
self.create_target(key_properties=key_properties)
self.update_schema()
self.merge_target: Optional[BigQueryTable] = None
Expand All @@ -331,53 +331,37 @@ def __init__(
and self._is_upsert_candidate()
):
self.merge_target = copy(self.table)
self.table = BigQueryTable(
name=f"{self.table_name}__{time.strftime('%Y%m%d%H%M%S')}__{uuid.uuid4()}",
**opts,
)
self.table.create_table(
self.client,
self.apply_transforms,
**{
"table": {
"expires": datetime.datetime.now() + datetime.timedelta(days=1),
},
"dataset": {
"location": self.config.get(
"location",
BigQueryTable.default_dataset_options()["location"],
)
},
},
)
time.sleep(2.5) # Wait for eventual consistency
self._create_overwrite_table()
elif self._is_overwrite_candidate():
self.overwrite_target = copy(self.table)
self.table = BigQueryTable(
name=f"{self.table_name}__{time.strftime('%Y%m%d%H%M%S')}__{uuid.uuid4()}",
**opts,
)
self.table.create_table(
self.client,
self.apply_transforms,
**{
"table": {
"expires": datetime.datetime.now() + datetime.timedelta(days=1),
},
"dataset": {
"location": self.config.get(
"location",
BigQueryTable.default_dataset_options()["location"],
)
},
},
)
time.sleep(2.5) # Wait for eventual consistency
self._create_overwrite_table()

self.global_par_typ = target.par_typ
self.global_queue = target.queue
self.increment_jobs_enqueued = target.increment_jobs_enqueued

def _create_overwrite_table(self) -> None:
self.table = BigQueryTable(
name=f"{self.table_name}__{time.strftime('%Y%m%d%H%M%S')}__{uuid.uuid4()}",
**self.table_opts,
)
self.table.create_table(
self.client,
self.apply_transforms,
**{
"table": {
"expires": datetime.datetime.now() + datetime.timedelta(days=1),
},
"dataset": {
"location": self.config.get(
"location",
BigQueryTable.default_dataset_options()["location"],
)
},
},
)
time.sleep(2.5) # Wait for eventual consistency

def _is_upsert_candidate(self) -> bool:
"""Determine if this stream is an upsert candidate based on user configuration."""
upsert_selection = self.config.get("upsert", False)
Expand Down Expand Up @@ -506,9 +490,21 @@ def update_schema(self) -> None:
"""Update the target schema in BigQuery."""
pass

def _get_bigquery_client(self) -> bigquery.Client:
# If gcs_stage method was used, self.client is probably
# an instance of storage.Client, instead of bigquery.Client
return (
self.client
if isinstance(self.client, bigquery.Client)
else bigquery_client_factory(self._credentials)
)

def pre_state_hook(self) -> None:
"""Called before state is emitted to stdout."""
pass
# if we have a merge_target we need to merge the table before writing out state
# otherwise we might end up with state being moved forward without data being written.
if self.merge_target:
self.merge_table(bigquery_client=self._get_bigquery_client())
self._create_overwrite_table()

@staticmethod
@abstractmethod
Expand All @@ -518,49 +514,48 @@ def worker_cls_factory(
"""Return a worker class for the given parallelization type."""
raise NotImplementedError

def merge_table(self, bigquery_client:bigquery.Client) -> None:
target = self.merge_target.as_table()
date_columns = ["_sdc_extracted_at", "_sdc_received_at"]
tmp, ctas_tmp = None, "SELECT 1 AS _no_op"
if self._is_dedupe_before_upsert_candidate():
# We can't use MERGE with a non-unique key, so we need to dedupe the temp table into
# a _SESSION scoped intermediate table.
tmp = f"{self.merge_target.name}__tmp"
dedupe_query = (
f"SELECT * FROM {self.table.get_escaped_name()} "
f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {', '.join(f'`{p}`' for p in self.key_properties)} "
f"ORDER BY COALESCE({', '.join(date_columns)}) DESC) = 1"
)
ctas_tmp = f"CREATE OR REPLACE TEMP TABLE `{tmp}` AS {dedupe_query}"
merge_clause = (
f"MERGE `{self.merge_target}` AS target USING `{tmp or self.table}` AS source ON "
+ " AND ".join(
f"target.`{f}` = source.`{f}`" for f in self.key_properties
)
)
update_clause = "UPDATE SET " + ", ".join(
f"target.`{f.name}` = source.`{f.name}`" for f in target.schema
)
insert_clause = (
f"INSERT ({', '.join(f'`{f.name}`' for f in target.schema)}) "
f"VALUES ({', '.join(f'source.`{f.name}`' for f in target.schema)})"
)
bigquery_client.query(
f"{ctas_tmp}; {merge_clause} "
f"WHEN MATCHED THEN {update_clause} "
f"WHEN NOT MATCHED THEN {insert_clause}; "
f"DROP TABLE IF EXISTS {self.table.get_escaped_name()};"
).result()

def clean_up(self) -> None:
"""Clean up the target table."""
# If gcs_stage method was used, self.client is probably
# an instance of storage.Client, instead of bigquery.Client
bigquery_client = (
self.client
if isinstance(self.client, bigquery.Client)
else bigquery_client_factory(self._credentials)
)
bigquery_client = self._get_bigquery_client()
if self.merge_target is not None:
# We must merge the temp table into the target table.
target = self.merge_target.as_table()
date_columns = ["_sdc_extracted_at", "_sdc_received_at"]
tmp, ctas_tmp = None, "SELECT 1 AS _no_op"
if self._is_dedupe_before_upsert_candidate():
# We can't use MERGE with a non-unique key, so we need to dedupe the temp table into
# a _SESSION scoped intermediate table.
tmp = f"{self.merge_target.name}__tmp"
dedupe_query = (
f"SELECT * FROM {self.table.get_escaped_name()} "
f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {', '.join(f'`{p}`' for p in self.key_properties)} "
f"ORDER BY COALESCE({', '.join(date_columns)}) DESC) = 1"
)
ctas_tmp = f"CREATE OR REPLACE TEMP TABLE `{tmp}` AS {dedupe_query}"
merge_clause = (
f"MERGE `{self.merge_target}` AS target USING `{tmp or self.table}` AS source ON "
+ " AND ".join(
f"target.`{f}` = source.`{f}`" for f in self.key_properties
)
)
update_clause = "UPDATE SET " + ", ".join(
f"target.`{f.name}` = source.`{f.name}`" for f in target.schema
)
insert_clause = (
f"INSERT ({', '.join(f'`{f.name}`' for f in target.schema)}) "
f"VALUES ({', '.join(f'source.`{f.name}`' for f in target.schema)})"
)
bigquery_client.query(
f"{ctas_tmp}; {merge_clause} "
f"WHEN MATCHED THEN {update_clause} "
f"WHEN NOT MATCHED THEN {insert_clause}; "
f"DROP TABLE IF EXISTS {self.table.get_escaped_name()};"
).result()
self.merge_table(bigquery_client=bigquery_client)
self.table = self.merge_target
self.merge_target = None
elif self.overwrite_target is not None:
Expand All @@ -574,7 +569,7 @@ def clean_up(self) -> None:
f" {self.table.get_escaped_name()};"
).result()
self.table = cast(BigQueryTable, self.merge_target)
self.merge_target = None
self.overwrite_target = None


class Denormalized:
Expand Down

0 comments on commit 66d94b3

Please sign in to comment.