Skip to content

Commit

Permalink
modified open_mdsdataset to add extra_variables (#205)
Browse files Browse the repository at this point in the history
* modified open_mdsdataset to add extra_variables

* Backwards compatible

* modified open_mdsdataset to add extra_variables

* Backwards compatible

* add tests and extend docstring

* only copy file if input is available.

Co-authored-by: Spencer Jones <[email protected]>
Co-authored-by: Aaron Schneider <[email protected]>
  • Loading branch information
3 people authored May 17, 2021
1 parent 7db1c32 commit 09f8f95
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
41 changes: 33 additions & 8 deletions xmitgcm/mds_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
endian=">", chunks=None,
ignore_unknown_vars=False, default_dtype=None,
nx=None, ny=None, nz=None,
llc_method="smallchunks", extra_metadata=None):
llc_method="smallchunks", extra_metadata=None,
extra_variables=None):
"""Open MITgcm-style mds (.data / .meta) file output as xarray datset.
Parameters
Expand Down Expand Up @@ -131,6 +132,23 @@ def open_mdsdataset(data_dir, grid_dir=None,
For global llc grids, no extra metadata is required and code
will set up to global llc default configuration.
extra_variables : dict, optional
Allow to pass variables not listed in the variables.py
or in available_diagnostics.log.
extra_variables must be a dict containing the variable names as keys with
the corresponging values being a dict with the keys being dims and attrs.
Syntax:
extra_variables = dict(varname = dict(dims=list_of_dims, attrs=dict(optional_attrs)))
where optional_attrs can contain standard_name, long_name, units as keys
Example:
extra_variables = dict(
ADJtheta = dict(dims=['k','j','i'], attrs=dict(
standard_name='Sensitivity_to_theta',
long_name='Sensitivity of cost function to theta', units='[J]/degC'))
)
Returns
-------
Expand All @@ -141,7 +159,7 @@ def open_mdsdataset(data_dir, grid_dir=None,
----------
.. [1] http://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch04s04.html
"""

# get frame info for history
frame = inspect.currentframe()
_, _, _, arg_values = inspect.getargvalues(frame)
Expand Down Expand Up @@ -216,7 +234,8 @@ def open_mdsdataset(data_dir, grid_dir=None,
ignore_unknown_vars=ignore_unknown_vars,
default_dtype=default_dtype,
nx=nx, ny=ny, nz=nz, llc_method=llc_method,
levels=levels, extra_metadata=extra_metadata)
levels=levels, extra_metadata=extra_metadata,
extra_variables=extra_variables)
datasets = [open_mdsdataset(
data_dir, iters=iternum, read_grid=False, **kwargs)
for iternum in iters]
Expand Down Expand Up @@ -255,7 +274,9 @@ def open_mdsdataset(data_dir, grid_dir=None,
ignore_unknown_vars=ignore_unknown_vars,
default_dtype=default_dtype,
nx=nx, ny=ny, nz=nz, llc_method=llc_method,
levels=levels, extra_metadata=extra_metadata)
levels=levels, extra_metadata=extra_metadata,
extra_variables=extra_variables)

ds = xr.Dataset.load_store(store)
if swap_dims:
ds = _swap_dimensions(ds, geometry)
Expand Down Expand Up @@ -338,7 +359,8 @@ def __init__(self, data_dir, grid_dir=None,
endian='>', ignore_unknown_vars=False,
default_dtype=np.dtype('f4'),
nx=None, ny=None, nz=None, llc_method="smallchunks",
levels=None, extra_metadata=None):
levels=None, extra_metadata=None,
extra_variables=None):
"""
This is not a user-facing class. See open_mdsdataset for argument
documentation. The only ones which are distinct are.
Expand All @@ -362,6 +384,7 @@ def __init__(self, data_dir, grid_dir=None,
# the directory where the files live
self.data_dir = data_dir
self.grid_dir = grid_dir if (grid_dir is not None) else data_dir
self.extra_variables = extra_variables
self._ignore_unknown_vars = ignore_unknown_vars

# The endianness of the files
Expand Down Expand Up @@ -537,7 +560,8 @@ def __init__(self, data_dir, grid_dir=None,
self.layers)
self._all_data_variables = _get_all_data_variables(self.data_dir,
self.grid_dir,
self.layers)
self.layers,
self.extra_variables)

# The rest of the data has to be read from disk.
# The list `prefixes` specifies file prefixes from which to infer
Expand Down Expand Up @@ -866,11 +890,12 @@ def _recursively_replace(item, search, replace):
return item


def _get_all_data_variables(data_dir, grid_dir, layers):
def _get_all_data_variables(data_dir, grid_dir, layers, extra_variables):
""""Put all the relevant data metadata into one big dictionary."""
allvars = [state_variables]
allvars.append(package_state_variables)

if extra_variables is not None:
allvars.append(extra_variables)
# add others from available_diagnostics.log
# search in the data dir
fnameD = os.path.join(data_dir, 'available_diagnostics.log')
Expand Down
43 changes: 43 additions & 0 deletions xmitgcm/test/test_mds_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,49 @@ def test_drc_length(all_mds_datadirs):
assert len(ds.drC) == (len(ds.drF)+1)


def test_extra_variables(all_mds_datadirs):
"""Test that open_mdsdataset reads extra_variables correctly"""
dirname, expected = all_mds_datadirs

extra_variable_data = dict(
# use U,V to test with attrs (including mate)
testdataU=dict(dims=['k','j','i_g'], attrs=dict(
standard_name='sea_water_x_velocity', mate='testdataV',
long_name='Zonal Component of Velocity', units='m s-1')),
testdataV=dict(dims=['k','j_g','i'], attrs=dict(
standard_name='sea_water_y_velocity', mate='testdataU',
long_name='Meridional Component of Velocity', units='m s-1')),
# use T to test without attrs
testdataT=dict(dims=['k','j','i'], attrs=dict())
)

for var in ["U","V","T"]:
input_dir = os.path.join(dirname, '{}.{:010d}'.format(var, expected['test_iternum']))
test_dir = os.path.join(dirname, 'testdata{}.{:010d}'.format(var, expected['test_iternum']))

if input_dir+".meta" not in os.listdir() or input_dir+".data" not in os.listdir():
return

copyfile(input_dir+".meta",test_dir+".meta")
copyfile(input_dir+".data",test_dir+".data")

assert test_dir + ".meta" not in os.listdir(), f"{var} did not copy meta!"
assert test_dir + ".data" not in os.listdir(), f"{var} did not copy data!"

ds = xmitgcm.open_mdsdataset(
dirname,
read_grid=False,
iters=expected['test_iternum'],
geometry=expected['geometry'],
prefix=list(extra_variable_data.keys()),
extra_variables=extra_variable_data)

for var in extra_variable_data.keys():
assert var in ds
if 'mate' in ds[var].attrs:
mate = ds[var].attrs['mate']
assert ds[mate].attrs['mate'] == var

def test_mask_values(all_mds_datadirs):
"""Test that open_mdsdataset generates binary masks with correct values"""

Expand Down

0 comments on commit 09f8f95

Please sign in to comment.