[dagster-pipes] Add key param to PipesS3MessageReader
danielgafni committed Oct 3, 2024
1 parent 4ef06eb commit c45c859
Showing 5 changed files with 236 additions and 29 deletions.
3 changes: 0 additions & 3 deletions python_modules/dagster/dagster/_core/pipes/utils.py
@@ -136,9 +136,6 @@ class PipesFileMessageReader(PipesMessageReader):
     def __init__(self, path: str):
         self._path = check.str_param(path, "path")
 
-    def on_launched(self, params: PipesLaunchedData) -> None:
-        self.launched_payload = params
-
     @contextmanager
     def read_messages(
         self,
@@ -11,6 +11,7 @@
 from dagster_aws.pipes.message_readers import (
     PipesCloudWatchMessageReader,
     PipesLambdaLogsMessageReader,
+    PipesS3LogReader,
     PipesS3MessageReader,
 )

@@ -21,6 +22,7 @@
"PipesS3ContextInjector",
"PipesLambdaEventContextInjector",
"PipesS3MessageReader",
"PipesS3LogReader",
"PipesLambdaLogsMessageReader",
"PipesCloudWatchMessageReader",
"PipesEMRServerlessClient",
@@ -46,9 +46,9 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
         context_injector (Optional[PipesContextInjector]): A context injector to use to inject
             context into the Glue job, for example, :py:class:`PipesS3ContextInjector`.
         message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
-            from the glue job run. Defaults to :py:class:`PipesCloudWatchsMessageReader`.
+            from the glue job run. Defaults to :py:class:`PipesCloudWatchMessageReader`.
             When provided with :py:class:`PipesCloudWatchMessageReader`,
-            it will be used to recieve logs and events from the ``.../output/<job-run-id>``
+            it will be used to receive logs and events from the ``.../output/<job-run-id>``
             CloudWatch log stream created by AWS Glue. Note that AWS Glue routes both
             ``stderr`` and ``stdout`` from the main job process into this LogStream.
         client (Optional[boto3.client]): The boto Glue client used to launch the Glue job
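For context, a minimal sketch of how the Glue client described in the docstring above might be wired up. It assumes a project with `dagster` and `dagster-aws` installed; the bucket and job names are hypothetical, and `start_job_run_params` is assumed to follow the boto3 `start_job_run` API as in recent dagster-aws versions:

import boto3
from dagster import AssetExecutionContext, asset
from dagster_aws.pipes import (
    PipesCloudWatchMessageReader,
    PipesGlueClient,
    PipesS3ContextInjector,
)

glue_pipes_client = PipesGlueClient(
    client=boto3.client("glue"),
    context_injector=PipesS3ContextInjector(
        bucket="my-pipes-bucket",  # hypothetical bucket
        client=boto3.client("s3"),
    ),
    # the default reader; shown explicitly for clarity
    message_reader=PipesCloudWatchMessageReader(client=boto3.client("logs")),
)


@asset
def glue_asset(context: AssetExecutionContext):
    # "my-glue-job" is a placeholder for a real Glue job name
    return glue_pipes_client.run(
        context=context,
        start_job_run_params={"JobName": "my-glue-job"},
    ).get_results()
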
@@ -1,73 +1,222 @@
 import base64
+import gzip
+import os
 import random
 import string
+import sys
+from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, TextIO, TypedDict
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    TextIO,
+    TypedDict,
+)
 
 import boto3
 import dagster._check as check
 from botocore.exceptions import ClientError
 from dagster._annotations import experimental
-from dagster._core.pipes.client import PipesMessageReader, PipesParams
+from dagster._core.pipes.client import PipesLaunchedData, PipesMessageReader, PipesParams
 from dagster._core.pipes.context import PipesMessageHandler
 from dagster._core.pipes.utils import (
     PipesBlobStoreMessageReader,
+    PipesChunkedLogReader,
     PipesLogReader,
+    extract_message_or_forward_to_file,
     extract_message_or_forward_to_stdout,
 )
+from dagster_pipes import PipesDefaultMessageWriter

+if TYPE_CHECKING:
+    from mypy_boto3_s3 import S3Client
+
+
+def _can_read_from_s3(client: "S3Client", bucket: Optional[str], key: Optional[str]):
+    if not bucket or not key:
+        return False
+    else:
+        try:
+            client.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError:
+            return False

-class PipesS3MessageReader(PipesBlobStoreMessageReader):
-    """Message reader that reads messages by periodically reading message chunks from a specified S3
-    bucket.
-
-    If `log_readers` is passed, this reader will also start the passed readers
+def default_log_decode_fn(contents: bytes) -> str:
+    return contents.decode("utf-8")
+
+
+def gzip_log_decode_fn(contents: bytes) -> str:
+    return gzip.decompress(contents).decode("utf-8")


+class PipesS3LogReader(PipesChunkedLogReader):
+    def __init__(
+        self,
+        *,
+        bucket: str,
+        key: str,
+        client=None,
+        interval: float = 10,
+        target_stream: Optional[IO[str]] = None,
+        # TODO: maybe move this parameter to a different scope
+        decode_fn: Optional[Callable[[bytes], str]] = None,
+    ):
+        self.bucket = bucket
+        self.key = key
+        self.client: "S3Client" = client or boto3.client("s3")
+        self.decode_fn = decode_fn or default_log_decode_fn
+
+        self.log_position = 0
+
+        super().__init__(interval=interval, target_stream=target_stream or sys.stdout)
+
+    @property
+    def name(self) -> str:
+        return f"PipesS3LogReader(s3://{os.path.join(self.bucket, self.key)})"
+
+    def can_start(self, params: PipesParams) -> bool:
+        return _can_read_from_s3(
+            client=self.client,
+            bucket=self.bucket,
+            key=self.key,
+        )
+
+    def download_log_chunk(self, params: PipesParams) -> Optional[str]:
+        # the S3 object is re-downloaded in full on every call; only the part
+        # past the last seen position is returned as the new chunk
+        text = self.decode_fn(
+            self.client.get_object(Bucket=self.bucket, Key=self.key)["Body"].read()
+        )
+        current_position = self.log_position
+        self.log_position = len(text)
+
+        return text[current_position:]
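
A minimal usage sketch for the new log reader, forwarding a single S3 object's contents to stdout alongside a message reader. Bucket and key are hypothetical; `gzip_log_decode_fn` would be passed as `decode_fn` if the external service writes gzipped logs:

import boto3
from dagster_aws.pipes import PipesS3LogReader, PipesS3MessageReader

message_reader = PipesS3MessageReader(
    bucket="my-pipes-bucket",  # hypothetical
    client=boto3.client("s3"),
    log_readers={
        "driver-stdout": PipesS3LogReader(
            bucket="my-pipes-bucket",
            key="logs/driver.stdout",  # hypothetical key
            interval=5,
        ),
    },
)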


+class PipesS3MessageReader(PipesBlobStoreMessageReader):
+    """Message reader that reads messages from S3. Can operate in two modes: reading messages
+    from objects created by :py:class:`dagster_pipes.PipesS3MessageWriter`, or reading messages
+    from a specific S3 object (typically a normal log file). If `log_readers` is passed, this
+    reader will also start the passed readers when the first message is received from the
+    external process.
+
+    - If `expect_s3_message_writer` is set to `True` (default), a :py:class:`dagster_pipes.PipesS3MessageWriter`
+      is expected to be used in the external process. The writer will write messages to a random
+      S3 prefix in chunks, and this reader will read them in order.
+    - If `expect_s3_message_writer` is set to `False`, this reader will read messages from a
+      specific S3 object, typically created by an external service (for example, by dumping
+      stdout/stderr containing Pipes messages). The object key can either be passed via the
+      corresponding constructor argument, or with the `on_launched` method (if not known in advance).
+
     Args:
         interval (float): interval in seconds between attempts to download a chunk
         bucket (str): The S3 bucket to read from.
         client (Optional[boto3.client]): A boto3 client.
-        log_readers (Optional[Sequence[PipesLogReader]]): A set of readers for logs on S3.
+        expect_s3_message_writer (bool): Whether to expect a PipesS3MessageWriter to be used in the external process.
+        key (Optional[str]): The S3 key to read from. If not set, the key must be passed via `on_launched`.
+        log_readers (Optional[Mapping[str, PipesLogReader]]): A mapping of arbitrary names to log readers.
     """

     def __init__(
         self,
         *,
-        interval: float = 10,
         bucket: str,
-        client: boto3.client,  # pyright: ignore (reportGeneralTypeIssues)
-        log_readers: Optional[Sequence[PipesLogReader]] = None,
+        client=None,
+        expect_s3_message_writer: bool = True,
+        key: Optional[str] = None,
+        interval: float = 10,
+        log_readers: Optional[Mapping[str, "PipesLogReader"]] = None,
     ):
+        if isinstance(log_readers, Sequence):
+            # backcompat conversion for the older Sequence type
+            log_readers = {str(i): lr for i, lr in enumerate(log_readers)}  # type: ignore
+
         super().__init__(
             interval=interval,
             log_readers=log_readers,
         )
         self.bucket = check.str_param(bucket, "bucket")
-        self.client = client
+        self.client: "S3Client" = client or boto3.client("s3")
+        self.expect_s3_message_writer = expect_s3_message_writer
+        self.key = key
+
+        self.offset = 0
+
+        if expect_s3_message_writer and key is not None:
+            raise ValueError("key should not be set if expect_s3_message_writer is True")
+
+    def on_launched(self, launched_payload: PipesLaunchedData) -> None:
+        if not self.expect_s3_message_writer:
+            self.key = launched_payload["extras"].get("key")
+
+        self.launched_payload = launched_payload

+    def can_start(self, params: PipesParams) -> bool:
+        if self.expect_s3_message_writer:
+            # we are supposed to be reading from {i}.json chunks created by the MessageWriter
+            return _can_read_from_s3(
+                client=self.client,
+                bucket=params.get("bucket") or self.bucket,
+                key=f"{params['key_prefix']}/1.json",
+            )
+        else:
+            return _can_read_from_s3(client=self.client, bucket=self.bucket, key=self.key)
+
     @contextmanager
     def get_params(self) -> Iterator[PipesParams]:
-        key_prefix = "".join(random.choices(string.ascii_letters, k=30))
-        yield {"bucket": self.bucket, "key_prefix": key_prefix}
+        if self.expect_s3_message_writer:
+            key_prefix = "".join(random.choices(string.ascii_letters, k=30))
+            yield {"bucket": self.bucket, "key_prefix": key_prefix}
+        else:
+            yield {PipesDefaultMessageWriter.STDIO_KEY: PipesDefaultMessageWriter.STDOUT}

     def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
-        key = f"{params['key_prefix']}/{index}.json"
-        try:
-            obj = self.client.get_object(Bucket=self.bucket, Key=key)
-            return obj["Body"].read().decode("utf-8")
-        except ClientError:
-            return None
+        if self.expect_s3_message_writer:
+            try:
+                obj = self.client.get_object(
+                    Bucket=self.bucket, Key=f"{params['key_prefix']}/{index}.json"
+                )
+                return obj["Body"].read().decode("utf-8")
+            except ClientError:
+                return None
+        else:
+            # we will be reading the same S3 object again and again
+            key = params.get("key") or self.key
+            try:
+                text = (
+                    self.client.get_object(Bucket=self.bucket, Key=key)["Body"]
+                    .read()
+                    .decode("utf-8")
+                )
+                next_text = text[self.offset :]
+                self.offset = len(text)
+                return next_text
+            except ClientError:
+                return None

     def no_messages_debug_text(self) -> str:
-        return (
-            f"Attempted to read messages from S3 bucket {self.bucket}. Expected"
-            " PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external"
-            " process."
-        )
+        if self.expect_s3_message_writer:
+            return f"Attempted to read messages from S3 bucket {self.bucket}. Expected PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external process."
+        else:
+            message = f"Attempted to read messages from S3 bucket {self.bucket}."
+
+            if self.key is not None:
+                message += f" Expected to read messages from S3 key {self.key}."
+            else:
+                message += " The `key` parameter was not set; it should be provided via `on_launched`."
+
+            message += " Please check if the object exists and is accessible."
+
+            return message
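
Putting the new mode together: a hedged sketch of an asset that launches some external process, learns the S3 key only at launch time, and reports it back to the reader via `session.report_launched` (which invokes `on_launched` above). All names are illustrative:

import boto3
from dagster import (
    AssetExecutionContext,
    PipesEnvContextInjector,
    asset,
    open_pipes_session,
)
from dagster_aws.pipes import PipesS3MessageReader

reader = PipesS3MessageReader(
    bucket="my-pipes-bucket",  # hypothetical
    client=boto3.client("s3"),
    expect_s3_message_writer=False,  # poll a single log-like object instead of chunks
)


@asset
def external_job_asset(context: AssetExecutionContext):
    with open_pipes_session(
        context=context,
        message_reader=reader,
        context_injector=PipesEnvContextInjector(),
    ) as session:
        run_id = "abc123"  # hypothetical id returned by the external service
        # tell the reader which S3 object to poll for Pipes messages
        session.report_launched({"extras": {"key": f"output/{run_id}.log"}})
    return session.get_results()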


class PipesLambdaLogsMessageReader(PipesMessageReader):
@@ -10,6 +10,7 @@
 import time
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
+from threading import Event
 from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Tuple
 from uuid import uuid4

@@ -38,6 +39,7 @@
     PipesLambdaClient,
     PipesLambdaLogsMessageReader,
     PipesS3ContextInjector,
+    PipesS3LogReader,
     PipesS3MessageReader,
 )
 from dagster_aws_tests.pipes_tests.fake_ecs import LocalECSMockClient
@@ -152,6 +154,63 @@ def s3_client(moto_server):
     return client


+def test_s3_log_reader(s3_client, capsys):
+    key = str(uuid4())
+    log_reader = PipesS3LogReader(client=s3_client, bucket=_S3_TEST_BUCKET, key=key)
+    is_session_closed = Event()
+
+    assert not log_reader.can_start({})
+
+    s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"Line 0\nLine 1")
+
+    assert log_reader.can_start({})
+
+    log_reader.start({}, is_session_closed)
+    assert log_reader.is_running()
+
+    s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"Line 0\nLine 1\nLine 2")
+
+    is_session_closed.set()
+
+    log_reader.stop()
+
+    assert not log_reader.is_running()
+
+    captured = capsys.readouterr()
+
+    assert captured.out == "Line 0\nLine 1\nLine 2"
+
+    assert sys.stdout is not None


+def test_s3_message_reader(s3_client):
+    message_reader = PipesS3MessageReader(
+        client=s3_client, bucket=_S3_TEST_BUCKET, expect_s3_message_writer=False
+    )
+
+    key = str(uuid4())
+
+    @asset
+    def my_asset(context: AssetExecutionContext):
+        with open_pipes_session(
+            context=context,
+            message_reader=message_reader,
+            context_injector=PipesEnvContextInjector(),
+        ) as session:
+            assert not message_reader.can_start({})
+            params = {"key": key}
+            session.report_launched({"extras": params})
+            assert not message_reader.can_start(params)
+
+            s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"hello world")
+
+            assert message_reader.can_start(params)
+
+        return session.get_results()
+
+    materialize([my_asset])


 def test_s3_pipes_components(
     capsys,
     tmpdir,
