diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..e85d945 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,25 @@ +# How to Contribute + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows [Google's Open Source Community +Guidelines](https://opensource.google/conduct/). diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..bc81e9f --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ +# GraphCast: Learning skillful medium-range global weather forecasting + +This package contains example code to run and train [GraphCast](https://arxiv.org/abs/2212.12794). +It also provides three pretrained models: + +1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree +resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017, + +2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree +resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from +1979 to 2015, useful to run a model with lower memory and compute constraints, + +3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 +pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on +HRES data from 2016 to 2021. This model can be initialized from HRES data (does +not require precipitation inputs). + +The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). + +Full model training requires downloading the +[ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5) +dataset, available from [ECMWF](https://www.ecmwf.int/). + +## Overview of files + +The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an +example of loading data, generating random weights or load a pre-trained +snapshot, generating predictions, computing the loss and computing gradients. +The one-step implementation of GraphCast architecture, is provided in +`graphcast.py`. + +### Brief description of library files: + +* `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast + to produce a sequence of predictions by auto-regressively feeding the + outputs back as inputs at each step, in JAX a differentiable way. +* `casting.py`: Wrapper used around GraphCast to make it work using + BFloat16 precision. +* `checkpoint.py`: Utils to serialize and deserialize trees. +* `data_utils.py`: Utils for data preprocessing. +* `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN) + that operates on `TypedGraph`'s where both inputs and outputs are flat + vectors of features for each of the nodes and edges. `graphcast.py` uses + three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid + GNN, respectively. +* `graphcast.py`: The main GraphCast model architecture for one-step of + predictions. +* `grid_mesh_connectivity.py`: Tools for converting between regular grids on a + sphere and triangular meshes. +* `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh. +* `losses.py`: Loss computations, including latitude-weighting. +* `model_utils.py`: Utilities to produce flat node and edge vector features + from input grid data, and to manipulate the node output vectors back + into a multilevel grid data. +* `normalization.py`: Wrapper for the one-step GraphCast used to normalize + inputs according to historical values, and targets according to historical + time differences. +* `predictor_base.py`: Defines the interface of the predictor, which GraphCast + and all of the wrappers implement. +* `rollout.py`: Similar to `autoregressive.py` but used only at inference time + using a python loop to produce longer, but non-differentiable trajectories. +* `typed_graph.py`: Definition of `TypedGraph`'s. +* `typed_graph_net.py`: Implementation of simple graph neural network + building blocks defined over `TypedGraph`'s that can be combined to build + deeper models. +* `xarray_jax.py`: A wrapper to let JAX work with `xarray`s. +* `xarray_tree.py`: An implementation of tree.map_structure that works with + `xarray`s. + + +### Dependencies. + +[Chex](https://github.com/deepmind/chex), +[Dask](https://github.com/dask/dask), +[Haiku](https://github.com/deepmind/dm-haiku), +[JAX](https://github.com/google/jax), +[JAXline](https://github.com/deepmind/jaxline), +[Jraph](https://github.com/deepmind/jraph), +[Numpy](https://numpy.org/), +[Pandas](https://pandas.pydata.org/), +[Python](https://www.python.org/), +[SciPy](https://scipy.org/), +[Tree](https://github.com/deepmind/tree), +[Trimesh](https://github.com/mikedh/trimesh) and +[XArray](https://github.com/pydata/xarray). + + +### License and attribution + +The Colab notebook and the associated code are licensed under the Apache +License, Version 2.0. You may obtain a copy of the License at: +https://www.apache.org/licenses/LICENSE-2.0. + +The model weights are made available for use under the terms of the Creative +Commons Attribution-NonCommercial-ShareAlike 4.0 International +(CC BY-NC-SA 4.0). You may obtain a copy of the License at: +https://creativecommons.org/licenses/by-nc-sa/4.0/. + +The weights were trained on ECMWF's ERA5 and HRES data. The colab includes a few +examples of ERA5 and HRES data that can be used as inputs to the models. +ECMWF data product are subject to the following terms: + +1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)". +2. Source www.ecmwf.int +3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/ +4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use. + +### Disclaimer + +This is not an officially supported Google product. + +Copyright 2023 DeepMind Technologies Limited. + +### Citation + +If you use this work, consider citing our [paper](https://arxiv.org/abs/2212.12794): + +```latex +@article{lam2022graphcast, + title={GraphCast: Learning skillful medium-range global weather forecasting}, + author={Remi Lam and Alvaro Sanchez-Gonzalez and Matthew Willson and Peter Wirnsberger and Meire Fortunato and Alexander Pritzel and Suman Ravuri and Timo Ewalds and Ferran Alet and Zach Eaton-Rosen and Weihua Hu and Alexander Merose and Stephan Hoyer and George Holland and Jacklynn Stott and Oriol Vinyals and Shakir Mohamed and Peter Battaglia}, + year={2022}, + eprint={2212.12794}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` diff --git a/graphcast/autoregressive.py b/graphcast/autoregressive.py new file mode 100644 index 0000000..1cf1324 --- /dev/null +++ b/graphcast/autoregressive.py @@ -0,0 +1,312 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Predictor wrapping a one-step Predictor to make autoregressive predictions. +""" + +from typing import Optional, cast + +from absl import logging +from graphcast import predictor_base +from graphcast import xarray_jax +from graphcast import xarray_tree +import haiku as hk +import jax +import xarray + + +def _unflatten_and_expand_time(flat_variables, tree_def, time_coords): + variables = jax.tree_util.tree_unflatten(tree_def, flat_variables) + return variables.expand_dims(time=time_coords, axis=0) + + +def _get_flat_arrays_and_single_timestep_treedef(variables): + flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...)) + _, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True)) + return flat_arrays, treedef + + +class Predictor(predictor_base.Predictor): + """Wraps a one-step Predictor to make multi-step predictions autoregressively. + + The wrapped Predictor will be used to predict a single timestep conditional + on the inputs passed to the outer Predictor. Its predictions are then + passed back in as inputs at the next timestep, for as many timesteps as are + requested in the targets_template. (When multiple timesteps of input are + used, a rolling window of inputs is maintained with new predictions + concatenated onto the end). + + You may ask for additional variables to be predicted as targets which aren't + used as inputs. These will be predicted as output variables only and not fed + back in autoregressively. All target variables must be time-dependent however. + + You may also specify static (non-time-dependent) inputs which will be passed + in at each timestep but are not predicted. + + At present, any time-dependent inputs must also be present as targets so they + can be passed in autoregressively. + + The loss of the wrapped one-step Predictor is averaged over all timesteps to + give a loss for the autoregressive Predictor. + """ + + def __init__( + self, + predictor: predictor_base.Predictor, + noise_level: Optional[float] = None, + gradient_checkpointing: bool = False, + ): + """Initializes an autoregressive predictor wrapper. + + Args: + predictor: A predictor to wrap in an auto-regressive way. + noise_level: Optional value that multiplies the standard normal noise + added to the time-dependent variables of the predictor inputs. In + particular, no noise is added to the predictions that are fed back + auto-regressively. Defaults to not adding noise. + gradient_checkpointing: If True, gradient checkpointing will be + used at each step of the computation to save on memory. Roughtly this + should make the backwards pass two times more expensive, and the time + per step counting the forward pass, should only increase by about 50%. + Note this parameter will be ignored with a warning if the scan sequence + length is 1. + """ + self._predictor = predictor + self._noise_level = noise_level + self._gradient_checkpointing = gradient_checkpointing + + def _get_and_validate_constant_inputs(self, inputs, targets, forcings): + constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore') + constant_inputs = constant_inputs.drop_vars( + forcings.keys(), errors='ignore') + for name, var in constant_inputs.items(): + if 'time' in var.dims: + raise ValueError( + f'Time-dependent input variable {name} must either be a forcing ' + 'variable, or a target variable to allow for auto-regressive ' + 'feedback.') + return constant_inputs + + def _validate_targets_and_forcings(self, targets, forcings): + for name, var in targets.items(): + if 'time' not in var.dims: + raise ValueError(f'Target variable {name} must be time-dependent.') + + for name, var in forcings.items(): + if 'time' not in var.dims: + raise ValueError(f'Forcing variable {name} must be time-dependent.') + + overlap = forcings.keys() & targets.keys() + if overlap: + raise ValueError('The following were specified as both targets and ' + f'forcings, which isn\'t allowed: {overlap}') + + def _update_inputs(self, inputs, next_frame): + num_inputs = inputs.dims['time'] + + predicted_or_forced_inputs = next_frame[list(inputs.keys())] + + # Combining datasets with inputs and target time stamps aligns them. + # Only keep the num_inputs trailing frames for use as next inputs. + return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time') + .tail(time=num_inputs) + # Update the time coordinate to reset the lead times for + # next AR iteration. + .assign_coords(time=inputs.coords['time'])) + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs) -> xarray.Dataset: + """Calls the Predictor. + + Args: + inputs: input variable used to make predictions. Inputs can include both + time-dependent and time independent variables. Any time-dependent + input variables must also be present in the targets_template or the + forcings. + targets_template: A target template containing informations about which + variables should be predicted and the time alignment of the predictions. + All target variables must be time-dependent. + The number of time frames is used to set the number of unroll of the AR + predictor (e.g. multiple unroll of the inner predictor for one time step + in the targets is not supported yet). + forcings: Variables that will be fed to the model. The variables + should not overlap with the target ones. The time coordinates of the + forcing variables should match the target ones. + Forcing variables which are also present in the inputs, will be used to + supply ground-truth values for those inputs when they are passed to the + underlying predictor at timesteps beyond the first timestep. + **kwargs: Additional arguments passed along to the inner Predictor. + + Returns: + predictions: the model predictions matching the target template. + + Raise: + ValueError: if the time coordinates of the inputs and targets are not + different by a constant time step. + """ + + constant_inputs = self._get_and_validate_constant_inputs( + inputs, targets_template, forcings) + self._validate_targets_and_forcings(targets_template, forcings) + + # After the above checks, the remaining inputs must be time-dependent: + inputs = inputs.drop_vars(constant_inputs.keys()) + + # A predictions template only including the next time to predict. + target_template = targets_template.isel(time=[0]) + + flat_forcings, forcings_treedef = ( + _get_flat_arrays_and_single_timestep_treedef(forcings)) + scan_variables = flat_forcings + + def one_step_prediction(inputs, scan_variables): + + flat_forcings = scan_variables + forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef, + target_template.coords['time']) + + # Add constant inputs: + all_inputs = xarray.merge([constant_inputs, inputs]) + predictions: xarray.Dataset = self._predictor( + all_inputs, target_template, + forcings=forcings, + **kwargs) + + next_frame = xarray.merge([predictions, forcings]) + next_inputs = self._update_inputs(inputs, next_frame) + + # Drop the length-1 time dimension, since scan will concat all the outputs + # for different times along a new leading time dimension: + predictions = predictions.squeeze('time', drop=True) + # We return the prediction flattened into plain jax arrays, because the + # extra leading dimension added by scan prevents the tree_util + # registrations in xarray_jax from unflattening them back into an + # xarray.Dataset automatically: + flat_pred = jax.tree_util.tree_leaves(predictions) + return next_inputs, flat_pred + + if self._gradient_checkpointing: + scan_length = targets_template.dims['time'] + if scan_length <= 1: + logging.warning( + 'Skipping gradient checkpointing for sequence length of 1') + else: + # Just in case we take gradients (e.g. for control), although + # in most cases this will just be for a forward pass. + one_step_prediction = hk.remat(one_step_prediction) + + # Loop (without unroll) with hk states in cell (jax.lax.scan won't do). + _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables) + + # The result of scan will have an extra leading axis on all arrays, + # corresponding to the target times in this case. We need to be prepared for + # it when unflattening the arrays back into a Dataset: + scan_result_template = ( + target_template.squeeze('time', drop=True) + .expand_dims(time=targets_template.coords['time'], axis=0)) + _, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template) + predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds) + return predictions + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs + ) -> predictor_base.LossAndDiagnostics: + """The mean of the per-timestep losses of the underlying predictor.""" + if targets.sizes['time'] == 1: + # If there is only a single target timestep then we don't need any + # autoregressive feedback and can delegate the loss directly to the + # underlying single-step predictor. This means the underlying predictor + # doesn't need to implement .loss_and_predictions. + return self._predictor.loss(inputs, targets, forcings, **kwargs) + + constant_inputs = self._get_and_validate_constant_inputs( + inputs, targets, forcings) + self._validate_targets_and_forcings(targets, forcings) + # After the above checks, the remaining inputs must be time-dependent: + inputs = inputs.drop_vars(constant_inputs.keys()) + + if self._noise_level: + def add_noise(x): + return x + self._noise_level * jax.random.normal( + hk.next_rng_key(), shape=x.shape) + # Add noise to time-dependent variables of the inputs. + inputs = jax.tree_map(add_noise, inputs) + + # The per-timestep targets passed by scan to one_step_loss below will have + # no leading time axis. We need a treedef without the time axis to use + # inside one_step_loss to unflatten it back into a dataset: + flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef( + targets) + scan_variables = flat_targets + + flat_forcings, forcings_treedef = ( + _get_flat_arrays_and_single_timestep_treedef(forcings)) + scan_variables = (flat_targets, flat_forcings) + + def one_step_loss(inputs, scan_variables): + flat_target, flat_forcings = scan_variables + forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef, + targets.coords['time'][:1]) + + target = _unflatten_and_expand_time(flat_target, target_treedef, + targets.coords['time'][:1]) + + # Add constant inputs: + all_inputs = xarray.merge([constant_inputs, inputs]) + + (loss, diagnostics), predictions = self._predictor.loss_and_predictions( + all_inputs, + target, + forcings=forcings, + **kwargs) + + # Unwrap to jax arrays shape (batch,): + loss, diagnostics = xarray_tree.map_structure( + xarray_jax.unwrap_data, (loss, diagnostics)) + + predictions = cast(xarray.Dataset, predictions) # Keeps pytype happy. + next_frame = xarray.merge([predictions, forcings]) + next_inputs = self._update_inputs(inputs, next_frame) + + return next_inputs, (loss, diagnostics) + + if self._gradient_checkpointing: + scan_length = targets.dims['time'] + if scan_length <= 1: + logging.warning( + 'Skipping gradient checkpointing for sequence length of 1') + else: + one_step_loss = hk.remat(one_step_loss) + + # We can pass inputs (the initial state of the loop) in directly as a + # Dataset because the shape we pass in to scan is the same as the shape scan + # passes to the inner function. But, for scan_variables, we must flatten the + # targets (and unflatten them inside the inner function) because they are + # passed to the inner function per-timestep without the original time axis. + # The same apply to the optional forcing. + _, (per_timestep_losses, per_timestep_diagnostics) = hk.scan( + one_step_loss, inputs, scan_variables) + + # Re-wrap loss and diagnostics as DataArray and average them over time: + (loss, diagnostics) = jax.tree_util.tree_map( + lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean( # pylint: disable=g-long-lambda + 'time', skipna=False), + (per_timestep_losses, per_timestep_diagnostics)) + + return loss, diagnostics diff --git a/graphcast/casting.py b/graphcast/casting.py new file mode 100644 index 0000000..3ea9e7a --- /dev/null +++ b/graphcast/casting.py @@ -0,0 +1,205 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrappers that take care of casting.""" + +import contextlib +from typing import Any, Mapping, Tuple + +import chex +from graphcast import predictor_base +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import xarray + + +PyTree = Any + + +class Bfloat16Cast(predictor_base.Predictor): + """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype.""" + + def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True): + """Inits the wrapper. + + Args: + predictor: predictor being wrapped. + enabled: disables the wrapper if False, for simpler hyperparameter scans. + + """ + self._enabled = enabled + self._predictor = predictor + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs + ) -> xarray.Dataset: + if not self._enabled: + return self._predictor(inputs, targets_template, forcings, **kwargs) + + with bfloat16_variable_view(): + predictions = self._predictor( + *_all_inputs_to_bfloat16(inputs, targets_template, forcings), + **kwargs,) + + predictions_dtype = infer_floating_dtype(predictions) + if predictions_dtype != jnp.bfloat16: + raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') + + targets_dtype = infer_floating_dtype(targets_template) + return tree_map_cast( + predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype) + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs, + ) -> predictor_base.LossAndDiagnostics: + if not self._enabled: + return self._predictor.loss(inputs, targets, forcings, **kwargs) + + with bfloat16_variable_view(): + loss, scalars = self._predictor.loss( + *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) + + if loss.dtype != jnp.bfloat16: + raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') + + targets_dtype = infer_floating_dtype(targets) + + # Note that casting back the loss to e.g. float32 should not affect data + # types of the backwards pass, because the first thing the backwards pass + # should do is to go backwards the casting op and cast back to bfloat16 + # (and xprofs seem to confirm this). + return tree_map_cast((loss, scalars), + input_dtype=jnp.bfloat16, output_dtype=targets_dtype) + + def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs, + ) -> Tuple[predictor_base.LossAndDiagnostics, + xarray.Dataset]: + if not self._enabled: + return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray + **kwargs) + + with bfloat16_variable_view(): + (loss, scalars), predictions = self._predictor.loss_and_predictions( + *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) + + if loss.dtype != jnp.bfloat16: + raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') + + predictions_dtype = infer_floating_dtype(predictions) + if predictions_dtype != jnp.bfloat16: + raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') + + targets_dtype = infer_floating_dtype(targets) + return tree_map_cast(((loss, scalars), predictions), + input_dtype=jnp.bfloat16, output_dtype=targets_dtype) + + +def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype: + """Infers a floating dtype from an input mapping of data.""" + dtypes = { + v.dtype + for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} + if len(dtypes) != 1: + dtypes_and_shapes = { + k: (v.dtype, v.shape) + for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} + raise ValueError( + f'Did not found exactly one floating dtype {dtypes} in input variables:' + f'{dtypes_and_shapes}') + return list(dtypes)[0] + + +def _all_inputs_to_bfloat16( + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + ) -> Tuple[xarray.Dataset, + xarray.Dataset, + xarray.Dataset]: + return (inputs.astype(jnp.bfloat16), + jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets), + forcings.astype(jnp.bfloat16)) + + +def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype, + ) -> PyTree: + def cast_fn(x): + if x.dtype == input_dtype: + return x.astype(output_dtype) + return jax.tree_map(cast_fn, inputs) + + +@contextlib.contextmanager +def bfloat16_variable_view(enabled: bool = True): + """Context for Haiku modules with float32 params, but bfloat16 activations. + + It works as follows: + * Every time a variable is requested to be created/set as np.bfloat16, + it will create an underlying float32 variable, instead. + * Every time a variable a variable is requested as bfloat16, it will check the + variable is of float32 type, and cast the variable to bfloat16. + + Note the gradients are still computed and accumulated as float32, because + the params returned by init are float32, so the gradient function with + respect to the params will already include an implicit casting to float32. + + Args: + enabled: Only enables bfloat16 behavior if True. + + Yields: + None + """ + + if enabled: + with hk.custom_creator( + _bfloat16_creator, state=True), hk.custom_getter( + _bfloat16_getter, state=True), hk.custom_setter( + _bfloat16_setter): + yield + else: + yield + + +def _bfloat16_creator(next_creator, shape, dtype, init, context): + """Creates float32 variables when bfloat16 is requested.""" + if context.original_dtype == jnp.bfloat16: + dtype = jnp.float32 + return next_creator(shape, dtype, init) + + +def _bfloat16_getter(next_getter, value, context): + """Casts float32 to bfloat16 when bfloat16 was originally requested.""" + if context.original_dtype == jnp.bfloat16: + assert value.dtype == jnp.float32 + value = value.astype(jnp.bfloat16) + return next_getter(value) + + +def _bfloat16_setter(next_setter, value, context): + """Casts bfloat16 to float32 when bfloat16 was originally set.""" + if context.original_dtype == jnp.bfloat16: + value = value.astype(jnp.float32) + return next_setter(value) diff --git a/graphcast/checkpoint.py b/graphcast/checkpoint.py new file mode 100644 index 0000000..b4c8433 --- /dev/null +++ b/graphcast/checkpoint.py @@ -0,0 +1,170 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Serialize and deserialize trees.""" + +import dataclasses +import io +import types +from typing import Any, BinaryIO, Optional, TypeVar + +import numpy as np + +_T = TypeVar("_T") + + +def dump(dest: BinaryIO, value: Any) -> None: + """Dump a tree of dicts/dataclasses to a file object. + + Args: + dest: a file object to write to. + value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and + other basic types. Unions are not supported, other than Optional/None + which is only supported in dataclasses, not in dicts, lists or tuples. + All leaves must be coercible to a numpy array, and recoverable as a single + arg to a type. + """ + buffer = io.BytesIO() # In case the destination doesn't support seeking. + np.savez(buffer, **_flatten(value)) + dest.write(buffer.getvalue()) + + +def load(source: BinaryIO, typ: type[_T]) -> _T: + """Load from a file object and convert it to the specified type. + + Args: + source: a file object to read from. + typ: a type object that acts as a schema for deserialization. It must match + what was serialized. If a type is Any, it will be returned however numpy + serialized it, which is what you want for a tree of numpy arrays. + + Returns: + the deserialized value as the specified type. + """ + return _convert_types(typ, _unflatten(np.load(source))) + + +_SEP = ":" + + +def _flatten(tree: Any) -> dict[str, Any]: + """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict.""" + if dataclasses.is_dataclass(tree): + # Don't use dataclasses.asdict as it is recursive so skips dropping None. + tree = {f.name: v for f in dataclasses.fields(tree) + if (v := getattr(tree, f.name)) is not None} + elif isinstance(tree, (list, tuple)): + tree = dict(enumerate(tree)) + + assert isinstance(tree, dict) + + flat = {} + for k, v in tree.items(): + k = str(k) + assert _SEP not in k + if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): + for a, b in _flatten(v).items(): + flat[f"{k}{_SEP}{a}"] = b + else: + assert v is not None + flat[k] = v + return flat + + +def _unflatten(flat: dict[str, Any]) -> dict[str, Any]: + """Unflatten a dict to a tree of dicts.""" + tree = {} + for flat_key, v in flat.items(): + node = tree + keys = flat_key.split(_SEP) + for k in keys[:-1]: + if k not in node: + node[k] = {} + node = node[k] + node[keys[-1]] = v + return tree + + +def _convert_types(typ: type[_T], value: Any) -> _T: + """Convert some structure into the given type. The structures must match.""" + if typ in (Any, ...): + return value + + if typ in (int, float, str, bool): + return typ(value) + + if typ is np.ndarray: + assert isinstance(value, np.ndarray) + return value + + if dataclasses.is_dataclass(typ): + kwargs = {} + for f in dataclasses.fields(typ): + # Only support Optional for dataclasses, as numpy can't serialize it + # directly (without pickle), and dataclasses are the only case where we + # can know the full set of values and types and therefore know the + # non-existence must mean None. + if isinstance(f.type, (types.UnionType, type(Optional[int]))): + constructors = [t for t in f.type.__args__ if t is not types.NoneType] + if len(constructors) != 1: + raise TypeError( + "Optional works, Union with anything except None doesn't") + if f.name not in value: + kwargs[f.name] = None + continue + constructor = constructors[0] + else: + constructor = f.type + + if f.name in value: + kwargs[f.name] = _convert_types(constructor, value[f.name]) + else: + raise ValueError(f"Missing value: {f.name}") + return typ(**kwargs) + + base_type = getattr(typ, "__origin__", None) + + if base_type is dict: + assert len(typ.__args__) == 2 + key_type, value_type = typ.__args__ + return {_convert_types(key_type, k): _convert_types(value_type, v) + for k, v in value.items()} + + if base_type is list: + assert len(typ.__args__) == 1 + value_type = typ.__args__[0] + return [_convert_types(value_type, v) + for _, v in sorted(value.items(), key=lambda x: int(x[0]))] + + if base_type is tuple: + if len(typ.__args__) == 2 and typ.__args__[1] == ...: + # An arbitrary length tuple of a single type, eg: tuple[int, ...] + value_type = typ.__args__[0] + return tuple(_convert_types(value_type, v) + for _, v in sorted(value.items(), key=lambda x: int(x[0]))) + else: + # A fixed length tuple of arbitrary types, eg: tuple[int, str, float] + assert len(typ.__args__) == len(value) + return tuple( + _convert_types(t, v) + for t, (_, v) in zip( + typ.__args__, sorted(value.items(), key=lambda x: int(x[0])))) + + # This is probably unreachable with reasonable serializable inputs. + try: + return typ(value) + except TypeError as e: + raise TypeError( + "_convert_types expects the type argument to be a dataclass defined " + "with types that are valid constructors (eg tuple is fine, Tuple " + "isn't), and accept a numpy array as the sole argument.") from e diff --git a/graphcast/checkpoint_test.py b/graphcast/checkpoint_test.py new file mode 100644 index 0000000..6ec0ce6 --- /dev/null +++ b/graphcast/checkpoint_test.py @@ -0,0 +1,124 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Check that the checkpoint serialization is reversable.""" + +import dataclasses +import io +from typing import Any, Optional, Union + +from absl.testing import absltest +from graphcast import checkpoint +import numpy as np + + +@dataclasses.dataclass +class SubConfig: + a: int + b: str + + +@dataclasses.dataclass +class Config: + bt: bool + bf: bool + i: int + f: float + o1: Optional[int] + o2: Optional[int] + o3: Union[int, None] + o4: Union[int, None] + o5: int | None + o6: int | None + li: list[int] + ls: list[str] + ldc: list[SubConfig] + tf: tuple[float, ...] + ts: tuple[str, ...] + t: tuple[str, int, SubConfig] + tdc: tuple[SubConfig, ...] + dsi: dict[str, int] + dss: dict[str, str] + dis: dict[int, str] + dsdis: dict[str, dict[int, str]] + dc: SubConfig + dco: Optional[SubConfig] + ddc: dict[str, SubConfig] + + +@dataclasses.dataclass +class Checkpoint: + params: dict[str, Any] + config: Config + + +class DataclassTest(absltest.TestCase): + + def test_serialize_dataclass(self): + ckpt = Checkpoint( + params={ + "layer1": { + "w": np.arange(10).reshape(2, 5), + "b": np.array([2, 6]), + }, + "layer2": { + "w": np.arange(8).reshape(2, 4), + "b": np.array([2, 6]), + }, + "blah": np.array([3, 9]), + }, + config=Config( + bt=True, + bf=False, + i=42, + f=3.14, + o1=1, + o2=None, + o3=2, + o4=None, + o5=3, + o6=None, + li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2], + ls=list("qhjfdxtpzgemryoikwvblcaus"), + ldc=[SubConfig(1, "hello"), SubConfig(2, "world")], + tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6), + ts=("hello", "world"), + t=("foo", 42, SubConfig(1, "bar")), + tdc=(SubConfig(1, "hello"), SubConfig(2, "world")), + dsi={"a": 1, "b": 2, "c": 3}, + dss={"d": "e", "f": "g"}, + dis={1: "a", 2: "b", 3: "c"}, + dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}}, + dc=SubConfig(1, "hello"), + dco=None, + ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")}, + )) + + buffer = io.BytesIO() + checkpoint.dump(buffer, ckpt) + buffer.seek(0) + ckpt2 = checkpoint.load(buffer, Checkpoint) + np.testing.assert_array_equal(ckpt.params["layer1"]["w"], + ckpt2.params["layer1"]["w"]) + np.testing.assert_array_equal(ckpt.params["layer1"]["b"], + ckpt2.params["layer1"]["b"]) + np.testing.assert_array_equal(ckpt.params["layer2"]["w"], + ckpt2.params["layer2"]["w"]) + np.testing.assert_array_equal(ckpt.params["layer2"]["b"], + ckpt2.params["layer2"]["b"]) + np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"]) + self.assertEqual(ckpt.config, ckpt2.config) + + +if __name__ == "__main__": + absltest.main() diff --git a/graphcast/data_utils.py b/graphcast/data_utils.py new file mode 100644 index 0000000..6888166 --- /dev/null +++ b/graphcast/data_utils.py @@ -0,0 +1,314 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dataset utilities.""" + +from typing import Any, Mapping, Sequence, Tuple, Union + +import numpy as np +import pandas as pd +import xarray + +TimedeltaLike = Any # Something convertible to pd.Timedelta. +TimedeltaStr = str # A string convertible to pd.Timedelta. + +TargetLeadTimes = Union[ + TimedeltaLike, + Sequence[TimedeltaLike], + slice # with TimedeltaLike as its start and stop. +] + +_SEC_PER_HOUR = 3600 +_HOUR_PER_DAY = 24 +SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY +_AVG_DAY_PER_YEAR = 365.24219 +AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR + +DAY_PROGRESS = "day_progress" +YEAR_PROGRESS = "year_progress" + + +def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray: + """Computes year progress for times in seconds. + + Args: + seconds_since_epoch: Times in seconds since the "epoch" (the point at which + UNIX time starts). + + Returns: + Year progress normalized to be in the [0, 1) interval for each time point. + """ + + # Start with the pure integer division, and then float at the very end. + # We will try to keep as much precision as possible. + years_since_epoch = ( + seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR) + ) + # Note depending on how these ops are down, we may end up with a "weak_type" + # which can cause issues in subtle ways, and hard to track here. + # In any case, casting to float32 should get rid of the weak type. + # [0, 1.) Interval. + return np.mod(years_since_epoch, 1.0).astype(np.float32) + + +def get_day_progress( + seconds_since_epoch: np.ndarray, + longitude: np.ndarray, +) -> np.ndarray: + """Computes day progress for times in seconds at each longitude. + + Args: + seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the + point at which UNIX time starts). + longitude: 1D array of longitudes at which day progress is computed. + + Returns: + 2D array of day progress values normalized to be in the [0, 1) inverval + for each time point at each longitude. + """ + + # [0.0, 1.0) Interval. + day_progress_greenwich = ( + np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY + ) + + # Offset the day progress to the longitude of each point on Earth. + longitude_offsets = np.deg2rad(longitude) / (2 * np.pi) + day_progress = np.mod( + day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0 + ) + return day_progress.astype(np.float32) + + +def featurize_progress( + name: str, dims: Sequence[str], progress: np.ndarray +) -> Mapping[str, xarray.Variable]: + """Derives features used by ML models from the `progress` variable. + + Args: + name: Base variable name from which features are derived. + dims: List of the output feature dimensions, e.g. ("day", "lon"). + progress: Progress variable values. + + Returns: + Dictionary of xarray variables derived from the `progress` values. It + includes the original `progress` variable along with its sin and cos + transformations. + + Raises: + ValueError if the number of feature dimensions is not equal to the number + of data dimensions. + """ + if len(dims) != progress.ndim: + raise ValueError( + f"Number of feature dimensions ({len(dims)}) must be equal to the" + f" number of data dimensions: {progress.ndim}." + ) + progress_phase = progress * (2 * np.pi) + return { + name: xarray.Variable(dims, progress), + name + "_sin": xarray.Variable(dims, np.sin(progress_phase)), + name + "_cos": xarray.Variable(dims, np.cos(progress_phase)), + } + + +def add_derived_vars(data: xarray.Dataset) -> None: + """Adds year and day progress features to `data` in place. + + NOTE: `toa_incident_solar_radiation` needs to be computed in this function + as well. + + Args: + data: Xarray dataset to which derived features will be added. + + Raises: + ValueError if `datetime` or `lon` are not in `data` coordinates. + """ + + for coord in ("datetime", "lon"): + if coord not in data.coords: + raise ValueError(f"'{coord}' must be in `data` coordinates.") + + # Compute seconds since epoch. + # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)` + # does not work as xarrays always cast dates into nanoseconds! + seconds_since_epoch = ( + data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64) + ) + batch_dim = ("batch",) if "batch" in data.dims else () + + # Add year progress features. + year_progress = get_year_progress(seconds_since_epoch) + data.update( + featurize_progress( + name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress + ) + ) + + # Add day progress features. + longitude_coord = data.coords["lon"] + day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) + data.update( + featurize_progress( + name=DAY_PROGRESS, + dims=batch_dim + ("time",) + longitude_coord.dims, + progress=day_progress, + ) + ) + + +def extract_input_target_times( + dataset: xarray.Dataset, + input_duration: TimedeltaLike, + target_lead_times: TargetLeadTimes, + ) -> Tuple[xarray.Dataset, xarray.Dataset]: + """Extracts inputs and targets for prediction, from a Dataset with a time dim. + + The input period is assumed to be contiguous (specified by a duration), but + the targets can be a list of arbitrary lead times. + + Examples: + + # Use 18 hours of data as inputs, and two specific lead times as targets: + # 3 days and 5 days after the final input. + extract_inputs_targets( + dataset, + input_duration='18h', + target_lead_times=('3d', '5d') + ) + + # Use 1 day of data as input, and all lead times between 6 hours and + # 24 hours inclusive as targets. Demonstrates a friendlier supported string + # syntax. + extract_inputs_targets( + dataset, + input_duration='1 day', + target_lead_times=slice('6 hours', '24 hours') + ) + + # Just use a single target lead time of 3 days: + extract_inputs_targets( + dataset, + input_duration='24h', + target_lead_times='3d' + ) + + Args: + dataset: An xarray.Dataset with a 'time' dimension whose coordinates are + timedeltas. It's assumed that the time coordinates have a fixed offset / + time resolution, and that the input_duration and target_lead_times are + multiples of this. + input_duration: pandas.Timedelta or something convertible to it (e.g. a + shorthand string like '6h' or '5d12h'). + target_lead_times: Either a single lead time, a slice with start and stop + (inclusive) lead times, or a sequence of lead times. Lead times should be + Timedeltas (or something convertible to). They are given relative to the + final input timestep, and should be positive. + + Returns: + inputs: + targets: + Two datasets with the same shape as the input dataset except that a + selection has been made from the time axis, and the origin of the + time coordinate will be shifted to refer to lead times relative to the + final input timestep. So for inputs the times will end at lead time 0, + for targets the time coordinates will refer to the lead times requested. + """ + + (target_lead_times, target_duration + ) = _process_target_lead_times_and_get_duration(target_lead_times) + + # Shift the coordinates for the time axis so that a timedelta of zero + # corresponds to the forecast reference time. That is, the final timestep + # that's available as input to the forecast, with all following timesteps + # forming the target period which needs to be predicted. + # This means the time coordinates are now forecast lead times. + time = dataset.coords["time"] + dataset = dataset.assign_coords(time=time + target_duration - time[-1]) + + # Slice out targets: + targets = dataset.sel({"time": target_lead_times}) + + input_duration = pd.Timedelta(input_duration) + # Both endpoints are inclusive with label-based slicing, so we offset by a + # small epsilon to make one of the endpoints non-inclusive: + zero = pd.Timedelta(0) + epsilon = pd.Timedelta(1, "ns") + inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) + return inputs, targets + + +def _process_target_lead_times_and_get_duration( + target_lead_times: TargetLeadTimes) -> TimedeltaLike: + """Returns the minimum duration for the target lead times.""" + if isinstance(target_lead_times, slice): + # A slice of lead times. xarray already accepts timedelta-like values for + # the begin/end/step of the slice. + if target_lead_times.start is None: + # If the start isn't specified, we assume it starts at the next timestep + # after lead time 0 (lead time 0 is the final input timestep): + target_lead_times = slice( + pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step + ) + target_duration = pd.Timedelta(target_lead_times.stop) + else: + if not isinstance(target_lead_times, (list, tuple, set)): + # A single lead time, which we wrap as a length-1 array to ensure there + # still remains a time dimension (here of length 1) for consistency. + target_lead_times = [target_lead_times] + + # A list of multiple (not necessarily contiguous) lead times: + target_lead_times = [pd.Timedelta(x) for x in target_lead_times] + target_lead_times.sort() + target_duration = target_lead_times[-1] + return target_lead_times, target_duration + + +def extract_inputs_targets_forcings( + dataset: xarray.Dataset, + *, + input_variables: Tuple[str, ...], + target_variables: Tuple[str, ...], + forcing_variables: Tuple[str, ...], + pressure_levels: Tuple[int, ...], + input_duration: TimedeltaLike, + target_lead_times: TargetLeadTimes, + ) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]: + """Extracts inputs, targets and forcings according to requirements.""" + dataset = dataset.sel(level=list(pressure_levels)) + + # "Forcings" are derived variables and do not exist in the original ERA5 or + # HRES datasets. Compute them if they are not in `dataset`. + if not set(forcing_variables).issubset(set(dataset.data_vars)): + add_derived_vars(dataset) + + # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts. + dataset = dataset.drop_vars("datetime") + + inputs, targets = extract_input_target_times( + dataset, + input_duration=input_duration, + target_lead_times=target_lead_times) + + if set(forcing_variables) & set(target_variables): + raise ValueError( + f"Forcing variables {forcing_variables} should not " + f"overlap with target variables {target_variables}." + ) + + inputs = inputs[list(input_variables)] + # The forcing uses the same time coordinates as the target. + forcings = targets[list(forcing_variables)] + targets = targets[list(target_variables)] + + return inputs, targets, forcings diff --git a/graphcast/data_utils_test.py b/graphcast/data_utils_test.py new file mode 100644 index 0000000..d199785 --- /dev/null +++ b/graphcast/data_utils_test.py @@ -0,0 +1,201 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for `data_utils.py`.""" + +import datetime +from absl.testing import absltest +from absl.testing import parameterized +from graphcast import data_utils +import numpy as np +import xarray + + +class DataUtilsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + # Fix the seed for reproducibility. + np.random.seed(0) + + def test_year_progress_is_zero_at_year_start_or_end(self): + year_progress = data_utils.get_year_progress( + np.array([ + 0, + data_utils.AVG_SEC_PER_YEAR, + data_utils.AVG_SEC_PER_YEAR * 42, # 42 years. + ]) + ) + np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape)) + + def test_year_progress_is_almost_one_before_year_ends(self): + year_progress = data_utils.get_year_progress( + np.array([ + data_utils.AVG_SEC_PER_YEAR - 1, + (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years + ]) + ) + with self.subTest("Year progress values are close to 1"): + self.assertTrue(np.all(year_progress > 0.999)) + with self.subTest("Year progress values != 1"): + self.assertTrue(np.all(year_progress < 1.0)) + + def test_day_progress_computes_for_all_times_and_longitudes(self): + times = np.random.randint(low=0, high=1e10, size=10) + longitudes = np.arange(0, 360.0, 1.0) + day_progress = data_utils.get_day_progress(times, longitudes) + with self.subTest("Day progress is computed for all times and longinutes"): + self.assertSequenceEqual( + day_progress.shape, (len(times), len(longitudes)) + ) + + @parameterized.named_parameters( + dict( + testcase_name="random_date_1", + year=1988, + month=11, + day=7, + hour=2, + minute=45, + second=34, + ), + dict( + testcase_name="random_date_2", + year=2022, + month=3, + day=12, + hour=7, + minute=1, + second=0, + ), + ) + def test_day_progress_is_in_between_zero_and_one( + self, year, month, day, hour, minute, second + ): + # Datetime from a timestamp. + dt = datetime.datetime(year, month, day, hour, minute, second) + # Epoch time. + epoch_time = datetime.datetime(1970, 1, 1) + # Seconds since epoch. + seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()]) + + # Longitudes with 1 degree resolution. + longitudes = np.arange(0, 360.0, 1.0) + + day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes) + with self.subTest("Day progress >= 0"): + self.assertTrue(np.all(day_progress >= 0.0)) + with self.subTest("Day progress < 1"): + self.assertTrue(np.all(day_progress < 1.0)) + + def test_day_progress_is_zero_at_day_start_or_end(self): + day_progress = data_utils.get_day_progress( + seconds_since_epoch=np.array([ + 0, + data_utils.SEC_PER_DAY, + data_utils.SEC_PER_DAY * 42, # 42 days. + ]), + longitude=np.array([0.0]), + ) + np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape)) + + def test_day_progress_specific_value(self): + day_progress = data_utils.get_day_progress( + seconds_since_epoch=np.array([123]), + longitude=np.array([0.0]), + ) + np.testing.assert_array_almost_equal( + day_progress, np.array([[0.00142361]]), decimal=6 + ) + + def test_featurize_progress_valid_values_and_dimensions(self): + day_progress = np.array([0.0, 0.45, 0.213]) + feature_dimensions = ("time",) + progress_features = data_utils.featurize_progress( + name="day_progress", dims=feature_dimensions, progress=day_progress + ) + for feature in progress_features.values(): + with self.subTest(f"Valid dimensions for {feature}"): + self.assertSequenceEqual(feature.dims, feature_dimensions) + + with self.subTest("Valid values for day_progress"): + np.testing.assert_array_equal( + day_progress, progress_features["day_progress"].values + ) + + with self.subTest("Valid values for day_progress_sin"): + np.testing.assert_array_almost_equal( + np.array([0.0, 0.30901699, 0.97309851]), + progress_features["day_progress_sin"].values, + decimal=6, + ) + + with self.subTest("Valid values for day_progress_cos"): + np.testing.assert_array_almost_equal( + np.array([1.0, -0.95105652, 0.23038943]), + progress_features["day_progress_cos"].values, + decimal=6, + ) + + def test_featurize_progress_invalid_dimensions(self): + year_progress = np.array([0.0, 0.45, 0.213]) + feature_dimensions = ("time", "longitude") + with self.assertRaises(ValueError): + data_utils.featurize_progress( + name="year_progress", dims=feature_dimensions, progress=year_progress + ) + + def test_add_derived_vars_variables_added(self): + data = xarray.Dataset( + data_vars={ + "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3)) + }, + coords={ + "lon": np.array([0.0, 0.5]), + "datetime": np.array([ + datetime.datetime(2021, 1, 1), + datetime.datetime(2023, 1, 1), + datetime.datetime(2023, 1, 3), + ]), + }, + ) + data_utils.add_derived_vars(data) + all_variables = set(data.variables) + + with self.subTest("Original value was not removed"): + self.assertIn("var1", all_variables) + with self.subTest("Year progress feature was added"): + self.assertIn(data_utils.YEAR_PROGRESS, all_variables) + with self.subTest("Day progress feature was added"): + self.assertIn(data_utils.DAY_PROGRESS, all_variables) + + @parameterized.named_parameters( + dict(testcase_name="missing_datetime", coord_name="lon"), + dict(testcase_name="missing_lon", coord_name="datetime"), + ) + def test_add_derived_vars_missing_coordinate_raises_value_error( + self, coord_name + ): + with self.subTest(f"Missing {coord_name} coordinate"): + data = xarray.Dataset( + data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))}, + coords={ + coord_name: np.array([0.0, 0.5]), + }, + ) + with self.assertRaises(ValueError): + data_utils.add_derived_vars(data) + + +if __name__ == "__main__": + absltest.main() diff --git a/graphcast/deep_typed_graph_net.py b/graphcast/deep_typed_graph_net.py new file mode 100644 index 0000000..93a1bd3 --- /dev/null +++ b/graphcast/deep_typed_graph_net.py @@ -0,0 +1,391 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of Graph Networks Simulator. + +Generalization to TypedGraphs of the deep Graph Neural Network from: + +@inproceedings{pfaff2021learning, + title={Learning Mesh-Based Simulation with Graph Networks}, + author={Pfaff, Tobias and Fortunato, Meire and Sanchez-Gonzalez, Alvaro and + Battaglia, Peter}, + booktitle={International Conference on Learning Representations}, + year={2021} +} + +@inproceedings{sanchez2020learning, + title={Learning to simulate complex physics with graph networks}, + author={Sanchez-Gonzalez, Alvaro and Godwin, Jonathan and Pfaff, Tobias and + Ying, Rex and Leskovec, Jure and Battaglia, Peter}, + booktitle={International conference on machine learning}, + pages={8459--8468}, + year={2020}, + organization={PMLR} +} +""" + +from typing import Mapping, Optional + +from graphcast import typed_graph +from graphcast import typed_graph_net +import haiku as hk +import jax +import jax.numpy as jnp +import jraph + + +class DeepTypedGraphNet(hk.Module): + """Deep Graph Neural Network. + + It works with TypedGraphs with typed nodes and edges. It runs message + passing on all of the node sets and all of the edge sets in the graph. For + each message passing step a `typed_graph_net.InteractionNetwork` is used to + update the full TypedGraph by using different MLPs for each of the node sets + and each of the edge sets. + + If embed_{nodes,edges} is specified the node/edge features will be embedded + into a fixed dimensionality before running the first step of message passing. + + If {node,edge}_output_size the final node/edge features will be embedded into + the specified output size. + + This class may be used for shared or unshared message passing: + * num_message_passing_steps = N, num_processor_repetitions = 1, gives + N layers of message passing with fully unshared weights: + [W_1, W_2, ... , W_M] (default) + * num_message_passing_steps = 1, num_processor_repetitions = M, gives + N layers of message passing with fully shared weights: + [W_1] * M + * num_message_passing_steps = N, num_processor_repetitions = M, gives + M*N layers of message passing with both shared and unshared message passing + such that the weights used at each iteration are: + [W_1, W_2, ... , W_N] * M + + """ + + def __init__(self, + *, + node_latent_size: Mapping[str, int], + edge_latent_size: Mapping[str, int], + mlp_hidden_size: int, + mlp_num_hidden_layers: int, + num_message_passing_steps: int, + num_processor_repetitions: int = 1, + embed_nodes: bool = True, + embed_edges: bool = True, + node_output_size: Optional[Mapping[str, int]] = None, + edge_output_size: Optional[Mapping[str, int]] = None, + include_sent_messages_in_node_update: bool = False, + use_layer_norm: bool = True, + activation: str = "relu", + f32_aggregation: bool = False, + aggregate_edges_for_nodes_fn: str = "segment_sum", + aggregate_normalization: Optional[float] = None, + name: str = "DeepTypedGraphNet"): + """Inits the model. + + Args: + node_latent_size: Size of the node latent representations. + edge_latent_size: Size of the edge latent representations. + mlp_hidden_size: Hidden layer size for all MLPs. + mlp_num_hidden_layers: Number of hidden layers in all MLPs. + num_message_passing_steps: Number of unshared message passing steps + in the processor steps. + num_processor_repetitions: Number of times that the same processor is + applied sequencially. + embed_nodes: If False, the node embedder will be omitted. + embed_edges: If False, the edge embedder will be omitted. + node_output_size: Size of the output node representations for + each node type. For node types not specified here, the latent node + representation from the output of the processor will be returned. + edge_output_size: Size of the output edge representations for + each edge type. For edge types not specified here, the latent edge + representation from the output of the processor will be returned. + include_sent_messages_in_node_update: Whether to include pooled sent + messages from each node in the node update. + use_layer_norm: Whether it uses layer norm or not. + activation: name of activation function. + f32_aggregation: Use float32 in the edge aggregation. + aggregate_edges_for_nodes_fn: function used to aggregate messages to each + node. + aggregate_normalization: An optional constant that normalizes the output + of aggregate_edges_for_nodes_fn. For context, this can be used to + reduce the shock the model undergoes when switching resolution, which + increase the number of edges connected to a node. In particular, this is + useful when using segment_sum, but should not be combined with + segment_mean. + name: Name of the model. + """ + + super().__init__(name=name) + + self._node_latent_size = node_latent_size + self._edge_latent_size = edge_latent_size + self._mlp_hidden_size = mlp_hidden_size + self._mlp_num_hidden_layers = mlp_num_hidden_layers + self._num_message_passing_steps = num_message_passing_steps + self._num_processor_repetitions = num_processor_repetitions + self._embed_nodes = embed_nodes + self._embed_edges = embed_edges + self._node_output_size = node_output_size + self._edge_output_size = edge_output_size + self._include_sent_messages_in_node_update = ( + include_sent_messages_in_node_update) + self._use_layer_norm = use_layer_norm + self._activation = _get_activation_fn(activation) + self._initialized = False + self._f32_aggregation = f32_aggregation + self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn( + aggregate_edges_for_nodes_fn) + self._aggregate_normalization = aggregate_normalization + + if aggregate_normalization: + # using aggregate_normalization only makes sense with segment_sum. + assert aggregate_edges_for_nodes_fn == "segment_sum" + + def __call__(self, + input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Forward pass of the learnable dynamics model.""" + self._networks_builder(input_graph) + + # Embed input features (if applicable). + latent_graph_0 = self._embed(input_graph) + + # Do `m` message passing steps in the latent graphs. + latent_graph_m = self._process(latent_graph_0) + + # Compute outputs from the last latent graph (if applicable). + return self._output(latent_graph_m) + + def _networks_builder(self, graph_template): + if self._initialized: + return + self._initialized = True + + def build_mlp(name, output_size): + mlp = hk.nets.MLP( + output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [ + output_size], name=name + "_mlp", activation=self._activation) + return jraph.concatenated_args(mlp) + + def build_mlp_with_maybe_layer_norm(name, output_size): + network = build_mlp(name, output_size) + if self._use_layer_norm: + layer_norm = hk.LayerNorm( + axis=-1, create_scale=True, create_offset=True, + name=name + "_layer_norm") + network = hk.Sequential([network, layer_norm]) + return jraph.concatenated_args(network) + + # The embedder graph network independently embeds edge and node features. + if self._embed_edges: + embed_edge_fn = _build_update_fns_for_edge_types( + build_mlp_with_maybe_layer_norm, + graph_template, + "encoder_edges_", + output_sizes=self._edge_latent_size) + else: + embed_edge_fn = None + if self._embed_nodes: + embed_node_fn = _build_update_fns_for_node_types( + build_mlp_with_maybe_layer_norm, + graph_template, + "encoder_nodes_", + output_sizes=self._node_latent_size) + else: + embed_node_fn = None + embedder_kwargs = dict( + embed_edge_fn=embed_edge_fn, + embed_node_fn=embed_node_fn, + ) + self._embedder_network = typed_graph_net.GraphMapFeatures( + **embedder_kwargs) + + if self._f32_aggregation: + def aggregate_fn(data, *args, **kwargs): + dtype = data.dtype + data = data.astype(jnp.float32) + output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs) + if self._aggregate_normalization: + output = output / self._aggregate_normalization + output = output.astype(dtype) + return output + + else: + def aggregate_fn(data, *args, **kwargs): + output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs) + if self._aggregate_normalization: + output = output / self._aggregate_normalization + return output + + # Create `num_message_passing_steps` graph networks with unshared parameters + # that update the node and edge latent features. + # Note that we can use `modules.InteractionNetwork` because + # it also outputs the messages as updated edge latent features. + self._processor_networks = [] + for step_i in range(self._num_message_passing_steps): + self._processor_networks.append( + typed_graph_net.InteractionNetwork( + update_edge_fn=_build_update_fns_for_edge_types( + build_mlp_with_maybe_layer_norm, + graph_template, + f"processor_edges_{step_i}_", + output_sizes=self._edge_latent_size), + update_node_fn=_build_update_fns_for_node_types( + build_mlp_with_maybe_layer_norm, + graph_template, + f"processor_nodes_{step_i}_", + output_sizes=self._node_latent_size), + aggregate_edges_for_nodes_fn=aggregate_fn, + include_sent_messages_in_node_update=( + self._include_sent_messages_in_node_update), + )) + + # The output MLPs converts edge/node latent features into the output sizes. + output_kwargs = dict( + embed_edge_fn=_build_update_fns_for_edge_types( + build_mlp, graph_template, "decoder_edges_", self._edge_output_size) + if self._edge_output_size else None, + embed_node_fn=_build_update_fns_for_node_types( + build_mlp, graph_template, "decoder_nodes_", self._node_output_size) + if self._node_output_size else None,) + self._output_network = typed_graph_net.GraphMapFeatures( + **output_kwargs) + + def _embed( + self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Embeds the input graph features into a latent graph.""" + + # Copy the context to all of the node types, if applicable. + context_features = input_graph.context.features + if jax.tree_util.tree_leaves(context_features): + # This code assumes a single input feature array for the context and for + # each node type. + assert len(jax.tree_util.tree_leaves(context_features)) == 1 + new_nodes = {} + for node_set_name, node_set in input_graph.nodes.items(): + node_features = node_set.features + broadcasted_context = jnp.repeat( + context_features, node_set.n_node, axis=0, + total_repeat_length=node_features.shape[0]) + new_nodes[node_set_name] = node_set._replace( + features=jnp.concatenate( + [node_features, broadcasted_context], axis=-1)) + input_graph = input_graph._replace( + nodes=new_nodes, + context=input_graph.context._replace(features=())) + + # Embeds the node and edge features. + latent_graph_0 = self._embedder_network(input_graph) + return latent_graph_0 + + def _process( + self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Processes the latent graph with several steps of message passing.""" + + # Do `num_message_passing_steps` with each of the `self._processor_networks` + # with unshared weights, and repeat that `self._num_processor_repetitions` + # times. + latent_graph = latent_graph_0 + for unused_repetition_i in range(self._num_processor_repetitions): + for processor_network in self._processor_networks: + latent_graph = self._process_step(processor_network, latent_graph) + + return latent_graph + + def _process_step( + self, processor_network_k, + latent_graph_prev_k: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Single step of message passing with node/edge residual connections.""" + + # One step of message passing. + latent_graph_k = processor_network_k(latent_graph_prev_k) + + # Add residuals. + nodes_with_residuals = {} + for k, prev_set in latent_graph_prev_k.nodes.items(): + nodes_with_residuals[k] = prev_set._replace( + features=prev_set.features + latent_graph_k.nodes[k].features) + + edges_with_residuals = {} + for k, prev_set in latent_graph_prev_k.edges.items(): + edges_with_residuals[k] = prev_set._replace( + features=prev_set.features + latent_graph_k.edges[k].features) + + latent_graph_k = latent_graph_k._replace( + nodes=nodes_with_residuals, edges=edges_with_residuals) + return latent_graph_k + + def _output(self, + latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Produces the output from the latent graph.""" + return self._output_network(latent_graph) + + +def _build_update_fns_for_node_types( + builder_fn, graph_template, prefix, output_sizes=None): + """Builds an update function for all node types or a subset of them.""" + + output_fns = {} + for node_set_name in graph_template.nodes.keys(): + if output_sizes is None: + # Use the default output size for all types. + output_size = None + else: + # Otherwise, ignore any type that does not have an explicit output size. + if node_set_name in output_sizes: + output_size = output_sizes[node_set_name] + else: + continue + output_fns[node_set_name] = builder_fn( + f"{prefix}{node_set_name}", output_size) + return output_fns + + +def _build_update_fns_for_edge_types( + builder_fn, graph_template, prefix, output_sizes=None): + """Builds an edge function for all node types or a subset of them.""" + output_fns = {} + for edge_set_key in graph_template.edges.keys(): + edge_set_name = edge_set_key.name + if output_sizes is None: + # Use the default output size for all types. + output_size = None + else: + # Otherwise, ignore any type that does not have an explicit output size. + if edge_set_name in output_sizes: + output_size = output_sizes[edge_set_name] + else: + continue + output_fns[edge_set_name] = builder_fn( + f"{prefix}{edge_set_name}", output_size) + return output_fns + + +def _get_activation_fn(name): + """Return activation function corresponding to function_name.""" + if name == "identity": + return lambda x: x + if hasattr(jax.nn, name): + return getattr(jax.nn, name) + if hasattr(jnp, name): + return getattr(jnp, name) + raise ValueError(f"Unknown activation function {name} specified.") + + +def _get_aggregate_edges_for_nodes_fn(name): + """Return aggregate_edges_for_nodes_fn corresponding to function_name.""" + if hasattr(jraph, name): + return getattr(jraph, name) + raise ValueError( + f"Unknown aggregate_edges_for_nodes_fn function {name} specified.") diff --git a/graphcast/graphcast.py b/graphcast/graphcast.py new file mode 100644 index 0000000..244128f --- /dev/null +++ b/graphcast/graphcast.py @@ -0,0 +1,796 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A predictor that runs multiple graph neural networks on mesh data. + +It learns to interpolate between the grid and the mesh nodes, with the loss +and the rollouts ultimately computed at the grid level. + +It uses ideas similar to those in Keisler (2022): + +Reference: + https://arxiv.org/pdf/2202.07575.pdf + +It assumes data across time and level is stacked, and operates only operates in +a 2D mesh over latitudes and longitudes. +""" + +from typing import Any, Callable, Mapping, Optional + +import chex +from graphcast import deep_typed_graph_net +from graphcast import grid_mesh_connectivity +from graphcast import icosahedral_mesh +from graphcast import losses +from graphcast import model_utils +from graphcast import predictor_base +from graphcast import typed_graph +from graphcast import xarray_jax +import jax.numpy as jnp +import jraph +import numpy as np +import xarray + +Kwargs = Mapping[str, Any] + +GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple] + + +# https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5 +PRESSURE_LEVELS_ERA5_37 = ( + 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300, + 350, 400, 450, 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900, + 925, 950, 975, 1000) + +# https://www.ecmwf.int/en/forecasts/datasets/set-i +PRESSURE_LEVELS_HRES_25 = ( + 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600, + 700, 800, 850, 900, 925, 950, 1000) + +# https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2020MS002203 +PRESSURE_LEVELS_WEATHERBENCH_13 = ( + 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000) + +PRESSURE_LEVELS = { + 13: PRESSURE_LEVELS_WEATHERBENCH_13, + 25: PRESSURE_LEVELS_HRES_25, + 37: PRESSURE_LEVELS_ERA5_37, +} + +# The list of all possible atmospheric variables. Taken from: +# https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Table9 +ALL_ATMOSPHERIC_VARS = ( + "potential_vorticity", + "specific_rain_water_content", + "specific_snow_water_content", + "geopotential", + "temperature", + "u_component_of_wind", + "v_component_of_wind", + "specific_humidity", + "vertical_velocity", + "vorticity", + "divergence", + "relative_humidity", + "ozone_mass_mixing_ratio", + "specific_cloud_liquid_water_content", + "specific_cloud_ice_water_content", + "fraction_of_cloud_cover", +) + +TARGET_SURFACE_VARS = ( + "2m_temperature", + "mean_sea_level_pressure", + "10m_v_component_of_wind", + "10m_u_component_of_wind", + "total_precipitation_6hr", +) +TARGET_SURFACE_NO_PRECIP_VARS = ( + "2m_temperature", + "mean_sea_level_pressure", + "10m_v_component_of_wind", + "10m_u_component_of_wind", +) +TARGET_ATMOSPHERIC_VARS = ( + "temperature", + "geopotential", + "u_component_of_wind", + "v_component_of_wind", + "vertical_velocity", + "specific_humidity", +) +TARGET_ATMOSPHERIC_NO_W_VARS = ( + "temperature", + "geopotential", + "u_component_of_wind", + "v_component_of_wind", + "specific_humidity", +) +EXTERNAL_FORCING_VARS = ( + "toa_incident_solar_radiation", +) +GENERATED_FORCING_VARS = ( + "year_progress_sin", + "year_progress_cos", + "day_progress_sin", + "day_progress_cos", +) +FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS +STATIC_VARS = ( + "geopotential_at_surface", + "land_sea_mask", +) + + +@chex.dataclass(frozen=True, eq=True) +class TaskConfig: + """Defines inputs and targets on which a model is trained and/or evaluated.""" + input_variables: tuple[str, ...] + # Target variables which the model is expected to predict. + target_variables: tuple[str, ...] + forcing_variables: tuple[str, ...] + pressure_levels: tuple[int, ...] + input_duration: str + +TASK = TaskConfig( + input_variables=( + TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS + + STATIC_VARS), + target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS, + forcing_variables=FORCING_VARS, + pressure_levels=PRESSURE_LEVELS_ERA5_37, + input_duration="12h", +) +TASK_13 = TaskConfig( + input_variables=( + TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS + + STATIC_VARS), + target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS, + forcing_variables=FORCING_VARS, + pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13, + input_duration="12h", +) +TASK_13_PRECIP_OUT = TaskConfig( + input_variables=( + TARGET_SURFACE_NO_PRECIP_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS + + STATIC_VARS), + target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS, + forcing_variables=FORCING_VARS, + pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13, + input_duration="12h", +) + + +@chex.dataclass(frozen=True, eq=True) +class ModelConfig: + """Defines the architecture of the GraphCast neural network architecture. + + Properties: + resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0). + mesh_size: How many refinements to do on the multi-mesh. + gnn_msg_steps: How many Graph Network message passing steps to do. + latent_size: How many latent features to include in the various MLPs. + hidden_layers: How many hidden layers for each MLP. + radius_query_fraction_edge_length: Scalar that will be multiplied by the + length of the longest edge of the finest mesh to define the radius of + connectivity to use in the Grid2Mesh graph. Reasonable values are + between 0.6 and 1. 0.6 reduces the number of grid points feeding into + multiple mesh nodes and therefore reduces edge count and memory use, but + 1 gives better predictions. + mesh2grid_edge_normalization_factor: Allows explicitly controlling edge + normalization for mesh2grid edges. If None, defaults to max edge length. + This supports using pre-trained model weights with a different graph + structure to what it was trained on. + """ + resolution: float + mesh_size: int + latent_size: int + gnn_msg_steps: int + hidden_layers: int + radius_query_fraction_edge_length: float + mesh2grid_edge_normalization_factor: Optional[float] = None + + +@chex.dataclass(frozen=True, eq=True) +class CheckPoint: + params: dict[str, Any] + model_config: ModelConfig + task_config: TaskConfig + description: str + license: str + + +class GraphCast(predictor_base.Predictor): + """GraphCast Predictor. + + The model works on graphs that take into account: + * Mesh nodes: nodes for the vertices of the mesh. + * Grid nodes: nodes for the points of the grid. + * Nodes: When referring to just "nodes", this means the joint set of + both mesh nodes, concatenated with grid nodes. + + The model works with 3 graphs: + * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from grid nodes to mesh nodes using a + fixed radius query. The grid2mesh_gnn will operate in this graph. The output + of this stage will be a latent representation for the mesh nodes, and a + latent representation for the grid nodes. + * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will + operate in this graph. It will update the latent state of the mesh nodes + only. + * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly + bipartite with edges going from mesh nodes to grid nodes such that each grid + nodes is connected to 3 nodes of the mesh triangular face that contains + the grid points. The mesh2grid_gnn will operate in this graph. It will + process the updated latent state of the mesh nodes, and the latent state + of the grid nodes, to produce the final output for the grid nodes. + + The model is built on top of `TypedGraph`s so the different types of nodes and + edges can be stored and treated separately. + + """ + + def __init__(self, model_config: ModelConfig, task_config: TaskConfig): + """Initializes the predictor.""" + self._spatial_features_kwargs = dict( + add_node_positions=False, + add_node_latitude=True, + add_node_longitude=True, + add_relative_positions=True, + relative_longitude_local_coordinates=True, + relative_latitude_local_coordinates=True, + ) + + # Specification of the multimesh. + self._meshes = ( + icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( + splits=model_config.mesh_size)) + + # Encoder, which moves data from the grid to the mesh with a single message + # passing step. + self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet( + embed_nodes=True, # Embed raw features of the grid and mesh nodes. + embed_edges=True, # Embed raw features of the grid2mesh edges. + edge_latent_size=dict(grid2mesh=model_config.latent_size), + node_latent_size=dict( + mesh_nodes=model_config.latent_size, + grid_nodes=model_config.latent_size), + mlp_hidden_size=model_config.latent_size, + mlp_num_hidden_layers=model_config.hidden_layers, + num_message_passing_steps=1, + use_layer_norm=True, + include_sent_messages_in_node_update=False, + activation="swish", + f32_aggregation=True, + aggregate_normalization=None, + name="grid2mesh_gnn", + ) + + # Processor, which performs message passing on the multi-mesh. + self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet( + embed_nodes=False, # Node features already embdded by previous layers. + embed_edges=True, # Embed raw features of the multi-mesh edges. + node_latent_size=dict(mesh_nodes=model_config.latent_size), + edge_latent_size=dict(mesh=model_config.latent_size), + mlp_hidden_size=model_config.latent_size, + mlp_num_hidden_layers=model_config.hidden_layers, + num_message_passing_steps=model_config.gnn_msg_steps, + use_layer_norm=True, + include_sent_messages_in_node_update=False, + activation="swish", + f32_aggregation=False, + name="mesh_gnn", + ) + + num_surface_vars = len( + set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS)) + num_atmospheric_vars = len( + set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS)) + num_outputs = (num_surface_vars + + len(task_config.pressure_levels) * num_atmospheric_vars) + + # Decoder, which moves data from the mesh back into the grid with a single + # message passing step. + self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet( + # Require a specific node dimensionaly for the grid node outputs. + node_output_size=dict(grid_nodes=num_outputs), + embed_nodes=False, # Node features already embdded by previous layers. + embed_edges=True, # Embed raw features of the mesh2grid edges. + edge_latent_size=dict(mesh2grid=model_config.latent_size), + node_latent_size=dict( + mesh_nodes=model_config.latent_size, + grid_nodes=model_config.latent_size), + mlp_hidden_size=model_config.latent_size, + mlp_num_hidden_layers=model_config.hidden_layers, + num_message_passing_steps=1, + use_layer_norm=True, + include_sent_messages_in_node_update=False, + activation="swish", + f32_aggregation=False, + name="mesh2grid_gnn", + ) + + # Obtain the query radius in absolute units for the unit-sphere for the + # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`. + self._query_radius = (_get_max_edge_distance(self._finest_mesh) + * model_config.radius_query_fraction_edge_length) + self._mesh2grid_edge_normalization_factor = ( + model_config.mesh2grid_edge_normalization_factor + ) + + # Other initialization is delayed until the first call (`_maybe_init`) + # when we get some sample data so we know the lat/lon values. + self._initialized = False + + # A "_init_mesh_properties": + # This one could be initialized at init but we delay it for consistency too. + self._num_mesh_nodes = None # num_mesh_nodes + self._mesh_nodes_lat = None # [num_mesh_nodes] + self._mesh_nodes_lon = None # [num_mesh_nodes] + + # A "_init_grid_properties": + self._grid_lat = None # [num_lat_points] + self._grid_lon = None # [num_lon_points] + self._num_grid_nodes = None # num_lat_points * num_lon_points + self._grid_nodes_lat = None # [num_grid_nodes] + self._grid_nodes_lon = None # [num_grid_nodes] + + # A "_init_{grid2mesh,processor,mesh2grid}_graph" + self._grid2mesh_graph_structure = None + self._mesh_graph_structure = None + self._mesh2grid_graph_structure = None + + @property + def _finest_mesh(self): + return self._meshes[-1] + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + is_training: bool = False, + ) -> xarray.Dataset: + self._maybe_init(inputs) + + # Convert all input data into flat vectors for each of the grid nodes. + # xarray (batch, time, lat, lon, level, multiple vars, forcings) + # -> [num_grid_nodes, batch, num_channels] + grid_node_features = self._inputs_to_grid_node_features(inputs, forcings) + + # Transfer data for the grid to the mesh, + # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size] + (latent_mesh_nodes, latent_grid_nodes + ) = self._run_grid2mesh_gnn(grid_node_features) + + # Run message passing in the multimesh. + # [num_mesh_nodes, batch, latent_size] + updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes) + + # Transfer data frome the mesh to the grid. + # [num_grid_nodes, batch, output_size] + output_grid_nodes = self._run_mesh2grid_gnn( + updated_latent_mesh_nodes, latent_grid_nodes) + + # Conver output flat vectors for the grid nodes to the format of the output. + # [num_grid_nodes, batch, output_size] -> + # xarray (batch, one time step, lat, lon, level, multiple vars) + return self._grid_node_outputs_to_prediction( + output_grid_nodes, targets_template) + + def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]: + # Forward pass. + predictions = self( + inputs, targets_template=targets, forcings=forcings, is_training=True) + # Compute loss. + loss = losses.weighted_mse_per_level( + predictions, targets, + per_variable_weights={ + # Any variables not specified here are weighted as 1.0. + # A single-level variable, but an important headline variable + # and also one which we have struggled to get good performance + # on at short lead times, so leaving it weighted at 1.0, equal + # to the multi-level variables: + "2m_temperature": 1.0, + # New single-level variables, which we don't weight too highly + # to avoid hurting performance on other variables. + "10m_u_component_of_wind": 0.1, + "10m_v_component_of_wind": 0.1, + "mean_sea_level_pressure": 0.1, + "total_precipitation_6hr": 0.1, + }) + return loss, predictions # pytype: disable=bad-return-type # jax-ndarray + + def loss( # pytype: disable=signature-mismatch # jax-ndarray + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + ) -> predictor_base.LossAndDiagnostics: + loss, _ = self.loss_and_predictions(inputs, targets, forcings) + return loss # pytype: disable=bad-return-type # jax-ndarray + + def _maybe_init(self, sample_inputs: xarray.Dataset): + """Inits everything that has a dependency on the input coordinates.""" + if not self._initialized: + self._init_mesh_properties() + self._init_grid_properties( + grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon) + self._grid2mesh_graph_structure = self._init_grid2mesh_graph() + self._mesh_graph_structure = self._init_mesh_graph() + self._mesh2grid_graph_structure = self._init_mesh2grid_graph() + + self._initialized = True + + def _init_mesh_properties(self): + """Inits static properties that have to do with mesh nodes.""" + self._num_mesh_nodes = self._finest_mesh.vertices.shape[0] + mesh_phi, mesh_theta = model_utils.cartesian_to_spherical( + self._finest_mesh.vertices[:, 0], + self._finest_mesh.vertices[:, 1], + self._finest_mesh.vertices[:, 2]) + ( + mesh_nodes_lat, + mesh_nodes_lon, + ) = model_utils.spherical_to_lat_lon( + phi=mesh_phi, theta=mesh_theta) + # Convert to f32 to ensure the lat/lon features aren't in f64. + self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32) + self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32) + + def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray): + """Inits static properties that have to do with grid nodes.""" + self._grid_lat = grid_lat.astype(np.float32) + self._grid_lon = grid_lon.astype(np.float32) + # Initialized the counters. + self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0] + + # Initialize lat and lon for the grid. + grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat) + self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32) + self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32) + + def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph: + """Build Grid2Mesh graph.""" + + # Create some edges according to distance between mesh and grid nodes. + assert self._grid_lat is not None and self._grid_lon is not None + (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._finest_mesh, + radius=self._query_radius) + + # Edges sending info from grid to mesh. + senders = grid_indices + receivers = mesh_indices + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._grid_nodes_lat, + senders_node_lon=self._grid_nodes_lon, + receivers_node_lat=self._mesh_nodes_lat, + receivers_node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=None, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([mesh_indices.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=senders_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=receivers_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")): + edge_set + } + grid2mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return grid2mesh_graph + + def _init_mesh_graph(self) -> typed_graph.TypedGraph: + """Build Mesh graph.""" + merged_mesh = icosahedral_mesh.merge_meshes(self._meshes) + + # Work simply on the mesh edges. + senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces) + + # Precompute structural node and edge features according to config options. + # Structural features are those that depend on the fixed values of the + # latitude and longitudes of the nodes. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + node_features, edge_features = model_utils.get_graph_spatial_features( + node_lat=self._mesh_nodes_lat, + node_lon=self._mesh_nodes_lon, + senders=senders, + receivers=receivers, + **self._spatial_features_kwargs, + ) + + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + assert n_mesh_node == len(node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set + } + mesh_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + + return mesh_graph + + def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph: + """Build Mesh2Grid graph.""" + + # Create some edges according to how the grid nodes are contained by + # mesh triangles. + (grid_indices, + mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices( + grid_latitude=self._grid_lat, + grid_longitude=self._grid_lon, + mesh=self._finest_mesh) + + # Edges sending info from mesh to grid. + senders = mesh_indices + receivers = grid_indices + + # Precompute structural node and edge features according to config options. + assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None + (senders_node_features, receivers_node_features, + edge_features) = model_utils.get_bipartite_graph_spatial_features( + senders_node_lat=self._mesh_nodes_lat, + senders_node_lon=self._mesh_nodes_lon, + receivers_node_lat=self._grid_nodes_lat, + receivers_node_lon=self._grid_nodes_lon, + senders=senders, + receivers=receivers, + edge_normalization_factor=self._mesh2grid_edge_normalization_factor, + **self._spatial_features_kwargs, + ) + + n_grid_node = np.array([self._num_grid_nodes]) + n_mesh_node = np.array([self._num_mesh_nodes]) + n_edge = np.array([senders.shape[0]]) + grid_node_set = typed_graph.NodeSet( + n_node=n_grid_node, features=receivers_node_features) + mesh_node_set = typed_graph.NodeSet( + n_node=n_mesh_node, features=senders_node_features) + edge_set = typed_graph.EdgeSet( + n_edge=n_edge, + indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers), + features=edge_features) + nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set} + edges = { + typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")): + edge_set + } + mesh2grid_graph = typed_graph.TypedGraph( + context=typed_graph.Context(n_graph=np.array([1]), features=()), + nodes=nodes, + edges=edges) + return mesh2grid_graph + + def _run_grid2mesh_gnn(self, grid_node_features: chex.Array, + ) -> tuple[chex.Array, chex.Array]: + """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes.""" + + # Concatenate node structural features with input features. + batch_size = grid_node_features.shape[1] + + grid2mesh_graph = self._grid2mesh_graph_structure + assert grid2mesh_graph is not None + grid_nodes = grid2mesh_graph.nodes["grid_nodes"] + mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"] + new_grid_nodes = grid_nodes._replace( + features=jnp.concatenate([ + grid_node_features, + _add_batch_second_axis( + grid_nodes.features.astype(grid_node_features.dtype), + batch_size) + ], + axis=-1)) + + # To make sure capacity of the embedded is identical for the grid nodes and + # the mesh nodes, we also append some dummy zero input features for the + # mesh nodes. + dummy_mesh_node_features = jnp.zeros( + (self._num_mesh_nodes,) + grid_node_features.shape[1:], + dtype=grid_node_features.dtype) + new_mesh_nodes = mesh_nodes._replace( + features=jnp.concatenate([ + dummy_mesh_node_features, + _add_batch_second_axis( + mesh_nodes.features.astype(dummy_mesh_node_features.dtype), + batch_size) + ], + axis=-1)) + + # Broadcast edge structural features to the required batch size. + grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh") + edges = grid2mesh_graph.edges[grid2mesh_edges_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(dummy_mesh_node_features.dtype), batch_size)) + + input_graph = self._grid2mesh_graph_structure._replace( + edges={grid2mesh_edges_key: new_edges}, + nodes={ + "grid_nodes": new_grid_nodes, + "mesh_nodes": new_mesh_nodes + }) + + # Run the GNN. + grid2mesh_out = self._grid2mesh_gnn(input_graph) + latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features + latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features + return latent_mesh_nodes, latent_grid_nodes + + def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array: + """Runs the mesh_gnn, extracting updated latent mesh nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = latent_mesh_nodes.shape[1] + + mesh_graph = self._mesh_graph_structure + assert mesh_graph is not None + mesh_edges_key = mesh_graph.edge_key_by_name("mesh") + edges = mesh_graph.edges[mesh_edges_key] + + # We are assuming here that the mesh gnn uses a single set of edge keys + # named "mesh" for the edges and that it uses a single set of nodes named + # "mesh_nodes" + msg = ("The setup currently requires to only have one kind of edge in the" + " mesh GNN.") + assert len(mesh_graph.edges) == 1, msg + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_mesh_nodes.dtype), batch_size)) + + nodes = mesh_graph.nodes["mesh_nodes"] + nodes = nodes._replace(features=latent_mesh_nodes) + + input_graph = mesh_graph._replace( + edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes}) + + # Run the GNN. + return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features + + def _run_mesh2grid_gnn(self, + updated_latent_mesh_nodes: chex.Array, + latent_grid_nodes: chex.Array, + ) -> chex.Array: + """Runs the mesh2grid_gnn, extracting the output grid nodes.""" + + # Add the structural edge features of this graph. Note we don't need + # to add the structural node features, because these are already part of + # the latent state, via the original Grid2Mesh gnn, however, we need + # the edge ones, because it is the first time we are seeing this particular + # set of edges. + batch_size = updated_latent_mesh_nodes.shape[1] + + mesh2grid_graph = self._mesh2grid_graph_structure + assert mesh2grid_graph is not None + mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"] + grid_nodes = mesh2grid_graph.nodes["grid_nodes"] + new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes) + new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes) + mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid") + edges = mesh2grid_graph.edges[mesh2grid_key] + + new_edges = edges._replace( + features=_add_batch_second_axis( + edges.features.astype(latent_grid_nodes.dtype), batch_size)) + + input_graph = mesh2grid_graph._replace( + edges={mesh2grid_key: new_edges}, + nodes={ + "mesh_nodes": new_mesh_nodes, + "grid_nodes": new_grid_nodes + }) + + # Run the GNN. + output_graph = self._mesh2grid_gnn(input_graph) + output_grid_nodes = output_graph.nodes["grid_nodes"].features + + return output_grid_nodes + + def _inputs_to_grid_node_features( + self, + inputs: xarray.Dataset, + forcings: xarray.Dataset, + ) -> chex.Array: + """xarrays -> [num_grid_nodes, batch, num_channels].""" + + # xarray `Dataset` (batch, time, lat, lon, level, multiple vars) + # to xarray `DataArray` (batch, lat, lon, channels) + stacked_inputs = model_utils.dataset_to_stacked(inputs) + stacked_forcings = model_utils.dataset_to_stacked(forcings) + stacked_inputs = xarray.concat( + [stacked_inputs, stacked_forcings], dim="channels") + + # xarray `DataArray` (batch, lat, lon, channels) + # to single numpy array with shape [lat_lon_node, batch, channels] + grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes( + stacked_inputs) + return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape( + (-1,) + grid_xarray_lat_lon_leading.data.shape[2:]) + + def _grid_node_outputs_to_prediction( + self, + grid_node_outputs: chex.Array, + targets_template: xarray.Dataset, + ) -> xarray.Dataset: + """[num_grid_nodes, batch, num_outputs] -> xarray.""" + + # numpy array with shape [lat_lon_node, batch, channels] + # to xarray `DataArray` (batch, lat, lon, channels) + assert self._grid_lat is not None and self._grid_lon is not None + grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0]) + grid_outputs_lat_lon_leading = grid_node_outputs.reshape( + grid_shape + grid_node_outputs.shape[1:]) + dims = ("lat", "lon", "batch", "channels") + grid_xarray_lat_lon_leading = xarray_jax.DataArray( + data=grid_outputs_lat_lon_leading, + dims=dims) + grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading) + + # xarray `DataArray` (batch, lat, lon, channels) + # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars) + return model_utils.stacked_to_dataset( + grid_xarray.variable, targets_template) + + +def _add_batch_second_axis(data, batch_size): + # data [leading_dim, trailing_dim] + assert data.ndim == 2 + ones = jnp.ones([batch_size, 1], dtype=data.dtype) + return data[:, None] * ones # [leading_dim, batch, trailing_dim] + + +def _get_max_edge_distance(mesh): + senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) + edge_distances = np.linalg.norm( + mesh.vertices[senders] - mesh.vertices[receivers], axis=-1) + return edge_distances.max() diff --git a/graphcast/grid_mesh_connectivity.py b/graphcast/grid_mesh_connectivity.py new file mode 100644 index 0000000..5cfc1b6 --- /dev/null +++ b/graphcast/grid_mesh_connectivity.py @@ -0,0 +1,133 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tools for converting from regular grids on a sphere, to triangular meshes.""" + +from graphcast import icosahedral_mesh +import numpy as np +import scipy +import trimesh + + +def _grid_lat_lon_to_coordinates( + grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray: + """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3].""" + # Convert to spherical coordinates phi and theta defined in the grid. + # Each [num_latitude_points, num_longitude_points] + phi_grid, theta_grid = np.meshgrid( + np.deg2rad(grid_longitude), + np.deg2rad(90 - grid_latitude)) + + # [num_latitude_points, num_longitude_points, 3] + # Note this assumes unit radius, since for now we model the earth as a + # sphere of unit radius, and keep any vertical dimension as a regular grid. + return np.stack( + [np.cos(phi_grid)*np.sin(theta_grid), + np.sin(phi_grid)*np.sin(theta_grid), + np.cos(theta_grid)], axis=-1) + + +def radius_query_indices( + *, + grid_latitude: np.ndarray, + grid_longitude: np.ndarray, + mesh: icosahedral_mesh.TriangularMesh, + radius: float) -> tuple[np.ndarray, np.ndarray]: + """Returns mesh-grid edge indices for radius query. + + Args: + grid_latitude: Latitude values for the grid [num_lat_points] + grid_longitude: Longitude values for the grid [num_lon_points] + mesh: Mesh object. + radius: Radius of connectivity in R3. for a sphere of unit radius. + + Returns: + tuple with `grid_indices` and `mesh_indices` indicating edges between the + grid and the mesh such that the distances in a straight line (not geodesic) + are smaller than or equal to `radius`. + * grid_indices: Indices of shape [num_edges], that index into a + [num_lat_points, num_lon_points] grid, after flattening the leading axes. + * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. + """ + + # [num_grid_points=num_lat_points * num_lon_points, 3] + grid_positions = _grid_lat_lon_to_coordinates( + grid_latitude, grid_longitude).reshape([-1, 3]) + + # [num_mesh_points, 3] + mesh_positions = mesh.vertices + kd_tree = scipy.spatial.cKDTree(mesh_positions) + + # [num_grid_points, num_mesh_points_per_grid_point] + # Note `num_mesh_points_per_grid_point` is not constant, so this is a list + # of arrays, rather than a 2d array. + query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius) + + grid_edge_indices = [] + mesh_edge_indices = [] + for grid_index, mesh_neighbors in enumerate(query_indices): + grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors))) + mesh_edge_indices.append(mesh_neighbors) + + # [num_edges] + grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int) + mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int) + + return grid_edge_indices, mesh_edge_indices + + +def in_mesh_triangle_indices( + *, + grid_latitude: np.ndarray, + grid_longitude: np.ndarray, + mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]: + """Returns mesh-grid edge indices for grid points contained in mesh triangles. + + Args: + grid_latitude: Latitude values for the grid [num_lat_points] + grid_longitude: Longitude values for the grid [num_lon_points] + mesh: Mesh object. + + Returns: + tuple with `grid_indices` and `mesh_indices` indicating edges between the + grid and the mesh vertices of the triangle that contain each grid point. + The number of edges is always num_lat_points * num_lon_points * 3 + * grid_indices: Indices of shape [num_edges], that index into a + [num_lat_points, num_lon_points] grid, after flattening the leading axes. + * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. + """ + + # [num_grid_points=num_lat_points * num_lon_points, 3] + grid_positions = _grid_lat_lon_to_coordinates( + grid_latitude, grid_longitude).reshape([-1, 3]) + + mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) + + # [num_grid_points] with mesh face indices for each grid point. + _, _, query_face_indices = trimesh.proximity.closest_point( + mesh_trimesh, grid_positions) + + # [num_grid_points, 3] with mesh node indices for each grid point. + mesh_edge_indices = mesh.faces[query_face_indices] + + # [num_grid_points, 3] with grid node indices, where every row simply contains + # the row (grid_point) index. + grid_indices = np.arange(grid_positions.shape[0]) + grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3]) + + # Flatten to get a regular list. + # [num_edges=num_grid_points*3] + mesh_edge_indices = mesh_edge_indices.reshape([-1]) + grid_edge_indices = grid_edge_indices.reshape([-1]) + + return grid_edge_indices, mesh_edge_indices diff --git a/graphcast/grid_mesh_connectivity_test.py b/graphcast/grid_mesh_connectivity_test.py new file mode 100644 index 0000000..d52188b --- /dev/null +++ b/graphcast/grid_mesh_connectivity_test.py @@ -0,0 +1,74 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for graphcast.grid_mesh_connectivity.""" + +from absl.testing import absltest +from graphcast import grid_mesh_connectivity +from graphcast import icosahedral_mesh +import numpy as np + + +class GridMeshConnectivityTest(absltest.TestCase): + + def test_grid_lat_lon_to_coordinates(self): + + # Intervals of 30 degrees. + grid_latitude = np.array([-45., 0., 45]) + grid_longitude = np.array([0., 90., 180., 270.]) + + inv_sqrt2 = 1 / np.sqrt(2) + expected_coordinates = np.array([ + [[inv_sqrt2, 0., -inv_sqrt2], + [0., inv_sqrt2, -inv_sqrt2], + [-inv_sqrt2, 0., -inv_sqrt2], + [0., -inv_sqrt2, -inv_sqrt2]], + [[1., 0., 0.], + [0., 1., 0.], + [-1., 0., 0.], + [0., -1., 0.]], + [[inv_sqrt2, 0., inv_sqrt2], + [0., inv_sqrt2, inv_sqrt2], + [-inv_sqrt2, 0., inv_sqrt2], + [0., -inv_sqrt2, inv_sqrt2]], + ]) + + coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates( + grid_latitude, grid_longitude) + np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15) + + def test_radius_query_indices_smoke(self): + # TODO(alvarosg): Add non-smoke test? + grid_latitude = np.linspace(-75, 75, 6) + grid_longitude = np.arange(12) * 30. + mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( + splits=3)[-1] + grid_mesh_connectivity.radius_query_indices( + grid_latitude=grid_latitude, + grid_longitude=grid_longitude, + mesh=mesh, radius=0.2) + + def test_in_mesh_triangle_indices_smoke(self): + # TODO(alvarosg): Add non-smoke test? + grid_latitude = np.linspace(-75, 75, 6) + grid_longitude = np.arange(12) * 30. + mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( + splits=3)[-1] + grid_mesh_connectivity.in_mesh_triangle_indices( + grid_latitude=grid_latitude, + grid_longitude=grid_longitude, + mesh=mesh) + + +if __name__ == "__main__": + absltest.main() diff --git a/graphcast/icosahedral_mesh.py b/graphcast/icosahedral_mesh.py new file mode 100644 index 0000000..4c43642 --- /dev/null +++ b/graphcast/icosahedral_mesh.py @@ -0,0 +1,281 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for creating icosahedral meshes.""" + +import itertools +from typing import List, NamedTuple, Sequence, Tuple + +import numpy as np +from scipy.spatial import transform + + +class TriangularMesh(NamedTuple): + """Data structure for triangular meshes. + + Attributes: + vertices: spatial positions of the vertices of the mesh of shape + [num_vertices, num_dims]. + faces: triangular faces of the mesh of shape [num_faces, 3]. Contains + integer indices into `vertices`. + + """ + vertices: np.ndarray + faces: np.ndarray + + +def merge_meshes( + mesh_list: Sequence[TriangularMesh]) -> TriangularMesh: + """Merges all meshes into one. Assumes the last mesh is the finest. + + Args: + mesh_list: Sequence of meshes, from coarse to fine refinement levels. The + vertices and faces may contain those from preceding, coarser levels. + + Returns: + `TriangularMesh` for which the vertices correspond to the highest + resolution mesh in the hierarchy, and the faces are the join set of the + faces at all levels of the hierarchy. + """ + for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list): + num_nodes_mesh_i = mesh_i.vertices.shape[0] + assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i]) + + return TriangularMesh( + vertices=mesh_list[-1].vertices, + faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0)) + + +def get_hierarchy_of_triangular_meshes_for_sphere( + splits: int) -> List[TriangularMesh]: + """Returns a sequence of meshes, each with triangularization sphere. + + Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with + circumscribed unit sphere. Then, each triangular face is iteratively + subdivided into 4 triangular faces `splits` times. The new vertices are then + projected back onto the unit sphere. All resulting meshes are returned in a + list, from lowest to highest resolution. + + The vertices in each face are specified in counter-clockwise order as + observed from the outside the icosahedron. + + Args: + splits: How many times to split each triangle. + Returns: + Sequence of `TriangularMesh`s of length `splits + 1` each with: + + vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm. + faces: [num_faces, 3] with triangular faces joining sets of 3 vertices. + Each row contains three indices into the vertices array, indicating + the vertices adjacent to the face. Always with positive orientation + (counterclock-wise when looking from the outside). + """ + current_mesh = get_icosahedron() + output_meshes = [current_mesh] + for _ in range(splits): + current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh) + output_meshes.append(current_mesh) + return output_meshes + + +def get_icosahedron() -> TriangularMesh: + """Returns a regular icosahedral mesh with circumscribed unit sphere. + + See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates + for details on the construction of the regular icosahedron. + + The vertices in each face are specified in counter-clockwise order as observed + from the outside of the icosahedron. + + Returns: + TriangularMesh with: + + vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm. + faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices. + Each row contains three indices into the vertices array, indicating + the vertices adjacent to the face. Always with positive orientation ( + counterclock-wise when looking from the outside). + + """ + phi = (1 + np.sqrt(5)) / 2 + vertices = [] + for c1 in [1., -1.]: + for c2 in [phi, -phi]: + vertices.append((c1, c2, 0.)) + vertices.append((0., c1, c2)) + vertices.append((c2, 0., c1)) + + vertices = np.array(vertices, dtype=np.float32) + vertices /= np.linalg.norm([1., phi]) + + # I did this manually, checking the orientation one by one. + faces = [(0, 1, 2), + (0, 6, 1), + (8, 0, 2), + (8, 4, 0), + (3, 8, 2), + (3, 2, 7), + (7, 2, 1), + (0, 4, 6), + (4, 11, 6), + (6, 11, 5), + (1, 5, 7), + (4, 10, 11), + (4, 8, 10), + (10, 8, 3), + (10, 3, 9), + (11, 10, 9), + (11, 9, 5), + (5, 9, 7), + (9, 3, 7), + (1, 6, 5), + ] + + # By default the top is an aris parallel to the Y axis. + # Need to rotate around the y axis by half the supplementary to the + # angle between faces divided by two to get the desired orientation. + # /O\ (top arist) + # / \ Z + # (adjacent face)/ \ (adjacent face) ^ + # / angle_between_faces \ | + # / \ | + # / \ YO-----> X + # This results in: + # (adjacent faceis now top plane) + # ----------------------O\ (top arist) + # \ + # \ + # \ (adjacent face) + # \ + # \ + # \ + + angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3)) + rotation_angle = (np.pi - angle_between_faces) / 2 + rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle) + rotation_matrix = rotation.as_matrix() + vertices = np.dot(vertices, rotation_matrix) + + return TriangularMesh(vertices=vertices.astype(np.float32), + faces=np.array(faces, dtype=np.int32)) + + +def _two_split_unit_sphere_triangle_faces( + triangular_mesh: TriangularMesh) -> TriangularMesh: + """Splits each triangular face into 4 triangles keeping the orientation.""" + + # Every time we split a triangle into 4 we will be adding 3 extra vertices, + # located at the edge centres. + # This class handles the positioning of the new vertices, and avoids creating + # duplicates. + new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices) + + new_faces = [] + for ind1, ind2, ind3 in triangular_mesh.faces: + # Transform each triangular face into 4 triangles, + # preserving the orientation. + # ind3 + # / \ + # / \ + # / #3 \ + # / \ + # ind31 -------------- ind23 + # / \ / \ + # / \ #4 / \ + # / #1 \ / #2 \ + # / \ / \ + # ind1 ------------ ind12 ------------ ind2 + ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2)) + ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3)) + ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1)) + # Note how each of the 4 triangular new faces specifies the order of the + # vertices to preserve the orientation of the original face. As the input + # face should always be counter-clockwise as specified in the diagram, + # this means child faces should also be counter-clockwise. + new_faces.extend([[ind1, ind12, ind31], # 1 + [ind12, ind2, ind23], # 2 + [ind31, ind23, ind3], # 3 + [ind12, ind23, ind31], # 4 + ]) + return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(), + faces=np.array(new_faces, dtype=np.int32)) + + +class _ChildVerticesBuilder(object): + """Bookkeeping of new child vertices added to an existing set of vertices.""" + + def __init__(self, parent_vertices): + + # Because the same new vertex will be required when splitting adjacent + # triangles (which share an edge) we keep them in a hash table indexed by + # sorted indices of the vertices adjacent to the edge, to avoid creating + # duplicated child vertices. + self._child_vertices_index_mapping = {} + self._parent_vertices = parent_vertices + # We start with all previous vertices. + self._all_vertices_list = list(parent_vertices) + + def _get_child_vertex_key(self, parent_vertex_indices): + return tuple(sorted(parent_vertex_indices)) + + def _create_child_vertex(self, parent_vertex_indices): + """Creates a new vertex.""" + # Position for new vertex is the middle point, between the parent points, + # projected to unit sphere. + child_vertex_position = self._parent_vertices[ + list(parent_vertex_indices)].mean(0) + child_vertex_position /= np.linalg.norm(child_vertex_position) + + # Add the vertex to the output list. The index for this new vertex will + # match the length of the list before adding it. + child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) + self._child_vertices_index_mapping[child_vertex_key] = len( + self._all_vertices_list) + self._all_vertices_list.append(child_vertex_position) + + def get_new_child_vertex_index(self, parent_vertex_indices): + """Returns index for a child vertex, creating it if necessary.""" + # Get the key to see if we already have a new vertex in the middle. + child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) + if child_vertex_key not in self._child_vertices_index_mapping: + self._create_child_vertex(parent_vertex_indices) + return self._child_vertices_index_mapping[child_vertex_key] + + def get_all_vertices(self): + """Returns an array with old vertices.""" + return np.array(self._all_vertices_list) + + +def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Transforms polygonal faces to sender and receiver indices. + + It does so by transforming every face into N_i edges. Such if the triangular + face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0. + + If all faces have consistent orientation, and the surface represented by the + faces is closed, then every edge in a polygon with a certain orientation + is also part of another polygon with the opposite orientation. In this + situation, the edges returned by the method are always bidirectional. + + Args: + faces: Integer array of shape [num_faces, 3]. Contains node indices + adjacent to each face. + Returns: + Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3]. + + """ + assert faces.ndim == 2 + assert faces.shape[-1] == 3 + senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) + receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) + return senders, receivers diff --git a/graphcast/icosahedral_mesh_test.py b/graphcast/icosahedral_mesh_test.py new file mode 100644 index 0000000..c76848e --- /dev/null +++ b/graphcast/icosahedral_mesh_test.py @@ -0,0 +1,131 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for icosahedral_mesh.""" + +from absl.testing import absltest +from absl.testing import parameterized +import chex +from graphcast import icosahedral_mesh +import numpy as np + + +def _get_mesh_spec(splits: int): + """Returns size of the final icosahedral mesh resulting from the splitting.""" + num_vertices = 12 + num_faces = 20 + for _ in range(splits): + # Each previous face adds three new vertices, but each vertex is shared + # by two faces. + num_vertices += num_faces * 3 // 2 + num_faces *= 4 + return num_vertices, num_faces + + +class IcosahedralMeshTest(parameterized.TestCase): + + def test_icosahedron(self): + mesh = icosahedral_mesh.get_icosahedron() + _assert_valid_mesh( + mesh, num_expected_vertices=12, num_expected_faces=20) + + @parameterized.parameters(list(range(5))) + def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits): + meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( + splits=splits) + prev_vertices = None + for mesh_i, mesh in enumerate(meshes): + # Check that `mesh` is valid. + num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i) + _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces) + + # Check that the first N vertices from this mesh match all of the + # vertices from the previous mesh. + if prev_vertices is not None: + leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]] + np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices) + + # Increase the expected/previous values for the next iteration. + if mesh_i < len(meshes) - 1: + prev_vertices = mesh.vertices + + @parameterized.parameters(list(range(4))) + def test_merge_meshes(self, splits): + mesh_hierarchy = ( + icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( + splits=splits)) + mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy) + + expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0) + np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices) + np.testing.assert_array_equal(mesh.faces, expected_faces) + + def test_faces_to_edges(self): + + faces = np.array([[0, 1, 2], + [3, 4, 5]]) + + # This also documents the order of the edges returned by the method. + expected_edges = np.array( + [[0, 1], + [3, 4], + [1, 2], + [4, 5], + [2, 0], + [5, 3]]) + expected_senders = expected_edges[:, 0] + expected_receivers = expected_edges[:, 1] + + senders, receivers = icosahedral_mesh.faces_to_edges(faces) + + np.testing.assert_array_equal(senders, expected_senders) + np.testing.assert_array_equal(receivers, expected_receivers) + + +def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces): + vertices = mesh.vertices + faces = mesh.faces + chex.assert_shape(vertices, [num_expected_vertices, 3]) + chex.assert_shape(faces, [num_expected_faces, 3]) + + # Vertices norm should be 1. + vertices_norm = np.linalg.norm(vertices, axis=-1) + np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6) + + _assert_positive_face_orientation(vertices, faces) + + +def _assert_positive_face_orientation(vertices, faces): + + # Obtain a unit vector that points, in the direction of the face. + face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]], + vertices[faces[:, 2]] - vertices[faces[:, 1]]) + face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True) + + # And a unit vector pointing from the origin to the center of the face. + face_centers = vertices[faces].mean(1) + face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True) + + # Positive orientation means those two vectors should be parallel + # (dot product, 1), and not anti-parallel (dot product, -1). + dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers) + + # Check that the face normal is parallel to the vector that joins the center + # of the face to the center of the sphere. Note we need a small tolerance + # because some discretizations are not exactly uniform, so it will not be + # exactly parallel. + np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4) + + +if __name__ == "__main__": + absltest.main() diff --git a/graphcast/losses.py b/graphcast/losses.py new file mode 100644 index 0000000..9ceeb94 --- /dev/null +++ b/graphcast/losses.py @@ -0,0 +1,179 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions (and terms for use in loss functions) used for weather.""" + +from typing import Mapping + +from graphcast import xarray_tree +import numpy as np +from typing_extensions import Protocol +import xarray + + +LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset] + + +class LossFunction(Protocol): + """A loss function. + + This is a protocol so it's fine to use a plain function which 'quacks like' + this. This is just to document the interface. + """ + + def __call__(self, + predictions: xarray.Dataset, + targets: xarray.Dataset, + **optional_kwargs) -> LossAndDiagnostics: + """Computes a loss function. + + Args: + predictions: Dataset of predictions. + targets: Dataset of targets. + **optional_kwargs: Implementations may support extra optional kwargs. + + Returns: + loss: A DataArray with dimensions ('batch',) containing losses for each + element of the batch. These will be averaged to give the final + loss, locally and across replicas. + diagnostics: Mapping of additional quantities to log by name alongside the + loss. These will will typically correspond to terms in the loss. They + should also have dimensions ('batch',) and will be averaged over the + batch before logging. + """ + + +def weighted_mse_per_level( + predictions: xarray.Dataset, + targets: xarray.Dataset, + per_variable_weights: Mapping[str, float], +) -> LossAndDiagnostics: + """Latitude- and pressure-level-weighted MSE loss.""" + def loss(prediction, target): + loss = (prediction - target)**2 + loss *= normalized_latitude_weights(target).astype(loss.dtype) + if 'level' in target.dims: + loss *= normalized_level_weights(target).astype(loss.dtype) + return _mean_preserving_batch(loss) + + losses = xarray_tree.map_structure(loss, predictions, targets) + return sum_per_variable_losses(losses, per_variable_weights) + + +def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray: + return x.mean([d for d in x.dims if d != 'batch'], skipna=False) + + +def sum_per_variable_losses( + per_variable_losses: Mapping[str, xarray.DataArray], + weights: Mapping[str, float], +) -> LossAndDiagnostics: + """Weighted sum of per-variable losses.""" + if not set(weights.keys()).issubset(set(per_variable_losses.keys())): + raise ValueError( + 'Passing a weight that does not correspond to any variable ' + f'{set(weights.keys())-set(per_variable_losses.keys())}') + + weighted_per_variable_losses = { + name: loss * weights.get(name, 1) + for name, loss in per_variable_losses.items() + } + total = xarray.concat( + weighted_per_variable_losses.values(), dim='variable', join='exact').sum( + 'variable', skipna=False) + return total, per_variable_losses + + +def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray: + """Weights proportional to pressure at each level.""" + level = data.coords['level'] + return level / level.mean(skipna=False) + + +def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray: + """Weights based on latitude, roughly proportional to grid cell area. + + This method supports two use cases only (both for equispaced values): + * Latitude values such that the closest value to the pole is at latitude + (90 - d_lat/2), where d_lat is the difference between contiguous latitudes. + For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2) + In this case each point with `lat` value represents a sphere slice between + `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be + proportional to: + `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and + we can simply omit the term `2 * sin(d_lat/2)` which is just a constant + that cancels during normalization. + * Latitude values that fall exactly at the poles. + For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2) + In this case each point with `lat` value also represents + a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`, + except for the points at the poles, that represent a slice between + `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`. + The areas of the first type of point are still proportional to: + * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat) + but for the points at the poles now is: + * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2 + and we will be using these weights, depending on whether we are looking at + pole cells, or non-pole cells (omitting the common factor of 2 which will be + absorbed by the normalization). + + It can be shown via a limit, or simple geometry, that in the small angles + regime, the proportion of area per pole-point is equal to 1/8th + the proportion of area covered by each of the nearest non-pole point, and we + test for this in the test. + + Args: + data: `DataArray` with latitude coordinates. + Returns: + Unit mean latitude weights. + """ + latitude = data.coords['lat'] + + if np.any(np.isclose(np.abs(latitude), 90.)): + weights = _weight_for_latitude_vector_with_poles(latitude) + else: + weights = _weight_for_latitude_vector_without_poles(latitude) + + return weights / weights.mean(skipna=False) + + +def _weight_for_latitude_vector_without_poles(latitude): + """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2].""" + delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) + if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or + not np.isclose(np.min(latitude), -90 + delta_latitude/2)): + raise ValueError( + f'Latitude vector {latitude} does not start/end at ' + '+- (90 - delta_latitude/2) degrees.') + return np.cos(np.deg2rad(latitude)) + + +def _weight_for_latitude_vector_with_poles(latitude): + """Weights for uniform latitudes of the form [+- 90, ..., -+90].""" + delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) + if (not np.isclose(np.max(latitude), 90.) or + not np.isclose(np.min(latitude), -90.)): + raise ValueError( + f'Latitude vector {latitude} does not start/end at +- 90 degrees.') + weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2)) + # The two checks above enough to guarantee that latitudes are sorted, so + # the extremes are the poles + weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2 + return weights + + +def _check_uniform_spacing_and_get_delta(vector): + diff = np.diff(vector) + if not np.all(np.isclose(diff[0], diff)): + raise ValueError(f'Vector {diff} is not uniformly spaced.') + return diff[0] diff --git a/graphcast/model_utils.py b/graphcast/model_utils.py new file mode 100644 index 0000000..949088c --- /dev/null +++ b/graphcast/model_utils.py @@ -0,0 +1,724 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for building models.""" + +from typing import Mapping, Optional, Tuple + +import numpy as np +from scipy.spatial import transform +import xarray + + +def get_graph_spatial_features( + *, node_lat: np.ndarray, node_lon: np.ndarray, + senders: np.ndarray, receivers: np.ndarray, + add_node_positions: bool, + add_node_latitude: bool, + add_node_longitude: bool, + add_relative_positions: bool, + relative_longitude_local_coordinates: bool, + relative_latitude_local_coordinates: bool, + sine_cosine_encoding: bool = False, + encoding_num_freqs: int = 10, + encoding_multiplicative_factor: float = 1.2, + ) -> Tuple[np.ndarray, np.ndarray]: + """Computes spatial features for the nodes. + + Args: + node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes] + node_lon: Longitudes in the [0, 360] interval of shape [num_nodes] + senders: Sender indices of shape [num_edges] + receivers: Receiver indices of shape [num_edges] + add_node_positions: Add unit norm absolute positions. + add_node_latitude: Add a feature for latitude (cos(90 - lat)) + Note even if this is set to False, the model may be able to infer the + longitude from relative features, unless + `relative_latitude_local_coordinates` is also True, or if there is any + bias on the relative edge sizes for different longitudes. + add_node_longitude: Add features for longitude (cos(lon), sin(lon)). + Note even if this is set to False, the model may be able to infer the + longitude from relative features, unless + `relative_longitude_local_coordinates` is also True, or if there is any + bias on the relative edge sizes for different longitudes. + add_relative_positions: Whether to relative positions in R3 to the edges. + relative_longitude_local_coordinates: If True, relative positions are + computed in a local space where the receiver is at 0 longitude. + relative_latitude_local_coordinates: If True, relative positions are + computed in a local space where the receiver is at 0 latitude. + sine_cosine_encoding: If True, we will transform the node/edge features + with sine and cosine functions, similar to NERF. + encoding_num_freqs: frequency parameter + encoding_multiplicative_factor: used for calculating the frequency. + + Returns: + Arrays of shape: [num_nodes, num_features] and [num_edges, num_features]. + with node and edge features. + + """ + + num_nodes = node_lat.shape[0] + num_edges = senders.shape[0] + dtype = node_lat.dtype + node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon) + + # Computing some node features. + node_features = [] + if add_node_positions: + # Already in [-1, 1.] range. + node_features.extend(spherical_to_cartesian(node_phi, node_theta)) + + if add_node_latitude: + # Using the cos of theta. + # From 1. (north pole) to -1 (south pole). + node_features.append(np.cos(node_theta)) + + if add_node_longitude: + # Using the cos and sin, which is already normalized. + node_features.append(np.cos(node_phi)) + node_features.append(np.sin(node_phi)) + + if not node_features: + node_features = np.zeros([num_nodes, 0], dtype=dtype) + else: + node_features = np.stack(node_features, axis=-1) + + # Computing some edge features. + edge_features = [] + + if add_relative_positions: + + relative_position = get_relative_position_in_receiver_local_coordinates( + node_phi=node_phi, + node_theta=node_theta, + senders=senders, + receivers=receivers, + latitude_local_coordinates=relative_latitude_local_coordinates, + longitude_local_coordinates=relative_longitude_local_coordinates + ) + + # Note this is L2 distance in 3d space, rather than geodesic distance. + relative_edge_distances = np.linalg.norm( + relative_position, axis=-1, keepdims=True) + + # Normalize to the maximum edge distance. Note that we expect to always + # have an edge that goes in the opposite direction of any given edge + # so the distribution of relative positions should be symmetric around + # zero. So by scaling by the maximum length, we expect all relative + # positions to fall in the [-1., 1.] interval, and all relative distances + # to fall in the [0., 1.] interval. + max_edge_distance = relative_edge_distances.max() + edge_features.append(relative_edge_distances / max_edge_distance) + edge_features.append(relative_position / max_edge_distance) + + if not edge_features: + edge_features = np.zeros([num_edges, 0], dtype=dtype) + else: + edge_features = np.concatenate(edge_features, axis=-1) + + if sine_cosine_encoding: + def sine_cosine_transform(x: np.ndarray) -> np.ndarray: + freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs) + phases = freqs * x[..., None] + x_sin = np.sin(phases) + x_cos = np.cos(phases) + x_cat = np.concatenate([x_sin, x_cos], axis=-1) + return x_cat.reshape([x.shape[0], -1]) + + node_features = sine_cosine_transform(node_features) + edge_features = sine_cosine_transform(edge_features) + + return node_features, edge_features + + +def lat_lon_to_leading_axes( + grid_xarray: xarray.DataArray) -> xarray.DataArray: + """Reorders xarray so lat/lon axes come first.""" + # leading + ["lat", "lon"] + trailing + # to + # ["lat", "lon"] + leading + trailing + return grid_xarray.transpose("lat", "lon", ...) + + +def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray: + """Reorders xarray so batch/time/level axes come first (if present).""" + + # ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing + # to + # [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing + + input_dims = list(grid_xarray.dims) + output_dims = list(input_dims) + for leading_key in ["level", "time", "batch"]: # reverse order for insert + if leading_key in input_dims: + output_dims.remove(leading_key) + output_dims.insert(0, leading_key) + return grid_xarray.transpose(*output_dims) + + +def lat_lon_deg_to_spherical(node_lat: np.ndarray, + node_lon: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + phi = np.deg2rad(node_lon) + theta = np.deg2rad(90 - node_lat) + return phi, theta + + +def spherical_to_lat_lon(phi: np.ndarray, + theta: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + lon = np.mod(np.rad2deg(phi), 360) + lat = 90 - np.rad2deg(theta) + return lat, lon + + +def cartesian_to_spherical(x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + phi = np.arctan2(y, x) + with np.errstate(invalid="ignore"): # circumventing b/253179568 + theta = np.arccos(z) # Assuming unit radius. + return phi, theta + + +def spherical_to_cartesian( + phi: np.ndarray, theta: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # Assuming unit radius. + return (np.cos(phi)*np.sin(theta), + np.sin(phi)*np.sin(theta), + np.cos(theta)) + + +def get_relative_position_in_receiver_local_coordinates( + node_phi: np.ndarray, + node_theta: np.ndarray, + senders: np.ndarray, + receivers: np.ndarray, + latitude_local_coordinates: bool, + longitude_local_coordinates: bool + ) -> np.ndarray: + """Returns relative position features for the edges. + + The relative positions will be computed in a rotated space for a local + coordinate system as defined by the receiver. The relative positions are + simply obtained by subtracting sender position minues receiver position in + that local coordinate system after the rotation in R^3. + + Args: + node_phi: [num_nodes] with polar angles. + node_theta: [num_nodes] with azimuthal angles. + senders: [num_edges] with indices. + receivers: [num_edges] with indices. + latitude_local_coordinates: Whether to rotate edges such that in the + positions are computed such that the receiver is always at latitude 0. + longitude_local_coordinates: Whether to rotate edges such that in the + positions are computed such that the receiver is always at longitude 0. + + Returns: + Array of relative positions in R3 [num_edges, 3] + """ + + node_pos = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1) + + # No rotation in this case. + if not (latitude_local_coordinates or longitude_local_coordinates): + return node_pos[senders] - node_pos[receivers] + + # Get rotation matrices for the local space space for every node. + rotation_matrices = get_rotation_matrices_to_local_coordinates( + reference_phi=node_phi, + reference_theta=node_theta, + rotate_latitude=latitude_local_coordinates, + rotate_longitude=longitude_local_coordinates) + + # Each edge will be rotated according to the rotation matrix of its receiver + # node. + edge_rotation_matrices = rotation_matrices[receivers] + + # Rotate all nodes to the rotated space of the corresponding edge. + # Note for receivers we can also do the matmul first and the gather second: + # ``` + # receiver_pos_in_rotated_space = rotate_with_matrices( + # rotation_matrices, node_pos)[receivers] + # ``` + # which is more efficient, however, we do gather first to keep it more + # symmetric with the sender computation. + receiver_pos_in_rotated_space = rotate_with_matrices( + edge_rotation_matrices, node_pos[receivers]) + sender_pos_in_in_rotated_space = rotate_with_matrices( + edge_rotation_matrices, node_pos[senders]) + # Note, here, that because the rotated space is chosen according to the + # receiver, if: + # * latitude_local_coordinates = True: latitude for the receivers will be + # 0, that is the z coordinate will always be 0. + # * longitude_local_coordinates = True: longitude for the receivers will be + # 0, that is the y coordinate will be 0. + + # Now we can just subtract. + # Note we are rotating to a local coordinate system, where the y-z axes are + # parallel to a tangent plane to the sphere, but still remain in a 3d space. + # Note that if both `latitude_local_coordinates` and + # `longitude_local_coordinates` are True, and edges are short, + # then the difference in x coordinate between sender and receiver + # should be small, so we could consider dropping the new x coordinate if + # we wanted to the tangent plane, however in doing so + # we would lose information about the curvature of the mesh, which may be + # important for very coarse meshes. + return sender_pos_in_in_rotated_space - receiver_pos_in_rotated_space + + +def get_rotation_matrices_to_local_coordinates( + reference_phi: np.ndarray, + reference_theta: np.ndarray, + rotate_latitude: bool, + rotate_longitude: bool) -> np.ndarray: + + """Returns a rotation matrix to rotate to a point based on a reference vector. + + The rotation matrix is build such that, a vector in the + same coordinate system at the reference point that points towards the pole + before the rotation, continues to point towards the pole after the rotation. + + Args: + reference_phi: [leading_axis] Polar angles of the reference. + reference_theta: [leading_axis] Azimuthal angles of the reference. + rotate_latitude: Whether to produce a rotation matrix that would rotate + R^3 vectors to zero latitude. + rotate_longitude: Whether to produce a rotation matrix that would rotate + R^3 vectors to zero longitude. + + Returns: + Matrices of shape [leading_axis] such that when applied to the reference + position with `rotate_with_matrices(rotation_matrices, reference_pos)` + + * phi goes to 0. if "rotate_longitude" is True. + + * theta goes to np.pi / 2 if "rotate_latitude" is True. + + The rotation consists of: + * rotate_latitude = False, rotate_longitude = True: + Latitude preserving rotation. + * rotate_latitude = True, rotate_longitude = True: + Latitude preserving rotation, followed by longitude preserving + rotation. + * rotate_latitude = True, rotate_longitude = False: + Latitude preserving rotation, followed by longitude preserving + rotation, and the inverse of the latitude preserving rotation. Note + this is computationally different from rotating the longitude only + and is. We do it like this, so the polar geodesic curve, continues + to be aligned with one of the axis after the rotation. + + """ + + if rotate_longitude and rotate_latitude: + + # We first rotate around the z axis "minus the azimuthal angle", to get the + # point with zero longitude + azimuthal_rotation = - reference_phi + + # One then we will do a polar rotation (which can be done along the y + # axis now that we are at longitude 0.), "minus the polar angle plus 2pi" + # to get the point with zero latitude. + polar_rotation = - reference_theta + np.pi/2 + + return transform.Rotation.from_euler( + "zy", np.stack([azimuthal_rotation, polar_rotation], + axis=1)).as_matrix() + elif rotate_longitude: + # Just like the previous case, but applying only the azimuthal rotation. + azimuthal_rotation = - reference_phi + return transform.Rotation.from_euler("z", -reference_phi).as_matrix() + elif rotate_latitude: + # Just like the first case, but after doing the polar rotation, undoing + # the azimuthal rotation. + azimuthal_rotation = - reference_phi + polar_rotation = - reference_theta + np.pi/2 + + return transform.Rotation.from_euler( + "zyz", np.stack( + [azimuthal_rotation, polar_rotation, -azimuthal_rotation] + , axis=1)).as_matrix() + else: + raise ValueError( + "At least one of longitude and latitude should be rotated.") + + +def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray + ) -> np.ndarray: + return np.einsum("bji,bi->bj", rotation_matrices, positions) + + +def get_bipartite_graph_spatial_features( + *, + senders_node_lat: np.ndarray, + senders_node_lon: np.ndarray, + senders: np.ndarray, + receivers_node_lat: np.ndarray, + receivers_node_lon: np.ndarray, + receivers: np.ndarray, + add_node_positions: bool, + add_node_latitude: bool, + add_node_longitude: bool, + add_relative_positions: bool, + edge_normalization_factor: Optional[float] = None, + relative_longitude_local_coordinates: bool, + relative_latitude_local_coordinates: bool, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Computes spatial features for the nodes. + + This function is almost identical to `get_graph_spatial_features`. The only + difference is that sender nodes and receiver nodes can be in different arrays. + This is necessary to enable combination with typed Graph. + + Args: + senders_node_lat: Latitudes in the [-90, 90] interval of shape + [num_sender_nodes] + senders_node_lon: Longitudes in the [0, 360] interval of shape + [num_sender_nodes] + senders: Sender indices of shape [num_edges], indices in [0, + num_sender_nodes) + receivers_node_lat: Latitudes in the [-90, 90] interval of shape + [num_receiver_nodes] + receivers_node_lon: Longitudes in the [0, 360] interval of shape + [num_receiver_nodes] + receivers: Receiver indices of shape [num_edges], indices in [0, + num_receiver_nodes) + add_node_positions: Add unit norm absolute positions. + add_node_latitude: Add a feature for latitude (cos(90 - lat)) Note even if + this is set to False, the model may be able to infer the longitude from + relative features, unless `relative_latitude_local_coordinates` is also + True, or if there is any bias on the relative edge sizes for different + longitudes. + add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note + even if this is set to False, the model may be able to infer the longitude + from relative features, unless `relative_longitude_local_coordinates` is + also True, or if there is any bias on the relative edge sizes for + different longitudes. + add_relative_positions: Whether to relative positions in R3 to the edges. + edge_normalization_factor: Allows explicitly controlling edge normalization. + If None, defaults to max edge length. This supports using pre-trained + model weights with a different graph structure to what it was trained on. + relative_longitude_local_coordinates: If True, relative positions are + computed in a local space where the receiver is at 0 longitude. + relative_latitude_local_coordinates: If True, relative positions are + computed in a local space where the receiver is at 0 latitude. + + Returns: + Arrays of shape: [num_nodes, num_features] and [num_edges, num_features]. + with node and edge features. + + """ + + num_senders = senders_node_lat.shape[0] + num_receivers = receivers_node_lat.shape[0] + num_edges = senders.shape[0] + dtype = senders_node_lat.dtype + assert receivers_node_lat.dtype == dtype + senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical( + senders_node_lat, senders_node_lon) + receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical( + receivers_node_lat, receivers_node_lon) + + # Computing some node features. + senders_node_features = [] + receivers_node_features = [] + if add_node_positions: + # Already in [-1, 1.] range. + senders_node_features.extend( + spherical_to_cartesian(senders_node_phi, senders_node_theta)) + receivers_node_features.extend( + spherical_to_cartesian(receivers_node_phi, receivers_node_theta)) + + if add_node_latitude: + # Using the cos of theta. + # From 1. (north pole) to -1 (south pole). + senders_node_features.append(np.cos(senders_node_theta)) + receivers_node_features.append(np.cos(receivers_node_theta)) + + if add_node_longitude: + # Using the cos and sin, which is already normalized. + senders_node_features.append(np.cos(senders_node_phi)) + senders_node_features.append(np.sin(senders_node_phi)) + + receivers_node_features.append(np.cos(receivers_node_phi)) + receivers_node_features.append(np.sin(receivers_node_phi)) + + if not senders_node_features: + senders_node_features = np.zeros([num_senders, 0], dtype=dtype) + receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype) + else: + senders_node_features = np.stack(senders_node_features, axis=-1) + receivers_node_features = np.stack(receivers_node_features, axis=-1) + + # Computing some edge features. + edge_features = [] + + if add_relative_positions: + + relative_position = get_bipartite_relative_position_in_receiver_local_coordinates( # pylint: disable=line-too-long + senders_node_phi=senders_node_phi, + senders_node_theta=senders_node_theta, + receivers_node_phi=receivers_node_phi, + receivers_node_theta=receivers_node_theta, + senders=senders, + receivers=receivers, + latitude_local_coordinates=relative_latitude_local_coordinates, + longitude_local_coordinates=relative_longitude_local_coordinates) + + # Note this is L2 distance in 3d space, rather than geodesic distance. + relative_edge_distances = np.linalg.norm( + relative_position, axis=-1, keepdims=True) + + if edge_normalization_factor is None: + # Normalize to the maximum edge distance. Note that we expect to always + # have an edge that goes in the opposite direction of any given edge + # so the distribution of relative positions should be symmetric around + # zero. So by scaling by the maximum length, we expect all relative + # positions to fall in the [-1., 1.] interval, and all relative distances + # to fall in the [0., 1.] interval. + edge_normalization_factor = relative_edge_distances.max() + + edge_features.append(relative_edge_distances / edge_normalization_factor) + edge_features.append(relative_position / edge_normalization_factor) + + if not edge_features: + edge_features = np.zeros([num_edges, 0], dtype=dtype) + else: + edge_features = np.concatenate(edge_features, axis=-1) + + return senders_node_features, receivers_node_features, edge_features + + +def get_bipartite_relative_position_in_receiver_local_coordinates( + senders_node_phi: np.ndarray, + senders_node_theta: np.ndarray, + senders: np.ndarray, + receivers_node_phi: np.ndarray, + receivers_node_theta: np.ndarray, + receivers: np.ndarray, + latitude_local_coordinates: bool, + longitude_local_coordinates: bool) -> np.ndarray: + """Returns relative position features for the edges. + + This function is equivalent to + `get_relative_position_in_receiver_local_coordinates`, but adapted to work + with bipartite typed graphs. + + The relative positions will be computed in a rotated space for a local + coordinate system as defined by the receiver. The relative positions are + simply obtained by subtracting sender position minues receiver position in + that local coordinate system after the rotation in R^3. + + Args: + senders_node_phi: [num_sender_nodes] with polar angles. + senders_node_theta: [num_sender_nodes] with azimuthal angles. + senders: [num_edges] with indices into sender nodes. + receivers_node_phi: [num_sender_nodes] with polar angles. + receivers_node_theta: [num_sender_nodes] with azimuthal angles. + receivers: [num_edges] with indices into receiver nodes. + latitude_local_coordinates: Whether to rotate edges such that in the + positions are computed such that the receiver is always at latitude 0. + longitude_local_coordinates: Whether to rotate edges such that in the + positions are computed such that the receiver is always at longitude 0. + + Returns: + Array of relative positions in R3 [num_edges, 3] + """ + + senders_node_pos = np.stack( + spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1) + + receivers_node_pos = np.stack( + spherical_to_cartesian(receivers_node_phi, receivers_node_theta), axis=-1) + + # No rotation in this case. + if not (latitude_local_coordinates or longitude_local_coordinates): + return senders_node_pos[senders] - receivers_node_pos[receivers] + + # Get rotation matrices for the local space space for every receiver node. + receiver_rotation_matrices = get_rotation_matrices_to_local_coordinates( + reference_phi=receivers_node_phi, + reference_theta=receivers_node_theta, + rotate_latitude=latitude_local_coordinates, + rotate_longitude=longitude_local_coordinates) + + # Each edge will be rotated according to the rotation matrix of its receiver + # node. + edge_rotation_matrices = receiver_rotation_matrices[receivers] + + # Rotate all nodes to the rotated space of the corresponding edge. + # Note for receivers we can also do the matmul first and the gather second: + # ``` + # receiver_pos_in_rotated_space = rotate_with_matrices( + # rotation_matrices, node_pos)[receivers] + # ``` + # which is more efficient, however, we do gather first to keep it more + # symmetric with the sender computation. + receiver_pos_in_rotated_space = rotate_with_matrices( + edge_rotation_matrices, receivers_node_pos[receivers]) + sender_pos_in_in_rotated_space = rotate_with_matrices( + edge_rotation_matrices, senders_node_pos[senders]) + # Note, here, that because the rotated space is chosen according to the + # receiver, if: + # * latitude_local_coordinates = True: latitude for the receivers will be + # 0, that is the z coordinate will always be 0. + # * longitude_local_coordinates = True: longitude for the receivers will be + # 0, that is the y coordinate will be 0. + + # Now we can just subtract. + # Note we are rotating to a local coordinate system, where the y-z axes are + # parallel to a tangent plane to the sphere, but still remain in a 3d space. + # Note that if both `latitude_local_coordinates` and + # `longitude_local_coordinates` are True, and edges are short, + # then the difference in x coordinate between sender and receiver + # should be small, so we could consider dropping the new x coordinate if + # we wanted to the tangent plane, however in doing so + # we would lose information about the curvature of the mesh, which may be + # important for very coarse meshes. + return sender_pos_in_in_rotated_space - receiver_pos_in_rotated_space + + +def variable_to_stacked( + variable: xarray.Variable, + sizes: Mapping[str, int], + preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"), + ) -> xarray.Variable: + """Converts an xarray.Variable to preserved_dims + ("channels",). + + Any dimensions other than those included in preserved_dims get stacked into a + final "channels" dimension. If any of the preserved_dims are missing then they + are added, with the data broadcast/tiled to match the sizes specified in + `sizes`. + + Args: + variable: An xarray.Variable. + sizes: Mapping including sizes for any dimensions which are not present in + `variable` but are needed for the output. This may be needed for example + for a static variable with only ("lat", "lon") dims, or if you want to + encode just the latitude coordinates (a variable with dims ("lat",)). + preserved_dims: dimensions of variable to not be folded in channels. + + Returns: + An xarray.Variable with dimensions preserved_dims + ("channels",). + """ + stack_to_channels_dims = [ + d for d in variable.dims if d not in preserved_dims] + if stack_to_channels_dims: + variable = variable.stack(channels=stack_to_channels_dims) + dims = {dim: variable.sizes.get(dim) or sizes[dim] for dim in preserved_dims} + dims["channels"] = variable.sizes.get("channels", 1) + return variable.set_dims(dims) + + +def dataset_to_stacked( + dataset: xarray.Dataset, + sizes: Optional[Mapping[str, int]] = None, + preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"), +) -> xarray.DataArray: + """Converts an xarray.Dataset to a single stacked array. + + This takes each consistuent data_var, converts it into BHWC layout + using `variable_to_stacked`, then concats them all along the channels axis. + + Args: + dataset: An xarray.Dataset. + sizes: Mapping including sizes for any dimensions which are not present in + the `dataset` but are needed for the output. See variable_to_stacked. + preserved_dims: dimensions from the dataset that should not be folded in + the predictions channels. + + Returns: + An xarray.DataArray with dimensions preserved_dims + ("channels",). + Existing coordinates for preserved_dims axes will be preserved, however + there will be no coordinates for "channels". + """ + data_vars = [ + variable_to_stacked(dataset.variables[name], sizes or dataset.sizes, + preserved_dims) + for name in sorted(dataset.data_vars.keys()) + ] + coords = { + dim: coord + for dim, coord in dataset.coords.items() + if dim in preserved_dims + } + return xarray.DataArray( + data=xarray.Variable.concat(data_vars, dim="channels"), coords=coords) + + +def stacked_to_dataset( + stacked_array: xarray.Variable, + template_dataset: xarray.Dataset, + preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"), + ) -> xarray.Dataset: + """The inverse of dataset_to_stacked. + + Requires a template dataset to demonstrate the variables/shapes/coordinates + required. + All variables must have preserved_dims dimensions. + + Args: + stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked + would if it was asked to encode `template_dataset`. + template_dataset: A template Dataset (or other mapping of DataArrays) + demonstrating the shape of output required (variables, shapes, + coordinates etc). + preserved_dims: dimensions from the target_template that were not folded in + the predictions channels. The preserved_dims need to be a subset of the + dims of all the variables of template_dataset. + + Returns: + An xarray.Dataset (or other mapping of DataArrays) with the same shape and + type as template_dataset. + """ + unstack_from_channels_sizes = {} + var_names = sorted(template_dataset.keys()) + for name in var_names: + template_var = template_dataset[name] + if not all(dim in template_var.dims for dim in preserved_dims): + raise ValueError( + f"stacked_to_dataset requires all Variables to have {preserved_dims} " + f"dimensions, but found only {template_var.dims}.") + unstack_from_channels_sizes[name] = { + dim: size for dim, size in template_var.sizes.items() + if dim not in preserved_dims} + + channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64) + for name, unstack_sizes in unstack_from_channels_sizes.items()} + total_expected_channels = sum(channels.values()) + found_channels = stacked_array.sizes["channels"] + if total_expected_channels != found_channels: + raise ValueError( + f"Expected {total_expected_channels} channels but found " + f"{found_channels}, when trying to convert a stacked array of shape " + f"{stacked_array.sizes} to a dataset of shape {template_dataset}.") + + data_vars = {} + index = 0 + for name in var_names: + template_var = template_dataset[name] + var = stacked_array.isel({"channels": slice(index, index + channels[name])}) + index += channels[name] + var = var.unstack({"channels": unstack_from_channels_sizes[name]}) + var = var.transpose(*template_var.dims) + data_vars[name] = xarray.DataArray( + data=var, + coords=template_var.coords, + # This might not always be the same as the name it's keyed under; it + # will refer to the original variable name, whereas the key might be + # some alias e.g. temperature_850 under which it should be logged: + name=template_var.name, + ) + return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count diff --git a/graphcast/normalization.py b/graphcast/normalization.py new file mode 100644 index 0000000..9bd63e4 --- /dev/null +++ b/graphcast/normalization.py @@ -0,0 +1,196 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrappers for Predictors which allow them to work with normalized data. + +The Predictor which is wrapped sees normalized inputs and targets, and makes +normalized predictions. The wrapper handles translating the predictions back +to the original domain. +""" + +import logging +from typing import Optional, Tuple + +from graphcast import predictor_base +from graphcast import xarray_tree +import xarray + + +def normalize(values: xarray.Dataset, + scales: xarray.Dataset, + locations: Optional[xarray.Dataset], + ) -> xarray.Dataset: + """Normalize variables using the given scales and (optionally) locations.""" + def normalize_array(array): + if array.name is None: + raise ValueError( + "Can't look up normalization constants because array has no name.") + if locations is not None: + if array.name in locations: + array = array - locations[array.name].astype(array.dtype) + else: + logging.warning('No normalization location found for %s', array.name) + if array.name in scales: + array = array / scales[array.name].astype(array.dtype) + else: + logging.warning('No normalization scale found for %s', array.name) + return array + return xarray_tree.map_structure(normalize_array, values) + + +def unnormalize(values: xarray.Dataset, + scales: xarray.Dataset, + locations: Optional[xarray.Dataset], + ) -> xarray.Dataset: + """Unnormalize variables using the given scales and (optionally) locations.""" + def unnormalize_array(array): + if array.name is None: + raise ValueError( + "Can't look up normalization constants because array has no name.") + if array.name in scales: + array = array * scales[array.name].astype(array.dtype) + else: + logging.warning('No normalization scale found for %s', array.name) + if locations is not None: + if array.name in locations: + array = array + locations[array.name].astype(array.dtype) + else: + logging.warning('No normalization location found for %s', array.name) + return array + return xarray_tree.map_structure(unnormalize_array, values) + + +class InputsAndResiduals(predictor_base.Predictor): + """Wraps with a residual connection, normalizing inputs and target residuals. + + The inner predictor is given inputs that are normalized using `locations` + and `scales` to roughly zero-mean unit variance. + + For target variables that are present in the inputs, the inner predictor is + trained to predict residuals (target - last_frame_of_input) that have been + normalized using `residual_scales` (and optionally `residual_locations`) to + roughly unit variance / zero mean. + + This replaces `residual.Predictor` in the case where you want normalization + that's based on the scales of the residuals. + + Since we return the underlying predictor's loss on the normalized residuals, + if the underlying predictor is a sum of per-variable losses, the normalization + will affect the relative weighting of the per-variable loss terms (hopefully + in a good way). + + For target variables *not* present in the inputs, the inner predictor is + trained to predict targets directly, that have been normalized in the same + way as the inputs. + + The transforms applied to the targets (the residual connection and the + normalization) are applied in reverse to the predictions before returning + them. + """ + + def __init__( + self, + predictor: predictor_base.Predictor, + stddev_by_level: xarray.Dataset, + mean_by_level: xarray.Dataset, + diffs_stddev_by_level: xarray.Dataset): + self._predictor = predictor + self._scales = stddev_by_level + self._locations = mean_by_level + self._residual_scales = diffs_stddev_by_level + self._residual_locations = None + + def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction): + if norm_prediction.sizes.get('time') != 1: + raise ValueError( + 'normalization.InputsAndResiduals only supports predicting a ' + 'single timestep.') + if norm_prediction.name in inputs: + # Residuals are assumed to be predicted as normalized (unit variance), + # but the scale and location they need mapping to is that of the residuals + # not of the values themselves. + prediction = unnormalize( + norm_prediction, self._residual_scales, self._residual_locations) + # A prediction for which we have a corresponding input -- we are + # predicting the residual: + last_input = inputs[norm_prediction.name].isel(time=-1) + prediction += last_input + return prediction + else: + # A predicted variable which is not an input variable. We are predicting + # it directly, so unnormalize it directly to the target scale/location: + return unnormalize(norm_prediction, self._scales, self._locations) + + def _subtract_input_and_normalize_target(self, inputs, target): + if target.sizes.get('time') != 1: + raise ValueError( + 'normalization.InputsAndResiduals only supports wrapping predictors' + 'that predict a single timestep.') + if target.name in inputs: + target_residual = target + last_input = inputs[target.name].isel(time=-1) + target_residual -= last_input + return normalize( + target_residual, self._residual_scales, self._residual_locations) + else: + return normalize(target, self._scales, self._locations) + + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs + ) -> xarray.Dataset: + norm_inputs = normalize(inputs, self._scales, self._locations) + norm_forcings = normalize(forcings, self._scales, self._locations) + norm_predictions = self._predictor( + norm_inputs, targets_template, forcings=norm_forcings, **kwargs) + return xarray_tree.map_structure( + lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred), + norm_predictions) + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs, + ) -> predictor_base.LossAndDiagnostics: + """Returns the loss computed on normalized inputs and targets.""" + norm_inputs = normalize(inputs, self._scales, self._locations) + norm_forcings = normalize(forcings, self._scales, self._locations) + norm_target_residuals = xarray_tree.map_structure( + lambda t: self._subtract_input_and_normalize_target(inputs, t), + targets) + return self._predictor.loss( + norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs) + + def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **kwargs, + ) -> Tuple[predictor_base.LossAndDiagnostics, + xarray.Dataset]: + """The loss computed on normalized data, with unnormalized predictions.""" + norm_inputs = normalize(inputs, self._scales, self._locations) + norm_forcings = normalize(forcings, self._scales, self._locations) + norm_target_residuals = xarray_tree.map_structure( + lambda t: self._subtract_input_and_normalize_target(inputs, t), + targets) + (loss, scalars), norm_predictions = self._predictor.loss_and_predictions( + norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs) + predictions = xarray_tree.map_structure( + lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred), + norm_predictions) + return (loss, scalars), predictions diff --git a/graphcast/predictor_base.py b/graphcast/predictor_base.py new file mode 100644 index 0000000..6ecb644 --- /dev/null +++ b/graphcast/predictor_base.py @@ -0,0 +1,170 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract base classes for an xarray-based Predictor API.""" + +import abc + +from typing import Tuple + +from graphcast import losses +from graphcast import xarray_jax +import jax.numpy as jnp +import xarray + +LossAndDiagnostics = losses.LossAndDiagnostics + + +class Predictor(abc.ABC): + """A possibly-trainable predictor of weather, exposing an xarray-based API. + + Typically wraps an underlying JAX model and handles translating the xarray + Dataset values to and from plain JAX arrays that are convenient for input to + (and output from) the underlying model. + + Different subclasses may exist to wrap different kinds of underlying model, + e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D + inputs/outputs, autoregressive models. + + You can also implement a specific model directly as a Predictor if you want, + for example if it has quite specific/unique requirements for its input/output + or loss function, or if it's convenient to implement directly using xarray. + """ + + @abc.abstractmethod + def __call__(self, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + **optional_kwargs + ) -> xarray.Dataset: + """Makes predictions. + + This is only used by the Experiment for inference / evaluation, with + training going via the .loss method. So it should default to making + predictions for evaluation, although you can also support making predictions + for use in the loss via an is_training argument -- see + LossFunctionPredictor which helps with that. + + Args: + inputs: An xarray.Dataset of inputs. + targets_template: An xarray.Dataset or other mapping of xarray.DataArrays, + with the same shape as the targets, to demonstrate what kind of + predictions are required. You can use this to determine which variables, + levels and lead times must be predicted. + You are free to raise an error if you don't support predicting what is + requested. + forcings: An xarray.Dataset of forcings terms. Forcings are variables + that can be fed to the model, but do not need to be predicted. This is + often because this variable can be computed analytically (e.g. the toa + radiation of the sun is mostly a function of geometry) or are considered + to be controlled for the experiment (e.g., impose a scenario of C02 + emission into the atmosphere). Unlike `inputs`, the `forcings` can + include information "from the future", that is, information at target + times specified in the `targets_template`. + **optional_kwargs: Implementations may support extra optional kwargs, + provided they set appropriate defaults for them. + + Returns: + Predictions, as an xarray.Dataset or other mapping of DataArrays which + is capable of being evaluated against targets with shape given by + targets_template. + For probabilistic predictors which can return multiple samples from a + predictive distribution, these should (by convention) be returned along + an additional 'sample' dimension. + """ + + def loss(self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **optional_kwargs, + ) -> LossAndDiagnostics: + """Computes a training loss, for predictors that are trainable. + + Why make this the Predictor's responsibility, rather than letting callers + compute their own loss function using predictions obtained from + Predictor.__call__? + + Doing it this way gives Predictors more control over their training setup. + For example, some predictors may wish to train using different targets to + the ones they predict at evaluation time -- perhaps different lead times and + variables, perhaps training to predict transformed versions of targets + where the transform needs to be inverted at evaluation time, etc. + + It's also necessary for generative models (VAEs, GANs, ...) where the + training loss is more complex and isn't expressible as a parameter-free + function of predictions and targets. + + Args: + inputs: An xarray.Dataset. + targets: An xarray.Dataset or other mapping of xarray.DataArrays. See + docs on __call__ for an explanation about the targets. + forcings: xarray.Dataset of forcing terms. + **optional_kwargs: Implementations may support extra optional kwargs, + provided they set appropriate defaults for them. + + Returns: + loss: A DataArray with dimensions ('batch',) containing losses for each + element of the batch. These will be averaged to give the final + loss, locally and across replicas. + diagnostics: Mapping of additional quantities to log by name alongside the + loss. These will will typically correspond to terms in the loss. They + should also have dimensions ('batch',) and will be averaged over the + batch before logging. + You need not include the loss itself in this dict; it will be added for + you. + """ + del targets, forcings, optional_kwargs + batch_size = inputs.sizes['batch'] + dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',)) + return dummy_loss, {} + + def loss_and_predictions( + self, + inputs: xarray.Dataset, + targets: xarray.Dataset, + forcings: xarray.Dataset, + **optional_kwargs, + ) -> Tuple[LossAndDiagnostics, xarray.Dataset]: + """Like .loss but also returns corresponding predictions. + + Implementing this is optional as it's not used directly by the Experiment, + but it is required by autoregressive.Predictor when applying an inner + Predictor autoregressively at training time; we need a loss at each step but + also predictions to feed back in for the next step. + + Note the loss itself may not be directly regressing the predictions towards + targets, the loss may be computed in terms of transformed predictions and + targets (or in some other way). For this reason we can't always cleanly + separate this into step 1: get predictions, step 2: compute loss from them, + hence the need for this combined method. + + Args: + inputs: + targets: + forcings: + **optional_kwargs: + As for self.loss. + + Returns: + (loss, diagnostics) + As for self.loss + predictions: + The predictions which the loss relates to. These should be of the same + shape as what you would get from + `self.__call__(inputs, targets_template=targets)`, and should be in the + same 'domain' as the inputs (i.e. they shouldn't be transformed + differently to how the predictor expects its inputs). + """ + raise NotImplementedError diff --git a/graphcast/rollout.py b/graphcast/rollout.py new file mode 100644 index 0000000..d0dc0e7 --- /dev/null +++ b/graphcast/rollout.py @@ -0,0 +1,267 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for rolling out models.""" + +from typing import Iterator + +from absl import logging +import chex +import dask +from graphcast import xarray_tree +import jax +import numpy as np +import typing_extensions +import xarray + + +class PredictorFn(typing_extensions.Protocol): + """Functional version of base.Predictor.__call__ with explicit rng.""" + + def __call__( + self, rng: chex.PRNGKey, inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + **optional_kwargs, + ) -> xarray.Dataset: + ... + + +def chunked_prediction( + predictor_fn: PredictorFn, + rng: chex.PRNGKey, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + num_steps_per_chunk: int = 1, + verbose: bool = False, +) -> xarray.Dataset: + """Outputs a long trajectory by iteratively concatenating chunked predictions. + + Args: + predictor_fn: Function to use to make predictions for each chunk. + rng: Random key. + inputs: Inputs for the model. + targets_template: Template for the target prediction, requires targets + equispaced in time. + forcings: Optional forcing for the model. + num_steps_per_chunk: How many of the steps in `targets_template` to predict + at each call of `predictor_fn`. It must evenly divide the number of + steps in `targets_template`. + verbose: Whether to log the current chunk being predicted. + + Returns: + Predictions for the targets template. + + """ + chunks_list = [] + for prediction_chunk in chunked_prediction_generator( + predictor_fn=predictor_fn, + rng=rng, + inputs=inputs, + targets_template=targets_template, + forcings=forcings, + num_steps_per_chunk=num_steps_per_chunk, + verbose=verbose): + chunks_list.append(jax.device_get(prediction_chunk)) + return xarray.concat(chunks_list, dim="time") + + +def chunked_prediction_generator( + predictor_fn: PredictorFn, + rng: chex.PRNGKey, + inputs: xarray.Dataset, + targets_template: xarray.Dataset, + forcings: xarray.Dataset, + num_steps_per_chunk: int = 1, + verbose: bool = False, +) -> Iterator[xarray.Dataset]: + """Outputs a long trajectory by yielding chunked predictions. + + Args: + predictor_fn: Function to use to make predictions for each chunk. + rng: Random key. + inputs: Inputs for the model. + targets_template: Template for the target prediction, requires targets + equispaced in time. + forcings: Optional forcing for the model. + num_steps_per_chunk: How many of the steps in `targets_template` to predict + at each call of `predictor_fn`. It must evenly divide the number of + steps in `targets_template`. + verbose: Whether to log the current chunk being predicted. + + Yields: + The predictions for each chunked step of the chunked rollout, such as + if all predictions are concatenated in time this would match the targets + template in structure. + + """ + + # Create copies to avoid mutating inputs. + inputs = xarray.Dataset(inputs) + targets_template = xarray.Dataset(targets_template) + forcings = xarray.Dataset(forcings) + + if "datetime" in inputs.coords: + del inputs.coords["datetime"] + + if "datetime" in targets_template.coords: + output_datetime = targets_template.coords["datetime"] + del targets_template.coords["datetime"] + else: + output_datetime = None + + if "datetime" in forcings.coords: + del forcings.coords["datetime"] + + num_target_steps = targets_template.dims["time"] + num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk) + if remainder != 0: + raise ValueError( + f"The number of steps per chunk {num_steps_per_chunk} must " + f"evenly divide the number of target steps {num_target_steps} ") + + if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1: + raise ValueError("The targets time coordinates must be evenly spaced") + + # Our template targets will always have a time axis corresponding for the + # timedeltas for the first chunk. + targets_chunk_time = targets_template.time.isel( + time=slice(0, num_steps_per_chunk)) + + current_inputs = inputs + for chunk_index in range(num_chunks): + if verbose: + logging.info("Chunk %d/%d", chunk_index, num_chunks) + logging.flush() + + # Select targets for the time period that we are predicting for this chunk. + target_offset = num_steps_per_chunk * chunk_index + target_slice = slice(target_offset, target_offset + num_steps_per_chunk) + current_targets_template = targets_template.isel(time=target_slice) + + # Replace the timedelta, by the one corresponding to the first chunk, so we + # don't recompile at every iteration, keeping the + actual_target_time = current_targets_template.coords["time"] + current_targets_template = current_targets_template.assign_coords( + time=targets_chunk_time).compute() + + current_forcings = forcings.isel(time=target_slice) + current_forcings = current_forcings.assign_coords(time=targets_chunk_time) + current_forcings = current_forcings.compute() + # Make predictions for the chunk. + rng, this_rng = jax.random.split(rng) + predictions = predictor_fn( + rng=this_rng, + inputs=current_inputs, + targets_template=current_targets_template, + forcings=current_forcings) + + next_frame = xarray.merge([predictions, current_forcings]) + + current_inputs = _get_next_inputs(current_inputs, next_frame) + + # At this point we can assign the actual targets time coordinates. + predictions = predictions.assign_coords(time=actual_target_time) + if output_datetime is not None: + predictions.coords["datetime"] = output_datetime.isel( + time=target_slice) + yield predictions + del predictions + + +def _get_next_inputs( + prev_inputs: xarray.Dataset, next_frame: xarray.Dataset, + ) -> xarray.Dataset: + """Computes next inputs, from previous inputs and predictions.""" + + # Make sure are are predicting all inputs with a time axis. + non_predicted_or_forced_inputs = list( + set(prev_inputs.keys()) - set(next_frame.keys())) + if "time" in prev_inputs[non_predicted_or_forced_inputs].dims: + raise ValueError( + "Found an input with a time index that is not predicted or forced.") + + # Keys we need to copy from predictions to inputs. + next_inputs_keys = list( + set(next_frame.keys()).intersection(set(prev_inputs.keys()))) + next_inputs = next_frame[next_inputs_keys] + + # Apply concatenate next frame with inputs, crop what we don't need and + # shift timedelta coordinates, so we don't recompile at every iteration. + num_inputs = prev_inputs.dims["time"] + return ( + xarray.concat( + [prev_inputs, next_inputs], dim="time", data_vars="different") + .tail(time=num_inputs) + .assign_coords(time=prev_inputs.coords["time"])) + + +def extend_targets_template( + targets_template: xarray.Dataset, + required_num_steps: int) -> xarray.Dataset: + """Extends `targets_template` to `required_num_steps` with lazy arrays. + + It uses lazy dask arrays of zeros, so it does not require instantiating the + array in memory. + + Args: + targets_template: Input template to extend. + required_num_steps: Number of steps required in the returned template. + + Returns: + `xarray.Dataset` identical in variables and timestep to `targets_template` + full of `dask.array.zeros` such that the time axis has `required_num_steps`. + + """ + + # Extend the "time" and "datetime" coordinates + time = targets_template.coords["time"] + + # Assert the first target time corresponds to the timestep. + timestep = time[0].data + if time.shape[0] > 1: + assert np.all(timestep == time[1:] - time[:-1]) + + extended_time = (np.arange(required_num_steps) + 1) * timestep + + if "datetime" in targets_template.coords: + datetime = targets_template.coords["datetime"] + extended_datetime = (datetime[0].data - timestep) + extended_time + else: + extended_datetime = None + + # Replace the values with empty dask arrays extending the time coordinates. + datetime = targets_template.coords["time"] + + def extend_time(data_array: xarray.DataArray) -> xarray.DataArray: + dims = data_array.dims + shape = list(data_array.shape) + shape[dims.index("time")] = required_num_steps + dask_data = dask.array.zeros( + shape=tuple(shape), + chunks=-1, # Will give chunk info directly to `ChunksToZarr``. + dtype=data_array.dtype) + + coords = dict(data_array.coords) + coords["time"] = extended_time + + if extended_datetime is not None: + coords["datetime"] = ("time", extended_datetime) + + return xarray.DataArray( + dims=dims, + data=dask_data, + coords=coords) + + return xarray_tree.map_structure(extend_time, targets_template) diff --git a/graphcast/typed_graph.py b/graphcast/typed_graph.py new file mode 100644 index 0000000..12dd981 --- /dev/null +++ b/graphcast/typed_graph.py @@ -0,0 +1,97 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data-structure for storing graphs with typed edges and nodes.""" + +from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar + +ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor +ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike + +_T = TypeVar('_T') + + +# All tensors have a "flat_batch_axis", which is similar to the leading +# axes of graph_tuples: +# * In the case of nodes this is simply a shared node and flat batch axis, with +# size corresponding to the total number of nodes in the flattened batch. +# * In the case of edges this is simply a shared edge and flat batch axis, with +# size corresponding to the total number of edges in the flattened batch. +# * In the case of globals this is simply the number of graphs in the flattened +# batch. + +# All shapes may also have any additional leading shape "batch_shape". +# Options for building batches are: +# * Use a provided "flatten" method that takes a leading `batch_shape` and +# it into the flat_batch_axis (this will be useful when using `tf.Dataset` +# which supports batching into RaggedTensors, with leading batch shape even +# if graphs have different numbers of nodes and edges), so the RaggedBatches +# can then be converted into something without ragged dimensions that jax can +# use. +# * Directly build a "flat batch" using a provided function for batching a list +# of graphs (how it is done in `jraph`). + + +class NodeSet(NamedTuple): + """Represents a set of nodes.""" + n_node: ArrayLike # [num_flat_graphs] + features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape + + +class EdgesIndices(NamedTuple): + """Represents indices to nodes adjacent to the edges.""" + senders: ArrayLike # [num_flat_edges] + receivers: ArrayLike # [num_flat_edges] + + +class EdgeSet(NamedTuple): + """Represents a set of edges.""" + n_edge: ArrayLike # [num_flat_graphs] + indices: EdgesIndices + features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape + + +class Context(NamedTuple): + # `n_graph` always contains ones but it is useful to query the leading shape + # in case of graphs without any nodes or edges sets. + n_graph: ArrayLike # [num_flat_graphs] + features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape + + +class EdgeSetKey(NamedTuple): + name: str # Name of the EdgeSet. + + # Sender node set name and receiver node set name connected by the edge set. + node_sets: Tuple[str, str] + + +class TypedGraph(NamedTuple): + """A graph with typed nodes and edges. + + A typed graph is made of a context, multiple sets of nodes and multiple + sets of edges connecting those nodes (as indicated by the EdgeSetKey). + """ + + context: Context + nodes: Mapping[str, NodeSet] + edges: Mapping[EdgeSetKey, EdgeSet] + + def edge_key_by_name(self, name: str) -> EdgeSetKey: + found_key = [k for k in self.edges.keys() if k.name == name] + if len(found_key) != 1: + raise KeyError("invalid edge key '{}'. Available edges: [{}]".format( + name, ', '.join(x.name for x in self.edges.keys()))) + return found_key[0] + + def edge_by_name(self, name: str) -> EdgeSet: + return self.edges[self.edge_key_by_name(name)] diff --git a/graphcast/typed_graph_net.py b/graphcast/typed_graph_net.py new file mode 100644 index 0000000..aa62ac3 --- /dev/null +++ b/graphcast/typed_graph_net.py @@ -0,0 +1,317 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A library of typed Graph Neural Networks.""" + +from typing import Callable, Mapping, Optional, Union + +from graphcast import typed_graph +import jax.numpy as jnp +import jax.tree_util as tree +import jraph + + +# All features will be an ArrayTree. +NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = ( + jraph.ArrayTree) + +# Signature: +# (node features, outgoing edge features, incoming edge features, +# globals) -> updated node features +GNUpdateNodeFn = Callable[ + [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures], + Globals], + NodeFeatures] + +GNUpdateGlobalFn = Callable[ + [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals], + Globals] + + +def GraphNetwork( # pylint: disable=invalid-name + update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn], + update_node_fn: Mapping[str, GNUpdateNodeFn], + update_global_fn: Optional[GNUpdateGlobalFn] = None, + aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph + .segment_sum, + aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph + .segment_sum, + aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph + .segment_sum, + ): + """Returns a method that applies a configured GraphNetwork. + + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + extended to Typed Graphs with multiple edge sets and node sets and extended to + allow aggregating not only edges received by the nodes, but also edges sent by + the nodes. + + Example usage:: + + gn = GraphNetwork(update_edge_function, + update_node_function, **kwargs) + # Conduct multiple rounds of message passing with the same parameters: + for _ in range(num_message_passing_steps): + graph = gn(graph) + + Args: + update_edge_fn: mapping of functions used to update a subset of the edge + types, indexed by edge type name. + update_node_fn: mapping of functions used to update a subset of the node + types, indexed by node type name. + update_global_fn: function used to update the globals or None to deactivate + globals updates. + aggregate_edges_for_nodes_fn: function used to aggregate messages to each + node. + aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the + globals. + aggregate_edges_for_globals_fn: function used to aggregate the edges for the + globals. + + Returns: + A method that applies the configured GraphNetwork. + """ + + def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + """Applies a configured GraphNetwork to a graph. + + This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 + extended to Typed Graphs with multiple edge sets and node sets and extended + to allow aggregating not only edges received by the nodes, but also edges + sent by the nodes. + + Args: + graph: a `TypedGraph` containing the graph. + + Returns: + Updated `TypedGraph`. + """ + + updated_graph = graph + + # Edge update. + updated_edges = dict(updated_graph.edges) + for edge_set_name, edge_fn in update_edge_fn.items(): + edge_set_key = graph.edge_key_by_name(edge_set_name) + updated_edges[edge_set_key] = _edge_update( + updated_graph, edge_fn, edge_set_key) + updated_graph = updated_graph._replace(edges=updated_edges) + + # Node update. + updated_nodes = dict(updated_graph.nodes) + for node_set_key, node_fn in update_node_fn.items(): + updated_nodes[node_set_key] = _node_update( + updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn) + updated_graph = updated_graph._replace(nodes=updated_nodes) + + # Global update. + if update_global_fn: + updated_context = _global_update( + updated_graph, update_global_fn, + aggregate_edges_for_globals_fn, + aggregate_nodes_for_globals_fn) + updated_graph = updated_graph._replace(context=updated_context) + + return updated_graph + + return _apply_graph_net + + +def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name + """Updates an edge set of a given key.""" + + sender_nodes = graph.nodes[edge_set_key.node_sets[0]] + receiver_nodes = graph.nodes[edge_set_key.node_sets[1]] + edge_set = graph.edges[edge_set_key] + senders = edge_set.indices.senders # pytype: disable=attribute-error + receivers = edge_set.indices.receivers # pytype: disable=attribute-error + + sent_attributes = tree.tree_map( + lambda n: n[senders], sender_nodes.features) + received_attributes = tree.tree_map( + lambda n: n[receivers], receiver_nodes.features) + + n_edge = edge_set.n_edge + sum_n_edge = senders.shape[0] + global_features = tree.tree_map( + lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge), + graph.context.features) + new_features = edge_fn( + edge_set.features, sent_attributes, received_attributes, + global_features) + return edge_set._replace(features=new_features) + + +def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name + """Updates an edge set of a given key.""" + node_set = graph.nodes[node_set_key] + sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0] + + sent_features = {} + for edge_set_key, edge_set in graph.edges.items(): + sender_node_set_key = edge_set_key.node_sets[0] + if sender_node_set_key == node_set_key: + assert isinstance(edge_set.indices, typed_graph.EdgesIndices) + senders = edge_set.indices.senders + sent_features[edge_set_key.name] = tree.tree_map( + lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop + + received_features = {} + for edge_set_key, edge_set in graph.edges.items(): + receiver_node_set_key = edge_set_key.node_sets[1] + if receiver_node_set_key == node_set_key: + assert isinstance(edge_set.indices, typed_graph.EdgesIndices) + receivers = edge_set.indices.receivers + received_features[edge_set_key.name] = tree.tree_map( + lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop + + n_node = node_set.n_node + global_features = tree.tree_map( + lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node), + graph.context.features) + new_features = node_fn( + node_set.features, sent_features, received_features, global_features) + return node_set._replace(features=new_features) + + +def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name + """Updates an edge set of a given key.""" + n_graph = graph.context.n_graph.shape[0] + graph_idx = jnp.arange(n_graph) + + edge_features = {} + for edge_set_key, edge_set in graph.edges.items(): + assert isinstance(edge_set.indices, typed_graph.EdgesIndices) + sum_n_edge = edge_set.indices.senders.shape[0] + edge_gr_idx = jnp.repeat( + graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge) + edge_features[edge_set_key.name] = tree.tree_map( + lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop + edge_set.features) + + node_features = {} + for node_set_key, node_set in graph.nodes.items(): + sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0] + node_gr_idx = jnp.repeat( + graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node) + node_features[node_set_key] = tree.tree_map( + lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop + node_set.features) + + new_features = global_fn(node_features, edge_features, graph.context.features) + return graph.context._replace(features=new_features) + + +InteractionUpdateNodeFn = Callable[ + [jraph.NodeFeatures, + Mapping[str, SenderFeatures], + Mapping[str, ReceiverFeatures]], + jraph.NodeFeatures] + + +InteractionUpdateNodeFnNoSentEdges = Callable[ + [jraph.NodeFeatures, + Mapping[str, ReceiverFeatures]], + jraph.NodeFeatures] + + +def InteractionNetwork( # pylint: disable=invalid-name + update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn], + update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn, + InteractionUpdateNodeFnNoSentEdges]], + aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph + .segment_sum, + include_sent_messages_in_node_update: bool = False): + """Returns a method that applies a configured InteractionNetwork. + + An interaction network computes interactions on the edges based on the + previous edges features, and on the features of the nodes sending into those + edges. It then updates the nodes based on the incoming updated edges. + See https://arxiv.org/abs/1612.00222 for more details. + + This implementation extends the behavior to `TypedGraphs` adding an option + to include edge features for which a node is a sender in the arguments to + the node update function. + + Args: + update_edge_fn: mapping of functions used to update a subset of the edge + types, indexed by edge type name. + update_node_fn: mapping of functions used to update a subset of the node + types, indexed by node type name. + aggregate_edges_for_nodes_fn: function used to aggregate messages to each + node. + include_sent_messages_in_node_update: pass edge features for which a node is + a sender to the node update function. + """ + # An InteractionNetwork is a GraphNetwork without globals features, + # so we implement the InteractionNetwork as a configured GraphNetwork. + + # An InteractionNetwork edge function does not have global feature inputs, + # so we filter the passed global argument in the GraphNetwork. + wrapped_update_edge_fn = tree.tree_map( + lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn) + + # Similarly, we wrap the update_node_fn to ensure only the expected + # arguments are passed to the Interaction net. + if include_sent_messages_in_node_update: + wrapped_update_node_fn = tree.tree_map( + lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn) + else: + wrapped_update_node_fn = tree.tree_map( + lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn) + return GraphNetwork( + update_edge_fn=wrapped_update_edge_fn, + update_node_fn=wrapped_update_node_fn, + aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn) + + +def GraphMapFeatures( # pylint: disable=invalid-name + embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None, + embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None, + embed_global_fn: Optional[jraph.EmbedGlobalFn] = None): + """Returns function which embeds the components of a graph independently. + + Args: + embed_edge_fn: mapping of functions used to embed each edge type, + indexed by edge type name. + embed_node_fn: mapping of functions used to embed each node type, + indexed by node type name. + embed_global_fn: function used to embed the globals. + """ + + def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: + + updated_edges = dict(graph.edges) + if embed_edge_fn: + for edge_set_name, embed_fn in embed_edge_fn.items(): + edge_set_key = graph.edge_key_by_name(edge_set_name) + edge_set = graph.edges[edge_set_key] + updated_edges[edge_set_key] = edge_set._replace( + features=embed_fn(edge_set.features)) + + updated_nodes = dict(graph.nodes) + if embed_node_fn: + for node_set_key, embed_fn in embed_node_fn.items(): + node_set = graph.nodes[node_set_key] + updated_nodes[node_set_key] = node_set._replace( + features=embed_fn(node_set.features)) + + updated_context = graph.context + if embed_global_fn: + updated_context = updated_context._replace( + features=embed_global_fn(updated_context.features)) + + return graph._replace(edges=updated_edges, nodes=updated_nodes, + context=updated_context) + + return _embed diff --git a/graphcast/xarray_jax.py b/graphcast/xarray_jax.py new file mode 100644 index 0000000..8d60743 --- /dev/null +++ b/graphcast/xarray_jax.py @@ -0,0 +1,795 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helpers to use xarray.{Variable,DataArray,Dataset} with JAX. + +Allows them to be based on JAX arrays without converting to numpy arrays under +the hood, so you can start with a JAX array, do some computation with it in +xarray-land, get a JAX array out the other end and (for example) jax.jit +through the whole thing. You can even jax.jit a function which accepts and +returns xarray.Dataset, DataArray and Variable. + +## Creating xarray datatypes from jax arrays, and vice-versa. + +You can use the xarray_jax.{Variable,DataArray,Dataset} constructors, which have +the same API as the standard xarray constructors but will accept JAX arrays +without converting them to numpy. + +It does this by wrapping the JAX array in a wrapper before passing it to +xarray; you can also do this manually by calling xarray_jax.wrap on your JAX +arrays before passing them to the standard xarray constructors. + +To get non-wrapped JAX arrays out the other end, you can use e.g.: + + xarray_jax.jax_vars(dataset) + xarray_jax.jax_data(dataset.some_var) + +which will complain if the data isn't actually a JAX array. Use this if you need +to make sure the computation has gone via JAX, e.g. if it's the output of code +that you want to JIT or compute gradients through. If this is not the case and +you want to support passing plain numpy arrays through as well as potentially +JAX arrays, you can use: + + xarray_jax.unwrap_vars(dataset) + xarray_jax.unwrap_data(dataset.some_var) + +which will unwrap the data if it is a wrapped JAX array, but otherwise pass +it through to you without complaint. + +The wrapped JAX arrays aim to support all the core operations from the numpy +array API that xarray expects, however there may still be some gaps; if you run +into any problems around this, you may need to add a few more proxy methods onto +the wrapper class below. + +In future once JAX and xarray support the new Python array API standard +(https://data-apis.org/array-api/latest/index.html), we hope to avoid the need +for wrapping the JAX arrays like this. + +## jax.jit and pmap of functions taking and returning xarray datatypes + +We register xarray datatypes with jax.tree_util, which allows them to be treated +as generic containers of jax arrays by various parts of jax including jax.jit. + +This allows for, e.g.: + + @jax.jit + def foo(input: xarray.Dataset) -> xarray.Dataset: + ... + +It will not work out-of-the-box with shape-modifying transformations like +jax.pmap, or e.g. a jax.tree_util.tree_map with some transform that alters array +shapes or dimension order. That's because we won't know what dimension names +and/or coordinates to use when unflattening, if the results have a different +shape to the data that was originally flattened. + +You can work around this using xarray_jax.dims_change_on_unflatten, however, +and in the case of jax.pmap we provide a wrapper xarray_jax.pmap which allows +it to be used with functions taking and returning xarrays. + +## Treatment of coordinates + +We don't support passing jax arrays as coordinates when constructing a +DataArray/Dataset. This is because xarray's advanced indexing and slicing is +unlikely to work with jax arrays (at least when a Tracer is used during +jax.jit), and also because some important datatypes used for coordinates, like +timedelta64 and datetime64, are not supported by jax. + +For the purposes of tree_util and jax.jit, coordinates are not treated as leaves +of the tree (array data 'contained' by a Dataset/DataArray), they are just a +static part of the structure. That means that if a jit'ed function is called +twice with Dataset inputs that use different coordinates, it will compile a +separate version of the function for each. The coordinates are treated like +static_argnums by jax.jit. + +If you want to use dynamic data for coordinates, we recommend making it a +data_var instead of a coord. You won't be able to do indexing and slicing using +the coordinate, but that wasn't going to work with a jax array anyway. +""" + +import collections +import contextlib +import contextvars +from typing import Any, Callable, Hashable, Iterator, Mapping, Optional, Union, Tuple, TypeVar, cast + +import jax +import jax.numpy as jnp +import numpy as np +import tree +import xarray + + +def Variable(dims, data, **kwargs) -> xarray.Variable: # pylint:disable=invalid-name + """Like xarray.Variable, but can wrap JAX arrays.""" + return xarray.Variable(dims, wrap(data), **kwargs) + + +_JAX_COORD_ATTR_NAME = '_jax_coord' + + +def DataArray( # pylint:disable=invalid-name + data, + coords=None, + dims=None, + name=None, + attrs=None, + jax_coords=None, + ) -> xarray.DataArray: + """Like xarray.DataArray, but supports using JAX arrays. + + Args: + data: As for xarray.DataArray, except jax arrays are also supported. + coords: Coordinates for the array, see xarray.DataArray. These coordinates + must be based on plain numpy arrays or something convertible to plain + numpy arrays. Their values will form a static part of the data structure + from the point of view of jax.tree_util. In particular this means these + coordinates will be passed as plain numpy arrays even inside a JIT'd + function, and the JIT'd function will be recompiled under the hood if the + coordinates of DataArrays passed into it change. + If this is not convenient for you, see also jax_coords below. + dims: See xarray.DataArray. + name: See xarray.DataArray. + attrs: See xarray.DataArray. + jax_coords: Additional coordinates, which *can* use JAX arrays. These + coordinates will be treated as JAX data from the point of view of + jax.tree_util, that means when JIT'ing they will be passed as tracers and + computation involving them will be JIT'd. + Unfortunately a side-effect of this is that they can't be used as index + coordinates (because xarray's indexing logic is not JIT-able). If you + specify a coordinate with the same name as a dimension here, it will not + be set as an index coordinate; this behaviour is different to the default + for `coords`, and it means that things like `.sel` based on the jax + coordinate will not work. + Note we require `jax_coords` to be explicitly specified via a different + constructor argument to `coords`, rather than just looking for jax arrays + within the `coords` and treating them differently. This is because it + affects the way jax.tree_util treats them, which is somewhat orthogonal to + whether the value is passed in as numpy or not, and generally needs to be + handled consistently so is something we encourage explicit control over. + + Returns: + An instance of xarray.DataArray. Where JAX arrays are used as data or + coords, they will be wrapped with JaxArrayWrapper and can be unwrapped via + `unwrap` and `unwrap_data`. + """ + result = xarray.DataArray( + wrap(data), dims=dims, name=name, attrs=attrs or {}) + return assign_coords(result, coords=coords, jax_coords=jax_coords) + + +def Dataset( # pylint:disable=invalid-name + data_vars, + coords=None, + attrs=None, + jax_coords=None, + ) -> xarray.Dataset: + """Like xarray.Dataset, but can wrap JAX arrays. + + Args: + data_vars: As for xarray.Dataset, except jax arrays are also supported. + coords: Coordinates for the dataset, see xarray.Dataset. These coordinates + must be based on plain numpy arrays or something convertible to plain + numpy arrays. Their values will form a static part of the data structure + from the point of view of jax.tree_util. In particular this means these + coordinates will be passed as plain numpy arrays even inside a JIT'd + function, and the JIT'd function will be recompiled under the hood if the + coordinates of DataArrays passed into it change. + If this is not convenient for you, see also jax_coords below. + attrs: See xarray.Dataset. + jax_coords: Additional coordinates, which *can* use JAX arrays. These + coordinates will be treated as JAX data from the point of view of + jax.tree_util, that means when JIT'ing they will be passed as tracers and + computation involving them will be JIT'd. + Unfortunately a side-effect of this is that they can't be used as index + coordinates (because xarray's indexing logic is not JIT-able). If you + specify a coordinate with the same name as a dimension here, it will not + be set as an index coordinate; this behaviour is different to the default + for `coords`, and it means that things like `.sel` based on the jax + coordinate will not work. + Note we require `jax_coords` to be explicitly specified via a different + constructor argument to `coords`, rather than just looking for jax arrays + within the `coords` and treating them differently. This is because it + affects the way jax.tree_util treats them, which is somewhat orthogonal to + whether the value is passed in as numpy or not, and generally needs to be + handled consistently so is something we encourage explicit control over. + + Returns: + An instance of xarray.Dataset. Where JAX arrays are used as data, they + will be wrapped with JaxArrayWrapper. + """ + wrapped_data_vars = {} + for name, var_like in data_vars.items(): + # xarray.Dataset accepts a few different formats for data_vars: + if isinstance(var_like, jax.Array): + wrapped_data_vars[name] = wrap(var_like) + elif isinstance(var_like, tuple): + # Layout is (dims, data, ...). We wrap data. + wrapped_data_vars[name] = (var_like[0], wrap(var_like[1])) + var_like[2:] + else: + # Could be a plain numpy array or scalar (we don't wrap), or an + # xarray.Variable, DataArray etc, which we must assume is already wrapped + # if necessary (e.g. if creating using xarray_jax.{Variable,DataArray}). + wrapped_data_vars[name] = var_like + + result = xarray.Dataset( + data_vars=wrapped_data_vars, + attrs=attrs) + + return assign_coords(result, coords=coords, jax_coords=jax_coords) + + +DatasetOrDataArray = TypeVar( + 'DatasetOrDataArray', xarray.Dataset, xarray.DataArray) + + +def assign_coords( + x: DatasetOrDataArray, + *, + coords: Optional[Mapping[Hashable, Any]] = None, + jax_coords: Optional[Mapping[Hashable, Any]] = None, + ) -> DatasetOrDataArray: + """Replacement for assign_coords which works in presence of jax_coords. + + `jax_coords` allow certain specified coordinates to have their data passed as + JAX arrays (including through jax.jit boundaries). The compromise in return is + that they are not created as index coordinates and cannot be used for .sel + and other coordinate-based indexing operations. See docs for `jax_coords` on + xarray_jax.Dataset and xarray_jax.DataArray for more information. + + This function can be used to set jax_coords on an existing DataArray or + Dataset, and also to set a mix of jax and non-jax coordinates. It implements + some workarounds to prevent xarray trying and failing to create IndexVariables + from jax arrays under the hood. + + If you have any jax_coords with the same name as a dimension, you'll need to + use this function instead of data_array.assign_coords or dataset.assign_coords + in general, to avoid an xarray bug where it tries (and in our case fails) to + create indexes for existing jax coords. See + https://github.com/pydata/xarray/issues/7885. + + Args: + x: An xarray Dataset or DataArray. + coords: Dict of (non-JAX) coords, or None if not assigning any. + jax_coords: Dict of JAX coords, or None if not assigning any. See docs for + xarray_jax.Dataset / DataArray for more information on jax_coords. + + Returns: + The Dataset or DataArray with coordinates assigned, similarly to + Dataset.assign_coords / DataArray.assign_coords. + """ + coords = {} if coords is None else dict(coords) # Copy before mutating. + jax_coords = {} if jax_coords is None else dict(jax_coords) + + # Any existing JAX coords must be dropped and re-added via the workaround + # below, since otherwise .assign_coords will trigger an xarray bug where + # it tries to recreate the indexes again for the existing coordinates. + # Can remove if/when https://github.com/pydata/xarray/issues/7885 fixed. + existing_jax_coords = { + name: coord_var for name, coord_var in x.coords.variables.items() + if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False) + } + jax_coords = existing_jax_coords | jax_coords + x = x.drop_vars(existing_jax_coords.keys()) + + # We need to ensure that xarray doesn't try to create an index for + # coordinates with the same name as a dimension, since this will fail if + # given a wrapped JAX tracer. + # It appears the only way to avoid this is to name them differently to any + # dimension name, then rename them back afterwards. + renamed_jax_coords = {} + for name, coord in jax_coords.items(): + if isinstance(coord, xarray.DataArray): + coord = coord.variable + if isinstance(coord, xarray.Variable): + coord = coord.copy(deep=False) # Copy before mutating attrs. + else: + # Must wrap as Variable with the correct dims first if this has not + # already been done, otherwise xarray.Dataset will assume the dimension + # name is also __NONINDEX_{n}. + coord = Variable((name,), coord) + + # We set an attr on each jax_coord identifying it as such. These attrs on + # the coord Variable gets reflected on the coord DataArray exposed too, and + # when set on coordinates they generally get preserved under the default + # keep_attrs setting. + # These attrs are used by jax.tree_util registered flatten/unflatten to + # determine which coords need to be treated as leaves of the flattened + # structure vs static data. + coord.attrs[_JAX_COORD_ATTR_NAME] = True + renamed_jax_coords[f'__NONINDEX_{name}'] = coord + + x = x.assign_coords(coords=coords | renamed_jax_coords) + + rename_back_mapping = {f'__NONINDEX_{name}': name for name in jax_coords} + if isinstance(x, xarray.Dataset): + # Using 'rename' doesn't work if renaming to the same name as a dimension. + return x.rename_vars(rename_back_mapping) + else: # DataArray + return x.rename(rename_back_mapping) + + +def assign_jax_coords( + x: DatasetOrDataArray, + jax_coords: Optional[Mapping[Hashable, Any]] = None, + **jax_coords_kwargs + ) -> DatasetOrDataArray: + """Assigns only jax_coords, with same API as xarray's assign_coords.""" + return assign_coords(x, jax_coords=jax_coords or jax_coords_kwargs) + + +def wrap(value): + """Wraps JAX arrays for use in xarray, passing through other values.""" + if isinstance(value, jax.Array): + return JaxArrayWrapper(value) + else: + return value + + +def unwrap(value, require_jax=False): + """Unwraps wrapped JAX arrays used in xarray, passing through other values.""" + if isinstance(value, JaxArrayWrapper): + return value.jax_array + elif isinstance(value, jax.Array): + return value + elif require_jax: + raise TypeError(f'Expected JAX array, found {type(value)}.') + else: + return value + + +def _wrapped(func): + """Surrounds a function with JAX array unwrapping/wrapping.""" + def wrapped_func(*args, **kwargs): + args, kwargs = tree.map_structure(unwrap, (args, kwargs)) + result = func(*args, **kwargs) + return tree.map_structure(wrap, result) + return wrapped_func + + +def unwrap_data( + value: Union[xarray.Variable, xarray.DataArray], + require_jax: bool = False + ) -> Union[jax.Array, np.ndarray]: + """The unwrapped (see unwrap) data of a an xarray.Variable or DataArray.""" + return unwrap(value.data, require_jax=require_jax) + + +def unwrap_vars( + dataset: Mapping[Hashable, xarray.DataArray], + require_jax: bool = False + ) -> Mapping[str, Union[jax.Array, np.ndarray]]: + """The unwrapped data (see unwrap) of the variables in a dataset.""" + # xarray types variable names as Hashable, but in practice they're invariably + # strings and we convert to str to allow for a more useful return type. + return {str(name): unwrap_data(var, require_jax=require_jax) + for name, var in dataset.items()} + + +def unwrap_coords( + dataset: Union[xarray.Dataset, xarray.DataArray], + require_jax: bool = False + ) -> Mapping[str, Union[jax.Array, np.ndarray]]: + """The unwrapped data (see unwrap) of the coords in a Dataset or DataArray.""" + return {str(name): unwrap_data(var, require_jax=require_jax) + for name, var in dataset.coords.items()} + + +def jax_data(value: Union[xarray.Variable, xarray.DataArray]) -> jax.Array: + """Like unwrap_data, but will complain if not a jax array.""" + # Implementing this separately so we can give a more specific return type + # for it. + return cast(jax.Array, unwrap_data(value, require_jax=True)) + + +def jax_vars( + dataset: Mapping[Hashable, xarray.DataArray]) -> Mapping[str, jax.Array]: + """Like unwrap_vars, but will complain if vars are not all jax arrays.""" + return cast(Mapping[str, jax.Array], unwrap_vars(dataset, require_jax=True)) + + +class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin): + """Wraps a JAX array into a duck-typed array suitable for use with xarray. + + This uses an older duck-typed array protocol based on __array_ufunc__ and + __array_function__ which works with numpy and xarray. This is in the process + of being superseded by the Python array API standard + (https://data-apis.org/array-api/latest/index.html), but JAX and xarray + haven't implemented it yet. Once they have, we should be able to get rid of + this wrapper and use JAX arrays directly with xarray. + """ + + def __init__(self, jax_array): + self.jax_array = jax_array + + def __array_ufunc__(self, ufunc, method, *args, **kwargs): + for x in args: + if not isinstance(x, (jax.typing.ArrayLike, type(self))): + return NotImplemented + if method != '__call__': + return NotImplemented + try: + # Get the corresponding jax.numpy function to the NumPy ufunc: + func = getattr(jnp, ufunc.__name__) + except AttributeError: + return NotImplemented + # There may be an 'out' kwarg requesting an in-place operation, e.g. when + # this is called via __iadd__ (+=), __imul__ (*=) etc. JAX doesn't support + # in-place operations so we just remove this argument and have the ufunc + # return a fresh JAX array instead. + kwargs.pop('out', None) + return _wrapped(func)(*args, **kwargs) + + def __array_function__(self, func, types, args, kwargs): + try: + # Get the corresponding jax.np function to the NumPy function: + func = getattr(jnp, func.__name__) + except AttributeError: + return NotImplemented + return _wrapped(func)(*args, **kwargs) + + def __repr__(self): + return f'xarray_jax.JaxArrayWrapper({repr(self.jax_array)})' + + # NDArrayOperatorsMixin already proxies most __dunder__ operator methods. + # We need to proxy through a few more methods in a similar way: + + # Essential array properties: + + @property + def shape(self): + return self.jax_array.shape + + @property + def dtype(self): + return self.jax_array.dtype + + @property + def ndim(self): + return self.jax_array.ndim + + @property + def size(self): + return self.jax_array.size + + # Array methods not covered by NDArrayOperatorsMixin: + + # Allows conversion to numpy array using np.asarray etc. Warning: doing this + # will fail in a jax.jit-ed function. + def __array__(self, dtype=None, context=None): + return np.asarray(self.jax_array, dtype=dtype) + + __getitem__ = _wrapped(lambda array, *args: array.__getitem__(*args)) + # We drop the kwargs on this as they are not supported by JAX, but xarray + # uses at least one of them (the copy arg). + astype = _wrapped(lambda array, *args, **kwargs: array.astype(*args)) + + # There are many more methods which are more canonically available via (j)np + # functions, e.g. .sum() available via jnp.sum, and also mean, max, min, + # argmax, argmin etc. We don't attempt to proxy through all of these as + # methods, since this doesn't appear to be expected from a duck-typed array + # implementation. But there are a few which xarray calls as methods, so we + # proxy those: + transpose = _wrapped(jnp.transpose) + reshape = _wrapped(jnp.reshape) + all = _wrapped(jnp.all) + + +def apply_ufunc(func, *args, require_jax=False, **apply_ufunc_kwargs): + """Like xarray.apply_ufunc but for jax-specific ufuncs. + + Many numpy ufuncs will work fine out of the box with xarray_jax and + JaxArrayWrapper, since JaxArrayWrapper quacks (mostly) like a numpy array and + will convert many numpy operations to jax ops under the hood. For these + situations, xarray.apply_ufunc should work fine. + + But sometimes you need a jax-specific ufunc which needs to be given a + jax array as input or return a jax array as output. In that case you should + use this helper as it will remove any JaxArrayWrapper before calling the func, + and wrap the result afterwards before handing it back to xarray. + + Args: + func: A function that works with jax arrays (e.g. using functions from + jax.numpy) but otherwise meets the spec for the func argument to + xarray.apply_ufunc. + *args: xarray arguments to be mapped to arguments for func + (see xarray.apply_ufunc). + require_jax: Whether to require that inputs are based on jax arrays or allow + those based on plain numpy arrays too. + **apply_ufunc_kwargs: See xarray.apply_ufunc. + + Returns: + Corresponding xarray results (see xarray.apply_ufunc). + """ + def wrapped_func(*maybe_wrapped_args): + unwrapped_args = [unwrap(a, require_jax) for a in maybe_wrapped_args] + result = func(*unwrapped_args) + # Result can be an array or a tuple of arrays, this handles both: + return jax.tree_util.tree_map(wrap, result) + return xarray.apply_ufunc(wrapped_func, *args, **apply_ufunc_kwargs) + + +def pmap(fn: Callable[..., Any], + dim: str, + axis_name: Optional[str] = None, + devices: ... = None, + backend: ... = None) -> Callable[..., Any]: + """Wraps a subset of jax.pmap functionality to handle xarray input/output. + + Constraints: + * Any Dataset or DataArray passed to the function must have `dim` as the + first dimension. This will be checked. You can ensure this if necessary + by calling `.transpose(dim, ...)` beforehand. + * All args and return values will be mapped over the first dimension, + it will use in_axes=0, out_axes=0. + * No support for static_broadcasted_argnums, donate_argnums etc. + + Args: + fn: Function to be pmap'd which takes and returns trees which may contain + xarray Dataset/DataArray. Any Dataset/DataArrays passed as input must use + `dim` as the first dimension on all arrays. + dim: The xarray dimension name corresponding to the first dimension that is + pmapped over (pmap is called with in_axes=0, out_axes=0). + axis_name: Used by jax to identify the mapped axis so that parallel + collectives can be applied. Defaults to same as `dim`. + devices: + backend: + See jax.pmap. + + Returns: + A pmap'd version of `fn`, which takes and returns Dataset/DataArray with an + extra leading dimension `dim` relative to what the original `fn` sees. + """ + input_treedef = None + output_treedef = None + + def fn_passed_to_pmap(*flat_args): + assert input_treedef is not None + # Inside the pmap the original first dimension will no longer be present: + def check_and_remove_leading_dim(dims): + try: + index = dims.index(dim) + except ValueError: + index = None + if index != 0: + raise ValueError(f'Expected dim {dim} at index 0, found at {index}.') + return dims[1:] + with dims_change_on_unflatten(check_and_remove_leading_dim): + args = jax.tree_util.tree_unflatten(input_treedef, flat_args) + result = fn(*args) + nonlocal output_treedef + flat_result, output_treedef = jax.tree_util.tree_flatten(result) + return flat_result + + pmapped_fn = jax.pmap( + fn_passed_to_pmap, + axis_name=axis_name or dim, + in_axes=0, + out_axes=0, + devices=devices, + backend=backend) + + def result_fn(*args): + nonlocal input_treedef + flat_args, input_treedef = jax.tree_util.tree_flatten(args) + flat_result = pmapped_fn(*flat_args) + assert output_treedef is not None + # After the pmap an extra leading axis will be present, we need to add an + # xarray dimension for this when unflattening the result: + with dims_change_on_unflatten(lambda dims: (dim,) + dims): + return jax.tree_util.tree_unflatten(output_treedef, flat_result) + + return result_fn + + +# Register xarray datatypes with jax.tree_util. + + +DimsChangeFn = Callable[[Tuple[Hashable, ...]], Tuple[Hashable, ...]] +_DIMS_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[DimsChangeFn] = ( + contextvars.ContextVar('dims_change_on_unflatten_fn')) + + +@contextlib.contextmanager +def dims_change_on_unflatten(dims_change_fn: DimsChangeFn): + """Can be used to change the dims used when unflattening arrays into xarrays. + + This is useful when some axes were added to / removed from the underlying jax + arrays after they were flattened using jax.tree_util.tree_flatten, and you + want to unflatten them again afterwards using the original treedef but + adjusted for the added/removed dimensions. + + It can also be used with jax.tree_util.tree_map, when it's called with a + function that adds/removes axes or otherwise changes the axis order. + + When dimensions are removed, any coordinates using those removed dimensions + will also be removed on unflatten. + + This is implemented as a context manager that sets some thread-local state + affecting the behaviour of our unflatten functions, because it's not possible + to directly modify the treedef to change the dims/coords in it (and with + tree_map, the treedef isn't exposed to you anyway). + + Args: + dims_change_fn: Maps a tuple of dimension names for the original + Variable/DataArray/Dataset that was flattened, to an updated tuple of + dimensions which should be used when unflattening. + + Yields: + To a context manager in whose scope jax.tree_util.tree_unflatten and + jax.tree_util.tree_map will apply the dims_change_fn before reconstructing + xarrays from jax arrays. + """ + token = _DIMS_CHANGE_ON_UNFLATTEN_FN.set(dims_change_fn) + try: + yield + finally: + _DIMS_CHANGE_ON_UNFLATTEN_FN.reset(token) + + +def _flatten_variable(v: xarray.Variable) -> Tuple[ + Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]: + """Flattens a Variable for jax.tree_util.""" + children = (unwrap_data(v),) + aux = v.dims + return children, aux + + +def _unflatten_variable( + aux: Tuple[Hashable, ...], + children: Tuple[jax.typing.ArrayLike]) -> xarray.Variable: + """Unflattens a Variable for jax.tree_util.""" + dims = aux + dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None) + if dims_change_fn: dims = dims_change_fn(dims) + return Variable(dims=dims, data=children[0]) + + +def _split_static_and_jax_coords( + coords: xarray.core.coordinates.Coordinates) -> Tuple[ + Mapping[Hashable, xarray.Variable], Mapping[Hashable, xarray.Variable]]: + static_coord_vars = {} + jax_coord_vars = {} + for name, coord in coords.items(): + if coord.attrs.get(_JAX_COORD_ATTR_NAME, False): + jax_coord_vars[name] = coord.variable + else: + assert not isinstance(coord, (jax.Array, JaxArrayWrapper)) + static_coord_vars[name] = coord.variable + return static_coord_vars, jax_coord_vars + + +def _drop_with_none_of_dims( + coord_vars: Mapping[Hashable, xarray.Variable], + dims: Tuple[Hashable]) -> Mapping[Hashable, xarray.Variable]: + return {name: var for name, var in coord_vars.items() + if set(var.dims) <= set(dims)} + + +class _HashableCoords(collections.abc.Mapping): + """Wraps a dict of xarray Variables as hashable, used for static coordinates. + + This needs to be hashable so that when an xarray.Dataset is passed to a + jax.jit'ed function, jax can check whether it's seen an array with the + same static coordinates(*) before or whether it needs to recompile the + function for the new values of the static coordinates. + + (*) note jax_coords are not included in this; their value can be different + on different calls without triggering a recompile. + """ + + def __init__(self, coord_vars: Mapping[Hashable, xarray.Variable]): + self._variables = coord_vars + + def __repr__(self) -> str: + return f'_HashableCoords({repr(self._variables)})' + + def __getitem__(self, key: Hashable) -> xarray.Variable: + return self._variables[key] + + def __len__(self) -> int: + return len(self._variables) + + def __iter__(self) -> Iterator[Hashable]: + return iter(self._variables) + + def __hash__(self): + if not hasattr(self, '_hash'): + self._hash = hash(frozenset((name, var.data.tobytes()) + for name, var in self._variables.items())) + return self._hash + + def __eq__(self, other): + if self is other: + return True + elif not isinstance(other, type(self)): + return NotImplemented + elif self._variables is other._variables: + return True + else: + return self._variables.keys() == other._variables.keys() and all( + variable.equals(other._variables[name]) + for name, variable in self._variables.items()) + + +def _flatten_data_array(v: xarray.DataArray) -> Tuple[ + # Children (data variable, jax_coord_vars): + Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]], + # Static auxiliary data (name, static_coord_vars): + Tuple[Optional[Hashable], _HashableCoords]]: + """Flattens a DataArray for jax.tree_util.""" + static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(v.coords) + children = (v.variable, jax_coord_vars) + aux = (v.name, _HashableCoords(static_coord_vars)) + return children, aux + + +def _unflatten_data_array( + aux: Tuple[Optional[Hashable], _HashableCoords], + children: Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]], +) -> xarray.DataArray: + """Unflattens a DataArray for jax.tree_util.""" + variable, jax_coord_vars = children + name, static_coord_vars = aux + # Drop static coords which have dims not present in any of the data_vars. + # These would generally be dims that were dropped by a dims_change_fn, but + # because static coordinates don't go through dims_change_fn on unflatten, we + # just drop them where this causes a problem. + # Since jax_coords go through the dims_change_fn on unflatten we don't need + # to do this for jax_coords. + static_coord_vars = _drop_with_none_of_dims(static_coord_vars, variable.dims) + return DataArray( + variable, name=name, coords=static_coord_vars, jax_coords=jax_coord_vars) + + +def _flatten_dataset(dataset: xarray.Dataset) -> Tuple[ + # Children (data variables, jax_coord_vars): + Tuple[Mapping[Hashable, xarray.Variable], + Mapping[Hashable, xarray.Variable]], + # Static auxiliary data (static_coord_vars): + _HashableCoords]: + """Flattens a Dataset for jax.tree_util.""" + variables = {name: data_array.variable + for name, data_array in dataset.data_vars.items()} + static_coord_vars, jax_coord_vars = _split_static_and_jax_coords( + dataset.coords) + children = (variables, jax_coord_vars) + aux = _HashableCoords(static_coord_vars) + return children, aux + + +def _unflatten_dataset( + aux: _HashableCoords, + children: Tuple[Mapping[Hashable, xarray.Variable], + Mapping[Hashable, xarray.Variable]], + ) -> xarray.Dataset: + """Unflattens a Dataset for jax.tree_util.""" + data_vars, jax_coord_vars = children + static_coord_vars = aux + dataset = xarray.Dataset(data_vars) + # Drop static coords which have dims not present in any of the data_vars. + # See corresponding comment in _unflatten_data_array. + static_coord_vars = _drop_with_none_of_dims(static_coord_vars, dataset.dims) + return assign_coords( + dataset, coords=static_coord_vars, jax_coords=jax_coord_vars) + + +jax.tree_util.register_pytree_node( + xarray.Variable, _flatten_variable, _unflatten_variable) +# This is a subclass of Variable but still needs registering separately. +# Flatten/unflatten for IndexVariable is a bit of a corner case but we do +# need to support it. +jax.tree_util.register_pytree_node( + xarray.IndexVariable, _flatten_variable, _unflatten_variable) +jax.tree_util.register_pytree_node( + xarray.DataArray, _flatten_data_array, _unflatten_data_array) +jax.tree_util.register_pytree_node( + xarray.Dataset, _flatten_dataset, _unflatten_dataset) diff --git a/graphcast/xarray_jax_test.py b/graphcast/xarray_jax_test.py new file mode 100644 index 0000000..6189ec9 --- /dev/null +++ b/graphcast/xarray_jax_test.py @@ -0,0 +1,526 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for xarray_jax.""" + +from absl.testing import absltest +import chex +from graphcast import xarray_jax +import jax +import jax.numpy as jnp +import numpy as np +import xarray + + +class XarrayJaxTest(absltest.TestCase): + + def test_jax_array_wrapper_with_numpy_api(self): + # This is just a side benefit of making things work with xarray, but the + # JaxArrayWrapper does allow you to manipulate JAX arrays using the + # standard numpy API, without converting them to numpy in the process: + ones = jnp.ones((3, 4), dtype=np.float32) + x = xarray_jax.JaxArrayWrapper(ones) + x = np.abs((x + 2) * (x - 3)) + x = x[:-1, 1:3] + x = np.concatenate([x, x + 1], axis=0) + x = np.transpose(x, (1, 0)) + x = np.reshape(x, (-1,)) + x = x.astype(np.int32) + self.assertIsInstance(x, xarray_jax.JaxArrayWrapper) + # An explicit conversion gets us out of JAX-land however: + self.assertIsInstance(np.asarray(x), np.ndarray) + + def test_jax_xarray_variable(self): + def ops_via_xarray(inputs): + x = xarray_jax.Variable(('lat', 'lon'), inputs) + # We'll apply a sequence of operations just to test that the end result is + # still a JAX array, i.e. we haven't converted to numpy at any point. + x = np.abs((x + 2) * (x - 3)) + x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)}) + x = xarray.Variable.concat([x, x + 1], dim='lat') + x = x.transpose('lon', 'lat') + x = x.stack(channels=('lon', 'lat')) + x = x.sum() + return xarray_jax.jax_data(x) + + # Check it doesn't leave jax-land when passed concrete values: + ones = jnp.ones((3, 4), dtype=np.float32) + result = ops_via_xarray(ones) + self.assertIsInstance(result, jax.Array) + + # And that you can JIT it and compute gradients through it. These will + # involve passing jax tracers through the xarray computation: + jax.jit(ops_via_xarray)(ones) + jax.grad(ops_via_xarray)(ones) + + def test_jax_xarray_data_array(self): + def ops_via_xarray(inputs): + x = xarray_jax.DataArray(dims=('lat', 'lon'), + data=inputs, + coords={'lat': np.arange(3) * 10, + 'lon': np.arange(4) * 10}) + x = np.abs((x + 2) * (x - 3)) + x = x.sel({'lat': slice(0, 20)}) + y = xarray_jax.DataArray(dims=('lat', 'lon'), + data=ones, + coords={'lat': np.arange(3, 6) * 10, + 'lon': np.arange(4) * 10}) + x = xarray.concat([x, y], dim='lat') + x = x.transpose('lon', 'lat') + x = x.stack(channels=('lon', 'lat')) + x = x.unstack() + x = x.sum() + return xarray_jax.jax_data(x) + + ones = jnp.ones((3, 4), dtype=np.float32) + result = ops_via_xarray(ones) + self.assertIsInstance(result, jax.Array) + + jax.jit(ops_via_xarray)(ones) + jax.grad(ops_via_xarray)(ones) + + def test_jax_xarray_dataset(self): + def ops_via_xarray(foo, bar): + x = xarray_jax.Dataset( + data_vars={'foo': (('lat', 'lon'), foo), + 'bar': (('time', 'lat', 'lon'), bar)}, + coords={ + 'time': np.arange(2), + 'lat': np.arange(3) * 10, + 'lon': np.arange(4) * 10}) + x = np.abs((x + 2) * (x - 3)) + x = x.sel({'lat': slice(0, 20)}) + y = xarray_jax.Dataset( + data_vars={'foo': (('lat', 'lon'), foo), + 'bar': (('time', 'lat', 'lon'), bar)}, + coords={ + 'time': np.arange(2), + 'lat': np.arange(3, 6) * 10, + 'lon': np.arange(4) * 10}) + x = xarray.concat([x, y], dim='lat') + x = x.transpose('lon', 'lat', 'time') + x = x.stack(channels=('lon', 'lat')) + x = (x.foo + x.bar).sum() + return xarray_jax.jax_data(x) + + foo = jnp.ones((3, 4), dtype=np.float32) + bar = jnp.ones((2, 3, 4), dtype=np.float32) + result = ops_via_xarray(foo, bar) + self.assertIsInstance(result, jax.Array) + + jax.jit(ops_via_xarray)(foo, bar) + jax.grad(ops_via_xarray)(foo, bar) + + def test_jit_function_with_xarray_variable_arguments_and_return(self): + function = jax.jit(lambda v: v + 1) + with self.subTest('jax input'): + inputs = xarray_jax.Variable( + ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) + _ = function(inputs) + # We test running the jitted function a second time, to exercise logic in + # jax which checks if the structure of the inputs (including dimension + # names and coordinates) is the same as it was for the previous call and + # so whether it needs to re-trace-and-compile a new version of the + # function or not. This can run into problems if the 'aux' structure + # returned by the registered flatten function is not hashable/comparable. + outputs = function(inputs) + self.assertEqual(outputs.dims, inputs.dims) + with self.subTest('numpy input'): + inputs = xarray.Variable( + ('lat', 'lon'), np.ones((3, 4), dtype=np.float32)) + _ = function(inputs) + outputs = function(inputs) + self.assertEqual(outputs.dims, inputs.dims) + + def test_jit_problem_if_convert_to_plain_numpy_array(self): + inputs = xarray_jax.DataArray( + data=jnp.ones((2,), dtype=np.float32), dims=('foo',)) + with self.assertRaises(jax.errors.TracerArrayConversionError): + # Calling .values on a DataArray converts its values to numpy: + jax.jit(lambda data_array: data_array.values)(inputs) + + def test_grad_function_with_xarray_variable_arguments(self): + x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) + # For grad we still need a JAX scalar as the output: + jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x) + + def test_jit_function_with_xarray_data_array_arguments_and_return(self): + inputs = xarray_jax.DataArray( + data=jnp.ones((3, 4), dtype=np.float32), + dims=('lat', 'lon'), + coords={'lat': np.arange(3), + 'lon': np.arange(4) * 10}) + fn = jax.jit(lambda v: v + 1) + _ = fn(inputs) + outputs = fn(inputs) + self.assertEqual(outputs.dims, inputs.dims) + chex.assert_trees_all_equal(outputs.coords, inputs.coords) + + def test_jit_function_with_data_array_and_jax_coords(self): + inputs = xarray_jax.DataArray( + data=jnp.ones((3, 4), dtype=np.float32), + dims=('lat', 'lon'), + coords={'lat': np.arange(3)}, + jax_coords={'lon': jnp.arange(4) * 10}) + # Verify the jax_coord 'lon' retains jax data, and has not been created + # as an index coordinate: + self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', inputs.indexes) + + @jax.jit + def fn(v): + # The non-JAX coord is passed with numpy array data and an index: + self.assertIsInstance(v.coords['lat'].data, np.ndarray) + self.assertIn('lat', v.indexes) + + # The jax_coord is passed with JAX array data: + self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', v.indexes) + + # Use the jax coord in the computation: + v = v + v.coords['lon'] + + # Return with an updated jax coord: + return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1) + + _ = fn(inputs) + outputs = fn(inputs) + + # Verify the jax_coord 'lon' has jax data in the output too: + self.assertIsInstance( + outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', outputs.indexes) + + self.assertEqual(outputs.dims, inputs.dims) + chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat']) + # Check our computations with the coordinate values worked: + chex.assert_trees_all_equal( + outputs.coords['lon'].data, (inputs.coords['lon']+1).data) + chex.assert_trees_all_equal( + outputs.data, (inputs + inputs.coords['lon']).data) + + def test_jit_function_with_xarray_dataset_arguments_and_return(self): + foo = jnp.ones((3, 4), dtype=np.float32) + bar = jnp.ones((2, 3, 4), dtype=np.float32) + inputs = xarray_jax.Dataset( + data_vars={'foo': (('lat', 'lon'), foo), + 'bar': (('time', 'lat', 'lon'), bar)}, + coords={ + 'time': np.arange(2), + 'lat': np.arange(3) * 10, + 'lon': np.arange(4) * 10}) + fn = jax.jit(lambda v: v + 1) + _ = fn(inputs) + outputs = fn(inputs) + self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys()) + self.assertEqual(inputs.foo.dims, outputs.foo.dims) + self.assertEqual(inputs.bar.dims, outputs.bar.dims) + chex.assert_trees_all_equal(outputs.coords, inputs.coords) + + def test_jit_function_with_dataset_and_jax_coords(self): + foo = jnp.ones((3, 4), dtype=np.float32) + bar = jnp.ones((2, 3, 4), dtype=np.float32) + inputs = xarray_jax.Dataset( + data_vars={'foo': (('lat', 'lon'), foo), + 'bar': (('time', 'lat', 'lon'), bar)}, + coords={ + 'time': np.arange(2), + 'lat': np.arange(3) * 10, + }, + jax_coords={'lon': jnp.arange(4) * 10} + ) + # Verify the jax_coord 'lon' retains jax data, and has not been created + # as an index coordinate: + self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', inputs.indexes) + + @jax.jit + def fn(v): + # The non-JAX coords are passed with numpy array data and an index: + self.assertIsInstance(v.coords['lat'].data, np.ndarray) + self.assertIn('lat', v.indexes) + + # The jax_coord is passed with JAX array data: + self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', v.indexes) + + # Use the jax coord in the computation: + v = v + v.coords['lon'] + + # Return with an updated jax coord: + return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1) + + _ = fn(inputs) + outputs = fn(inputs) + + # Verify the jax_coord 'lon' has jax data in the output too: + self.assertIsInstance( + outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('lon', outputs.indexes) + + self.assertEqual(outputs.dims, inputs.dims) + chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat']) + # Check our computations with the coordinate values worked: + chex.assert_trees_all_equal( + (outputs.coords['lon']).data, + (inputs.coords['lon']+1).data, + ) + outputs_dict = {key: outputs[key].data for key in outputs} + inputs_and_inputs_coords_dict = { + key: (inputs + inputs.coords['lon'])[key].data + for key in inputs + inputs.coords['lon'] + } + chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict) + + def test_flatten_unflatten_variable(self): + variable = xarray_jax.Variable( + ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) + children, aux = xarray_jax._flatten_variable(variable) + # Check auxiliary info is hashable/comparable (important for jax.jit): + hash(aux) + self.assertEqual(aux, aux) + roundtrip = xarray_jax._unflatten_variable(aux, children) + self.assertTrue(variable.equals(roundtrip)) + + def test_flatten_unflatten_data_array(self): + data_array = xarray_jax.DataArray( + data=jnp.ones((3, 4), dtype=np.float32), + dims=('lat', 'lon'), + coords={'lat': np.arange(3)}, + jax_coords={'lon': np.arange(4) * 10}, + ) + children, aux = xarray_jax._flatten_data_array(data_array) + # Check auxiliary info is hashable/comparable (important for jax.jit): + hash(aux) + self.assertEqual(aux, aux) + roundtrip = xarray_jax._unflatten_data_array(aux, children) + self.assertTrue(data_array.equals(roundtrip)) + + def test_flatten_unflatten_dataset(self): + foo = jnp.ones((3, 4), dtype=np.float32) + bar = jnp.ones((2, 3, 4), dtype=np.float32) + dataset = xarray_jax.Dataset( + data_vars={'foo': (('lat', 'lon'), foo), + 'bar': (('time', 'lat', 'lon'), bar)}, + coords={ + 'time': np.arange(2), + 'lat': np.arange(3) * 10}, + jax_coords={ + 'lon': np.arange(4) * 10}) + children, aux = xarray_jax._flatten_dataset(dataset) + # Check auxiliary info is hashable/comparable (important for jax.jit): + hash(aux) + self.assertEqual(aux, aux) + roundtrip = xarray_jax._unflatten_dataset(aux, children) + self.assertTrue(dataset.equals(roundtrip)) + + def test_flatten_unflatten_added_dim(self): + data_array = xarray_jax.DataArray( + data=jnp.ones((3, 4), dtype=np.float32), + dims=('lat', 'lon'), + coords={'lat': np.arange(3), + 'lon': np.arange(4) * 10}) + leaves, treedef = jax.tree_util.tree_flatten(data_array) + leaves = [jnp.expand_dims(x, 0) for x in leaves] + with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims): + with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves) + self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims) + xarray.testing.assert_identical( + jax.device_get(data_array), + jax.device_get(with_new_dim.isel(new=0))) + + def test_map_added_dim(self): + data_array = xarray_jax.DataArray( + data=jnp.ones((3, 4), dtype=np.float32), + dims=('lat', 'lon'), + coords={'lat': np.arange(3), + 'lon': np.arange(4) * 10}) + with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims): + with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), + data_array) + self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims) + xarray.testing.assert_identical( + jax.device_get(data_array), + jax.device_get(with_new_dim.isel(new=0))) + + def test_map_remove_dim(self): + foo = jnp.ones((1, 3, 4), dtype=np.float32) + bar = jnp.ones((1, 2, 3, 4), dtype=np.float32) + dataset = xarray_jax.Dataset( + data_vars={'foo': (('batch', 'lat', 'lon'), foo), + 'bar': (('batch', 'time', 'lat', 'lon'), bar)}, + coords={ + 'batch': np.array([123]), + 'time': np.arange(2), + 'lat': np.arange(3) * 10, + 'lon': np.arange(4) * 10}) + with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]): + with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0), + dataset) + self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims) + self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims) + self.assertNotIn('batch', with_removed_dim.dims) + self.assertNotIn('batch', with_removed_dim.coords) + xarray.testing.assert_identical( + jax.device_get(dataset.isel(batch=0, drop=True)), + jax.device_get(with_removed_dim)) + + def test_pmap(self): + devices = jax.local_device_count() + foo = jnp.zeros((devices, 3, 4), dtype=np.float32) + bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32) + dataset = xarray_jax.Dataset({ + 'foo': (('device', 'lat', 'lon'), foo), + 'bar': (('device', 'time', 'lat', 'lon'), bar)}) + + def func(d): + self.assertNotIn('device', d.dims) + return d + 1 + func = xarray_jax.pmap(func, dim='device') + + result = func(dataset) + xarray.testing.assert_identical( + jax.device_get(dataset + 1), + jax.device_get(result)) + + # Can call it again with a different argument structure (it will recompile + # under the hood but should work): + dataset = dataset.drop_vars('foo') + result = func(dataset) + xarray.testing.assert_identical( + jax.device_get(dataset + 1), + jax.device_get(result)) + + def test_pmap_with_jax_coords(self): + devices = jax.local_device_count() + foo = jnp.zeros((devices, 3, 4), dtype=np.float32) + bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32) + time = jnp.zeros((devices, 2), dtype=np.float32) + dataset = xarray_jax.Dataset( + {'foo': (('device', 'lat', 'lon'), foo), + 'bar': (('device', 'time', 'lat', 'lon'), bar)}, + coords={ + 'lat': np.arange(3), + 'lon': np.arange(4), + }, + jax_coords={ + # Currently any jax_coords need a leading device dimension to use + # with pmap, same as for data_vars. + # TODO(matthjw): have pmap automatically broadcast to all devices + # where the device dimension not present. + 'time': xarray_jax.Variable(('device', 'time'), time), + } + ) + + def func(d): + self.assertNotIn('device', d.dims) + self.assertNotIn('device', d.coords['time'].dims) + + # The jax_coord 'time' should be passed in backed by a JAX array, but + # not as an index coordinate. + self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper) + self.assertNotIn('time', d.indexes) + + return d + 1 + func = xarray_jax.pmap(func, dim='device') + + result = func(dataset) + xarray.testing.assert_identical( + jax.device_get(dataset + 1), + jax.device_get(result)) + + # Can call it again with a different argument structure (it will recompile + # under the hood but should work): + dataset = dataset.drop_vars('foo') + result = func(dataset) + xarray.testing.assert_identical( + jax.device_get(dataset + 1), + jax.device_get(result)) + + def test_pmap_with_tree_mix_of_xarray_and_jax_array(self): + devices = jax.local_device_count() + data_array = xarray_jax.DataArray( + data=jnp.ones((devices, 3, 4), dtype=np.float32), + dims=('device', 'lat', 'lon')) + plain_array = jnp.ones((devices, 2), dtype=np.float32) + inputs = {'foo': data_array, + 'bar': plain_array} + + def func(x): + return x['foo'] + 1, x['bar'] + 1 + + func = xarray_jax.pmap(func, dim='device') + result_foo, result_bar = func(inputs) + xarray.testing.assert_identical( + jax.device_get(inputs['foo'] + 1), + jax.device_get(result_foo)) + np.testing.assert_array_equal( + jax.device_get(inputs['bar'] + 1), + jax.device_get(result_bar)) + + def test_pmap_complains_when_dim_not_first(self): + devices = jax.local_device_count() + data_array = xarray_jax.DataArray( + data=jnp.ones((3, devices, 4), dtype=np.float32), + dims=('lat', 'device', 'lon')) + + func = xarray_jax.pmap(lambda x: x+1, dim='device') + with self.assertRaisesRegex( + ValueError, 'Expected dim device at index 0, found at 1'): + func(data_array) + + def test_apply_ufunc(self): + inputs = xarray_jax.DataArray( + data=jnp.asarray([[1, 2], [3, 4]]), + dims=('x', 'y'), + coords={'x': [0, 1], + 'y': [2, 3]}) + result = xarray_jax.apply_ufunc( + lambda x: jnp.sum(x, axis=-1), + inputs, + input_core_dims=[['x']]) + expected_result = xarray_jax.DataArray( + data=[4, 6], + dims=('y',), + coords={'y': [2, 3]}) + xarray.testing.assert_identical(expected_result, jax.device_get(result)) + + def test_apply_ufunc_multiple_return_values(self): + def ufunc(array): + return jnp.min(array, axis=-1), jnp.max(array, axis=-1) + inputs = xarray_jax.DataArray( + data=jnp.asarray([[1, 4], [3, 2]]), + dims=('x', 'y'), + coords={'x': [0, 1], + 'y': [2, 3]}) + result = xarray_jax.apply_ufunc( + ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []]) + expected = ( + # Mins: + xarray_jax.DataArray( + data=[1, 2], + dims=('y',), + coords={'y': [2, 3]} + ), + # Maxes: + xarray_jax.DataArray( + data=[3, 4], + dims=('y',), + coords={'y': [2, 3]} + ) + ) + xarray.testing.assert_identical(expected[0], jax.device_get(result[0])) + xarray.testing.assert_identical(expected[1], jax.device_get(result[1])) + +if __name__ == '__main__': + absltest.main() diff --git a/graphcast/xarray_tree.py b/graphcast/xarray_tree.py new file mode 100644 index 0000000..e8854ef --- /dev/null +++ b/graphcast/xarray_tree.py @@ -0,0 +1,70 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for working with trees of xarray.DataArray (including Datasets). + +Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library; +it won't work as a leaf node since it implements Mapping, but also won't work +as an internal node since tree doesn't know how to re-create it properly. + +To fix this, we reimplement a subset of `map_structure`, exposing its +constituent DataArrays as leaf nodes. This means it can be mapped over as a +generic container of DataArrays, while still preserving the result as a Dataset +where possible. + +This is useful because in a few places we need to handle a general +Mapping[str, DataArray] (where the coordinates might not be compatible across +the constituent DataArrays) but also the special case of a Dataset nicely. + +For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for +some of the child DataArrays, they will be omitted from the returned dataset. If +any values other than DataArrays or None are returned, then we don't attempt to +return a Dataset and just return a plain dict of the results. Similarly if +DataArrays are returned but with non-matching coordinates, it will just return a +plain dict of DataArrays. + +Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py, +but `jax.tree_util.tree_map` is distinct from the `xarray_tree.map_structure`. +as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the +latter exposes DataArrays as leaf nodes. +""" + +from typing import Any, Callable + +import xarray + + +def map_structure(func: Callable[..., Any], *structures: Any) -> Any: + """Maps func through given structures with xarrays. See tree.map_structure.""" + if not callable(func): + raise TypeError(f'func must be callable, got: {func}') + if not structures: + raise ValueError('Must provide at least one structure') + + first = structures[0] + if isinstance(first, xarray.Dataset): + data = {k: func(*[s[k] for s in structures]) for k in first.keys()} + if all(isinstance(a, (type(None), xarray.DataArray)) + for a in data.values()): + data_arrays = [v.rename(k) for k, v in data.items() if v is not None] + try: + return xarray.merge(data_arrays, join='exact') + except ValueError: # Exact join not possible. + pass + return data + if isinstance(first, dict): + return {k: map_structure(func, *[s[k] for s in structures]) + for k in first.keys()} + if isinstance(first, (list, tuple, set)): + return type(first)(map_structure(func, *s) for s in zip(*structures)) + return func(*structures) diff --git a/graphcast/xarray_tree_test.py b/graphcast/xarray_tree_test.py new file mode 100644 index 0000000..b94b11d --- /dev/null +++ b/graphcast/xarray_tree_test.py @@ -0,0 +1,95 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for xarray_tree.""" + +from absl.testing import absltest +from graphcast import xarray_tree +import numpy as np +import xarray + + +TEST_DATASET = xarray.Dataset( + data_vars={ + "foo": (("x", "y"), np.zeros((2, 3))), + "bar": (("x",), np.zeros((2,))), + }, + coords={ + "x": [1, 2], + "y": [10, 20, 30], + } +) + + +class XarrayTreeTest(absltest.TestCase): + + def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self): + def fn(leaf): + self.assertIsInstance(leaf, xarray.DataArray) + result = leaf + 1 + # Removing the name from the returned DataArray to test that we don't rely + # on it being present to restore the correct names in the result: + result = result.rename(None) + return result + + result = xarray_tree.map_structure(fn, TEST_DATASET) + self.assertIsInstance(result, xarray.Dataset) + self.assertSameElements({"foo", "bar"}, result.keys()) + + def test_map_structure_on_data_arrays(self): + data_arrays = dict(TEST_DATASET) + result = xarray_tree.map_structure(lambda x: x+1, data_arrays) + self.assertIsInstance(result, dict) + self.assertSameElements({"foo", "bar"}, result.keys()) + + def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self): + def fn(leaf): + # Returns DataArrays that can't be exactly merged back into a Dataset + # due to the coordinates not matching: + if leaf.name == "foo": + return xarray.DataArray( + data=np.zeros(2), dims=("x",), coords={"x": [1, 2]}) + else: + return xarray.DataArray( + data=np.zeros(2), dims=("x",), coords={"x": [3, 4]}) + + result = xarray_tree.map_structure(fn, TEST_DATASET) + self.assertIsInstance(result, dict) + self.assertSameElements({"foo", "bar"}, result.keys()) + + def test_map_structure_on_dataset_drops_vars_with_none_return_values(self): + def fn(leaf): + return leaf if leaf.name == "foo" else None + + result = xarray_tree.map_structure(fn, TEST_DATASET) + self.assertIsInstance(result, xarray.Dataset) + self.assertSameElements({"foo"}, result.keys()) + + def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self): + def fn(leaf): + self.assertIsInstance(leaf, xarray.DataArray) + return "not a DataArray" + + result = xarray_tree.map_structure(fn, TEST_DATASET) + self.assertEqual({"foo": "not a DataArray", + "bar": "not a DataArray"}, result) + + def test_map_structure_two_args_different_variable_orders(self): + dataset_different_order = TEST_DATASET[["bar", "foo"]] + def fn(arg1, arg2): + self.assertEqual(arg1.name, arg2.name) + xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order) + + +if __name__ == "__main__": + absltest.main() diff --git a/graphcast_demo.ipynb b/graphcast_demo.ipynb new file mode 100644 index 0000000..723b6ac --- /dev/null +++ b/graphcast_demo.ipynb @@ -0,0 +1,857 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "-jAYlxeKxvAJ" + }, + "source": [ + "# GraphCast\n", + "\n", + "This colab lets you run several versions of GraphCast.\n", + "\n", + "The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).\n", + "\n", + "A Colab runtime with TPU/GPU acceleration will substantially speed up generating predictions and computing the loss/gradients. If you're using a CPU-only runtime, you can switch using the menu \"Runtime \u003e Change runtime type\"." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IIWlNRupdI2i" + }, + "source": [ + "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eCopyright 2023 DeepMind Technologies Limited.\u003c/small\u003e\u003c/p\u003e\n", + "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at \u003ca href=\"http://www.apache.org/licenses/LICENSE-2.0\"\u003ehttp://www.apache.org/licenses/LICENSE-2.0\u003c/a\u003e.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e\n", + "\u003e \u003cp\u003e\u003csmall\u003e\u003csmall\u003eUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\u003c/small\u003e\u003c/small\u003e\u003c/p\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMbbXFl4msJw" + }, + "source": [ + "# Installation and Initialization\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "-W4K9skv9vh-" + }, + "outputs": [], + "source": [ + "# @title Pip install graphcast and dependencies\n", + "\n", + "%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "MA5087Vb29z2" + }, + "outputs": [], + "source": [ + "# @title Workaround for cartopy crashes\n", + "\n", + "# Workaround for cartopy crashes due to the shapely installed by default in\n", + "# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):\n", + "!pip uninstall -y shapely\n", + "!pip install shapely --no-binary shapely" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Z_j8ej4Pyg1L" + }, + "outputs": [], + "source": [ + "# @title Imports\n", + "\n", + "import dataclasses\n", + "import datetime\n", + "import functools\n", + "import math\n", + "import re\n", + "from typing import Optional\n", + "\n", + "import cartopy.crs as ccrs\n", + "from google.cloud import storage\n", + "from graphcast import autoregressive\n", + "from graphcast import casting\n", + "from graphcast import checkpoint\n", + "from graphcast import data_utils\n", + "from graphcast import graphcast\n", + "from graphcast import normalization\n", + "from graphcast import rollout\n", + "from graphcast import xarray_jax\n", + "from graphcast import xarray_tree\n", + "from IPython.display import HTML\n", + "import ipywidgets as widgets\n", + "import haiku as hk\n", + "import jax\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import xarray\n", + "\n", + "\n", + "def parse_file_parts(file_name):\n", + " return dict(part.split(\"-\", 1) for part in file_name.split(\"_\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "4wagX1TL_f15" + }, + "outputs": [], + "source": [ + "# @title Authenticate with Google Cloud Storage\n", + "\n", + "# TODO: Figure out how to access a public cloud bucket without authentication.\n", + "from google.colab import auth\n", + "auth.authenticate_user()\n", + "\n", + "gcs_client = storage.Client()\n", + "gcs_bucket = gcs_client.get_bucket(\"dm_graphcast\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "5JUymx84dI2m" + }, + "outputs": [], + "source": [ + "# @title Plotting functions\n", + "\n", + "def select(\n", + " data: xarray.Dataset,\n", + " variable: str,\n", + " level: Optional[int] = None,\n", + " max_steps: Optional[int] = None\n", + " ) -\u003e xarray.Dataset:\n", + " data = data[variable]\n", + " if \"batch\" in data.dims:\n", + " data = data.isel(batch=0)\n", + " if max_steps is not None and \"time\" in data.sizes and max_steps \u003c data.sizes[\"time\"]:\n", + " data = data.isel(time=range(0, max_steps))\n", + " if level is not None and \"level\" in data.coords:\n", + " data = data.sel(level=level)\n", + " return data\n", + "\n", + "def scale(\n", + " data: xarray.Dataset,\n", + " center: Optional[float] = None,\n", + " robust: bool = False,\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + " vmin = np.nanpercentile(data, (2 if robust else 0))\n", + " vmax = np.nanpercentile(data, (98 if robust else 100))\n", + " if center is not None:\n", + " diff = max(vmax - center, center - vmin)\n", + " vmin = center - diff\n", + " vmax = center + diff\n", + " return (data, matplotlib.colors.Normalize(vmin, vmax),\n", + " (\"RdBu_r\" if center is not None else \"viridis\"))\n", + "\n", + "def plot_data(\n", + " data: dict[str, xarray.Dataset],\n", + " fig_title: str,\n", + " plot_size: float = 5,\n", + " robust: bool = False,\n", + " cols: int = 4\n", + " ) -\u003e tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:\n", + "\n", + " first_data = next(iter(data.values()))[0]\n", + " max_steps = first_data.sizes.get(\"time\", 1)\n", + " assert all(max_steps == d.sizes.get(\"time\", 1) for d, _, _ in data.values())\n", + "\n", + " cols = min(cols, len(data))\n", + " rows = math.ceil(len(data) / cols)\n", + " figure = plt.figure(figsize=(plot_size * 2 * cols,\n", + " plot_size * rows))\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " figure.subplots_adjust(wspace=0, hspace=0)\n", + " figure.tight_layout()\n", + "\n", + " images = []\n", + " for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):\n", + " ax = figure.add_subplot(rows, cols, i+1)\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(title)\n", + " im = ax.imshow(\n", + " plot_data.isel(time=0, missing_dims=\"ignore\"), norm=norm,\n", + " origin=\"lower\", cmap=cmap)\n", + " plt.colorbar(\n", + " mappable=im,\n", + " ax=ax,\n", + " orientation=\"vertical\",\n", + " pad=0.02,\n", + " aspect=16,\n", + " shrink=0.75,\n", + " cmap=cmap,\n", + " extend=(\"both\" if robust else \"neither\"))\n", + " images.append(im)\n", + "\n", + " def update(frame):\n", + " if \"time\" in first_data.dims:\n", + " td = datetime.timedelta(microseconds=first_data[\"time\"][frame].item() / 1000)\n", + " figure.suptitle(f\"{fig_title}, {td}\", fontsize=16)\n", + " else:\n", + " figure.suptitle(fig_title, fontsize=16)\n", + " for im, (plot_data, norm, cmap) in zip(images, data.values()):\n", + " im.set_data(plot_data.isel(time=frame, missing_dims=\"ignore\"))\n", + "\n", + " ani = animation.FuncAnimation(\n", + " fig=figure, func=update, frames=max_steps, interval=250)\n", + " plt.close(figure.number)\n", + " return HTML(ani.to_jshtml())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WEtSV8HEkHtf" + }, + "source": [ + "# Load the Data and initialize the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G50ORsY_dI2n" + }, + "source": [ + "## Load the model params\n", + "\n", + "Choose one of the two ways of getting model params:\n", + "- **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.\n", + "- **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device. In particular generating gradients uses a lot of memory, so you'll need at least 25GB of ram (TPUv4 or A100).\n", + "\n", + "Checkpoints vary across a few axes:\n", + "- The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.\n", + "- The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.\n", + "- All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as \"ERA5\" take precipitation as input and expect ERA5 data as input, while model marked \"ERA5-HRES\" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).\n", + "\n", + "We provide three pre-trained models.\n", + "1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,\n", + "\n", + "2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful to run a model with lower memory and compute constraints,\n", + "\n", + "3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "KGaJ6V9MdI2n" + }, + "outputs": [], + "source": [ + "# @title Choose the model\n", + "\n", + "params_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=\"params/\")\n", + " if (name := blob.name.removeprefix(\"params/\"))] # Drop empty string.\n", + "\n", + "random_mesh_size = widgets.IntSlider(\n", + " value=4, min=4, max=6, description=\"Mesh size:\")\n", + "random_gnn_msg_steps = widgets.IntSlider(\n", + " value=4, min=1, max=32, description=\"GNN message steps:\")\n", + "random_latent_size = widgets.Dropdown(\n", + " options=[int(2**i) for i in range(4, 10)], value=32,description=\"Latent size:\")\n", + "random_levels = widgets.Dropdown(\n", + " options=[13, 37], value=13, description=\"Pressure levels:\")\n", + "\n", + "\n", + "params_file = widgets.Dropdown(\n", + " options=params_file_options,\n", + " description=\"Params file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "\n", + "source_tab = widgets.Tab([\n", + " widgets.VBox([\n", + " random_mesh_size,\n", + " random_gnn_msg_steps,\n", + " random_latent_size,\n", + " random_levels,\n", + " ]),\n", + " params_file,\n", + "])\n", + "source_tab.set_title(0, \"Random\")\n", + "source_tab.set_title(1, \"Checkpoint\")\n", + "widgets.VBox([\n", + " source_tab,\n", + " widgets.Label(value=\"Run the next cell to load the model. Rerunning this cell clears your selection.\")\n", + "])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lYQgrPgPdI2n" + }, + "outputs": [], + "source": [ + "# @title Load the model\n", + "\n", + "source = source_tab.get_title(source_tab.selected_index)\n", + "\n", + "if source == \"Random\":\n", + " params = None # Filled in below\n", + " state = {}\n", + " model_config = graphcast.ModelConfig(\n", + " resolution=0,\n", + " mesh_size=random_mesh_size.value,\n", + " latent_size=random_latent_size.value,\n", + " gnn_msg_steps=random_gnn_msg_steps.value,\n", + " hidden_layers=1,\n", + " radius_query_fraction_edge_length=0.6)\n", + " task_config = graphcast.TaskConfig(\n", + " input_variables=graphcast.TASK.input_variables,\n", + " target_variables=graphcast.TASK.target_variables,\n", + " forcing_variables=graphcast.TASK.forcing_variables,\n", + " pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],\n", + " input_duration=graphcast.TASK.input_duration,\n", + " )\n", + "else:\n", + " assert source == \"Checkpoint\"\n", + " with gcs_bucket.blob(f\"params/{params_file.value}\").open(\"rb\") as f:\n", + " ckpt = checkpoint.load(f, graphcast.CheckPoint)\n", + " params = ckpt.params\n", + " state = {}\n", + "\n", + " model_config = ckpt.model_config\n", + " task_config = ckpt.task_config\n", + " print(\"Model description:\\n\", ckpt.description, \"\\n\")\n", + " print(\"Model license:\\n\", ckpt.license, \"\\n\")\n", + "\n", + "model_config" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rQWk0RRuCjDN" + }, + "source": [ + "## Load the example data\n", + "\n", + "Several example datasets are available, varying across a few axes:\n", + "- **Source**: fake, era5, hres\n", + "- **Resolution**: 0.25deg, 1deg, 6deg\n", + "- **Levels**: 13, 37\n", + "- **Steps**: How many timesteps are included\n", + "\n", + "Not all combinations are available.\n", + "- Higher resolution is only available for fewer steps due to the memory requirements of loading them.\n", + "- HRES is only available in 0.25 deg, with 13 pressure levels.\n", + "\n", + "The data resolution must match the model that is loaded.\n", + "\n", + "Some transformations were done from the base datasets:\n", + "- We accumulated precipitation over 6 hours instead of the default 1 hour.\n", + "- For HRES data, each time step corresponds to the HRES forecast at leadtime 0, essentially providing an \"initialisation\" from HRES. See HRES-fc0 in the GraphCast paper for further description. Note that a 6h accumulation of precipitation is not available from HRES, so our model taking HRES inputs does not depend on precipitation. However, because our models predict precipitation, we include the ERA5 precipitation in the example data so it can serve as an illustrative example of ground truth.\n", + "- We include ERA5 `toa_incident_solar_radiation` in the data. Our model use the radiation a -6h, 0h and +6h as a forcing term for each 1-step prediction. In an operational setting, one can compute the radiation using a package such as `pysolar`, if the +6h radiation is not readily available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "-DJzie5me2-H" + }, + "outputs": [], + "source": [ + "# @title Get and filter the list of available example datasets\n", + "\n", + "dataset_file_options = [\n", + " name for blob in gcs_bucket.list_blobs(prefix=\"dataset/\")\n", + " if (name := blob.name.removeprefix(\"dataset/\"))] # Drop empty string.\n", + "\n", + "def data_valid_for_model(\n", + " file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):\n", + " file_parts = parse_file_parts(file_name.removesuffix(\".nc\"))\n", + " return (\n", + " model_config.resolution in (0, float(file_parts[\"res\"])) and\n", + " len(task_config.pressure_levels) == int(file_parts[\"levels\"]) and\n", + " (\n", + " (\"total_precipitation_6hr\" in task_config.input_variables and\n", + " file_parts[\"source\"] in (\"era5\", \"fake\")) or\n", + " (\"total_precipitation_6hr\" not in task_config.input_variables and\n", + " file_parts[\"source\"] in (\"hres\", \"fake\"))\n", + " )\n", + " )\n", + "\n", + "\n", + "dataset_file = widgets.Dropdown(\n", + " options=[\n", + " (\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(option.removesuffix(\".nc\")).items()]), option)\n", + " for option in dataset_file_options\n", + " if data_valid_for_model(option, model_config, task_config)\n", + " ],\n", + " description=\"Dataset file:\",\n", + " layout={\"width\": \"max-content\"})\n", + "widgets.VBox([\n", + " dataset_file,\n", + " widgets.Label(value=\"Run the next cell to load the dataset. Rerunning this cell clears your selection and refilters the datasets that match your model.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Yz-ekISoJxeZ" + }, + "outputs": [], + "source": [ + "# @title Load weather data\n", + "\n", + "if not data_valid_for_model(dataset_file.value, model_config, task_config):\n", + " raise ValueError(\n", + " \"Invalid dataset file, rerun the cell above and choose a valid dataset file.\")\n", + "\n", + "with gcs_bucket.blob(f\"dataset/{dataset_file.value}\").open(\"rb\") as f:\n", + " example_batch = xarray.load_dataset(f).compute()\n", + "\n", + "assert example_batch.dims[\"time\"] \u003e= 3 # 2 for input, \u003e=1 for targets\n", + "\n", + "print(\", \".join([f\"{k}: {v}\" for k, v in parse_file_parts(dataset_file.value.removesuffix(\".nc\")).items()]))\n", + "\n", + "example_batch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lXjFvdE6qStr" + }, + "outputs": [], + "source": [ + "# @title Choose data to plot\n", + "\n", + "plot_example_variable = widgets.Dropdown(\n", + " options=example_batch.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_example_level = widgets.Dropdown(\n", + " options=example_batch.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_example_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_example_max_steps = widgets.IntSlider(\n", + " min=1, max=example_batch.dims[\"time\"], value=example_batch.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "\n", + "widgets.VBox([\n", + " plot_example_variable,\n", + " plot_example_level,\n", + " plot_example_robust,\n", + " plot_example_max_steps,\n", + " widgets.Label(value=\"Run the next cell to plot the data. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "kIK-EgMdkHtk" + }, + "outputs": [], + "source": [ + "# @title Plot example data\n", + "\n", + "plot_size = 7\n", + "\n", + "data = {\n", + " \" \": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),\n", + " robust=plot_example_robust.value),\n", + "}\n", + "fig_title = plot_example_variable.value\n", + "if \"level\" in example_batch[plot_example_variable.value].coords:\n", + " fig_title += f\" at {plot_example_level.value} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, plot_example_robust.value)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "tPVy1GHokHtk" + }, + "outputs": [], + "source": [ + "# @title Choose training and eval data to extract\n", + "train_steps = widgets.IntSlider(\n", + " value=1, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Train steps\")\n", + "eval_steps = widgets.IntSlider(\n", + " value=example_batch.sizes[\"time\"]-2, min=1, max=example_batch.sizes[\"time\"]-2, description=\"Eval steps\")\n", + "\n", + "widgets.VBox([\n", + " train_steps,\n", + " eval_steps,\n", + " widgets.Label(value=\"Run the next cell to extract the data. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Ogp4vTBvsgSt" + }, + "outputs": [], + "source": [ + "# @title Extract training and eval data\n", + "\n", + "train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"6h\", f\"{train_steps.value*6}h\"),\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(\n", + " example_batch, target_lead_times=slice(\"6h\", f\"{eval_steps.value*6}h\"),\n", + " **dataclasses.asdict(task_config))\n", + "\n", + "print(\"All Examples: \", example_batch.dims.mapping)\n", + "print(\"Train Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Train Targets: \", train_targets.dims.mapping)\n", + "print(\"Train Forcings:\", train_forcings.dims.mapping)\n", + "print(\"Eval Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Eval Targets: \", eval_targets.dims.mapping)\n", + "print(\"Eval Forcings: \", eval_forcings.dims.mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Q--ZRhpTdI2o" + }, + "outputs": [], + "source": [ + "# @title Load normalization data\n", + "\n", + "with gcs_bucket.blob(\"stats/diffs_stddev_by_level.nc\").open(\"rb\") as f:\n", + " diffs_stddev_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(\"stats/mean_by_level.nc\").open(\"rb\") as f:\n", + " mean_by_level = xarray.load_dataset(f).compute()\n", + "with gcs_bucket.blob(\"stats/stddev_by_level.nc\").open(\"rb\") as f:\n", + " stddev_by_level = xarray.load_dataset(f).compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ke2zQyuT_sMA" + }, + "outputs": [], + "source": [ + "# @title Build jitted functions, and possibly initialize random weights\n", + "\n", + "def construct_wrapped_graphcast(\n", + " model_config: graphcast.ModelConfig,\n", + " task_config: graphcast.TaskConfig):\n", + " \"\"\"Constructs and wraps the GraphCast Predictor.\"\"\"\n", + " # Deeper one-step predictor.\n", + " predictor = graphcast.GraphCast(model_config, task_config)\n", + "\n", + " # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to\n", + " # from/to float32 to/from BFloat16.\n", + " predictor = casting.Bfloat16Cast(predictor)\n", + "\n", + " # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from\n", + " # BFloat16 happens after applying normalization to the inputs/targets.\n", + " predictor = normalization.InputsAndResiduals(\n", + " predictor,\n", + " diffs_stddev_by_level=diffs_stddev_by_level,\n", + " mean_by_level=mean_by_level,\n", + " stddev_by_level=stddev_by_level)\n", + "\n", + " # Wraps everything so the one-step model can produce trajectories.\n", + " predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)\n", + " return predictor\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def run_forward(model_config, task_config, inputs, targets_template, forcings):\n", + " predictor = construct_wrapped_graphcast(model_config, task_config)\n", + " return predictor(inputs, targets_template=targets_template, forcings=forcings)\n", + "\n", + "\n", + "@hk.transform_with_state\n", + "def loss_fn(model_config, task_config, inputs, targets, forcings):\n", + " predictor = construct_wrapped_graphcast(model_config, task_config)\n", + " loss, diagnostics = predictor.loss(inputs, targets, forcings)\n", + " return xarray_tree.map_structure(\n", + " lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),\n", + " (loss, diagnostics))\n", + "\n", + "def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):\n", + " def _aux(params, state, i, t, f):\n", + " (loss, diagnostics), next_state = loss_fn.apply(\n", + " params, state, jax.random.PRNGKey(0), model_config, task_config,\n", + " i, t, f)\n", + " return loss, (diagnostics, next_state)\n", + " (loss, (diagnostics, next_state)), grads = jax.value_and_grad(\n", + " _aux, has_aux=True)(params, state, inputs, targets, forcings)\n", + " return loss, diagnostics, next_state, grads\n", + "\n", + "# Jax doesn't seem to like passing configs as args through the jit. Passing it\n", + "# in via partial (instead of capture by closure) forces jax to invalidate the\n", + "# jit cache if you change configs.\n", + "def with_configs(fn):\n", + " return functools.partial(\n", + " fn, model_config=model_config, task_config=task_config)\n", + "\n", + "# Always pass params and state, so the usage below are simpler\n", + "def with_params(fn):\n", + " return functools.partial(fn, params=params, state=state)\n", + "\n", + "# Our models aren't stateful, so the state is always empty, so just return the\n", + "# predictions. This is requiredy by our rollout code, and generally simpler.\n", + "def drop_state(fn):\n", + " return lambda **kw: fn(**kw)[0]\n", + "\n", + "init_jitted = jax.jit(with_configs(run_forward.init))\n", + "\n", + "if params is None:\n", + " params, state = init_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets_template=train_targets,\n", + " forcings=train_forcings)\n", + "\n", + "loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))\n", + "grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))\n", + "run_forward_jitted = drop_state(with_params(jax.jit(with_configs(\n", + " run_forward.apply))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBNutliiCyqA" + }, + "source": [ + "# Run the model\n", + "\n", + "Note that the cell below may take a while (possibly minutes) to run the first time you execute them, because this will include the time it takes for the code to compile. The second time running will be significantly faster.\n", + "\n", + "This use the python loop to iterate over prediction steps, where the 1-step prediction is jitted. This has lower memory requirements than the training steps below, and should enable making prediction with the small GraphCast model on 1 deg resolution data for 4 steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "7obeY9i9oTtD" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (loop in python)\n", + "\n", + "assert model_config.resolution in (0, 360. / eval_inputs.sizes[\"lon\"]), (\n", + " \"Model resolution doesn't match the data resolution. You likely want to \"\n", + " \"re-filter the dataset list, and download the correct data.\")\n", + "\n", + "print(\"Inputs: \", eval_inputs.dims.mapping)\n", + "print(\"Targets: \", eval_targets.dims.mapping)\n", + "print(\"Forcings:\", eval_forcings.dims.mapping)\n", + "\n", + "predictions = rollout.chunked_prediction(\n", + " run_forward_jitted,\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=eval_inputs,\n", + " targets_template=eval_targets * np.nan,\n", + " forcings=eval_forcings)\n", + "predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ft298eZskHtn" + }, + "outputs": [], + "source": [ + "# @title Choose predictions to plot\n", + "\n", + "plot_pred_variable = widgets.Dropdown(\n", + " options=predictions.data_vars.keys(),\n", + " value=\"2m_temperature\",\n", + " description=\"Variable\")\n", + "plot_pred_level = widgets.Dropdown(\n", + " options=predictions.coords[\"level\"].values,\n", + " value=500,\n", + " description=\"Level\")\n", + "plot_pred_robust = widgets.Checkbox(value=True, description=\"Robust\")\n", + "plot_pred_max_steps = widgets.IntSlider(\n", + " min=1,\n", + " max=predictions.dims[\"time\"],\n", + " value=predictions.dims[\"time\"],\n", + " description=\"Max steps\")\n", + "\n", + "widgets.VBox([\n", + " plot_pred_variable,\n", + " plot_pred_level,\n", + " plot_pred_robust,\n", + " plot_pred_max_steps,\n", + " widgets.Label(value=\"Run the next cell to plot the predictions. Rerunning this cell clears your selection.\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "_tTdx6fmmj1I" + }, + "outputs": [], + "source": [ + "# @title Plot predictions\n", + "\n", + "plot_size = 5\n", + "plot_max_steps = min(predictions.dims[\"time\"], plot_pred_max_steps.value)\n", + "\n", + "data = {\n", + " \"Targets\": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Predictions\": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),\n", + " \"Diff\": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -\n", + " select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),\n", + " robust=plot_pred_robust.value, center=0),\n", + "}\n", + "fig_title = plot_pred_variable.value\n", + "if \"level\" in predictions[plot_pred_variable.value].coords:\n", + " fig_title += f\" at {plot_pred_level.value} hPa\"\n", + "\n", + "plot_data(data, fig_title, plot_size, plot_pred_robust.value)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pa78b64bLYe1" + }, + "source": [ + "# Train the model\n", + "\n", + "The following operations require a large amount of memory and, depending on the accelerator being used, will only fit the very small \"random\" model on low resolution data. It uses the number of training steps selected above.\n", + "\n", + "The first time executing the cell takes more time, as it include the time to jit the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Nv-u3dAP7IRZ" + }, + "outputs": [], + "source": [ + "# @title Loss computation (autoregressive loss over multiple steps)\n", + "loss, diagnostics = loss_fn_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "print(\"Loss:\", float(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "mBNFq1IGZNLz" + }, + "outputs": [], + "source": [ + "# @title Gradient computation (backprop through time)\n", + "loss, diagnostics, next_state, grads = grads_fn_jitted(\n", + " inputs=train_inputs,\n", + " targets=train_targets,\n", + " forcings=train_forcings)\n", + "mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])\n", + "print(f\"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "J4FJFKWD8Loz" + }, + "outputs": [], + "source": [ + "# @title Autoregressive rollout (keep the loop in JAX)\n", + "print(\"Inputs: \", train_inputs.dims.mapping)\n", + "print(\"Targets: \", train_targets.dims.mapping)\n", + "print(\"Forcings:\", train_forcings.dims.mapping)\n", + "\n", + "predictions = run_forward_jitted(\n", + " rng=jax.random.PRNGKey(0),\n", + " inputs=train_inputs,\n", + " targets_template=train_targets * np.nan,\n", + " forcings=train_forcings)\n", + "predictions" + ] + } + ], + "metadata": { + "colab": { + "name": "GraphCast", + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2d76155 --- /dev/null +++ b/setup.py @@ -0,0 +1,60 @@ +# Copyright 2023 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS-IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module setuptools script.""" + +from setuptools import setup + +description = ( + "GraphCast: Learning skillful medium-range global weather forecasting" +) + +setup( + name="graphcast", + version="0.1", + description=description, + long_description=description, + author="DeepMind", + license="Apache License, Version 2.0", + keywords="GraphCast Weather Prediction", + url="https://github.com/deepmind/graphcast", + packages=["graphcast"], + install_requires=[ + "cartopy", + "chex", + "colabtools", + "dask", + "dm-haiku", + "jax", + "jraph", + "matplotlib", + "numpy", + "pandas", + "rtree", + "scipy", + "tree", + "trimesh", + "typing_extensions", + "xarray", + ], + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Atmospheric Science", + "Topic :: Scientific/Engineering :: Physics", + ], +)