diff --git a/doc/conf.py b/doc/conf.py
index a1fec6a06e..5ffbb59065 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -423,6 +423,8 @@
     (f'https://docs.esmvaltool.org/projects/ESMValCore/en/{rtd_version}/',
      None),
     'esmvaltool': (f'https://docs.esmvaltool.org/en/{rtd_version}/', None),
+    'dask': ('https://docs.dask.org/en/stable/', None),
+    'distributed': ('https://distributed.dask.org/en/stable/', None),
     'iris': ('https://scitools-iris.readthedocs.io/en/latest/', None),
     'iris-esmf-regrid':
     ('https://iris-esmf-regrid.readthedocs.io/en/latest', None),
diff --git a/doc/quickstart/configure.rst b/doc/quickstart/configure.rst
index 3ca78b9f26..68fe667533 100644
--- a/doc/quickstart/configure.rst
+++ b/doc/quickstart/configure.rst
@@ -199,6 +199,161 @@ the user.
 debugging, etc. You can even provide any config user value as a run flag
 ``--argument_name argument_value``
 
+.. _config-dask:
+
+Dask distributed configuration
+==============================
+
+The :ref:`preprocessor functions ` and many of the
+:ref:`Python diagnostics in ESMValTool ` make use of the
+:ref:`Iris ` library to work with the data.
+In Iris, data can be either :ref:`real or lazy `.
+Lazy data is represented by `dask arrays `_.
+Dask arrays consist of many small
+`numpy arrays `_
+(called chunks), and if possible, computations are run on those small arrays
+in parallel.
+In order to figure out what needs to be computed when, Dask makes use of a
+'`scheduler `_'.
+The default scheduler in Dask is rather basic: it can only run on a single
+computer and it may not always find the optimal task scheduling solution,
+resulting in excessive memory use when using e.g. the
+:func:`esmvalcore.preprocessor.multi_model_statistics` preprocessor function.
+Therefore it is recommended that you take a moment to configure the
+`Dask distributed `_ scheduler.
+A Dask scheduler and the 'workers' running the actual computations are
+collectively called a 'Dask cluster'.
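+
+As a quick illustration of these concepts (the snippet below is only an
+example and not part of ESMValCore), a Dask array is split into chunks, and
+no computation takes place until a result is explicitly requested; at that
+point the scheduler runs the per-chunk computations, in parallel where
+possible:
+
+.. code:: python
+
+  import dask.array as da
+
+  # A 4000x4000 array split into 16 chunks, each a 1000x1000 numpy array.
+  array = da.ones((4000, 4000), chunks=(1000, 1000))
+  # Operations are lazy: this only builds a task graph, nothing is computed.
+  result = (array + 1).mean()
+  # compute() hands the task graph to the scheduler, which runs the
+  # per-chunk computations (in parallel if the scheduler supports it).
+  print(result.compute())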
+
+In ESMValCore, the Dask cluster can be configured by creating a file called
+``~/.esmvaltool/dask.yml``, where ``~`` is short for your home directory.
+In this file, under the ``client`` keyword, the arguments to
+:obj:`distributed.Client` can be provided.
+Under the ``cluster`` keyword, the type of cluster (e.g.
+:obj:`distributed.LocalCluster`) as well as any arguments required to start
+the cluster can be provided.
+Extensive documentation on setting up Dask clusters is available
+`here `__.
+
+.. warning::
+
+   The format of the ``~/.esmvaltool/dask.yml`` configuration file is not yet
+   fixed and may change in the next release of ESMValCore.
+
+.. note::
+
+   Not all preprocessor functions are lazy yet; for recipes that rely heavily
+   on such non-lazy functions, computational performance may still be best
+   with the default scheduler.
+   See `issue #674 `_ for
+   progress on making all preprocessor functions lazy.
+
+**Example configurations**
+
+*Personal computer*
+
+Create a Dask distributed cluster on the computer running ESMValCore using
+all available resources:
+
+.. code:: yaml
+
+  cluster:
+    type: distributed.LocalCluster
+
+This should work well for most personal computers.
+
+.. note::
+
+   If you run this configuration on a shared node of an HPC cluster, Dask
+   will try to use all the resources it can find, which may lead to a single
+   user (you) overcrowding the node!
+
+*Shared computer*
+
+Create a Dask distributed cluster on the computer running ESMValCore, with
+2 workers that each have 4 threads and 4 GiB of memory (8 GiB in total):
+
+.. code:: yaml
+
+  cluster:
+    type: distributed.LocalCluster
+    n_workers: 2
+    threads_per_worker: 4
+    memory_limit: 4 GiB
+
+This should work well for shared computers.
+
+*Computer cluster*
+
+Create a Dask distributed cluster on the
+`Levante `_
+supercomputer using the `Dask-Jobqueue `_
+package:
+
+.. code:: yaml
+
+  cluster:
+    type: dask_jobqueue.SLURMCluster
+    queue: shared
+    account: bk1088
+    cores: 8
+    memory: 7680MiB
+    processes: 2
+    interface: ib0
+    local_directory: "/scratch/b/b381141/dask-tmp"
+    n_workers: 24
+
+This will start 24 workers with ``cores / processes = 4`` threads each,
+resulting in ``n_workers / processes = 12`` Slurm jobs, where each Slurm job
+will request 8 CPU cores and 7680 MiB of memory and start ``processes = 2``
+workers.
+This example will use the fast InfiniBand network connection (called ``ib0``
+on Levante) for communication between workers running on different nodes.
+It is
+`important to set the right location for temporary storage `__;
+in this example the ``/scratch`` space is used.
+It is also possible to use environment variables to configure the temporary
+storage location, if your cluster provides these.
+
+A configuration like this should work well for larger computations where it
+is advantageous to use multiple nodes in a compute cluster.
+See
+`Deploying Dask Clusters on High Performance Computers `_
+for more information.
+
+*Externally managed Dask cluster*
+
+Use an externally managed cluster, e.g. a cluster that you started using the
+`Dask JupyterLab extension `_:
+
+.. code:: yaml
+
+  client:
+    address: '127.0.0.1:8786'
+
+See `here `_
+for an example of how to configure this on a remote system.
+
+For debugging purposes, it can be useful to start the cluster outside of
+ESMValCore, because then the
+`Dask dashboard `_ remains
+available after ESMValCore has finished running.
+
+**Advice on choosing performant configurations**
+
+The threads within a single worker share memory, so they can pass chunks
+around without copying them, whereas sending a chunk from one worker to
+another requires copying it, which is somewhat slower.
+Therefore it is beneficial for performance to have multiple threads per
+worker.
+However, due to a limitation of the CPython implementation known as the
+Global Interpreter Lock (GIL), only a single thread in a worker can execute
+Python code at any given time (this limitation does not apply to compiled
+code called from Python, e.g. numpy).
+Therefore, the best performing configurations will typically not use much
+more than 10 threads per worker.
+
+Because the NetCDF library is not thread-safe, only one of the threads in a
+worker can read or write to a NetCDF file at a time.
+Therefore, it may be beneficial to use fewer threads per worker if the
+computation is very simple and the runtime is determined by the speed with
+which the data can be read from and/or written to disk.
 
 .. 
_config-esgf: diff --git a/environment.yml b/environment.yml index 66c4782c88..9b9da11b50 100644 --- a/environment.yml +++ b/environment.yml @@ -10,6 +10,8 @@ dependencies: - cftime - compilers - dask + - dask-jobqueue + - distributed - esgf-pyclient>=0.3.1 - esmpy!=8.1.0 - filelock @@ -18,7 +20,7 @@ dependencies: - geopy - humanfriendly - importlib_resources - - iris>=3.4.0 + - iris>=3.6.0 - iris-esmf-regrid >=0.6.0 # to work with latest esmpy - isodate - jinja2 diff --git a/esmvalcore/_main.py b/esmvalcore/_main.py index 0a0f563ff5..0aee610edf 100755 --- a/esmvalcore/_main.py +++ b/esmvalcore/_main.py @@ -74,6 +74,7 @@ def process_recipe(recipe_file: Path, session): import shutil from esmvalcore._recipe.recipe import read_recipe_file + from esmvalcore.config._dask import check_distributed_config if not recipe_file.is_file(): import errno raise OSError(errno.ENOENT, "Specified recipe file does not exist", @@ -103,6 +104,8 @@ def process_recipe(recipe_file: Path, session): logger.info("If you experience memory problems, try reducing " "'max_parallel_tasks' in your user configuration file.") + check_distributed_config() + if session['compress_netcdf']: logger.warning( "You have enabled NetCDF compression. Accessing .nc files can be " diff --git a/esmvalcore/_task.py b/esmvalcore/_task.py index c05491eda0..01d7861f28 100644 --- a/esmvalcore/_task.py +++ b/esmvalcore/_task.py @@ -19,9 +19,11 @@ import psutil import yaml +from distributed import Client from ._citation import _write_citation_files from ._provenance import TrackedFile, get_task_provenance +from .config._dask import get_distributed_client from .config._diagnostics import DIAGNOSTICS, TAGS @@ -718,10 +720,22 @@ def run(self, max_parallel_tasks: Optional[int] = None) -> None: max_parallel_tasks : int Number of processes to run. If `1`, run the tasks sequentially. """ - if max_parallel_tasks == 1: - self._run_sequential() - else: - self._run_parallel(max_parallel_tasks) + with get_distributed_client() as client: + if client is None: + address = None + else: + address = client.scheduler.address + for task in self.flatten(): + if (isinstance(task, DiagnosticTask) + and Path(task.script).suffix.lower() == '.py'): + # Only insert the scheduler address if running a + # Python script. 
+ task.settings['scheduler_address'] = address + + if max_parallel_tasks == 1: + self._run_sequential() + else: + self._run_parallel(address, max_parallel_tasks) def _run_sequential(self) -> None: """Run tasks sequentially.""" @@ -732,7 +746,7 @@ def _run_sequential(self) -> None: for task in sorted(tasks, key=lambda t: t.priority): task.run() - def _run_parallel(self, max_parallel_tasks=None): + def _run_parallel(self, scheduler_address, max_parallel_tasks): """Run tasks in parallel.""" scheduled = self.flatten() running = {} @@ -757,7 +771,8 @@ def done(task): if len(running) >= max_parallel_tasks: break if all(done(t) for t in task.ancestors): - future = pool.apply_async(_run_task, [task]) + future = pool.apply_async(_run_task, + [task, scheduler_address]) running[task] = future scheduled.remove(task) @@ -790,7 +805,14 @@ def _copy_results(task, future): task.output_files, task.products = future.get() -def _run_task(task): +def _run_task(task, scheduler_address): """Run task and return the result.""" - output_files = task.run() + if scheduler_address is None: + client = contextlib.nullcontext() + else: + client = Client(scheduler_address) + + with client: + output_files = task.run() + return output_files, task.products diff --git a/esmvalcore/config/_dask.py b/esmvalcore/config/_dask.py new file mode 100644 index 0000000000..7030ea816a --- /dev/null +++ b/esmvalcore/config/_dask.py @@ -0,0 +1,79 @@ +"""Configuration for Dask distributed.""" +import contextlib +import importlib +import logging +from pathlib import Path + +import yaml +from distributed import Client + +logger = logging.getLogger(__name__) + +CONFIG_FILE = Path.home() / '.esmvaltool' / 'dask.yml' + + +def check_distributed_config(): + """Check the Dask distributed configuration.""" + if not CONFIG_FILE.exists(): + logger.warning( + "Using the Dask basic scheduler. This may lead to slow " + "computations and out-of-memory errors. " + "Note that the basic scheduler may still be the best choice for " + "preprocessor functions that are not lazy. " + "In that case, you can safely ignore this warning. " + "See https://docs.esmvaltool.org/projects/ESMValCore/en/latest/" + "quickstart/configure.html#dask-distributed-configuration for " + "more information. ") + + +@contextlib.contextmanager +def get_distributed_client(): + """Get a Dask distributed client.""" + dask_args = {} + if CONFIG_FILE.exists(): + config = yaml.safe_load(CONFIG_FILE.read_text(encoding='utf-8')) + if config is not None: + dask_args = config + + client_args = dask_args.get('client') or {} + cluster_args = dask_args.get('cluster') or {} + + # Start a cluster, if requested + if 'address' in client_args: + # Use an externally managed cluster. + cluster = None + if cluster_args: + logger.warning( + "Not using Dask 'cluster' settings from %s because a cluster " + "'address' is already provided in 'client'.", CONFIG_FILE) + elif cluster_args: + # Start cluster. + cluster_type = cluster_args.pop( + 'type', + 'distributed.LocalCluster', + ) + cluster_module_name, cluster_cls_name = cluster_type.rsplit('.', 1) + cluster_module = importlib.import_module(cluster_module_name) + cluster_cls = getattr(cluster_module, cluster_cls_name) + cluster = cluster_cls(**cluster_args) + client_args['address'] = cluster.scheduler_address + else: + # No cluster configured, use Dask basic scheduler, or a LocalCluster + # managed through Client. 
+ cluster = None + + # Start a client, if requested + if dask_args: + client = Client(**client_args) + logger.info("Dask dashboard: %s", client.dashboard_link) + else: + logger.info("Using the Dask basic scheduler.") + client = None + + try: + yield client + finally: + if client is not None: + client.close() + if cluster is not None: + cluster.close() diff --git a/esmvalcore/experimental/recipe.py b/esmvalcore/experimental/recipe.py index f0ffe85fcc..5839d72df8 100644 --- a/esmvalcore/experimental/recipe.py +++ b/esmvalcore/experimental/recipe.py @@ -10,7 +10,7 @@ import yaml from esmvalcore._recipe.recipe import Recipe as RecipeEngine -from esmvalcore.config import CFG, Session +from esmvalcore.config import CFG, Session, _dask from ._logging import log_to_dir from .recipe_info import RecipeInfo @@ -132,6 +132,7 @@ def run( session['diagnostics'] = task with log_to_dir(session.run_dir): + _dask.check_distributed_config() self._engine = self._load(session=session) self._engine.run() diff --git a/setup.py b/setup.py index 5743b3f603..2073c1817d 100755 --- a/setup.py +++ b/setup.py @@ -28,9 +28,9 @@ # Use with pip install . to install from source 'install': [ 'cartopy', - # see https://github.com/SciTools/cf-units/issues/218 'cf-units', - 'dask[array]', + 'dask[array,distributed]', + 'dask-jobqueue', 'esgf-pyclient>=0.3.1', 'esmf-regrid', 'esmpy!=8.1.0', @@ -56,8 +56,8 @@ 'pyyaml', 'requests', 'scipy>=1.6', - 'scitools-iris>=3.4.0', - 'shapely[vectorized]', + 'scitools-iris>=3.6.0', + 'shapely', 'stratify>=0.3', 'yamale', ], diff --git a/tests/integration/test_diagnostic_run.py b/tests/integration/test_diagnostic_run.py index 4d95dd850f..243d4c6a28 100644 --- a/tests/integration/test_diagnostic_run.py +++ b/tests/integration/test_diagnostic_run.py @@ -8,10 +8,26 @@ import pytest import yaml +import esmvalcore._task from esmvalcore._main import run from esmvalcore.config._diagnostics import TAGS +@pytest.fixture(autouse=True) +def get_mock_distributed_client(monkeypatch): + """Mock `get_distributed_client` to avoid starting a Dask cluster.""" + + @contextlib.contextmanager + def get_distributed_client(): + yield None + + monkeypatch.setattr( + esmvalcore._task, + 'get_distributed_client', + get_distributed_client, + ) + + def write_config_user_file(dirname): config_file = dirname / 'config-user.yml' cfg = { @@ -51,7 +67,9 @@ def check(result_file): } missing = required_keys - set(result) assert not missing - unwanted_keys = ['profile_diagnostic', ] + unwanted_keys = [ + 'profile_diagnostic', + ] for unwanted_key in unwanted_keys: assert unwanted_key not in result diff --git a/tests/integration/test_task.py b/tests/integration/test_task.py index 42b724e1c9..3c9801f189 100644 --- a/tests/integration/test_task.py +++ b/tests/integration/test_task.py @@ -1,6 +1,7 @@ import multiprocessing import os import shutil +from contextlib import contextmanager from functools import partial from multiprocessing.pool import ThreadPool @@ -13,6 +14,7 @@ DiagnosticTask, TaskSet, _py2ncl, + _run_task, ) from esmvalcore.config._diagnostics import DIAGNOSTICS @@ -61,6 +63,16 @@ def example_tasks(tmp_path): return tasks +def get_distributed_client_mock(client): + """Mock `get_distributed_client` to avoid starting a Dask cluster.""" + + @contextmanager + def get_distributed_client(): + yield client + + return get_distributed_client + + @pytest.mark.parametrize(['mpmethod', 'max_parallel_tasks'], [ ('fork', 1), ('fork', 2), @@ -68,9 +80,13 @@ def example_tasks(tmp_path): ('fork', None), ('spawn', 2), ]) -def 
test_run_tasks(monkeypatch, tmp_path, max_parallel_tasks, example_tasks, - mpmethod): +def test_run_tasks(monkeypatch, max_parallel_tasks, example_tasks, mpmethod): """Check that tasks are run correctly.""" + monkeypatch.setattr( + esmvalcore._task, + 'get_distributed_client', + get_distributed_client_mock(None), + ) monkeypatch.setattr(esmvalcore._task, 'Pool', multiprocessing.get_context(mpmethod).Pool) example_tasks.run(max_parallel_tasks=max_parallel_tasks) @@ -80,9 +96,43 @@ def test_run_tasks(monkeypatch, tmp_path, max_parallel_tasks, example_tasks, assert task.output_files +def test_diag_task_updated_with_address(monkeypatch, mocker, tmp_path): + """Test that the scheduler address is passed to the diagnostic tasks.""" + # Set up mock Dask distributed client + client = mocker.Mock() + monkeypatch.setattr( + esmvalcore._task, + 'get_distributed_client', + get_distributed_client_mock(client), + ) + + # Create a task + mocker.patch.object(DiagnosticTask, '_initialize_cmd') + task = DiagnosticTask( + script='test.py', + settings={'run_dir': tmp_path / 'run'}, + output_dir=tmp_path / 'work', + ) + + # Create a taskset + mocker.patch.object(TaskSet, '_run_sequential') + tasks = TaskSet() + tasks.add(task) + tasks.run(max_parallel_tasks=1) + + # Check that the scheduler address was added to the + # diagnostic task settings. + assert 'scheduler_address' in task.settings + assert task.settings['scheduler_address'] is client.scheduler.address + + @pytest.mark.parametrize('runner', [ TaskSet._run_sequential, - partial(TaskSet._run_parallel, max_parallel_tasks=1), + partial( + TaskSet._run_parallel, + scheduler_address=None, + max_parallel_tasks=1, + ), ]) def test_runner_uses_priority(monkeypatch, runner, example_tasks): """Check that the runner tries to respect task priority.""" @@ -102,6 +152,22 @@ def _run(self, input_files): assert order == sorted(order) +@pytest.mark.parametrize('address', [None, 'localhost:1234']) +def test_run_task(mocker, address): + # Set up mock Dask distributed client + mocker.patch.object(esmvalcore._task, 'Client') + + task = mocker.create_autospec(DiagnosticTask, instance=True) + task.products = mocker.Mock() + output_files, products = _run_task(task, scheduler_address=address) + assert output_files == task.run.return_value + assert products == task.products + if address is None: + esmvalcore._task.Client.assert_not_called() + else: + esmvalcore._task.Client.assert_called_once_with(address) + + def test_py2ncl(): """Test for _py2ncl func.""" ncl_text = _py2ncl(None, 'tas') diff --git a/tests/sample_data/experimental/test_run_recipe.py b/tests/sample_data/experimental/test_run_recipe.py index 86f31a438a..8dadb73805 100644 --- a/tests/sample_data/experimental/test_run_recipe.py +++ b/tests/sample_data/experimental/test_run_recipe.py @@ -3,11 +3,13 @@ Runs recipes using :meth:`esmvalcore.experimental.Recipe.run`. 
""" +from contextlib import contextmanager from pathlib import Path import iris import pytest +import esmvalcore._task from esmvalcore.config._config_object import CFG_DEFAULT from esmvalcore.config._diagnostics import TAGS from esmvalcore.exceptions import RecipeError @@ -20,7 +22,6 @@ esmvaltool_sample_data = pytest.importorskip("esmvaltool_sample_data") - AUTHOR_TAGS = { 'authors': { 'doe_john': { @@ -32,6 +33,21 @@ } +@pytest.fixture(autouse=True) +def get_mock_distributed_client(monkeypatch): + """Mock `get_distributed_client` to avoid starting a Dask cluster.""" + + @contextmanager + def get_distributed_client(): + yield None + + monkeypatch.setattr( + esmvalcore._task, + 'get_distributed_client', + get_distributed_client, + ) + + @pytest.fixture def recipe(): recipe = get_recipe(Path(__file__).with_name('recipe_api_test.yml')) diff --git a/tests/unit/config/test_dask.py b/tests/unit/config/test_dask.py new file mode 100644 index 0000000000..22e7735628 --- /dev/null +++ b/tests/unit/config/test_dask.py @@ -0,0 +1,76 @@ +import pytest +import yaml + +from esmvalcore.config import _dask + + +def test_get_no_distributed_client(mocker, tmp_path): + mocker.patch.object(_dask, 'CONFIG_FILE', tmp_path / 'nonexistent.yml') + with _dask.get_distributed_client() as client: + assert client is None + + +@pytest.mark.parametrize('warn_unused_args', [False, True]) +def test_get_distributed_client_external(mocker, tmp_path, warn_unused_args): + # Create mock client configuration. + cfg = { + 'client': { + 'address': 'tcp://127.0.0.1:42021', + }, + } + if warn_unused_args: + cfg['cluster'] = {'n_workers': 2} + cfg_file = tmp_path / 'dask.yml' + with cfg_file.open('w', encoding='utf-8') as file: + yaml.safe_dump(cfg, file) + mocker.patch.object(_dask, 'CONFIG_FILE', cfg_file) + + # Create mock distributed.Client + mock_client = mocker.Mock() + mocker.patch.object(_dask, + 'Client', + create_autospec=True, + return_value=mock_client) + + with _dask.get_distributed_client() as client: + assert client is mock_client + _dask.Client.assert_called_with(**cfg['client']) + mock_client.close.assert_called() + + +def test_get_distributed_client_slurm(mocker, tmp_path): + cfg = { + 'cluster': { + 'type': 'dask_jobqueue.SLURMCluster', + 'queue': 'interactive', + 'cores': '8', + 'memory': '16GiB', + }, + } + cfg_file = tmp_path / 'dask.yml' + with cfg_file.open('w', encoding='utf-8') as file: + yaml.safe_dump(cfg, file) + mocker.patch.object(_dask, 'CONFIG_FILE', cfg_file) + + # Create mock distributed.Client + mock_client = mocker.Mock() + mocker.patch.object(_dask, + 'Client', + create_autospec=True, + return_value=mock_client) + + mock_module = mocker.Mock() + mock_cluster_cls = mocker.Mock() + mock_module.SLURMCluster = mock_cluster_cls + mocker.patch.object(_dask.importlib, + 'import_module', + create_autospec=True, + return_value=mock_module) + with _dask.get_distributed_client() as client: + assert client is mock_client + mock_client.close.assert_called() + mock_cluster = mock_cluster_cls.return_value + _dask.Client.assert_called_with(address=mock_cluster.scheduler_address) + args = {k: v for k, v in cfg['cluster'].items() if k != 'type'} + mock_cluster_cls.assert_called_with(**args) + mock_cluster.close.assert_called()