Skip to content

Commit

Permalink
fix: listen state to receive sigterm cancellation (#17)
Browse files Browse the repository at this point in the history
* fix: listen state to receive sigterm cancellation

* fix: no need to use thread for terminating jobs

* refactor: handle_job_cancel to register_job

* refactor: on_job_cancel to o_job_regsiter
  • Loading branch information
deryrahman authored Aug 4, 2023
1 parent 4e200bb commit 88b2f01
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 20 deletions.
16 changes: 8 additions & 8 deletions task/bq2bq/executor/bumblebee/bigquery_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_table(self, full_table_name):

class BigqueryService(BaseBigqueryService):

def __init__(self, client, labels, writer, on_job_finish = None, on_job_cancelled = None):
def __init__(self, client, labels, writer, on_job_finish = None, on_job_register = None):
"""
:rtype:
Expand All @@ -62,7 +62,7 @@ def __init__(self, client, labels, writer, on_job_finish = None, on_job_cancelle
self.labels = labels
self.writer = writer
self.on_job_finish = on_job_finish
self.on_job_cancelled = on_job_cancelled
self.on_job_register = on_job_register

def execute_query(self, query):
query_job_config = QueryJobConfig()
Expand All @@ -78,8 +78,8 @@ def execute_query(self, query):
logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state,
query_job.project))

if self.on_job_cancelled:
self.on_job_cancelled(self.client, query_job)
if self.on_job_register:
self.on_job_register(self.client, query_job)

try:
result = query_job.result()
Expand Down Expand Up @@ -129,8 +129,8 @@ def transform_load(self,
logger.info("Job {} is initially in state {} of {} project".format(query_job.job_id, query_job.state,
query_job.project))

if self.on_job_cancelled:
self.on_job_cancelled(self.client, query_job)
if self.on_job_register:
self.on_job_register(self.client, query_job)

try:
result = query_job.result()
Expand Down Expand Up @@ -174,7 +174,7 @@ def get_table(self, full_table_name):
return self.client.get_table(table_ref)


def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_job_finish = None, on_job_cancelled = None):
def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_job_finish = None, on_job_register = None):
if writer is None:
writer = writer.StdWriter()

Expand All @@ -183,7 +183,7 @@ def create_bigquery_service(task_config: TaskConfigFromEnv, labels, writer, on_j
default_query_job_config.priority = task_config.query_priority
default_query_job_config.allow_field_addition = task_config.allow_field_addition
client = bigquery.Client(project=task_config.execution_project, credentials=credentials, default_query_job_config=default_query_job_config)
return BigqueryService(client, labels, writer, on_job_finish=on_job_finish, on_job_cancelled=on_job_cancelled)
return BigqueryService(client, labels, writer, on_job_finish=on_job_finish, on_job_register=on_job_register)


def _get_bigquery_credentials():
Expand Down
4 changes: 2 additions & 2 deletions task/bq2bq/executor/bumblebee/bq2bq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def bq2bq(properties_file: str,
labels: dict = {},
output_on: str = './return.json',
on_job_finish = None,
on_job_cancelled = None,
on_job_register = None,
):

logger.info("Using bumblebee version: {}".format(VERSION))
Expand All @@ -39,7 +39,7 @@ def bq2bq(properties_file: str,

bigquery_service = DummyService()
if not dry_run:
bigquery_service = create_bigquery_service(task_config, job_labels, writer, on_job_finish=on_job_finish, on_job_cancelled=on_job_cancelled)
bigquery_service = create_bigquery_service(task_config, job_labels, writer, on_job_finish=on_job_finish, on_job_register=on_job_register)

transformation = Transformation(bigquery_service,
task_config,
Expand Down
28 changes: 19 additions & 9 deletions task/bq2bq/executor/bumblebee/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,30 @@ class BigqueryJobHandler:
def __init__(self) -> None:
self._sum_slot_millis = 0
self._sum_total_bytes_processed = 0
self.client = None
self.jobs = []
self._init_signal_handling()

def _init_signal_handling(self):
def handle_sigterm(signum, frame):
self._terminate_jobs()
sys.exit(1)
signal.signal(signal.SIGTERM, handle_sigterm)

def _terminate_jobs(self):
if self.client and self.jobs:
for job in self.jobs:
job_id = job.job_id
self.client.cancel_job(job_id)
logger.info(f"{job_id} successfully cancelled")

def handle_job_finish(self, job) -> None:
self._sum_slot_millis += job.slot_millis
self._sum_total_bytes_processed += job.total_bytes_processed

def handle_job_cancelled(self, client, job):
c = client
job_id = job.job_id
def handler(signum, frame):
c.cancel_job(job_id)
logger.info(f"{job_id} successfully cancelled")
sys.exit(1)

signal.signal(signal.SIGTERM, handler)
def register_job(self, client, job):
self.client = client
self.jobs.append(job)

def get_sum_slot_millis(self) -> int:
return self._sum_slot_millis
Expand Down
2 changes: 1 addition & 1 deletion task/bq2bq/executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
app_config.job_labels,
app_config.xcom_path,
on_job_finish = job_handler.handle_job_finish,
on_job_cancelled = job_handler.handle_job_cancelled,
on_job_register = job_handler.register_job,
)

xcom_data['monitoring'] = {
Expand Down

0 comments on commit 88b2f01

Please sign in to comment.