diff --git a/task/bq2bq/executor/bumblebee/bigquery_service.py b/task/bq2bq/executor/bumblebee/bigquery_service.py index 8e0725a..d22ee15 100644 --- a/task/bq2bq/executor/bumblebee/bigquery_service.py +++ b/task/bq2bq/executor/bumblebee/bigquery_service.py @@ -4,7 +4,9 @@ from abc import ABC, abstractmethod import google as google +import requests.exceptions from google.api_core.exceptions import BadRequest, Forbidden +from google.api_core.retry import if_exception_type, if_transient_error from google.cloud import bigquery from google.cloud.bigquery.job import QueryJobConfig, CreateDisposition from google.cloud.bigquery.schema import _parse_schema_resource @@ -50,10 +52,14 @@ def delete_table(self, full_table_name): def get_table(self, full_table_name): pass +def if_exception_funcs(fn_origin, fn_additional): + def if_exception_func_predicate(exception): + return fn_origin(exception) or fn_additional(exception) + return if_exception_func_predicate class BigqueryService(BaseBigqueryService): - def __init__(self, client, labels, writer, on_job_finish = None, on_job_register = None): + def __init__(self, client, labels, writer, retry_timeout = None, on_job_finish = None, on_job_register = None): """ :rtype: @@ -61,6 +67,13 @@ def __init__(self, client, labels, writer, on_job_finish = None, on_job_register self.client = client self.labels = labels self.writer = writer + if_additional_transient_error = if_exception_type( + requests.exceptions.Timeout, + requests.exceptions.SSLError, + ) + predicate = if_exception_funcs(if_transient_error, if_additional_transient_error) + retry = bigquery.DEFAULT_RETRY.with_deadline(retry_timeout) if retry_timeout else bigquery.DEFAULT_RETRY + self.retry = retry.with_predicate(predicate) self.on_job_finish = on_job_finish self.on_job_register = on_job_register @@ -74,7 +87,8 @@ def execute_query(self, query): logger.info("executing query") query_job = self.client.query(query=query, - job_config=query_job_config) + job_config=query_job_config, + retry=self.retry) logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state, query_job.project)) @@ -125,7 +139,9 @@ def transform_load(self, query_job_config.destination = table_ref logger.info("transform load") - query_job = self.client.query(query=query, job_config=query_job_config) + query_job = self.client.query(query=query, + job_config=query_job_config, + retry=self.retry) logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state, query_job.project)) @@ -183,7 +199,7 @@ def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_j default_query_job_config.priority = task_config.query_priority default_query_job_config.allow_field_addition = task_config.allow_field_addition client = bigquery.Client(project=task_config.execution_project, credentials=credentials, default_query_job_config=default_query_job_config) - return BigqueryService(client, labels, writer, on_job_finish=on_job_finish, on_job_register=on_job_register) + return BigqueryService(client, labels, writer, retry_timeout=task_config.retry_timeout, on_job_finish=on_job_finish, on_job_register=on_job_register) def _get_bigquery_credentials(): diff --git a/task/bq2bq/executor/bumblebee/config.py b/task/bq2bq/executor/bumblebee/config.py index 7762399..13ecefd 100644 --- a/task/bq2bq/executor/bumblebee/config.py +++ b/task/bq2bq/executor/bumblebee/config.py @@ -126,6 +126,7 @@ def __init__(self): self._use_spillover = _bool_from_str(get_env_config("USE_SPILLOVER", default="true")) self._concurrency = _validate_greater_than_zero(int(get_env_config("CONCURRENCY", default=1))) self._allow_field_addition = _bool_from_str(get_env_config("ALLOW_FIELD_ADDITION", default="false")) + self._retry_timeout = get_env_config("RETRY_TIMEOUT_IN_SECONDS", default=None) @property def destination_project(self) -> str: @@ -178,6 +179,12 @@ def timezone(self): def concurrency(self) -> int: return self._concurrency + @property + def retry_timeout(self) -> Optional[float]: + if self._retry_timeout: + return float(self._retry_timeout) + return None + def print(self): logger.info("task config:\n{}".format( "\n".join([ @@ -348,6 +355,7 @@ def __init__(self, raw_properties): self._use_spillover = _bool_from_str(self._get_property_or_default("USE_SPILLOVER", "true")) self._concurrency = _validate_greater_than_zero(int(self._get_property_or_default("CONCURRENCY", 1))) self._allow_field_addition = _bool_from_str(self._get_property_or_default("ALLOW_FIELD_ADDITION", "false")) + self._retry_timeout = self._get_property_or_default("RETRY_TIMEOUT_IN_SECONDS", None) @property def sql_type(self) -> str: @@ -412,6 +420,12 @@ def filter_expression(self) -> str: def allow_field_addition(self) -> bool: return self._allow_field_addition + @property + def retry_timeout(self) -> Optional[float]: + if self._retry_timeout: + return float(self._retry_timeout) + return None + def print(self): logger.info("task config:\n{}".format( "\n".join([ diff --git a/task/bq2bq/executor/bumblebee/log.py b/task/bq2bq/executor/bumblebee/log.py index 7eaade3..3158e43 100644 --- a/task/bq2bq/executor/bumblebee/log.py +++ b/task/bq2bq/executor/bumblebee/log.py @@ -1,11 +1,16 @@ import sys import logging +import os +def get_log_level(): + log_level = str(os.environ.get("LOG_LEVEL", default="INFO")).upper() + log_level = log_level if log_level in logging._nameToLevel else "INFO" + return logging._nameToLevel.get(log_level) def get_logger(name: str): logger = logging.getLogger(name) logformat = "[%(asctime)s] %(levelname)s:%(name)s: %(message)s" - logging.basicConfig(level=logging.INFO, stream=sys.stdout, + logging.basicConfig(level=get_log_level(), stream=sys.stdout, format=logformat, datefmt="%Y-%m-%d %H:%M:%S") return logger diff --git a/task/bq2bq/executor/requirements.txt b/task/bq2bq/executor/requirements.txt index 77d99d0..9f8131e 100644 --- a/task/bq2bq/executor/requirements.txt +++ b/task/bq2bq/executor/requirements.txt @@ -2,12 +2,13 @@ cachetools==4.1.1 certifi==2020.6.20 chardet==3.0.4 google==3.0.0 -google-api-core==1.21.0 -google-auth==1.18.0 -google-cloud-bigquery==1.25.0 -google-cloud-core==1.3.0 -google-resumable-media==0.5.1 -googleapis-common-protos==1.52.0 +google-api-core==2.8.0 +google-auth==2.29.0 +google-cloud-bigquery==1.28.3 +google-cloud-core==2.4.1 +google-crc32c==1.5.0 +google-resumable-media==1.3.3 +googleapis-common-protos==1.56.0 idna==2.10 iso8601==0.1.12 protobuf==3.12.2 diff --git a/task/bq2bq/executor/tests/test_config.py b/task/bq2bq/executor/tests/test_config.py index ec87c06..bd3bf5e 100644 --- a/task/bq2bq/executor/tests/test_config.py +++ b/task/bq2bq/executor/tests/test_config.py @@ -156,6 +156,16 @@ def test_concurrency(self): self.assertEqual(config.concurrency, 2) + def test_retry_timeout(self): + self.set_vars_with_default() + config = TaskConfigFromEnv() + self.assertEqual(config.retry_timeout, None) + + self.set_vars_with_default() + os.environ['RETRY_TIMEOUT_IN_SECONDS'] = "120.0" + config = TaskConfigFromEnv() + self.assertEqual(config.retry_timeout, 120.0) + def test_concurrency_should_not_zero_exception(self): self.set_vars_with_default() os.environ['CONCURRENCY'] = "0"