diff --git a/jupyter_scheduler/scheduler.py b/jupyter_scheduler/scheduler.py index 1360c70c..3ec33aab 100644 --- a/jupyter_scheduler/scheduler.py +++ b/jupyter_scheduler/scheduler.py @@ -7,6 +7,7 @@ import fsspec import psutil from dask.distributed import Client as DaskClient +from distributed import LocalCluster from jupyter_core.paths import jupyter_data_dir from jupyter_server.transutils import _i18n from jupyter_server.utils import to_os_path @@ -402,6 +403,12 @@ class Scheduler(BaseScheduler): ), ) + dask_cluster_url = Unicode( + allow_none=True, + config=True, + help="URL of the Dask cluster to connect to.", + ) + db_url = Unicode(help=_i18n("Scheduler database url")) task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner") @@ -425,7 +432,10 @@ def __init__( def _get_dask_client(self): """Creates and configures a Dask client.""" - return DaskClient() + if self.dask_cluster_url: + return DaskClient(self.dask_cluster_url) + cluster = LocalCluster(processes=True) + return DaskClient(cluster) @property def db_session(self): @@ -786,7 +796,7 @@ async def stop_extension(self): Cleanup code to run when the server is stopping. """ if self.dask_client: - self.dask_client.close() + await self.dask_client.close() class ArchivingScheduler(Scheduler):