Skip to content

Commit

Permalink
use default process-based Dask distributed cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Aug 7, 2024
1 parent b02f850 commit f4d8f8a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 31 deletions.
33 changes: 17 additions & 16 deletions jupyter_scheduler/job_files_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import random
import tarfile
from typing import Awaitable, Dict, List, Optional, Type
from multiprocessing import Process
from typing import Dict, List, Optional, Type

import fsspec
from dask.distributed import Client as DaskClient
Expand All @@ -14,10 +15,7 @@
class JobFilesManager:
scheduler = None

def __init__(
self,
scheduler: Type[BaseScheduler],
):
def __init__(self, scheduler: Type[BaseScheduler]):
self.scheduler = scheduler

async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = False):
Expand All @@ -26,17 +24,20 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
output_filenames = self.scheduler.get_job_filenames(job)
output_dir = self.scheduler.get_local_output_path(model=job, root_dir_relative=True)

dask_client: DaskClient = await self.scheduler.dask_client_future
dask_client.submit(
Downloader(
output_formats=job.output_formats,
output_filenames=output_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=redownload,
include_staging_files=job.package_input_folder,
).download
)
download = Downloader(
output_formats=job.output_formats,
output_filenames=output_filenames,
staging_paths=staging_paths,
output_dir=output_dir,
redownload=redownload,
include_staging_files=job.package_input_folder,
).download
if self.scheduler.dask_client:
dask_client: DaskClient = self.scheduler.dask_client
dask_client.submit(download)
else:
p = Process(target=download)
p.start()


class Downloader:
Expand Down
25 changes: 10 additions & 15 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import random
import shutil
from typing import Awaitable, Dict, List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union

import fsspec
import psutil
Expand Down Expand Up @@ -421,12 +421,11 @@ def __init__(
if self.task_runner_class:
self.task_runner = self.task_runner_class(scheduler=self, config=config)

loop = asyncio.get_event_loop()
self.dask_client_future: Awaitable[DaskClient] = loop.create_task(self._get_dask_client())
self.dask_client: DaskClient = self._get_dask_client()

async def _get_dask_client(self):
def _get_dask_client(self):
"""Creates and configures a Dask client."""
return DaskClient(processes=False, asynchronous=True)
return DaskClient()

@property
def db_session(self):
Expand All @@ -451,7 +450,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
destination_dir=staging_dir,
)

async def create_job(self, model: CreateJob) -> str:
def create_job(self, model: CreateJob) -> str:
if not model.job_definition_id and not self.file_exists(model.input_uri):
raise InputUriError(model.input_uri)

Expand Down Expand Up @@ -492,8 +491,7 @@ async def create_job(self, model: CreateJob) -> str:
else:
self.copy_input_file(model.input_uri, staging_paths["input"])

dask_client: DaskClient = await self.dask_client_future
future = dask_client.submit(
future = self.dask_client.submit(
self.execution_manager_class(
job_id=job.job_id,
staging_paths=staging_paths,
Expand Down Expand Up @@ -755,16 +753,14 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit

return list_response

async def create_job_from_definition(
self, job_definition_id: str, model: CreateJobFromDefinition
):
def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition):
job_id = None
definition = self.get_job_definition(job_definition_id)
if definition:
input_uri = self.get_staging_paths(definition)["input"]
attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
job_id = await self.create_job(CreateJob(**attributes))
job_id = self.create_job(CreateJob(**attributes))

return job_id

Expand All @@ -789,9 +785,8 @@ async def stop_extension(self):
"""
Cleanup code to run when the server is stopping.
"""
if self.dask_client_future:
dask_client: DaskClient = await self.dask_client_future
await dask_client.close()
if self.dask_client:
self.dask_client.close()


class ArchivingScheduler(Scheduler):
Expand Down

0 comments on commit f4d8f8a

Please sign in to comment.