Add dask.delayed to map_over_subtree #253

Draft: wants to merge 3 commits into main

@Illviljan commented Aug 3, 2023

  • Closes Parallelize map_over_subtree #252
  • Tests added
  • Passes pre-commit run --all-files
  • New functions/methods are listed in api.rst
  • Changes are summarized in docs/source/whats-new.rst

Main:

%timeit dt.interp(time=new_time)
49.9 s ± 1.3 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

This PR:

%timeit dt.interp(time=new_time)
16.7 s ± 297 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
from functools import partial
import time

import numpy as np
import xarray as xr
from datatree import DataTree
import datatree
import dask
import matplotlib.pyplot as plt

number_of_files = 25
number_of_groups = 20
number_of_variables = 2000


def create_datatree(number_of_files, number_of_groups, number_of_variables):
    datasets = {}
    for f in range(number_of_files):
        for g in range(number_of_groups):
            # Create random data:
            time = np.linspace(0, 50 + f, 100 + g)
            y = f * time + g

            # Create dataset:
            ds = xr.Dataset(
                data_vars={
                    f"temperature_{g}{i}": ("time", y)
                    for i in range(number_of_variables // number_of_groups)
                },
                coords={"time": ("time", time)},
            ).chunk()

            # Prepare for Datatree:
            name = f"file_{f}/group_{g}"
            datasets[name] = ds
    dt = DataTree.from_dict(datasets)

    return dt


# %% Interpolate to same time coordinate
dt = create_datatree(number_of_files, number_of_groups, number_of_variables)

new_time = np.linspace(0, 150, 50)
datatree.mapping._map_over_subtree_kwargs.update(parallel=True)
dt_interp = dt.interp(time=new_time)

# %timeit dt.interp(time=new_time)

# %% Usages of map_over_subtree
# Built in:
datatree.mapping._map_over_subtree_kwargs.update(parallel=True)
dt_interp = dt.interp(time=new_time)


# Decorator:
@partial(datatree.map_over_subtree, parallel=True)
def mean(ds):
    return ds.mean()


mean(dt)

# Function:
datatree.map_over_subtree(np.mean, parallel=True)(dt)

# %% Time a file sweep
new_time = np.linspace(0, 150, 50)


def time_many_files(n=50, step=5):
    times = {}
    for f in range(1, n, step):
        dt = create_datatree(f, number_of_groups, number_of_variables)

        # perf_counter is a monotonic clock intended for measuring short durations.
        start = time.perf_counter()
        dt_interp = dt.interp(time=new_time)
        end = time.perf_counter()

        diff = end - start
        print(f"{f} files took {diff:0.5} seconds.")
        times[f] = diff

    return times


print("Sequential:")
datatree.mapping._map_over_subtree_kwargs.update(parallel=False)
times_seq = time_many_files()

print("Parallel:")
datatree.mapping._map_over_subtree_kwargs.update(parallel=True)
times_par = time_many_files()


# plt.subplots already creates the figure; a separate plt.figure() call would leave an empty one behind.
fig, ax = plt.subplots(1, 1)
ax.plot(list(times_seq.keys()), list(times_seq.values()), label="Sequential")
ax.plot(list(times_par.keys()), list(times_par.values()), label="Parallel")
ax.set_title(
    (
        "Time to interpolate datatree\n"
        f"Each file has {number_of_variables} variables and {number_of_groups} groups"
    )
)
ax.set_ylabel("Time [s]")
ax.set_xlabel("Number of files")
ax.legend()
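
For readers unfamiliar with the pattern, here is a minimal sketch of how `dask.delayed` can parallelize a per-node mapping. It is not the code in this PR: `map_over_nodes` and `nodes` are hypothetical names, and plain lists stand in for the datasets stored at each tree node. Only `dask.delayed` and `dask.compute` are real dask calls.

```python
import dask


def map_over_nodes(func, nodes, parallel=False):
    """Apply func to every value of nodes, a hypothetical path -> data mapping."""
    if not parallel:
        # Sequential path: call func on each node in turn.
        return {path: func(data) for path, data in nodes.items()}

    # Parallel path: wrap each call lazily; nothing runs yet.
    tasks = [dask.delayed(func)(data) for data in nodes.values()]
    # dask.compute evaluates all tasks together and returns a tuple of results
    # in the same order as the tasks were passed in.
    values = dask.compute(*tasks)
    return dict(zip(nodes.keys(), values))


def mean(values):
    return sum(values) / len(values)


# Stand-in data: two "groups" keyed by the same path style as the script above.
nodes = {
    "file_0/group_0": [1.0, 2.0, 3.0],
    "file_0/group_1": [4.0, 5.0, 6.0],
}

sequential = map_over_nodes(mean, nodes)
parallel = map_over_nodes(mean, nodes, parallel=True)
print(sequential == parallel)
```

Both paths return the same mapping; only the execution strategy differs. Whether the parallel path is actually faster depends on the dask scheduler in use and on how much of the per-node work releases the GIL.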
