Revert "Improve safe chunk validation (#9527)" (#9558)
This reverts commit 2a6212e.
shoyer authored Sep 30, 2024
1 parent ece582d commit 7bdc6d4
Showing 5 changed files with 54 additions and 303 deletions.
4 changes: 1 addition & 3 deletions doc/whats-new.rst
@@ -61,9 +61,7 @@ Bug fixes
<https://github.com/spencerkclark>`_.
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix the safe_chunks validation option on the to_zarr method
(:issue:`5511`, :pull:`9513`). By `Joseph Nowak
<https://github.com/josephnowak>`_.


Documentation
~~~~~~~~~~~~~
168 changes: 47 additions & 121 deletions xarray/backends/zarr.py
@@ -112,9 +112,7 @@ def __getitem__(self, key):
# could possibly have a work-around for 0d data here


def _determine_zarr_chunks(
enc_chunks, var_chunks, ndim, name, safe_chunks, region, mode, shape
):
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
"""
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
@@ -165,9 +163,7 @@ def _determine_zarr_chunks(

if len(enc_chunks_tuple) != ndim:
# throw away encoding chunks, start over
return _determine_zarr_chunks(
None, var_chunks, ndim, name, safe_chunks, region, mode, shape
)
return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)

for x in enc_chunks_tuple:
if not isinstance(x, int):
@@ -193,59 +189,20 @@ def _determine_zarr_chunks(
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
if var_chunks and enc_chunks_tuple:
        # If partial chunks can be written, then it is not necessary to check
        # the last chunk contained in the region
allow_partial_chunks = mode != "r+"

base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r} "
f"on the region {region}. "
f"Writing this array in parallel with dask could lead to corrupted data."
f"Consider either rechunking using `chunk()`, deleting "
f"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)

for zchunk, dchunks, interval, size in zip(
enc_chunks_tuple, var_chunks, region, shape, strict=True
):
if not safe_chunks:
continue

for dchunk in dchunks[1:-1]:
for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks, strict=True):
for dchunk in dchunks[:-1]:
if dchunk % zchunk:
raise ValueError(base_error)

region_start = interval.start if interval.start else 0

if len(dchunks) > 1:
# The first border size is the amount of data that needs to be updated on the
# first chunk taking into account the region slice.
first_border_size = zchunk
if allow_partial_chunks:
first_border_size = zchunk - region_start % zchunk

if (dchunks[0] - first_border_size) % zchunk:
raise ValueError(base_error)

if not allow_partial_chunks:
chunk_start = sum(dchunks[:-1]) + region_start
if chunk_start % zchunk:
                # The last chunk (which may also be the only one) is a
                # partial chunk if it is not aligned at the beginning
raise ValueError(base_error)

region_stop = interval.stop if interval.stop else size

if size - region_stop + 1 < zchunk:
                # If the region covers the last chunk, then check whether
                # the remainder with the default chunk size is equal to
                # the size of the last chunk
if dchunks[-1] % zchunk != size % zchunk:
raise ValueError(base_error)
elif dchunks[-1] % zchunk:
raise ValueError(base_error)

base_error = (
f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
f"Writing this array in parallel with dask could lead to corrupted data."
)
if safe_chunks:
raise ValueError(
base_error
+ " Consider either rechunking using `chunk()`, deleting "
"or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
)
return enc_chunks_tuple

raise AssertionError("We should never get here. Function logic must be wrong.")
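The validation restored above enforces a many-to-one relationship between dask chunks and zarr chunks: along each dimension, every dask chunk except the last must be an integer multiple of the zarr chunk size. A minimal, self-contained sketch of that rule (the chunks_are_safe helper is hypothetical, not part of xarray):

def chunks_are_safe(enc_chunks, var_chunks):
    # Only the final dask chunk along each dimension may be a partial
    # zarr chunk; every earlier dask chunk must be an exact multiple of
    # the zarr chunk size, so no two dask tasks share a zarr chunk.
    for zchunk, dchunks in zip(enc_chunks, var_chunks):
        if any(dchunk % zchunk for dchunk in dchunks[:-1]):
            return False
    return True

# zarr chunks of 10 with dask chunks (20, 20, 7) are safe: each task
# writes whole zarr chunks. Dask chunks (15, 15, 17) are not: two tasks
# would both write into the zarr chunk spanning elements 10-19.
assert chunks_are_safe((10,), ((20, 20, 7),))
assert not chunks_are_safe((10,), ((15, 15, 17),))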
@@ -286,14 +243,7 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):


def extract_zarr_variable_encoding(
variable,
raise_on_invalid=False,
name=None,
*,
safe_chunks=True,
region=None,
mode=None,
shape=None,
variable, raise_on_invalid=False, name=None, safe_chunks=True
):
"""
Extract zarr encoding dictionary from xarray Variable
@@ -302,18 +252,12 @@ def extract_zarr_variable_encoding(
----------
variable : Variable
raise_on_invalid : bool, optional
name: str | Hashable, optional
safe_chunks: bool, optional
region: tuple[slice, ...], optional
mode: str, optional
shape: tuple[int, ...], optional
Returns
-------
encoding : dict
Zarr encoding for `variable`
"""

shape = shape if shape else variable.shape
encoding = variable.encoding.copy()

safe_to_drop = {"source", "original_shape"}
@@ -341,14 +285,7 @@ def extract_zarr_variable_encoding(
del encoding[k]

chunks = _determine_zarr_chunks(
enc_chunks=encoding.get("chunks"),
var_chunks=variable.chunks,
ndim=variable.ndim,
name=name,
safe_chunks=safe_chunks,
region=region,
mode=mode,
shape=shape,
encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
)
encoding["chunks"] = chunks
return encoding
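A hedged usage sketch of the reverted signature (assumes dask is installed; the variable name and data are made up): extract_zarr_variable_encoding again takes only variable, raise_on_invalid, name, and safe_chunks, and infers chunks from the dask chunking when the encoding specifies none:

import numpy as np
import xarray as xr
from xarray.backends.zarr import extract_zarr_variable_encoding

# A variable with uniform dask chunks of 10 along "x"
var = xr.Variable(("x",), np.arange(30)).chunk({"x": 10})
enc = extract_zarr_variable_encoding(var, name="foo", safe_chunks=True)
print(enc["chunks"])  # (10,) -- taken from the dask chunks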
@@ -825,10 +762,16 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
if v.encoding == {"_FillValue": None} and fill_value is None:
v.encoding = {}

zarr_array = None
zarr_shape = None
write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
)

if name in existing_keys:
# existing variable
@@ -858,40 +801,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
)
else:
zarr_array = self.zarr_group[name]

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

zarr_shape = zarr_array.shape

region = tuple(write_region[dim] for dim in dims)

# We need to do this for both new and existing variables to ensure we're not
# writing to a partial chunk, even though we don't use the `encoding` value
# when writing to an existing variable. See
# https://github.com/pydata/xarray/issues/8371 for details.
# Note: Ideally there should be two functions, one for validating the chunks and
# another one for extracting the encoding.
encoding = extract_zarr_variable_encoding(
v,
raise_on_invalid=vn in check_encoding_set,
name=vn,
safe_chunks=self._safe_chunks,
region=region,
mode=self._mode,
shape=zarr_shape,
)

if name not in existing_keys:
else:
# new variable
encoded_attrs = {}
# the magic for storing the hidden dimension data
@@ -923,6 +833,22 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
)
zarr_array = _put_attrs(zarr_array, encoded_attrs)

write_region = self._write_region if self._write_region is not None else {}
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}

if self._append_dim is not None and self._append_dim in dims:
# resize existing variable
append_axis = dims.index(self._append_dim)
assert write_region[self._append_dim] == slice(None)
write_region[self._append_dim] = slice(
zarr_array.shape[append_axis], None
)

new_shape = list(zarr_array.shape)
new_shape[append_axis] += v.shape[append_axis]
zarr_array.resize(new_shape)

region = tuple(write_region[dim] for dim in dims)
writer.add(v.data, zarr_array, region)

def close(self) -> None:
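The append branch moved above resizes the existing zarr array before the write region is computed. An end-to-end illustration of that path (a sketch; assumes zarr and dask are available and "store.zarr" is a writable local path):

import xarray as xr

ds = xr.Dataset({"t": ("time", [1.0, 2.0])})
ds.to_zarr("store.zarr", mode="w")

# Appending along "time" resizes "t" from length 2 to 3, after which
# the write region for the new data becomes slice(2, None).
xr.Dataset({"t": ("time", [3.0])}).to_zarr("store.zarr", mode="a", append_dim="time")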
@@ -971,9 +897,9 @@ def _validate_and_autodetect_region(self, ds) -> None:
if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
if any(v == "auto" for v in region.values()):
if self._mode not in ["r+", "a"]:
if self._mode != "r+":
raise ValueError(
f"``mode`` must be 'r+' or 'a' when using ``region='auto'``, got {self._mode!r}"
f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
)
region = self._auto_detect_regions(ds, region)

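With the revert, region="auto" is again accepted only with mode="r+". A sketch of the supported workflow (path and data are illustrative; assumes zarr is installed):

import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.zeros(10))}, coords={"x": np.arange(10)})
ds.to_zarr("store.zarr", mode="w")  # initialize the full store first

# Region bounds are inferred from the "x" coordinate; per the check
# above, passing mode="a" here would now raise a ValueError.
ds.isel(x=slice(2, 6)).to_zarr("store.zarr", mode="r+", region={"x": "auto"})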
8 changes: 0 additions & 8 deletions xarray/core/dataarray.py
@@ -4316,14 +4316,6 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr store from an existing
DataArray with arbitrary chunk structure.
        In addition to the many-to-one relationship validation, it also detects
        partial chunk writes when using the region parameter. These partial
        chunks are considered unsafe in mode "r+" but safe in mode "a".
        Note: Even with these validations it can still be unsafe to write
        two or more chunked arrays to the same location in parallel if they
        are not writing to independent regions; for those cases it is better
        to use a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
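The behaviour the safe_chunks docstring describes can be reproduced with a deliberately misaligned chunking (a hedged sketch; the store path and variable name are made up, and zarr plus dask are assumed):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(30), dims="x", name="foo").chunk({"x": 15})
try:
    # dask chunks of 15 straddle zarr chunks of 10 -> ValueError
    da.to_zarr("unsafe.zarr", mode="w", encoding={"foo": {"chunks": (10,)}})
except ValueError as err:
    print(err)

# Opting out of the validation shifts responsibility to the caller:
da.to_zarr("unsafe.zarr", mode="w", encoding={"foo": {"chunks": (10,)}},
           safe_chunks=False)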
8 changes: 0 additions & 8 deletions xarray/core/dataset.py
@@ -2509,14 +2509,6 @@ def to_zarr(
if Zarr arrays are written in parallel. This option may be useful in combination
with ``compute=False`` to initialize a Zarr from an existing
Dataset with arbitrary chunk structure.
        In addition to the many-to-one relationship validation, it also detects
        partial chunk writes when using the region parameter. These partial
        chunks are considered unsafe in mode "r+" but safe in mode "a".
        Note: Even with these validations it can still be unsafe to write
        two or more chunked arrays to the same location in parallel if they
        are not writing to independent regions; for those cases it is better
        to use a synchronizer.
storage_options : dict, optional
Any additional parameters for the storage backend (ignored for local
paths).
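For the region-based workflow discussed here, a minimal sketch (store path illustrative; zarr and dask assumed): initialize the store without computing, then write chunk-aligned regions, possibly from separate processes:

import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.zeros(100))}).chunk({"x": 10})
ds.to_zarr("big.zarr", mode="w", compute=False)  # write metadata only

# Each job then writes a region aligned with the 10-element zarr chunks;
# mode defaults to "r+" when region is given.
part = ds.isel(x=slice(0, 50)).compute()
part.to_zarr("big.zarr", region={"x": slice(0, 50)})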
