Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle reading multi-band datasetes #146

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
max-line-length = 120
exclude =
.pyi
typings
typings
ignore =
E203 # whitespace before ':'
W503 # line break before binary operator
34 changes: 34 additions & 0 deletions stackstac/nodata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Tuple, Union
import re

import numpy as np
from rasterio.windows import Window

State = Tuple[np.dtype, Union[int, float]]


def nodata_for_window(
ndim: int, window: Window, fill_value: Union[int, float], dtype: np.dtype
):
return np.full((ndim, window.height, window.width), fill_value, dtype)


def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool:
"""
Whether an exception matches one of the pattern exceptions

Parameters
----------
e:
The exception to check
patterns:
Instances of an Exception type to catch, where ``str(exception_pattern)``
is a regex pattern to match against ``str(e)``.
"""
e_type = type(e)
e_msg = str(e)
for pattern in patterns:
if issubclass(e_type, type(pattern)):
if re.match(str(pattern), e_msg):
return True
return False
65 changes: 0 additions & 65 deletions stackstac/nodata_reader.py

This file was deleted.

45 changes: 41 additions & 4 deletions stackstac/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,16 @@
from .stac_types import ItemSequence
from . import accumulate_metadata, geom_utils

ASSET_TABLE_DT = np.dtype([("url", object), ("bounds", "float64", 4)])
ASSET_TABLE_DT = np.dtype(
[("url", object), ("bounds", "float64", 4), ("bands", object)]
)
# ^ NOTE: `bands` should be a `Sequence[int]` of _1-indexed_ bands to fetch from the asset.
# We support specifying a sequence of band indices (rather than just the number of bands,
# and always doing a `read()` of all bands) for future optimizations to support fetching
# (and possibly reordering?) a subset of bands per asset. This could be done either via
# another argument to `stack` (please no!) or a custom Dask optimization, akin to column
# projection for DataFrames.
# But at the moment, `bands == list(range(1, ds.count + 1))`.


class Mimetype(NamedTuple):
Expand Down Expand Up @@ -64,7 +73,7 @@ def prepare_items(
bounds: Optional[Bbox] = None,
bounds_latlon: Optional[Bbox] = None,
snap_bounds: bool = True,
) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence]:
) -> Tuple[np.ndarray, RasterSpec, List[str], ItemSequence, tuple[int, ...]]:

if bounds is not None and bounds_latlon is not None:
raise ValueError(
Expand Down Expand Up @@ -119,6 +128,7 @@ def prepare_items(
asset_ids = assets

asset_table = np.full((len(items), len(asset_ids)), None, dtype=ASSET_TABLE_DT)
nbands_per_asset: list[int | None] = [None] * len(asset_ids)

# TODO support item-assets https://github.com/radiantearth/stac-spec/tree/master/extensions/item-assets

Expand Down Expand Up @@ -321,7 +331,25 @@ def prepare_items(
)

# Phew, we figured out all the spatial stuff! Now actually store the information we care about.
asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj)

bands: Optional[Sequence[int]] = None
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this is where we'd actually figure out band counts from STAC metadata

# ^ TODO actually determine this from `eo:bands` or `raster:bands`
# https://github.com/gjoseph92/stackstac/issues/62

nbands = 1 if bands is None else len(bands)
prev_nbands = nbands_per_asset[asset_i]
if prev_nbands is None:
nbands_per_asset[asset_i] = nbands
else:
if prev_nbands != nbands:
raise ValueError(
f"The asset {id!r} has {nbands} band(s) on item {item_i} {item['id']!r}, "
f"but on all previous items, it had {prev_nbands}."
# TODO improve this error message with something actionable
# (it's probably a data provider issue), once multi-band is actually supported.
)

asset_table[item_i, asset_i] = (asset["href"], asset_bbox_proj, bands)
# ^ NOTE: If `asset_bbox_proj` is None, NumPy automatically converts it to NaNs

# At this point, everything has been set (or there was as error)
Expand All @@ -346,9 +374,18 @@ def prepare_items(
if item_isnan.any() or asset_id_isnan.any():
asset_table = asset_table[np.ix_(~item_isnan, ~asset_id_isnan)]
asset_ids = [id for id, isnan in zip(asset_ids, asset_id_isnan) if not isnan]
nbands_per_asset = [
id for id, isnan in zip(nbands_per_asset, asset_id_isnan) if not isnan
]
items = [item for item, isnan in zip(items, item_isnan) if not isnan]

return asset_table, spec, asset_ids, items
# Being for the benefit of mr. typechecker
nbpa = tuple(x for x in nbands_per_asset if x is not None)
assert len(nbpa) == len(
nbands_per_asset
), f"Some `nbands_per_asset` are None: {nbands_per_asset}"

return asset_table, spec, asset_ids, items, nbpa


def to_coords(
Expand Down
27 changes: 22 additions & 5 deletions stackstac/reader_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations
from typing import Optional, Protocol, Tuple, Type, TYPE_CHECKING, TypeVar, Union
from typing import (
Optional,
Protocol,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

import numpy as np

Expand Down Expand Up @@ -30,6 +39,7 @@ def __init__(
self,
*,
url: str,
bands: Optional[Sequence[int]],
spec: RasterSpec,
resampling: Resampling,
dtype: np.dtype,
Expand All @@ -45,6 +55,9 @@ def __init__(
----------
url:
Fetch data from the asset at this URL.
bands:
List of (one-indexed!) band indices to read, or None for all bands.
If None, the asset must have exactly one band.
spec:
Reproject data to match this georeferencing information.
resampling:
Expand All @@ -69,7 +82,6 @@ def __init__(
where ``str(exception_pattern)`` is a regex pattern to match against
``str(raised_exception)``.
"""
# TODO colormaps?

def read(self, window: Window) -> np.ndarray:
"""
Expand All @@ -87,7 +99,7 @@ def read(self, window: Window) -> np.ndarray:

Returns
-------
array: The window of data read
array: The window of data read from all bands, as a 3D array
"""
...

Expand All @@ -113,11 +125,16 @@ class FakeReader:
or inherent to the dask graph.
"""

def __init__(self, *, dtype: np.dtype, **kwargs) -> None:
def __init__(
self, *, bands: Optional[Sequence[int]], dtype: np.dtype, **kwargs
) -> None:
self.dtype = dtype
self.ndim = len(bands) if bands is not None else 1
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof, ndim is not the right name for this, since this is just the length of one dimension... it was a late night.


def read(self, window: Window, **kwargs) -> np.ndarray:
return np.random.random((window.height, window.width)).astype(self.dtype)
return np.random.random((self.ndim, window.height, window.width)).astype(
self.dtype
)

def close(self) -> None:
pass
Expand Down
Loading