Skip to content

Commit

Permalink
full size xg and yg (#246)
Browse files Browse the repository at this point in the history
* add get_xg_yg_from_input code

* refactored some code to reduce repeats

* added testing

* minimized PR, dropped the sub-tile things

* remove tests for now

* add outer

* resolve feedback from @raphaeldussin

* Update xmitgcm/utils.py

Co-authored-by: Ryan Abernathey <[email protected]>

* Update xmitgcm/utils.py

Co-authored-by: Ryan Abernathey <[email protected]>

* Update xmitgcm/utils.py

Co-authored-by: Ryan Abernathey <[email protected]>

* Update xmitgcm/utils.py

Co-authored-by: Ryan Abernathey <[email protected]>

* add tests

* change logic, pass all tests

* feedback

* remove unnecessary outer flags in tests

* cleanup

* CS Grid check

* add precision, endianness to grid test

Co-authored-by: Spencer Jones <[email protected]>
Co-authored-by: Ryan Abernathey <[email protected]>
Co-authored-by: Raphael Dussin <[email protected]>
  • Loading branch information
4 people authored Jul 29, 2021
1 parent c793ecc commit 3b341ad
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 40 deletions.
122 changes: 100 additions & 22 deletions xmitgcm/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,17 +957,22 @@ def test_get_extra_metadata(domain, nx):
em = get_extra_metadata(domain='notinlist', nx=nx)


@pytest.mark.parametrize("outer", [True, False])
@pytest.mark.parametrize("usedask", [True, False])
def test_get_grid_from_input(all_grid_datadirs, usedask):
def test_get_grid_from_input(all_grid_datadirs, usedask, outer):
from xmitgcm.utils import get_grid_from_input, get_extra_metadata
from xmitgcm.utils import read_raw_data
dirname, expected = all_grid_datadirs
md = get_extra_metadata(domain=expected['domain'], nx=expected['nx'])
dtype = np.dtype('{}{}'.format(expected['endianness'], expected['precision']))
ds = get_grid_from_input(dirname + '/' + expected['gridfile'],
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
dtype=np.dtype(expected['precision']),
endian=expected['endianness'],
use_dask=usedask,
extra_metadata=md)
extra_metadata=md,
outer=outer)

# test types
assert type(ds) == xarray.Dataset
assert type(ds['XC']) == xarray.core.dataarray.DataArray
Expand All @@ -980,9 +985,19 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

outerx_vars = ['DXC', 'RAW', 'DYG'] if outer else []
outery_vars = ['DYC', 'RAS', 'DXG'] if outer else []
outerxy_vars = ['XG', 'YG', 'RAZ'] if outer else []

for var in expected_variables:
expected_shape = list(expected['shape'])
if var in outerx_vars or var in outerxy_vars:
expected_shape[-1] = expected_shape[-1] + 1
if var in outery_vars or var in outerxy_vars:
expected_shape[-2] = expected_shape[-2] + 1

assert type(ds[var]) == xarray.core.dataarray.DataArray
assert ds[var].values.shape == expected['shape']
assert ds[var].values.shape == tuple(expected_shape)

# check we don't leave points behind
if expected['geometry'] == 'llc':
Expand All @@ -1009,37 +1024,98 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
ny4 = int(size4 / sizeofd / nvars / nx)
ny5 = int(size5 / sizeofd / nvars / nx)

