
Commit

[dagster-pipes] Add key param to PipesS3MessageReader
danielgafni committed Oct 1, 2024
1 parent 2ebe652 commit 1bd1bc8
Showing 4 changed files with 221 additions and 25 deletions.
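For orientation, here is a minimal sketch (not part of the commit) of how the two modes of the updated PipesS3MessageReader might be wired up on the orchestration side, assuming both readers are exported from dagster_aws.pipes as the __init__ change below suggests; the bucket name, keys, and client are illustrative placeholders:

import boto3

from dagster_aws.pipes import PipesS3LogReader, PipesS3MessageReader

s3_client = boto3.client("s3")  # illustrative client; any configured S3 client works

# Default mode: a PipesS3MessageWriter in the external process writes numbered
# message chunks under a key prefix chosen by this reader.
writer_mode_reader = PipesS3MessageReader(bucket="my-bucket", client=s3_client)

# Mode added by this commit: read Pipes messages from a single, known S3 object
# (for example, a log file dumped by an external service), optionally tailing
# additional S3 objects with PipesS3LogReader.
log_mode_reader = PipesS3MessageReader(
    bucket="my-bucket",
    key="runs/123/output.log",  # may also be supplied later via update_params
    expect_s3_message_writer=False,
    client=s3_client,
    log_readers={
        "stdout": PipesS3LogReader(bucket="my-bucket", key="runs/123/stdout.log", client=s3_client),
    },
)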
@@ -11,6 +11,7 @@
from dagster_aws.pipes.message_readers import (
PipesCloudWatchMessageReader,
PipesLambdaLogsMessageReader,
PipesS3LogReader,
PipesS3MessageReader,
)

@@ -21,6 +22,7 @@
"PipesS3ContextInjector",
"PipesLambdaEventContextInjector",
"PipesS3MessageReader",
"PipesS3LogReader",
"PipesLambdaLogsMessageReader",
"PipesCloudWatchMessageReader",
"PipesEMRServerlessClient",
@@ -46,9 +46,9 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
context_injector (Optional[PipesContextInjector]): A context injector to use to inject
context into the Glue job, for example, :py:class:`PipesS3ContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesCloudWatchsMessageReader`.
from the glue job run. Defaults to :py:class:`PipesCloudWatchMessageReader`.
When provided with :py:class:`PipesCloudWatchMessageReader`,
it will be used to recieve logs and events from the ``.../output/<job-run-id>``
it will be used to receive logs and events from the ``.../output/<job-run-id>``
CloudWatch log stream created by AWS Glue. Note that AWS Glue routes both
``stderr`` and ``stdout`` from the main job process into this LogStream.
client (Optional[boto3.client]): The boto Glue client used to launch the Glue job
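For reference, a rough sketch of how the Glue client described above might be constructed, with the default message reader spelled out; the bucket name and clients are placeholders, and the constructor arguments are assumed from the docstrings rather than taken from this diff:

import boto3

from dagster_aws.pipes import (
    PipesCloudWatchMessageReader,
    PipesGlueClient,
    PipesS3ContextInjector,
)

glue_client = PipesGlueClient(
    context_injector=PipesS3ContextInjector(bucket="my-pipes-bucket", client=boto3.client("s3")),
    # PipesCloudWatchMessageReader is the default; it receives logs and events from the
    # .../output/<job-run-id> CloudWatch log stream created by AWS Glue.
    message_reader=PipesCloudWatchMessageReader(),
)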
@@ -2,8 +2,21 @@
import random
import string
import sys
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, TextIO, TypedDict
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
TextIO,
TypedDict,
)

import boto3
import dagster._check as check
@@ -13,61 +26,183 @@
from dagster._core.pipes.context import PipesMessageHandler
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader,
PipesChunkedLogReader,
PipesLogReader,
extract_message_or_forward_to_file,
extract_message_or_forward_to_stdout,
)
from dagster_pipes import PipesDefaultMessageWriter

if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client


def _can_read_from_s3(client: "S3Client", bucket: Optional[str], key: Optional[str]):
if not bucket or not key:
return False
else:
try:
client.head_object(Bucket=bucket, Key=key)
return True
except ClientError:
return False

class PipesS3MessageReader(PipesBlobStoreMessageReader):
"""Message reader that reads messages by periodically reading message chunks from a specified S3
bucket.

If `log_readers` is passed, this reader will also start the passed readers
class PipesS3LogReader(PipesChunkedLogReader):
    """Log reader that periodically reads a single S3 object and forwards newly appended
    content to a target stream (``sys.stdout`` by default)."""

def __init__(
self,
*,
bucket: Optional[str] = None,
key: Optional[str] = None,
client=None,
interval: float = 10,
target_stream: Optional[IO[str]] = None,
):
self.bucket = bucket
self.key = key
self.client: "S3Client" = client or boto3.client("s3")

self.log_position = 0

super().__init__(interval=interval, target_stream=target_stream or sys.stdout)

def can_start(self, params: PipesParams) -> bool:
return _can_read_from_s3(
client=self.client,
bucket=params.get("bucket") or self.bucket,
key=params.get("key") or self.key,
)

    def download_log_chunk(self, params: PipesParams) -> Optional[str]:
        bucket = params.get("bucket") or self.bucket
        key = params.get("key") or self.key

        if not bucket or not key:
            return None

        try:
            text = self.client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
        except ClientError:
            # the object may not exist yet or may be temporarily unreadable
            return None

        # forward only the portion of the object that hasn't been emitted yet
        chunk = text[self.log_position :]
        self.log_position = len(text)
        return chunk


class PipesS3MessageReader(PipesBlobStoreMessageReader):
"""Message reader that reads messages from S3. Can operate in two modes: reading messages from objects
created by py:class:`dagster_pipes.PipesS3MessageWriter`, or reading messages from a specific S3 object
(typically a normal log file). If `log_readers` is passed, this reader will also start the passed readers
when the first message is received from the external process.
- if `expect_s3_message_writer` is set to `True` (default), a py:class:`dagster_pipes.PipesS3MessageWriter`
is expected to be used in the external process. The writer will write messages to a random S3 prefix in chunks,
and this reader will read them in order.
- if `expect_s3_message_writer` is set to `False`, this reader will read messages from a specific S3 object,
typically created by an external service (for example, by dumping stdout/stderr containing Pipes messages).
The object key can either be passed via corresponding constructor argument,
or with `update_params` method (if not known in advance).
Args:
        interval (float): Interval in seconds between attempts to download a chunk.
        bucket (str): The S3 bucket to read from.
        client (Optional[boto3.client]): A boto3 S3 client.
log_readers (Optional[Sequence[PipesLogReader]]): A set of readers for logs on S3.
expect_s3_message_writer (bool): Whether to expect a PipesS3MessageWriter to be used in the external process.
key (Optional[str]): The S3 key to read from. If not set, the key must be passed via `update_params`.
log_readers (Optional[Mapping[str, PipesLogReader]]): A mapping of arbitrary names to log readers.
"""

def __init__(
self,
*,
interval: float = 10,
bucket: str,
client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
log_readers: Optional[Sequence[PipesLogReader]] = None,
client=None,
expect_s3_message_writer: bool = True,
key: Optional[str] = None,
interval: float = 10,
log_readers: Optional[Mapping[str, "PipesLogReader"]] = None,
):
if isinstance(log_readers, Sequence):
# backcompat conversion for the older Sequence type
log_readers = {str(i): lr for i, lr in enumerate(log_readers)} # type: ignore

super().__init__(
interval=interval,
log_readers=log_readers,
)
self.bucket = check.str_param(bucket, "bucket")
self.client = client
self.client: "S3Client" = client or boto3.client("s3")
self.expect_s3_message_writer = expect_s3_message_writer
self.key = key

self.offset = 0

if expect_s3_message_writer and key is not None:
raise ValueError("key should not be set if expect_s3_message_writer is True")

def can_start(self, params: PipesParams) -> bool:
if self.expect_s3_message_writer:
# we are supposed to be reading from {i}.json chunks created by the MessageWriter
return _can_read_from_s3(
client=self.client,
bucket=params.get("bucket") or self.bucket,
key=f"{params['key_prefix']}/1.json",
)

else:
key = params.get("key") or self.key
return _can_read_from_s3(client=self.client, bucket=self.bucket, key=key)

@contextmanager
def get_params(self) -> Iterator[PipesParams]:
key_prefix = "".join(random.choices(string.ascii_letters, k=30))
yield {"bucket": self.bucket, "key_prefix": key_prefix}
if self.expect_s3_message_writer:
key_prefix = "".join(random.choices(string.ascii_letters, k=30))
yield {"bucket": self.bucket, "key_prefix": key_prefix}
else:
yield {PipesDefaultMessageWriter.STDIO_KEY: PipesDefaultMessageWriter.STDOUT}

def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
key = f"{params['key_prefix']}/{index}.json"
try:
obj = self.client.get_object(Bucket=self.bucket, Key=key)
return obj["Body"].read().decode("utf-8")
except ClientError:
return None
if self.expect_s3_message_writer:
try:
obj = self.client.get_object(
Bucket=self.bucket, Key=f"{params['key_prefix']}/{index}.json"
)
return obj["Body"].read().decode("utf-8")
except ClientError:
return None
else:
# we will be reading the same S3 object again and again
key = params.get("key") or self.key
try:
text = (
self.client.get_object(Bucket=self.bucket, Key=key)["Body"]
.read()
.decode("utf-8")
)
next_text = text[self.offset :]
self.offset = len(text)
return next_text
except ClientError:
return None

def no_messages_debug_text(self) -> str:
return (
f"Attempted to read messages from S3 bucket {self.bucket}. Expected"
" PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the external"
" process."
)
        if self.expect_s3_message_writer:
            return (
                f"Attempted to read messages from S3 bucket {self.bucket}. Expected"
                " PipesS3MessageWriter to be explicitly passed to open_dagster_pipes in the"
                " external process."
            )
        else:
            message = f"Attempted to read messages from S3 bucket {self.bucket}."

            if self.key is not None:
                message += f" Expected to read messages from S3 key {self.key}."
            else:
                message += (
                    " The `key` parameter was not set; it should be provided via `update_params`."
                )

            message += " Please check that the object exists and is accessible."

            return message


class PipesLambdaLogsMessageReader(PipesMessageReader):
Expand Down
@@ -10,6 +10,7 @@
import time
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from threading import Event
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Tuple
from uuid import uuid4

@@ -38,6 +39,7 @@
PipesLambdaClient,
PipesLambdaLogsMessageReader,
PipesS3ContextInjector,
PipesS3LogReader,
PipesS3MessageReader,
)
from dagster_aws_tests.pipes_tests.fake_ecs import LocalECSMockClient
@@ -152,6 +154,63 @@ def s3_client(moto_server):
return client


def test_s3_log_reader(s3_client, capsys):
key = str(uuid4())
log_reader = PipesS3LogReader(client=s3_client, bucket=_S3_TEST_BUCKET, key=key)
is_session_closed = Event()

assert not log_reader.can_start({})

s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"Line 0\nLine 1")

assert log_reader.can_start({})

log_reader.start({}, is_session_closed)
assert log_reader.is_running()

s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"Line 0\nLine 1\nLine 2")

is_session_closed.set()

log_reader.stop()

assert not log_reader.is_running()

captured = capsys.readouterr()

assert captured.out == "Line 0\nLine 1\nLine 2"

assert sys.stdout is not None


def test_s3_message_reader(s3_client):
message_reader = PipesS3MessageReader(
client=s3_client, bucket=_S3_TEST_BUCKET, expect_s3_message_writer=False
)

key = str(uuid4())

@asset
def my_asset(context: AssetExecutionContext):
with open_pipes_session(
context=context,
message_reader=message_reader,
context_injector=PipesEnvContextInjector(),
) as session:
assert not message_reader.can_start({})
params = {"key": key}
message_reader.update_params(params)
assert not message_reader.can_start(params)

s3_client.put_object(Bucket=_S3_TEST_BUCKET, Key=key, Body=b"hello world")

assert message_reader.can_start(params)

return session.get_results()

materialize([my_asset])


def test_s3_pipes_components(
capsys,
tmpdir,
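As a complement to the test above, a rough sketch of the external-process side that the `expect_s3_message_writer=False` mode targets; the upload step is illustrative, since in practice it would be the external service (for example, a batch runner capturing stdout) that lands the output at the S3 key registered via `update_params`:

# external_process.py -- a sketch, assuming the default stdout-based message writer
from dagster_pipes import open_dagster_pipes

with open_dagster_pipes() as pipes:
    pipes.log.info("hello from the external process")
    pipes.report_asset_materialization(metadata={"row_count": 42})

# Pipes messages are emitted to stdout. Some mechanism outside this commit (e.g. the
# service running the process) is then expected to upload the captured stdout to the
# S3 object whose key the orchestration side passes via message_reader.update_params({"key": ...}).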
