Skip to content

Commit

Permalink
feat: allow stream to be None
Browse files Browse the repository at this point in the history
This will allow the client to decide whether to stream or not.
  • Loading branch information
gadomski committed Oct 8, 2024
1 parent 6bbc73f commit 40b6c71
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 22 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

### Changed

- `stream` can now be `None`, with each client deciding what its preferred mode is ([#226](https://github.com/stac-utils/stac-asset/pull/226))

### Removed

- Python 3.9 ([#224](https://github.com/stac-utils/stac-asset/pull/224))
Expand Down
27 changes: 18 additions & 9 deletions src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import click_logging
import tabulate
import tqdm
from click import Choice
from pystac import Asset, Item, ItemCollection

from . import Config, ErrorStrategy, _functions
Expand Down Expand Up @@ -147,10 +148,10 @@ def cli() -> None:
default=_functions.DEFAULT_MAX_CONCURRENT_DOWNLOADS,
)
@click.option(
"--no-stream",
help="Disable chunked reading of assets",
default=False,
is_flag=True,
"--stream",
help="Enable, disable, or defer to the client for chunked reading of assets",
default="defer",
type=Choice(["true", "false", "defer"]),
show_default=True,
)
# TODO add option to disable content type checking
Expand All @@ -172,7 +173,7 @@ def download(
fail_fast: bool,
overwrite: bool,
max_concurrent_downloads: int,
no_stream: bool,
stream: str,
) -> None:
"""Download STAC assets from an item or item collection.
Expand All @@ -193,6 +194,14 @@ def download(
$ stac-asset download -i asset-key-to-include item.json
"""
# Translate the CLI's tri-state ``--stream`` choice into the ``bool | None``
# passed on as ``stream=``: "defer" maps to None so the decision is left to
# the client.
stream_choices = {"true": True, "false": False, "defer": None}
if stream not in stream_choices:
    # Only reachable if the option's Choice list and this table drift apart.
    raise NotImplementedError(f"Unknown stream value: {stream}")
stream_actual = stream_choices[stream]
asyncio.run(
download_async(
href,
Expand All @@ -212,7 +221,7 @@ def download(
fail_fast=fail_fast,
overwrite=overwrite,
max_concurrent_downloads=max_concurrent_downloads,
no_stream=no_stream,
stream=stream_actual,
)
)

Expand All @@ -235,7 +244,7 @@ async def download_async(
fail_fast: bool,
overwrite: bool,
max_concurrent_downloads: int,
no_stream: bool,
stream: bool | None,
) -> None:
config = Config(
alternate_assets=alternate_assets,
Expand Down Expand Up @@ -282,7 +291,7 @@ async def download() -> Item | ItemCollection:
config=config,
messages=messages,
max_concurrent_downloads=max_concurrent_downloads,
stream=not no_stream,
stream=stream,
)

elif type_ == "FeatureCollection":
Expand All @@ -297,7 +306,7 @@ async def download() -> Item | ItemCollection:
config=config,
messages=messages,
max_concurrent_downloads=max_concurrent_downloads,
stream=not no_stream,
stream=stream,
)

else:
Expand Down
15 changes: 9 additions & 6 deletions src/stac_asset/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Download:
config: Config

async def download(
self, messages: MessageQueue | None, stream: bool = True
self, messages: MessageQueue | None, stream: bool | None = None
) -> Download | WrappedError:
if not os.path.exists(self.path) or self.config.overwrite:
try:
Expand Down Expand Up @@ -135,7 +135,7 @@ async def add(
stac_object.assets = assets

async def download(
self, messages: MessageQueue | None, stream: bool = True
self, messages: MessageQueue | None, stream: bool | None = None
) -> None:
tasks: set[Task[Download | WrappedError]] = set()
for download in self.downloads:
Expand Down Expand Up @@ -173,7 +173,10 @@ async def download(
raise DownloadError(exceptions)

async def download_with_lock(
self, download: Download, messages: MessageQueue | None, stream: bool = True
self,
download: Download,
messages: MessageQueue | None,
stream: bool | None = None,
) -> Download | WrappedError:
await self.semaphore.acquire()
try:
Expand Down Expand Up @@ -212,7 +215,7 @@ async def download_item(
clients: list[Client] | None = None,
keep_non_downloaded: bool = False,
max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS,
stream: bool = True,
stream: bool | None = None,
) -> Item:
"""Downloads an item to the local filesystem.
Expand Down Expand Up @@ -326,7 +329,7 @@ async def download_item_collection(
clients: list[Client] | None = None,
keep_non_downloaded: bool = False,
max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS,
stream: bool = True,
stream: bool | None = None,
) -> ItemCollection:
"""Downloads an item collection to the local filesystem.
Expand Down Expand Up @@ -385,7 +388,7 @@ async def download_asset(
config: Config,
messages: MessageQueue | None = None,
clients: Clients | None = None,
stream: bool = True,
stream: bool | None = None,
) -> Asset:
"""Downloads an asset.
Expand Down
6 changes: 3 additions & 3 deletions src/stac_asset/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def open_url(
url: URL,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Opens a url and yields an iterator over its bytes.
Expand All @@ -68,7 +68,7 @@ async def open_href(
href: str,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Opens a href and yields an iterator over its bytes.
Expand All @@ -94,7 +94,7 @@ async def download_href(
clean: bool = True,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> None:
"""Downloads a file to the local filesystem.
Expand Down
4 changes: 3 additions & 1 deletion src/stac_asset/filesystem_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def open_url(
url: URL,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Iterates over data from a local url.
Expand All @@ -42,6 +42,8 @@ async def open_url(
ValueError: Raised if the url has a scheme. This behavior will
change if/when we support Windows paths.
"""
if stream is None:
stream = False
if url.scheme:
raise ValueError(
"cannot read a file with the filesystem client if it has a url scheme: "
Expand Down
4 changes: 3 additions & 1 deletion src/stac_asset/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def open_url(
url: URL,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Opens a url with this client's session and iterates over its bytes.
Expand All @@ -157,6 +157,8 @@ async def open_url(
Raises:
:py:class:`aiohttp.ClientResponseError`: Raised if the response is not OK
"""
if stream is None:
stream = True
async with self.session.get(url, allow_redirects=True) as response:
response.raise_for_status()
if self.check_content_type and content_type:
Expand Down
2 changes: 1 addition & 1 deletion src/stac_asset/planetary_computer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def open_url(
url: URL,
content_type: str | None = None,
messages: Queue[Any] | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Opens a url and iterates over its bytes.
Expand Down
4 changes: 3 additions & 1 deletion src/stac_asset/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def open_url(
url: URL,
content_type: str | None = None,
messages: MessageQueue | None = None,
stream: bool = True,
stream: bool | None = None,
) -> AsyncIterator[bytes]:
"""Opens an s3 url and iterates over its bytes.
Expand All @@ -105,6 +105,8 @@ async def open_url(
Raises:
SchemeError: Raised if the url's scheme is not ``s3``
"""
if stream is None:
stream = True
async with self._create_client() as client:
response = await client.get_object(**self._params(url))
if content_type:
Expand Down

0 comments on commit 40b6c71

Please sign in to comment.