From f4d8f8af5c70c8d89d7b8e43d0946a9f8e0e0fee Mon Sep 17 00:00:00 2001
From: Andrii Ieroshenko
Date: Fri, 28 Jun 2024 13:24:41 -0700
Subject: [PATCH] use default process-based Dask distributed cluster

---
 jupyter_scheduler/job_files_manager.py | 33 +++++++++++++-------------
 jupyter_scheduler/scheduler.py         | 25 ++++++++-----------
 2 files changed, 27 insertions(+), 31 deletions(-)

diff --git a/jupyter_scheduler/job_files_manager.py b/jupyter_scheduler/job_files_manager.py
index fec5caee..384bcbbd 100644
--- a/jupyter_scheduler/job_files_manager.py
+++ b/jupyter_scheduler/job_files_manager.py
@@ -1,7 +1,8 @@
 import os
 import random
 import tarfile
-from typing import Awaitable, Dict, List, Optional, Type
+from multiprocessing import Process
+from typing import Dict, List, Optional, Type
 
 import fsspec
 from dask.distributed import Client as DaskClient
@@ -14,10 +15,7 @@
 class JobFilesManager:
     scheduler = None
 
-    def __init__(
-        self,
-        scheduler: Type[BaseScheduler],
-    ):
+    def __init__(self, scheduler: Type[BaseScheduler]):
         self.scheduler = scheduler
 
     async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = False):
@@ -26,17 +24,20 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
         output_filenames = self.scheduler.get_job_filenames(job)
         output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True)
 
-        dask_client: DaskClient = await self.scheduler.dask_client_future
-        dask_client.submit(
-            Downloader(
-                output_formats=job.output_formats,
-                output_filenames=output_filenames,
-                staging_paths=staging_paths,
-                output_dir=output_dir,
-                redownload=redownload,
-                include_staging_files=job.package_input_folder,
-            ).download
-        )
+        download = Downloader(
+            output_formats=job.output_formats,
+            output_filenames=output_filenames,
+            staging_paths=staging_paths,
+            output_dir=output_dir,
+            redownload=redownload,
+            include_staging_files=job.package_input_folder,
+        ).download
+        if self.scheduler.dask_client:
+            dask_client: DaskClient = self.scheduler.dask_client
+            dask_client.submit(download)
+        else:
+            p = Process(target=download)
+            p.start()
 
 
 class Downloader:
diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py
index 2ae53a13..1360c70c 100644
--- a/jupyter_scheduler/scheduler.py
+++ b/jupyter_scheduler/scheduler.py
@@ -2,7 +2,7 @@
 import os
 import random
 import shutil
-from typing import Awaitable, Dict, List, Optional, Type, Union
+from typing import Dict, List, Optional, Type, Union
 
 import fsspec
 import psutil
@@ -421,12 +421,11 @@ def __init__(
         if self.task_runner_class:
             self.task_runner = self.task_runner_class(scheduler=self, config=config)
 
-        loop = asyncio.get_event_loop()
-        self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
+        self.dask_client: DaskClient = self._get_dask_client()
 
-    async def _get_dask_client(self):
+    def _get_dask_client(self):
         """Creates and configures a Dask client."""
-        return DaskClient(processes=False, asynchronous=True)
+        return DaskClient()
 
     @property
     def db_session(self):
@@ -451,7 +450,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
             destination_dir=staging_dir,
         )
 
-    async def create_job(self, model: CreateJob) -> str:
+    def create_job(self, model: CreateJob) -> str:
         if not model.job_definition_id and not self.file_exists(model.input_uri):
             raise InputUriError(model.input_uri)
 
@@ -492,8 +491,7 @@ async def create_job(self, model: CreateJob) -> str:
             else:
                 self.copy_input_file(model.input_uri, staging_paths["input"])
 
-            dask_client: DaskClient = await self.dask_client_future
-            future = dask_client.submit(
+            future = self.dask_client.submit(
                 self.execution_manager_class(
                     job_id=job.job_id,
                     staging_paths=staging_paths,
@@ -755,16 +753,14 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit
         return list_response
 
-    async def create_job_from_definition(
-        self, job_definition_id: str, model: CreateJobFromDefinition
-    ):
+    def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition):
        job_id = None
        definition = self.get_job_definition(job_definition_id)

        if definition:
            input_uri = self.get_staging_paths(definition)["input"]
            attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
            attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
-            job_id = await self.create_job(CreateJob(**attributes))
+            job_id = self.create_job(CreateJob(**attributes))

        return job_id

@@ -789,9 +785,8 @@ async def stop_extension(self):
         """
         Cleanup code to run when the server is stopping.
         """
-        if self.dask_client_future:
-            dask_client: DaskClient = await self.dask_client_future
-            await dask_client.close()
+        if self.dask_client:
+            self.dask_client.close()
 
 
 class ArchivingScheduler(Scheduler):