
[Refactor] Improve the performance of temporal group averaging #689

Merged (8 commits into main on Sep 4, 2024)

Conversation

@tomvothecoder (Collaborator) commented on Aug 29, 2024

Description

TODO:

  • Identify performance bottlenecks
    1. Generating labeled time coordinates (i.e., assigning groups), adding them as auxiliary coordinates on the existing time dimension, and then performing the Xarray groupby is extremely slow (an Xarray issue; the cause is unclear, refer to comment). Fix: replace the time coordinates with the labeled time coordinates directly for grouping, rather than attaching the labels as auxiliary coordinates (see the sketch after this list); a question will be posted to the Xarray forum.
    2. In _get_weights(), loading time lengths into memory is slow. Fix: replace .load() with a cast to "timedelta64[ns]" and then float64.
    3. In _get_weights(), validating that the weights of each group sum to 1 is slow. Fix: remove this unnecessary runtime assertion.
  • Identify performance optimizations (not necessary right now)
    1. Xarray groupby with vs. without the flox package
    2. Try with Dask chunking
  • Make sure unit tests still pass
  • Measure the performance difference between this branch and main
  • Perform regression testing between the branches on the same dataset
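A minimal sketch of the bottleneck #1 fix on toy data (the seasonal labels stand in for xcdat's labeled time coordinates; variable names are illustrative, and whether the slowdown reproduces depends on data size and xarray version):

import numpy as np
import pandas as pd
import xarray as xr

# Toy data: 1000 daily time steps.
times = pd.date_range("2000-01-01", periods=1000, freq="D")
da = xr.DataArray(np.random.rand(1000), coords={"time": times}, dims="time")

# Group labels (seasons here, standing in for xcdat's labeled time coords).
labels = da.time.dt.season

# Slow path (pre-refactor): labels attached as an auxiliary coordinate on
# the time dimension, then grouped.
slow = da.assign_coords(season=labels).groupby("season").mean()

# Fast path (post-refactor): labels replace the time coordinate itself,
# so the groupby runs directly on the dimension coordinate.
fast = da.assign_coords(time=labels.values).groupby("time").mean()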

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • My changes generate no new warnings
  • Any dependent changes have been merged and published in downstream modules

If applicable:

  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass with my changes (locally and CI/CD build)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have noted that this is a breaking change for a major release (fix or feature that would cause existing functionality to not work as expected)

@tomvothecoder changed the title from "Add initial temporal performance refactor code" to "[Refactor] Improve the performance of temporal group averaging" on Aug 30, 2024
codecov bot commented on Aug 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 100.00%. Comparing base (584fcce) to head (6459c1b).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff            @@
##              main      #689   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           15        15           
  Lines         1544      1546    +2     
=========================================
+ Hits          1544      1546    +2     


Replace `.load()` with `.astype("timedelta64[ns]")` for clarity
@tomvothecoder (Collaborator, Author) left a comment:

My initial self-review. The GH Actions build is passing.

Comment on lines 782 to 784

-        # 5. Calculate the departures for the data variable.
-        # ----------------------------------------------------------------------
-        # This step allows us to perform xarray's grouped arithmetic to
-        # calculate departures.
-        dv_obs = ds_obs[data_var].copy()
-        self._labeled_time = self._label_time_coords(dv_obs[self.dim])
-        dv_obs_grouped = self._group_data(dv_obs)
-
-        # 5. Align time dimension names using the labeled time dimension name.
-        # ----------------------------------------------------------------------
-        # The climatology's time dimension is renamed to the labeled time
-        # dimension in step #4 above (e.g., "time" -> "season"). xarray requires
-        # dimension names to be aligned to perform grouped arithmetic, which we
-        # use for calculating departures in step #5. Otherwise, this error is
-        # raised: "`ValueError: incompatible dimensions for a grouped binary
-        # operation: the group variable '<FREQ ARG>' is not a dimension on the
-        # other argument`".
-        dv_climo = ds_climo[data_var]
-        dv_climo = dv_climo.rename({self.dim: self._labeled_time.name})
-
-        # 6. Calculate the departures for the data variable.
-        # ----------------------------------------------------------------------
-        # departures = observation - climatology
-        with xr.set_options(keep_attrs=True):
-            dv_departs = dv_obs_grouped - dv_climo
-            dv_departs = self._add_operation_attrs(dv_departs)
-            ds_obs[data_var] = dv_departs
+        ds_departs = self._calculate_departures(ds_obs, ds_climo, data_var)
@tomvothecoder (Collaborator, Author):
Refactored this block of code into self._calculate_departures() for readability.
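A condensed sketch of how the extracted helper plausibly fits together, stitched from the block removed above and the later comments in this thread (the `ds_departs = ds_obs.copy()` line is an assumption; see xcdat/temporal.py for the real implementation):

def _calculate_departures(self, ds_obs, ds_climo, data_var):
    # Group the observational data variable by its labeled time coords so
    # xarray's grouped arithmetic can subtract the climatology.
    dv_obs = ds_obs[data_var].copy()
    self._labeled_time = self._label_time_coords(dv_obs[self.dim])
    dv_obs_grouped = self._group_data(dv_obs)

    # Align dimension names (e.g., "time" -> "season"); otherwise xarray
    # raises "incompatible dimensions for a grouped binary operation".
    dv_climo = ds_climo[data_var].rename({self.dim: self._labeled_time.name})

    # departures = observation - climatology
    with xr.set_options(keep_attrs=True):
        dv_departs = dv_obs_grouped - dv_climo

    dv_departs = self._add_operation_attrs(dv_departs)

    # Reassign the original, unlabeled time coordinates (see the comment
    # on lines +1742 to +1743 below).
    dv_departs = dv_departs.assign_coords({self.dim: ds_obs[self.dim]})

    ds_departs = ds_obs.copy()  # assumption: result built from a copy of ds_obs
    ds_departs[data_var] = dv_departs

    return ds_departs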

Comment on lines 1168 to +1169

         self._labeled_time = self._label_time_coords(dv[self.dim])
+        dv = dv.assign_coords({self.dim: self._labeled_time})
@tomvothecoder (Collaborator, Author):

Address bottleneck #1 from the PR description:

replace the time coordinates with the labeled time coordinates directly for grouping, rather than attaching the labels as auxiliary coordinates on the time dimension (which slows things down in Xarray for an unknown reason; a question will be posted to the Xarray forum)

@@ -1285,19 +1248,14 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
         # or time unit (with rare exceptions see release notes). To avoid this
         # warning please use the scalar types `np.float64`, or string notation.`
-        if isinstance(time_lengths.data, Array):
-            time_lengths.load()
+        time_lengths = time_lengths.astype("timedelta64[ns]")
@tomvothecoder (Collaborator, Author):

Address bottleneck #2 from the PR description.
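A minimal, self-contained sketch of this idea on toy data (the `time_bounds` and `time_lengths` names stand in for xcdat's time-bound differences): casting to nanosecond timedeltas and then float64 keeps the computation lazy on Dask-backed arrays, whereas .load() eagerly pulled everything into memory.

import numpy as np
import pandas as pd
import xarray as xr

# Toy monthly time bounds: lower and upper bound per time step.
lower = pd.date_range("2000-01-01", periods=12, freq="MS")
upper = pd.date_range("2000-02-01", periods=12, freq="MS")
time_bounds = xr.DataArray(np.stack([lower, upper], axis=1), dims=["time", "bnds"])

# Time lengths as timedelta64 values.
time_lengths = time_bounds[:, 1] - time_bounds[:, 0]

# Instead of time_lengths.load(), cast to "timedelta64[ns]" and then to
# float64 (nanoseconds); both casts stay lazy for Dask-backed data.
time_lengths = time_lengths.astype("timedelta64[ns]").astype(np.float64)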

Comment on lines +1282 to +1283

+        dv = dv.assign_coords({self.dim: self._labeled_time})
+        dv_gb = dv.groupby(self.dim)
@tomvothecoder (Collaborator, Author):

Address bottleneck #1 from the PR description:

replace the time coordinates with the labeled time coordinates directly for grouping, rather than attaching the labels as auxiliary coordinates on the time dimension (which slows things down in Xarray for an unknown reason; a question will be posted to the Xarray forum)

Comment on lines 1334 to 1340

         time_grouped = xr.DataArray(
-            name="_".join(df_dt_components.columns),
+            name=self.dim,
             data=dt_objects,
-            coords={self.dim: time_coords[self.dim]},
+            coords={self.dim: dt_objects},
             dims=[self.dim],
             attrs=time_coords[self.dim].attrs,
         )
@tomvothecoder (Collaborator, Author):

Address bottleneck #1 from the PR description.

Comment on lines 1644 to 1648

         if self._mode in ["group_average", "climatology"]:
-            self._weights = self._weights.rename({self.dim: f"{self.dim}_original"})
-            # Only keep the original time coordinates, not the ones labeled
-            # by group.
-            self._weights = self._weights.drop_vars(self._labeled_time.name)
+            weights = self._weights.assign_coords({self.dim: self._dataset[self.dim]})
+            weights = weights.rename({self.dim: f"{self.dim}_original"})

-            ds[self._weights.name] = self._weights
+            ds[weights.name] = weights
@tomvothecoder (Collaborator, Author):

Reassign the original, unlabeled time coordinates back to the weights xr.DataArray, then rename it to "time_original" to avoid conflicting with the labeled time coordinates (now called "time").

Comment on lines +1742 to +1743

+        dv_departs = dv_departs.assign_coords({self.dim: ds_obs[self.dim]})
+        ds_departs[data_var] = dv_departs
@tomvothecoder (Collaborator, Author):

Reassign the original, unlabeled time coordinates back to the final departures time coordinates (the labeled, grouped time coordinates sometimes drop the year from the time coordinates).

@tomvothecoder self-assigned this on Sep 3, 2024
@tomvothecoder added the "type: enhancement" label on Sep 3, 2024
@tomvothecoder marked this pull request as ready for review on September 3, 2024, 21:31
@tomvothecoder (Collaborator, Author) commented on Sep 3, 2024

Hi @chengzhuzhang, this PR is ready for review.

After refactoring, I cut the runtimes down as follows:

  1. Annual climatology: 33s -> 5.85s
  2. Annual departures: 1min 9s -> 11.6s
  3. Monthly group averages: 33.5s -> 5.59s

I also performed a regression test using the same e3sm_diags dataset between main and this branch and produced identical results. The GH Actions build also passes.

Benchmarking script

# %%
import xarray as xr
import xcdat as xc

### 1. Using temporal.climatology from xcdat
file_path = "/global/cfs/cdirs/e3sm/e3sm_diags/postprocessed_e3sm_v2_data_for_e3sm_diags/20221103.v2.LR.amip.NGD_v3atm.chrysalis/arm-diags-data/PRECT_sgpc1_198501_201412.nc"
ds = xc.open_dataset(file_path)

branch = "dev"
# %%
# 1. Calculate annual climatology
# -------------------------------
ds_annual_cycle = ds.temporal.climatology("PRECT", "month", keep_weights=True)
ds_annual_cycle.to_netcdf(f"temporal_climatology_{branch}.nc")
"""
main
--------------------------
CPU times: user 33 s, sys: 2.41 s, total: 35.4 s
Wall time: 35.4 s

refactor/688-temp-api-perf
--------------------------
CPU times: user 5.85 s, sys: 2.88 s, total: 8.72 s
Wall time: 8.78 s
"""

# %%
# 2. Calculate annual departures
# ------------------------------
ds_annual_cycle_anom = ds.temporal.departures("PRECT", "month", keep_weights=True)
ds_annual_cycle_anom.to_netcdf(f"temporal_departures_{branch}.nc")
"""
main
--------------------------
CPU times: user 1min 9s, sys: 4.8 s, total: 1min 14s
Wall time: 1min 14s

refactor/688-temp-api-perf
--------------------------
CPU times: user 11.6 s, sys: 4.32 s, total: 15.9 s
Wall time: 15.9 s
"""

# %%
# 3. Calculate monthly group averages
# -----------------------------------
ds_annual_avg = ds.temporal.group_average("PRECT", "month", keep_weights=True)
ds_annual_avg.to_netcdf(f"temporal_group_average_{branch}.nc")

"""
main
--------------------------
CPU times: user 33.5 s, sys: 2.27 s, total: 35.8 s
Wall time: 35.9 s

refactor/688-temp-api-perf
--------------------------
CPU times: user 5.59 s, sys: 2.06 s, total: 7.65 s
Wall time: 7.65 s
"""

Regression testing script

import glob

import xarray as xr

# Get the filepaths for the dev and main branches
dev_filepaths = sorted(glob.glob("qa/issue-688/dev/*.nc"))
main_filepaths = sorted(glob.glob("qa/issue-688/main/*.nc"))

for fp, mp in zip(dev_filepaths, main_filepaths):
    print(f"Comparing {fp} and {mp}")
    # Load the datasets
    dev_ds = xr.open_dataset(fp)
    main_ds = xr.open_dataset(mp)

    # Compare the datasets
    try:
        xr.testing.assert_identical(dev_ds, main_ds)
    except AssertionError as e:
        print(f"Datasets are not identical: {e}")
    else:
        print("Datasets are identical")

Next steps

  1. I will investigate the differences you pointed out here between xCDAT and the e3sm_diags climatology functions separately from this PR (related e3sm_diags discussion post).
  2. Open a GH issue on the Xarray repo about grouping with auxiliary time coordinates causing a large performance hit.


time_lengths = time_lengths.astype(np.float64)

grouped_time_lengths = self._group_data(time_lengths)
weights: xr.DataArray = grouped_time_lengths / grouped_time_lengths.sum()
weights.name = f"{self.dim}_wts"

# Validate the sum of weights for each group is 1.0.
@chengzhuzhang (Collaborator) commented on Sep 3, 2024:

Checking that the sums match seems like a good feature to have, but if it degrades performance this much, we can exclude it. Maybe this check can be implemented in testing instead (if it is not covered there already). Also, the _get_weights() description needs to be updated to reflect that the sum is no longer validated.

@tomvothecoder (Collaborator, Author):
We should expect the logic of _get_weights() to be correct, so this assertion should not be necessary at runtime (especially with the performance hit).

I like your suggestion of making it a unit test instead. I will push a commit with this change soon.
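A sketch of what that unit test could look like (hypothetical test name and toy weights; the actual test lives in xcdat's test suite):

import numpy as np
import xarray as xr

def test_sum_of_weights_for_each_group_is_one():
    # Toy weights labeled by group: each group's weights should sum to 1.0.
    weights = xr.DataArray(
        [0.5, 0.5, 0.25, 0.75],
        coords={"time": ["DJF", "DJF", "JJA", "JJA"]},
        dims="time",
        name="time_wts",
    )
    actual = weights.groupby("time").sum()
    np.testing.assert_allclose(actual, np.ones(2))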

-        if weighted and keep_weights:
             self._weights = ds_climo.time_wts
             ds_obs = self._keep_weights(ds_obs)
+        if keep_weights:
@chengzhuzhang (Collaborator):

I notice this if statement changed from if weighted and keep_weights; should it be kept the same?

@tomvothecoder (Collaborator, Author):

Thank you for catching this. I reverted the conditional.

@chengzhuzhang (Collaborator) left a comment:

Hi Tom, thank you for the PR! I think it looks great; I just have minor comments for you to consider.

- Check if sum of each weight group equals 1.0
- Update `_get_weights()` docs to remove validation portion
@tomvothecoder merged commit 94c8932 into main on Sep 4, 2024
9 checks passed
@tomvothecoder deleted the refactor/688-temp-api-perf branch on September 4, 2024
Labels: type: enhancement
Projects: Status: Done

Successfully merging this pull request may close these issues:

  • [Enhancement]: Temporal averaging performance

2 participants