diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 680433565..6e4fcb156 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -332,6 +332,7 @@ async def save( *args: NDArrayLike, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, **kwargs: Any, # TODO: type kwargs as valid args to save ) -> None: @@ -345,6 +346,11 @@ async def save( NumPy arrays with data to save. zarr_format : {2, 3, None}, optional The zarr format to use when saving. + mode : {'r', 'r+', 'a', 'w', 'w-'}, optional + Persistence mode: 'r' means read only (must exist); 'r+' means + read/write (must exist); 'a' means read/write (create if doesn't + exist); 'w' means create (overwrite if exists); 'w-' means create + (fail if exists). path : str or None, optional The path within the group where the arrays will be saved. **kwargs @@ -352,12 +358,19 @@ async def save( """ zarr_format = _handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format) + for arg in args: + if not isinstance(arg, np.ndarray): + raise TypeError("All arguments must be numpy arrays") + for k, v in kwargs.items(): + if not isinstance(v, np.ndarray): + raise TypeError(f"Keyword argument '{k}' must be a numpy array") + if len(args) == 0 and len(kwargs) == 0: raise ValueError("at least one array must be provided") if len(args) == 1 and len(kwargs) == 0: - await save_array(store, args[0], zarr_format=zarr_format, path=path) + await save_array(store, args[0], zarr_format=zarr_format, mode=mode, path=path) else: - await save_group(store, *args, zarr_format=zarr_format, path=path, **kwargs) + await save_group(store, *args, zarr_format=zarr_format, mode=mode, path=path, **kwargs) async def save_array( @@ -366,6 +379,7 @@ async def save_array( *, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, storage_options: dict[str, Any] | None = None, **kwargs: Any, # TODO: type kwargs as valid args to create @@ -381,6 +395,11 @@ async def save_array( NumPy array with data to save. zarr_format : {2, 3, None}, optional The zarr format to use when saving. + mode : {'r', 'r+', 'a', 'w', 'w-'}, optional + Persistence mode: 'r' means read only (must exist); 'r+' means + read/write (must exist); 'a' means read/write (create if doesn't + exist); 'w' means create (overwrite if exists); 'w-' means create + (fail if exists). path : str or None, optional The path within the store where the array will be saved. storage_options : dict @@ -394,7 +413,6 @@ async def save_array( or _default_zarr_version() ) - mode = kwargs.pop("mode", None) store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options) new = await AsyncArray.create( store_path, @@ -412,6 +430,7 @@ async def save_group( *args: NDArrayLike, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, storage_options: dict[str, Any] | None = None, **kwargs: NDArrayLike, @@ -427,6 +446,11 @@ async def save_group( NumPy arrays with data to save. zarr_format : {2, 3, None}, optional The zarr format to use when saving. + mode : {'r', 'r+', 'a', 'w', 'w-'}, optional + Persistence mode: 'r' means read only (must exist); 'r+' means + read/write (must exist); 'a' means read/write (create if doesn't + exist); 'w' means create (overwrite if exists); 'w-' means create + (fail if exists). path : str or None, optional Path within the store where the group will be saved. storage_options : dict @@ -452,6 +476,7 @@ async def save_group( store, arr, zarr_format=zarr_format, + mode=mode, path=f"{path}/arr_{i}", storage_options=storage_options, ) @@ -460,7 +485,12 @@ async def save_group( _path = f"{path}/{k}" if path is not None else k aws.append( save_array( - store, arr, zarr_format=zarr_format, path=_path, storage_options=storage_options + store, + arr, + zarr_format=zarr_format, + mode=mode, + path=_path, + storage_options=storage_options, ) ) await asyncio.gather(*aws) diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 9dcd6fe2d..d9d35ca8b 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -101,12 +101,19 @@ def save( *args: NDArrayLike, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, **kwargs: Any, # TODO: type kwargs as valid args to async_api.save ) -> None: return sync( async_api.save( - store, *args, zarr_version=zarr_version, zarr_format=zarr_format, path=path, **kwargs + store, + *args, + zarr_version=zarr_version, + zarr_format=zarr_format, + mode=mode, + path=path, + **kwargs, ) ) @@ -118,6 +125,7 @@ def save_array( *, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, **kwargs: Any, # TODO: type kwargs as valid args to async_api.save_array ) -> None: @@ -127,6 +135,7 @@ def save_array( arr=arr, zarr_version=zarr_version, zarr_format=zarr_format, + mode=mode, path=path, **kwargs, ) @@ -138,6 +147,7 @@ def save_group( *args: NDArrayLike, zarr_version: ZarrFormat | None = None, # deprecated zarr_format: ZarrFormat | None = None, + mode: AccessModeLiteral | None = None, path: str | None = None, storage_options: dict[str, Any] | None = None, **kwargs: NDArrayLike, @@ -148,6 +158,7 @@ def save_group( *args, zarr_version=zarr_version, zarr_format=zarr_format, + mode=mode, path=path, storage_options=storage_options, **kwargs, diff --git a/tests/test_api.py b/tests/test_api.py index 5b62e3a2f..6b5249fe7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,6 +23,7 @@ ) from zarr.core.common import MemoryOrder, ZarrFormat from zarr.errors import MetadataValidationError +from zarr.storage import StorePath from zarr.storage._utils import normalize_path from zarr.storage.memory import MemoryStore @@ -999,3 +1000,10 @@ async def test_metadata_validation_error() -> None: match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.", ): await zarr.api.asynchronous.open_array(shape=(1,), zarr_format="3.0") # type: ignore[arg-type] + + +@pytest.mark.parametrize("store", ["local"], indirect=["store"]) +def test_zarr_save(store: Store) -> None: + a = np.arange(1000).reshape(10, 10, 10) + zarr.save(StorePath(store), a, mode="w") + assert_array_equal(zarr.load(store)[...], a) # type: ignore[index]