diff --git a/task/bq2bq/executor/bumblebee/bigquery_service.py b/task/bq2bq/executor/bumblebee/bigquery_service.py index 8e0725a..d169888 100644 --- a/task/bq2bq/executor/bumblebee/bigquery_service.py +++ b/task/bq2bq/executor/bumblebee/bigquery_service.py @@ -53,7 +53,7 @@ def get_table(self, full_table_name): class BigqueryService(BaseBigqueryService): - def __init__(self, client, labels, writer, on_job_finish = None, on_job_register = None): + def __init__(self, client, labels, writer, retry_timeout = None, on_job_finish = None, on_job_register = None): """ :rtype: @@ -61,6 +61,7 @@ def __init__(self, client, labels, writer, on_job_finish = None, on_job_register self.client = client self.labels = labels self.writer = writer + self.retry = bigquery.DEFAULT_RETRY.with_deadline(retry_timeout) if retry_timeout else bigquery.DEFAULT_RETRY self.on_job_finish = on_job_finish self.on_job_register = on_job_register @@ -74,7 +75,8 @@ def execute_query(self, query): logger.info("executing query") query_job = self.client.query(query=query, - job_config=query_job_config) + job_config=query_job_config, + retry=self.retry) logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state, query_job.project)) @@ -125,7 +127,9 @@ def transform_load(self, query_job_config.destination = table_ref logger.info("transform load") - query_job = self.client.query(query=query, job_config=query_job_config) + query_job = self.client.query(query=query, + job_config=query_job_config, + retry=self.retry) logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state, query_job.project)) @@ -183,7 +187,7 @@ def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_j default_query_job_config.priority = task_config.query_priority default_query_job_config.allow_field_addition = task_config.allow_field_addition client = bigquery.Client(project=task_config.execution_project, credentials=credentials, default_query_job_config=default_query_job_config) - return BigqueryService(client, labels, writer, on_job_finish=on_job_finish, on_job_register=on_job_register) + return BigqueryService(client, labels, writer, retry_timeout=task_config.retry_timeout, on_job_finish=on_job_finish, on_job_register=on_job_register) def _get_bigquery_credentials(): diff --git a/task/bq2bq/executor/bumblebee/config.py b/task/bq2bq/executor/bumblebee/config.py index 7762399..63d63ab 100644 --- a/task/bq2bq/executor/bumblebee/config.py +++ b/task/bq2bq/executor/bumblebee/config.py @@ -126,6 +126,7 @@ def __init__(self): self._use_spillover = _bool_from_str(get_env_config("USE_SPILLOVER", default="true")) self._concurrency = _validate_greater_than_zero(int(get_env_config("CONCURRENCY", default=1))) self._allow_field_addition = _bool_from_str(get_env_config("ALLOW_FIELD_ADDITION", default="false")) + self._retry_timeout = get_env_config("RETRY_TIMEOUT", default=None) @property def destination_project(self) -> str: @@ -178,6 +179,12 @@ def timezone(self): def concurrency(self) -> int: return self._concurrency + @property + def retry_timeout(self) -> Optional[float]: + if self._retry_timeout: + return float(self._retry_timeout) + return None + def print(self): logger.info("task config:\n{}".format( "\n".join([ @@ -348,6 +355,7 @@ def __init__(self, raw_properties): self._use_spillover = _bool_from_str(self._get_property_or_default("USE_SPILLOVER", "true")) self._concurrency = _validate_greater_than_zero(int(self._get_property_or_default("CONCURRENCY", 1))) self._allow_field_addition = _bool_from_str(self._get_property_or_default("ALLOW_FIELD_ADDITION", "false")) + self._retry_timeout = self._get_property_or_default("RETRY_TIMEOUT", None) @property def sql_type(self) -> str: @@ -412,6 +420,12 @@ def filter_expression(self) -> str: def allow_field_addition(self) -> bool: return self._allow_field_addition + @property + def retry_timeout(self) -> Optional[float]: + if self._retry_timeout: + return float(self._retry_timeout) + return None + def print(self): logger.info("task config:\n{}".format( "\n".join([ diff --git a/task/bq2bq/executor/tests/test_config.py b/task/bq2bq/executor/tests/test_config.py index ec87c06..5af12a9 100644 --- a/task/bq2bq/executor/tests/test_config.py +++ b/task/bq2bq/executor/tests/test_config.py @@ -156,6 +156,16 @@ def test_concurrency(self): self.assertEqual(config.concurrency, 2) + def test_retry_timeout(self): + self.set_vars_with_default() + config = TaskConfigFromEnv() + self.assertEqual(config.retry_timeout, None) + + self.set_vars_with_default() + os.environ['RETRY_TIMEOUT'] = "120.0" + config = TaskConfigFromEnv() + self.assertEqual(config.retry_timeout, 120.0) + def test_concurrency_should_not_zero_exception(self): self.set_vars_with_default() os.environ['CONCURRENCY'] = "0"