xc1 = read_raw_data(grid1, dtype=np.dtype('>d'), shape=(ny1, nx),
xc1 = read_raw_data(grid1, dtype=dtype, shape=(ny1, nx),
partial_read=True)
xc2 = read_raw_data(grid2, dtype=np.dtype('>d'), shape=(ny2, nx),
xc2 = read_raw_data(grid2, dtype=dtype, shape=(ny2, nx),
partial_read=True)
xc3 = read_raw_data(grid3, dtype=np.dtype('>d'), shape=(ny3, nx),
xc3 = read_raw_data(grid3, dtype=dtype, shape=(ny3, nx),
partial_read=True)
xc4 = read_raw_data(grid4, dtype=np.dtype('>d'), shape=(ny4, nx),
xc4 = read_raw_data(grid4, dtype=dtype, shape=(ny4, nx),
order='F', partial_read=True)
xc5 = read_raw_data(grid5, dtype=np.dtype('>d'), shape=(ny5, nx),
xc5 = read_raw_data(grid5, dtype=dtype, shape=(ny5, nx),
order='F', partial_read=True)

yc1 = read_raw_data(grid1, dtype=np.dtype('>d'), shape=(ny1, nx),
yc1 = read_raw_data(grid1, dtype=dtype, shape=(ny1, nx),
partial_read=True, offset=nx*ny1*sizeofd)
yc2 = read_raw_data(grid2, dtype=np.dtype('>d'), shape=(ny2, nx),
yc2 = read_raw_data(grid2, dtype=dtype, shape=(ny2, nx),
partial_read=True, offset=nx*ny2*sizeofd)
yc3 = read_raw_data(grid3, dtype=np.dtype('>d'), shape=(ny3, nx),
yc3 = read_raw_data(grid3, dtype=dtype, shape=(ny3, nx),
partial_read=True, offset=nx*ny3*sizeofd)
yc4 = read_raw_data(grid4, dtype=np.dtype('>d'), shape=(ny4, nx),
yc4 = read_raw_data(grid4, dtype=dtype, shape=(ny4, nx),
order='F', partial_read=True,
offset=nx*ny4*sizeofd)
yc5 = read_raw_data(grid5, dtype=np.dtype('>d'), shape=(ny5, nx),
yc5 = read_raw_data(grid5, dtype=dtype, shape=(ny5, nx),
order='F', partial_read=True,
offset=nx*ny5*sizeofd)

xc = np.concatenate([xc1[:-1, :-1].flatten(), xc2[:-1, :-1].flatten(),
xc3[:-1, :-1].flatten(), xc4[:-1, :-1].flatten(),
xc5[:-1, :-1].flatten()])
xc = np.concatenate([xc1.flatten(), xc2.flatten(),
xc3.flatten(), xc4.flatten(),
xc5.flatten()])

yc = np.concatenate([yc1.flatten(), yc2.flatten(),
yc3.flatten(), yc4.flatten(),
yc5.flatten()])

xc_from_ds = ds['XC'].values.flatten()
yc_from_ds = ds['YC'].values.flatten()

assert xc.min() == xc_from_ds.min()
assert xc.max() == xc_from_ds.max()
assert yc.min() == yc_from_ds.min()
assert yc.max() == yc_from_ds.max()

if expected['geometry'] == 'cs':
nx = expected['nx'] + 1
sizeofd = 8

grid = expected['gridfile']
grid1 = dirname + '/' + grid.replace('<NFACET>', '001')
grid2 = dirname + '/' + grid.replace('<NFACET>', '002')
grid3 = dirname + '/' + grid.replace('<NFACET>', '003')
grid4 = dirname + '/' + grid.replace('<NFACET>', '004')
grid5 = dirname + '/' + grid.replace('<NFACET>', '005')
grid6 = dirname + '/' + grid.replace('<NFACET>', '006')


xc1 = read_raw_data(grid1, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)
xc2 = read_raw_data(grid2, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)
xc3 = read_raw_data(grid3, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)
xc4 = read_raw_data(grid4, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)
xc5 = read_raw_data(grid5, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)
xc6 = read_raw_data(grid6, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True)

yc1 = read_raw_data(grid1, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)
yc2 = read_raw_data(grid2, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)
yc3 = read_raw_data(grid3, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)
yc4 = read_raw_data(grid4, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)
yc5 = read_raw_data(grid5, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)
yc6 = read_raw_data(grid6, dtype=dtype, shape=(nx, nx),
order='F', partial_read=True,
offset=nx * nx * sizeofd)

xc = np.concatenate([xc1.flatten(), xc2.flatten(),
xc3.flatten(), xc4.flatten(),
xc5.flatten(), xc6.flatten()])

yc = np.concatenate([yc1[:-1, :-1].flatten(), yc2[:-1, :-1].flatten(),
yc3[:-1, :-1].flatten(), yc4[:-1, :-1].flatten(),
yc5[:-1, :-1].flatten()])
yc = np.concatenate([yc1.flatten(), yc2.flatten(),
yc3.flatten(), yc4.flatten(),
yc5.flatten(), yc6.flatten()])

xc_from_ds = ds['XC'].values.flatten()
yc_from_ds = ds['YC'].values.flatten()
Expand All @@ -1054,9 +1130,11 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
with pytest.raises(ValueError):
ds = get_grid_from_input(dirname + '/' + expected['gridfile'],
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
dtype=np.dtype(expected['precision']),
endian=expected['endianness'],
use_dask=False,
extra_metadata=None)
extra_metadata=None,
outer=outer)


@pytest.mark.parametrize("dtype", [np.dtype('d'), np.dtype('f')])
Expand Down
9 changes: 6 additions & 3 deletions xmitgcm/test/test_xmitgcm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,22 @@ def hide_file(origdir, *basenames):
'dlink': dlroot + '14072594',
'md5': 'f66c3195a62790d539debe6ca8f3a851',
'gridfile': 'tile<NFACET>.mitgrid',
'nx': 90, 'shape': (13, 90, 90)},
'nx': 90, 'shape': (13, 90, 90),
'endianness': '>', 'precision': 'd'},

'grid_aste270': {'geometry': 'llc', 'domain': 'aste',
'dlink': dlroot + '14072591',
'md5': '92b28c65e0dfb54b253bfcd0a249359b',
'gridfile': 'tile<NFACET>.mitgrid',
'nx': 270, 'shape': (6, 270, 270)},
'nx': 270, 'shape': (6, 270, 270),
'endianness': '>', 'precision': 'd'},

'grid_cs32': {'geometry': 'cs', 'domain': 'cs',
'dlink': dlroot + '14072597',
'md5': '848cd5b6daab5b069e96a0cff67d4b57',
'gridfile': 'grid_cs32.face<NFACET>.bin',
'nx': 32, 'shape': (6, 32, 32)}
'nx': 32, 'shape': (6, 32, 32),
'endianness': '>', 'precision': 'd'},
}


Expand Down
55 changes: 40 additions & 15 deletions xmitgcm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ def _pad_array(data, file_metadata, face=0):


def get_extra_metadata(domain='llc', nx=90):
"""
"""
Return the extra_metadata dictionay for selected domains
PARAMETERS
Expand Down Expand Up @@ -1308,9 +1308,9 @@ def get_extra_metadata(domain='llc', nx=90):


def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
dtype=np.dtype('d'), endian='>', use_dask=False,
dtype=np.dtype('d'), endian='>', use_dask=False, outer=False,
extra_metadata=None):
"""
"""
Read grid variables from grid input files, this is especially useful
for llc and cube sphere configurations used with land tiles
elimination. Reading the input grid files (e.g. tile00[1-5].mitgrid)
Expand All @@ -1332,11 +1332,13 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
endianness of input data
use_dask : bool
use dask or not
outer : bool
include outer boundary or not
extra_metadata : dict
dictionary of extra metadata, needed for llc configurations
RETURNS
-------
-------
grid : xarray.Dataset
all grid variables
"""
Expand All @@ -1347,6 +1349,10 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

outerx_vars = ['DXC', 'RAW', 'DYG'] if outer else []
outery_vars = ['DYC', 'RAS', 'DXG'] if outer else []
outerxy_vars = ['XG', 'YG', 'RAZ'] if outer else []

file_metadata['vars'] = file_metadata['fldList']
dims_vars_list = []
for var in file_metadata['fldList']:
Expand Down Expand Up @@ -1399,6 +1405,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
nxgrid = file_metadata['ny_facets'][kfacet] + 1
nygrid = file_metadata['nx'] + 1


grid_metadata.update({'nx': nxgrid, 'ny': nygrid,
'has_faces': False})

Expand All @@ -1412,8 +1419,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
{file_metadata['fldList'][kfield]: raw[kfield]})

for field in file_metadata['fldList']:
# symetrize
tmp = rawfields[field][:, :, :-1, :-1].squeeze()

# get the full array
tmp = rawfields[field].squeeze()
# transpose
if grid_metadata['facet_orders'][kfacet] == 'F':
tmp = tmp.transpose()
Expand All @@ -1423,15 +1431,30 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
if grid_metadata['face_facets'][face] == kfacet:
# get offset of face from facet
offset = file_metadata['face_offsets'][face]
nx = file_metadata['nx']

nx = file_metadata['nx'] + 1
nxm1 = file_metadata['nx']
pad_metadata = file_metadata.copy()
pad_metadata['nx'] = file_metadata['nx'] + 1
# pad data, if needed (would trigger eager data eval)
# needs a new array not to pad multiple times
padded = _pad_array(tmp, file_metadata, face=face)
padded = _pad_array(tmp, pad_metadata, face=face)
# extract the data
dataface = padded[offset*nx:(offset+1)*nx, :]
dataface = padded[offset*nxm1:offset*nxm1 + nx, :]
# transpose, if needed
if file_metadata['transpose_face'][face]:
dataface = dataface.transpose()

# remove irrelevant data
if field in outerx_vars:
dataface = dataface[..., :-1, :].squeeze()
elif field in outery_vars:
dataface = dataface[..., :-1].squeeze()
elif field in outerxy_vars:
dataface = dataface.squeeze()
else:
dataface = dataface[..., :-1, :-1].squeeze()

# assign values
dataface = dsa.stack([dataface], axis=0)
if face == 0:
Expand All @@ -1441,6 +1464,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
[gridfields[field], dataface], axis=0)

# create the dataset
nxouter = file_metadata['nx'] + 1 if outer else file_metadata['nx']
if geometry == 'llc':
grid = xr.Dataset({'XC': (['face', 'j', 'i'], gridfields['XC']),
'YC': (['face', 'j', 'i'], gridfields['YC']),
Expand All @@ -1462,9 +1486,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['nx'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'face': (['face'], np.arange(nfaces))
}
)
Expand All @@ -1489,13 +1513,14 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['nx'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'face': (['face'], np.arange(nfaces))
}
)
else: # pragma: no cover
nyouter = file_metadata['ny'] + 1 if outer else file_metadata['ny']
grid = xr.Dataset({'XC': (['j', 'i'], gridfields['XC']),
'YC': (['j', 'i'], gridfields['YC']),
'DXF': (['j', 'i'], gridfields['DXF']),
Expand All @@ -1516,9 +1541,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['ny'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['ny']))
np.arange(nyouter))
}
)

Expand Down

0 comments on commit 3b341ad

Please sign in to comment.