From 11c7229a10e43b5de8f44a029af9546de42dc485 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Mon, 18 Apr 2022 02:43:28 -0700 Subject: [PATCH] Start to multiband on the dask side - Think it's piped through in `to_dask` and readers - Needs tests - Should `bands` list even go in the asset table? It's kinda redundant. But maybe no worse than inlining it in every task. - Custom chunks still feels awkward. `prepare` obvs doesn't actually figure out band counts from STAC yet. `to_coords` will also need to be updated to handle this. --- .flake8 | 5 +- stackstac/nodata.py | 34 ++++++++ stackstac/nodata_reader.py | 65 --------------- stackstac/prepare.py | 45 +++++++++- stackstac/reader_protocol.py | 27 ++++-- stackstac/rio_reader.py | 97 ++++++++++++++++++---- stackstac/stack.py | 5 +- stackstac/tests/test_to_dask.py | 56 ++++++++++++- stackstac/to_dask.py | 141 +++++++++++++++++++++++++++++--- 9 files changed, 368 insertions(+), 107 deletions(-) create mode 100644 stackstac/nodata.py delete mode 100644 stackstac/nodata_reader.py diff --git a/.flake8 b/.flake8 index eabd35d..b808521 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,7 @@ max-line-length = 120 exclude = .pyi - typings \ No newline at end of file + typings +ignore = + E203 # whitespace before ':' + W503 # line break before binary operator diff --git a/stackstac/nodata.py b/stackstac/nodata.py new file mode 100644 index 0000000..4a0a711 --- /dev/null +++ b/stackstac/nodata.py @@ -0,0 +1,34 @@ +from typing import Tuple, Union +import re + +import numpy as np +from rasterio.windows import Window + +State = Tuple[np.dtype, Union[int, float]] + + +def nodata_for_window( + ndim: int, window: Window, fill_value: Union[int, float], dtype: np.dtype +): + return np.full((ndim, window.height, window.width), fill_value, dtype) + + +def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool: + """ + Whether an exception matches one of the pattern exceptions + + Parameters + ---------- + e: + The exception to check + patterns: + Instances of an Exception type to catch, where ``str(exception_pattern)`` + is a regex pattern to match against ``str(e)``. 
+ """ + e_type = type(e) + e_msg = str(e) + for pattern in patterns: + if issubclass(e_type, type(pattern)): + if re.match(str(pattern), e_msg): + return True + return False diff --git a/stackstac/nodata_reader.py b/stackstac/nodata_reader.py deleted file mode 100644 index 8aab7f1..0000000 --- a/stackstac/nodata_reader.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Tuple, Type, Union -import re - -import numpy as np -from rasterio.windows import Window - -from .reader_protocol import Reader - -State = Tuple[np.dtype, Union[int, float]] - - -class NodataReader: - "Reader that returns a constant (nodata) value for all reads" - scale_offset = (1.0, 0.0) - - def __init__( - self, - *, - dtype: np.dtype, - fill_value: Union[int, float], - **kwargs, - ) -> None: - self.dtype = dtype - self.fill_value = fill_value - - def read(self, window: Window, **kwargs) -> np.ndarray: - return nodata_for_window(window, self.fill_value, self.dtype) - - def close(self) -> None: - pass - - def __getstate__(self) -> State: - return (self.dtype, self.fill_value) - - def __setstate__(self, state: State) -> None: - self.dtype, self.fill_value = state - - -def nodata_for_window(window: Window, fill_value: Union[int, float], dtype: np.dtype): - return np.full((window.height, window.width), fill_value, dtype) - - -def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool: - """ - Whether an exception matches one of the pattern exceptions - - Parameters - ---------- - e: - The exception to check - patterns: - Instances of an Exception type to catch, where ``str(exception_pattern)`` - is a regex pattern to match against ``str(e)``. - """ - e_type = type(e) - e_msg = str(e) - for pattern in patterns: - if issubclass(e_type, type(pattern)): - if re.match(str(pattern), e_msg): - return True - return False - - -# Type assertion -_: Type[Reader] = NodataReader diff --git a/stackstac/prepare.py b/stackstac/prepare.py index faa00ef..55fa913 100644 --- a/stackstac/prepare.py +++ b/stackstac/prepare.py @@ -27,7 +27,16 @@ from .stac_types import ItemSequence from . import accumulate_metadata, geom_utils -ASSET_TABLE_DT = np.dtype([("url", object), ("bounds", "float64", 4)]) +ASSET_TABLE_DT = np.dtype( + [("url", object), ("bounds", "float64", 4), ("bands", object)] +) +# ^ NOTE: `bands` should be a `Sequence[int]` of _1-indexed_ bands to fetch from the asset. +# We support specifying a sequence of band indices (rather than just the number of bands, +# and always doing a `read()` of all bands) for future optimizations to support fetching +# (and possibly reordering?) a subset of bands per asset. This could be done either via +# another argument to `stack` (please no!) or a custom Dask optimization, akin to column +# projection for DataFrames. +# But at the moment, `bands == list(range(1, ds.count + 1))`. 
 class Mimetype(NamedTuple):
@@ -64,7 +73,7 @@ def prepare_items(
     bounds: Optional[Bbox] = None,
     bounds_latlon: Optional[Bbox] = None,
     snap_bounds: bool = True,
-) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence]:
+) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence, tuple[int, ...]]:
 
     if bounds is not None and bounds_latlon is not None:
         raise ValueError(
@@ -119,6 +128,7 @@ def prepare_items(
         asset_ids = assets
 
     asset_table = np.full((len(items), len(asset_ids)), None, dtype=ASSET_TABLE_DT)
+    nbands_per_asset: list[int | None] = [None] * len(asset_ids)
 
     # TODO support item-assets https://github.com/radiantearth/stac-spec/tree/master/extensions/item-assets
 
@@ -321,7 +331,25 @@ def prepare_items(
                 )
 
             # Phew, we figured out all the spatial stuff! Now actually store the information we care about.
-            asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj)
+
+            bands: Optional[Sequence[int]] = None
+            # ^ TODO actually determine this from `eo:bands` or `raster:bands`
+            # https://github.com/gjoseph92/stackstac/issues/62
+
+            nbands = 1 if bands is None else len(bands)
+            prev_nbands = nbands_per_asset[asset_i]
+            if prev_nbands is None:
+                nbands_per_asset[asset_i] = nbands
+            else:
+                if prev_nbands != nbands:
+                    raise ValueError(
+                        f"The asset {id!r} has {nbands} band(s) on item {item_i} {item['id']!r}, "
+                        f"but on all previous items, it had {prev_nbands}."
+                        # TODO improve this error message with something actionable
+                        # (it's probably a data provider issue), once multi-band is actually supported.
+                    )
+
+            asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj, bands)
             # ^ NOTE: If `asset_bbox_proj` is None, NumPy automatically converts it to NaNs
 
-    # At this point, everything has been set (or there was as error)
+    # At this point, everything has been set (or there was an error)
@@ -346,9 +374,18 @@ def prepare_items(
     if item_isnan.any() or asset_id_isnan.any():
         asset_table = asset_table[np.ix_(~item_isnan, ~asset_id_isnan)]
         asset_ids = [id for id, isnan in zip(asset_ids, asset_id_isnan) if not isnan]
+        nbands_per_asset = [
+            nb for nb, isnan in zip(nbands_per_asset, asset_id_isnan) if not isnan
+        ]
         items = [item for item, isnan in zip(items, item_isnan) if not isnan]
 
-    return asset_table, spec, asset_ids, items
+    # Being for the benefit of Mr. Typechecker
+    nbpa = tuple(x for x in nbands_per_asset if x is not None)
+    assert len(nbpa) == len(
+        nbands_per_asset
+    ), f"Some `nbands_per_asset` are None: {nbands_per_asset}"
+
+    return asset_table, spec, asset_ids, items, nbpa
 
 
 def to_coords(
diff --git a/stackstac/reader_protocol.py b/stackstac/reader_protocol.py
index 81f7da4..c621b1f 100644
--- a/stackstac/reader_protocol.py
+++ b/stackstac/reader_protocol.py
@@ -1,5 +1,14 @@
 from __future__ import annotations
-from typing import Optional, Protocol, Tuple, Type, TYPE_CHECKING, TypeVar, Union
+from typing import (
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 
@@ -30,6 +39,7 @@ def __init__(
         self,
         *,
         url: str,
+        bands: Optional[Sequence[int]],
         spec: RasterSpec,
         resampling: Resampling,
         dtype: np.dtype,
@@ -45,6 +55,9 @@ def __init__(
         ----------
         url:
             Fetch data from the asset at this URL.
+        bands:
+            List of (one-indexed!) band indices to read, or None if the band count is unknown.
+            If None, the asset must have exactly one band, which is read.
         spec:
             Reproject data to match this georeferencing information.
         resampling:
@@ -69,7 +82,6 @@ def __init__(
             where ``str(exception_pattern)`` is a regex pattern
             to match against ``str(raised_exception)``.
         """
-        # TODO colormaps?
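
Since the `errors_as_nodata` matching rule described in this docstring is easy to misread, here's a tiny illustration of how `exception_matches` (from `stackstac/nodata.py` above) treats its patterns; `FakeRasterError` is a made-up stand-in for whatever a driver might raise:

```python
from stackstac.nodata import exception_matches

class FakeRasterError(Exception):
    """Stand-in for an exception a driver might raise."""

# Patterns are exception *instances*: the type is matched via issubclass,
# and str(pattern) is used as a regex (re.match) against str(raised_exception).
patterns = (FakeRasterError(r"HTTP response code: 4\d\d"),)

assert exception_matches(FakeRasterError("HTTP response code: 404"), patterns)
assert not exception_matches(FakeRasterError("HTTP response code: 503"), patterns)
assert not exception_matches(ValueError("HTTP response code: 404"), patterns)  # wrong type
```
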
def read(self, window: Window) -> np.ndarray: """ @@ -87,7 +99,7 @@ def read(self, window: Window) -> np.ndarray: Returns ------- - array: The window of data read + array: The window of data read from all bands, as a 3D array """ ... @@ -113,11 +125,16 @@ class FakeReader: or inherent to the dask graph. """ - def __init__(self, *, dtype: np.dtype, **kwargs) -> None: + def __init__( + self, *, bands: Optional[Sequence[int]], dtype: np.dtype, **kwargs + ) -> None: self.dtype = dtype + self.ndim = len(bands) if bands is not None else 1 def read(self, window: Window, **kwargs) -> np.ndarray: - return np.random.random((window.height, window.width)).astype(self.dtype) + return np.random.random((self.ndim, window.height, window.width)).astype( + self.dtype + ) def close(self) -> None: pass diff --git a/stackstac/rio_reader.py b/stackstac/rio_reader.py index 73e6bf2..edaca63 100644 --- a/stackstac/rio_reader.py +++ b/stackstac/rio_reader.py @@ -3,7 +3,16 @@ import logging import threading import warnings -from typing import TYPE_CHECKING, Optional, Protocol, Tuple, Type, TypedDict, Union +from typing import ( + TYPE_CHECKING, + Optional, + Protocol, + Sequence, + Tuple, + Type, + TypedDict, + Union, +) import numpy as np import rasterio as rio @@ -13,7 +22,7 @@ from .timer import time from .reader_protocol import Reader from .raster_spec import RasterSpec -from .nodata_reader import NodataReader, exception_matches, nodata_for_window +from .nodata import exception_matches, nodata_for_window if TYPE_CHECKING: from rasterio.enums import Resampling @@ -70,7 +79,7 @@ def _curthread(): class ThreadsafeRioDataset(Protocol): scale_offset: Tuple[float, float] - def read(self, window: Window, **kwargs) -> np.ndarray: + def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray: ... def close(self) -> None: @@ -99,11 +108,11 @@ def __init__( self._lock = threading.Lock() - def read(self, window: Window, **kwargs) -> np.ndarray: + def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray: "Acquire the lock, then read from the dataset" reader = self.vrt or self.ds with self._lock, self.env.read: - return reader.read(1, window=window, **kwargs) + return reader.read(bands, window=window, **kwargs) def close(self) -> None: "Acquire the lock, then close the dataset" @@ -220,11 +229,11 @@ def dataset(self) -> Union[SelfCleaningDatasetReader, WarpedVRT]: except AttributeError: return self._open() - def read(self, window: Window, **kwargs) -> np.ndarray: + def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray: "Read from the current thread's dataset, opening a new copy of the dataset on first access from each thread." 
         with time(f"Read {self._url!r} in {_curthread()}: {{t}}"):
             with self._env.read:
-                return self.dataset.read(1, window=window, **kwargs)
+                return self.dataset.read(bands, window=window, **kwargs)
 
     def close(self) -> None:
         """
@@ -274,8 +283,29 @@ def __del__(self):
         self.close()
 
 
+class Nodataset:
+    "`ThreadsafeRioDataset` that returns a constant (nodata) value for all reads"
+    scale_offset = (1.0, 0.0)
+
+    def __init__(
+        self,
+        *,
+        dtype: np.dtype,
+        fill_value: Union[int, float],
+    ) -> None:
+        self.dtype = dtype
+        self.fill_value = fill_value
+
+    def read(self, bands: Sequence[int], window: Window, **kwargs) -> np.ndarray:
+        return nodata_for_window(len(bands), window, self.fill_value, self.dtype)
+
+    def close(self) -> None:
+        pass
+
+
 class PickleState(TypedDict):
     url: str
+    bands: Optional[Sequence[int]]
     spec: RasterSpec
     resampling: Resampling
     dtype: np.dtype
@@ -295,10 +325,22 @@ class AutoParallelRioReader:
     for non-thread-safe drivers.
     """
 
+    url: str
+    bands: Sequence[int]
+    exactly_one_band: bool
+    spec: RasterSpec
+    resampling: Resampling
+    dtype: np.dtype
+    fill_value: Union[int, float]
+    rescale: bool
+    gdal_env: LayeredEnv
+    errors_as_nodata: Tuple[Exception, ...]
+
     def __init__(
         self,
         *,
         url: str,
+        bands: Optional[Sequence[int]],
         spec: RasterSpec,
         resampling: Resampling,
         dtype: np.dtype,
@@ -308,6 +350,8 @@ def __init__(
         errors_as_nodata: Tuple[Exception, ...] = (),
     ) -> None:
         self.url = url
+        self.bands = bands if bands is not None else (1,)
+        self.exactly_one_band = bands is None
         self.spec = spec
         self.resampling = resampling
         self.dtype = dtype
@@ -330,17 +374,34 @@ def _open(self) -> ThreadsafeRioDataset:
                 msg = f"Error opening {self.url!r}: {e!r}"
                 if exception_matches(e, self.errors_as_nodata):
                     warnings.warn(msg)
-                    return NodataReader(
-                        dtype=self.dtype, fill_value=self.fill_value
+                    return Nodataset(
+                        dtype=self.dtype,
+                        fill_value=self.fill_value,
                     )
                 raise RuntimeError(msg) from e
-            if ds.count != 1:
+
+            if self.exactly_one_band:
+                # Unknown band count. If the asset actually has 3 bands, we don't want to
+                # silently read just the first one.
+                if ds.count != 1:
+                    ds.close()
+                    raise RuntimeError(
+                        f"Assets must have exactly 1 band, but file {self.url!r} has {ds.count}. "
+                        "We can't currently handle multi-band rasters (each band has to be "
+                        "a separate STAC asset), so you'll need to exclude this asset from your analysis."
+                        # TODO change this error message once we actually determine band counts from STAC metadata.
+                        # Then, this should mention that the asset was missing `eo:bands` and `raster:bands` metadata,
+                        # so the expected band count was unknown and defaults to 1.
+                        # Alternatively, we could get rid of this bands==None codepath entirely, and always require
+                        # STAC metadata to specify `eo:bands` or `raster:bands` (allowing you to explicitly provide
+                        # values for them if they're missing?).
+                    )
+            elif ds.count < max(self.bands):
                 ds.close()
                 raise RuntimeError(
-                    f"Assets must have exactly 1 band, but file {self.url!r} has {ds.count}. "
-                    "We can't currently handle multi-band rasters (each band has to be "
-                    "a separate STAC asset), so you'll need to exclude this asset from your analysis."
+                    f"Band(s) {tuple(self.bands)} were requested, but there are only "
+                    f"{ds.count} band(s) in the asset at {self.url!r}."
                 )
 
         # Only make a VRT if the dataset doesn't match the spatial spec we want
@@ -375,7 +436,7 @@
         return SingleThreadedRioDataset(self.gdal_env, ds, vrt=vrt)
 
     @property
-    def dataset(self):
+    def dataset(self) -> ThreadsafeRioDataset:
         with self._dataset_lock:
             if self._dataset is None:
                 self._dataset = self._open()
@@ -385,6 +446,7 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
         reader = self.dataset
         try:
             result = reader.read(
+                self.bands,
                 window=window,
                 masked=True,
                 # ^ NOTE: we always do a masked array, so we can safely apply scales and offsets
@@ -395,10 +457,14 @@ def read(self, window: Window, **kwargs) -> np.ndarray:
             msg = f"Error reading {window} from {self.url!r}: {e!r}"
             if exception_matches(e, self.errors_as_nodata):
                 warnings.warn(msg)
-                return nodata_for_window(window, self.fill_value, self.dtype)
+                return nodata_for_window(
+                    len(self.bands), window, self.fill_value, self.dtype
+                )
 
             raise RuntimeError(msg) from e
 
+        # TODO scale and offset might not apply to all bands.
+        # Should probably just remove this.
         if self.rescale:
             scale, offset = reader.scale_offset
-            if scale != 1 and offset != 0:
+            if scale != 1 or offset != 0:
@@ -430,6 +496,7 @@ def __getstate__(
     ) -> PickleState:
         return {
             "url": self.url,
+            "bands": None if self.exactly_one_band else self.bands,
             "spec": self.spec,
             "resampling": self.resampling,
             "dtype": self.dtype,
diff --git a/stackstac/stack.py b/stackstac/stack.py
index 6ab51e0..85815ef 100644
--- a/stackstac/stack.py
+++ b/stackstac/stack.py
@@ -276,7 +276,7 @@ def stack(
         reverse=sortby_date == "desc",
     )
 
-    asset_table, spec, asset_ids, plain_items = prepare_items(
+    asset_table, spec, asset_ids, plain_items, nbands_per_asset = prepare_items(
         plain_items,
         assets=assets,
         epsg=epsg,
@@ -289,6 +289,7 @@ def stack(
         asset_table,
         spec,
         chunksize=chunksize,
+        nbands_per_asset=nbands_per_asset,
         dtype=dtype,
         resampling=resampling,
         fill_value=fill_value,
@@ -309,5 +310,5 @@ def stack(
             band_coords=band_coords,
         ),
         attrs=to_attrs(spec),
-        name="stackstac-" + dask.base.tokenize(arr)
+        name="stackstac-" + dask.base.tokenize(arr),
     )
diff --git a/stackstac/tests/test_to_dask.py b/stackstac/tests/test_to_dask.py
index 1a85bfb..37f6089 100644
--- a/stackstac/tests/test_to_dask.py
+++ b/stackstac/tests/test_to_dask.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
+import itertools
 from threading import Lock
 from typing import ClassVar
-from hypothesis import given, settings, strategies as st
+from hypothesis import given, note, settings, strategies as st
 import hypothesis.extra.numpy as st_np
 import numpy as np
+import pytest
 from rasterio import windows
 import dask.core
 import dask.threaded
@@ -16,6 +18,7 @@
     ChunksParam,
     items_to_dask,
     normalize_chunks,
+    process_multiband_chunks,
     window_from_bounds,
 )
 from stackstac.testing import strategies as st_stc
@@ -194,9 +197,58 @@ def test_normalize_chunks(
     chunksize: ChunksParam, shape: tuple[int, int, int, int], dtype: np.dtype
 ):
-    chunks = normalize_chunks(chunksize, shape, dtype)
+    nbands_per_asset = (1,) * shape[1]  # not testing this here, keep it simple
+    chunks, asset_table_band_chunks = normalize_chunks(
+        chunksize, shape, nbands_per_asset, dtype
+    )
     numblocks = tuple(map(len, chunks))
     assert len(numblocks) == 4
     assert all(x >= 1 for t in chunks for x in t)
-    if isinstance(chunksize, int) or isinstance(chunks, tuple) and len(chunks) == 2:
+    if isinstance(chunksize, int) or isinstance(chunksize, tuple) and len(chunksize) == 2:
         assert numblocks[:2] == shape[:2]
+
+
+@given(st.data(), st.lists(st.integers(1, 5), max_size=5).map(tuple))
+def test_process_multiband_chunks(
data: st.DataObject, nbands_per_asset: tuple[int, ...] +): + total_bands = sum(nbands_per_asset) + chunks: list[int] = [] + remaining = total_bands + while remaining: + c = data.draw(st.integers(1, remaining)) + remaining -= c + assert remaining >= 0 + chunks.append(c) + + note(f"{nbands_per_asset=}") + note(f" {chunks=}") + + # Expand chunks form into 1-elem-per-band form. This is a simpler but less efficient way to validate. + # Ex: [2, 4, 1, 1] -> [0, 0, 1, 1, 1, 1, 2, 3] + physical_layout = [ + x for i, n in enumerate(nbands_per_asset) for x in itertools.repeat(i, n) + ] + requested_layout = [x for i, n in enumerate(chunks) for x in itertools.repeat(i, n)] + assert len(physical_layout) == len(requested_layout) + + invalid = False + for i in range(1, len(requested_layout)): + if requested_layout[i - 1] != requested_layout[i]: + # Wherever the asset we're pulling from changes in the requested layout, + # it must also change in the physical layout. + if physical_layout[i - 1] == physical_layout[i]: + invalid = True + break + + note(f" {physical_layout=}") + note(f"{requested_layout=}") + + if invalid: + with pytest.raises(NotImplementedError): + process_multiband_chunks(tuple(chunks), nbands_per_asset) + else: + asset_table_band_chunks = process_multiband_chunks( + tuple(chunks), nbands_per_asset + ) + assert len(asset_table_band_chunks) == len(chunks) diff --git a/stackstac/to_dask.py b/stackstac/to_dask.py index 513a1c0..97b83f6 100644 --- a/stackstac/to_dask.py +++ b/stackstac/to_dask.py @@ -18,13 +18,17 @@ from .reader_protocol import Reader ChunkVal = Union[int, Literal["auto"], str, None] -ChunksParam = Union[ChunkVal, Tuple[ChunkVal, ...], Dict[int, ChunkVal]] +ChunksParam = Union[ + ChunkVal, Tuple[Union[ChunkVal, Tuple[ChunkVal, ...]], ...], Dict[int, ChunkVal] +] +TBYXChunks = Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]] def items_to_dask( asset_table: np.ndarray, spec: RasterSpec, chunksize: ChunksParam, + nbands_per_asset: tuple[int, ...], resampling: Resampling = Resampling.nearest, dtype: np.dtype = np.dtype("float64"), fill_value: Union[int, float] = np.nan, @@ -42,8 +46,11 @@ def items_to_dask( f"Either use `dtype={np.array(fill_value).dtype.name!r}`, or pick a different `fill_value`." 
     )
 
-    chunks = normalize_chunks(chunksize, asset_table.shape + spec.shape, dtype)
-    chunks_tb, chunks_yx = chunks[:2], chunks[2:]
+    chunks, asset_table_band_chunks = normalize_chunks(
+        chunksize, asset_table.shape + spec.shape, nbands_per_asset, dtype
+    )
+    chunks_tb = chunks[:1] + (asset_table_band_chunks,)
+    chunks_yx = chunks[2:]
 
     # The overall strategy in this function is to materialize the outer two dimensions (items, assets)
     # as one dask array (the "asset table"), then map a function over it which opens each URL as a `Reader`
@@ -56,7 +63,7 @@
     # make URLs into dask array, chunked as requested for the time,band dimensions
     asset_table_dask = da.from_array(
         asset_table,
-        chunks=chunks_tb,
+        chunks=chunks_tb,  # type: ignore
         inline_array=True,
         name="asset-table-" + dask.base.tokenize(asset_table),
     )
@@ -97,6 +104,8 @@
             None,
             fill_value,
             None,
+            nbands_per_asset,
+            None,
             numblocks={reader_table.name: reader_table.numblocks},  # ugh
         )
         dsk = HighLevelGraph.from_collections(name, lyr, [reader_table])
@@ -136,6 +145,7 @@ def asset_table_to_reader_and_window(
             entry: ReaderTableEntry = (
                 reader(
                     url=url,
+                    bands=asset_entry["bands"],
                     spec=spec,
                     resampling=resampling,
                     dtype=dtype,
@@ -155,24 +165,34 @@ def fetch_raster_window(
     slices: Tuple[slice, slice],
     dtype: np.dtype,
     fill_value: Union[int, float],
+    nbands_per_asset: tuple[int, ...],
 ) -> np.ndarray:
     "Do a spatially-windowed read of raster data from all the Readers in the table."
     assert len(slices) == 2, slices
     current_window = windows.Window.from_slices(*slices)
 
     assert reader_table.size, f"Empty reader_table: {reader_table.shape=}"
+    assert (
+        len(nbands_per_asset) == reader_table.shape[1]
+    ), f"{nbands_per_asset=}, {reader_table.shape[1]=}"
     # Start with an empty output array, using the broadcast trick for even fewer memz.
     # If none of the assets end up actually existing, or overlapping the current window,
     # or containing data, we'll just return this 1-element array that's been broadcast
     # to look like a full-size array.
     output = np.broadcast_to(
         np.array(fill_value, dtype),
-        reader_table.shape + (current_window.height, current_window.width),
+        (
+            reader_table.shape[0],
+            sum(nbands_per_asset),
+            current_window.height,
+            current_window.width,
+        ),
     )
+    asset_i_to_band = np.cumsum([0, *nbands_per_asset])
+    # ^ Exclusive prefix sum: asset_i_to_band[i] is the first output band index for asset i.
 
     all_empty: bool = True
     entry: ReaderTableEntry
-    for index, entry in np.ndenumerate(reader_table):
+    for (time_i, asset_i), entry in np.ndenumerate(reader_table):
         if entry:
             reader, asset_window = entry
             # Only read if the window we're fetching actually overlaps with the asset
@@ -183,6 +203,9 @@
                 # TODO when the Reader won't be rescaling, support passing `output` to avoid the copy?
                 data = reader.read(current_window)
+                assert (
+                    data.shape[0] == nbands_per_asset[asset_i]
+                ), f"Band count mismatch: {nbands_per_asset[asset_i]=}, {data.shape[0]=}"
 
                 if all_empty:
                     # Turn `output` from a broadcast-trick array to a real array, so it's writeable
@@ -196,36 +219,128 @@ def fetch_raster_window(
                     output = np.array(output)
                     all_empty = False
 
-                output[index] = data
+                band_i = asset_i_to_band[asset_i]
+                output[time_i, band_i : band_i + data.shape[0]] = data
 
     return output
 
 
 def normalize_chunks(
-    chunks: ChunksParam, shape: Tuple[int, int, int, int], dtype: np.dtype
-) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
+    chunks: ChunksParam,
+    shape: Tuple[int, int, int, int],
+    nbands_per_asset: tuple[int, ...],
+    dtype: np.dtype,
+) -> tuple[TBYXChunks, tuple[int, ...]]:
     """
     Normalize chunks to tuple of tuples, assuming 1D and 2D chunks only apply to spatial coordinates
 
     If only 1 or 2 chunks are given, assume they're for the ``y, x`` coordinates,
-    and that the ``time, band`` coordinates should be chunksize 1.
+    that ``time`` should be chunksize 1, and that ``band`` should follow ``nbands_per_asset``.
+
+    If "auto" is given for bands, uses ``nbands_per_asset``.
+
+    Returns
+    -------
+    chunks:
+        Normalized chunks
+    asset_table_band_chunks:
+        Band chunks to apply to the asset table (see `process_multiband_chunks`)
     """
     # TODO implement our own auto-chunking that makes the time,band coordinates
    # >1 if the spatial chunking would create too many tasks?
     if isinstance(chunks, int):
-        chunks = (1, 1, chunks, chunks)
+        chunks = (1, nbands_per_asset, chunks, chunks)
     elif isinstance(chunks, tuple) and len(chunks) == 2:
-        chunks = (1, 1) + chunks
+        chunks = (1, nbands_per_asset) + chunks
+    elif isinstance(chunks, tuple) and len(chunks) == 4 and chunks[1] == "auto":
+        chunks = (chunks[0], nbands_per_asset, chunks[2], chunks[3])
 
-    return da.core.normalize_chunks(
+    norm: TBYXChunks = da.core.normalize_chunks(
         chunks,
         shape,
         dtype=dtype,
-        previous_chunks=((1,) * shape[0], (1,) * shape[1], (shape[2],), (shape[3],)),
+        previous_chunks=((1,) * shape[0], nbands_per_asset, (shape[2],), (shape[3],)),
         # ^ Give dask some hint of the physical layout of the data, so it prefers widening
         # the spatial chunks over bundling together items/assets. This isn't totally accurate.
     )
+    # Ensure we aren't trying to split apart multi-band assets. This would require rewriting
+    # the asset table (adding duplicate columns) and is generally not what you want, assuming
+    # that in multi-band assets, the bands are stored interleaved, so reading one requires reading
+    # them all anyway.
+    asset_table_band_chunks = process_multiband_chunks(norm[1], nbands_per_asset)
+    return norm, asset_table_band_chunks
+
+
+def process_multiband_chunks(
+    chunks: tuple[int, ...], nbands_per_asset: tuple[int, ...]
+) -> tuple[int, ...]:
+    """
+    Validate that the bands chunks don't try to split apart any multi-band assets.
+
+    Returns
+    -------
+    asset_table_band_chunks:
+        Band chunks to apply to the asset table (so that assets are combined into single chunks as necessary).
+        ``len(asset_table_band_chunks) == len(chunks)``. In other words, along the bands, we'll have the same
+        ``numblocks`` in the asset table as ``numblocks`` in the final array. But each block in the final array
+        may be longer (have more bands) than the number of assets (when they're multi-band assets).
+    """
+    n_chunks = len(chunks)
+    n_assets = len(nbands_per_asset)
+
+    final_msg = (
+        f"Requested bands chunks: {chunks}\n"
+        f"Physical bands chunks: {nbands_per_asset}\n"
+        "This would entail splitting apart multi-band assets. This typically (but not always) has "
+        "much worse performance, since GeoTIFF bands are generally interleaved (so reading one "
+        "band from a file requires reading them all).\n"
+        "If you have a use-case for this, please discuss on https://github.com/gjoseph92/stackstac/issues."
+    )
+
+    if n_chunks > n_assets:
+        raise NotImplementedError(
+            f"Refusing to make {n_chunks} chunk(s) for the bands when there are only {n_assets} asset(s).\n"
+            + final_msg
+        )
+    elif n_chunks == n_assets:
+        if chunks != nbands_per_asset:
+            raise NotImplementedError(final_msg)
+        # One asset per chunk: along the bands, the asset table keeps one column per block.
+        return (1,) * n_assets
+    else:
+        # Trying to combine multiple assets into one chunk; must be whole multiples.
+        # n_chunks < n_assets
+        asset_table_band_chunks: list[int] = []
+        i = nbands_so_far = nbands_requested = n_assets_so_far = 0
+        for nb in nbands_per_asset:
+            if nbands_requested == 0:
+                if i == n_chunks:
+                    raise ValueError(
+                        f"Invalid chunks for {sum(nbands_per_asset)} band(s): only {sum(chunks)} band(s) used.\n"
+                        f"Requested bands chunks: {chunks}\n"
+                        f"Physical bands chunks: {nbands_per_asset}\n"
+                    )
+                nbands_requested = chunks[i]
+
+            nbands_so_far += nb
+            n_assets_so_far += 1
+            if nbands_so_far < nbands_requested:
+                continue
+            elif nbands_so_far == nbands_requested:
+                # nailed it
+                i += 1
+                nbands_so_far = 0
+                nbands_requested = 0
+                asset_table_band_chunks.append(n_assets_so_far)
+                n_assets_so_far = 0
+            else:
+                # `nbands_so_far > nbands_requested`
+                raise NotImplementedError(
+                    f"Specified chunks do not evenly combine multi-band assets: chunk {i} would split one apart.\n"
+                    + final_msg
+                )
+        return tuple(asset_table_band_chunks)
+
 
 # FIXME remove this once rasterio bugs are fixed
 def window_from_bounds(bounds: Bbox, transform: Affine) -> windows.Window:
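
To make the combination rules concrete, here are a few worked examples for a hypothetical layout of three assets with 1, 3, and 2 bands:

```python
from stackstac.to_dask import process_multiband_chunks

nbands_per_asset = (1, 3, 2)  # three assets, six bands total

# Chunks that line up exactly with asset boundaries: one asset per chunk.
assert process_multiband_chunks((1, 3, 2), nbands_per_asset) == (1, 1, 1)

# Combining whole assets is fine: the first chunk covers the 1-band and
# 3-band assets; the second covers the 2-band asset.
assert process_multiband_chunks((4, 2), nbands_per_asset) == (2, 1)

# Everything in a single chunk.
assert process_multiband_chunks((6,), nbands_per_asset) == (3,)

# Splitting the 3-band asset across chunks is refused.
try:
    process_multiband_chunks((2, 2, 2), nbands_per_asset)
except NotImplementedError:
    pass
else:
    raise AssertionError("expected NotImplementedError")
```

The returned tuple always has one entry per requested band chunk, giving the number of asset-table columns to bundle into that chunk — which is what `items_to_dask` uses to chunk the asset table along its second axis.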