Skip to content

Commit

Permalink
[BUG] Fix multi-output tasks in RayRunner (#2291)
Browse files Browse the repository at this point in the history
Fixes a user-reported bug where `into_partitions` failed on the Ray
runner in a specific corner case: when **some** input partitions were not
split, the RayRunner performed an overly aggressive assert, which this
change relaxes.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored May 22, 2024
1 parent df9aa15 commit f5b8dfa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
5 changes: 4 additions & 1 deletion daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,9 @@ def place_in_queue(item):
# If it is a no-op task, just run it locally immediately.
elif len(next_step.instructions) == 0:
logger.debug("Running task synchronously in main thread: %s", next_step)
assert isinstance(next_step, SingleOutputPartitionTask)
assert (
len(next_step.partial_metadatas) == 1
), "No-op tasks must have one output by definition, since there are no instructions to run"
[single_partial] = next_step.partial_metadatas
if single_partial.num_rows is None:
[single_meta] = ray.get(get_metas.remote(next_step.inputs))
Expand All @@ -577,6 +579,7 @@ def place_in_queue(item):
)
]
)

next_step.set_result(
[RayMaterializedResult(partition, accessor, 0) for partition in next_step.inputs]
)
Expand Down
12 changes: 12 additions & 0 deletions tests/dataframe/test_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ def test_into_partitions_coalesce(make_df) -> None:
data = {"foo": list(range(100))}
df = make_df(data).into_partitions(20).into_partitions(1).collect()
assert df.to_pydict() == data


def test_into_partitions_some_no_split(make_df) -> None:
    """Check that repartitioning 3 partitions into 4 preserves the data.

    Going from 3 to 4 partitions means only one split occurs, so some
    input partitions are not split at all. The resulting DataFrame must
    still contain exactly the original rows.
    """
    expected = {"foo": [1, 2, 3]}

    # Build the frame and materialize it as 3 partitions.
    three_way = make_df(expected).into_partitions(3)
    materialized = three_way.collect()

    # Ask for 4 partitions, so only 1 split occurs.
    four_way = materialized.into_partitions(4)
    result = four_way.collect()

    assert result.to_pydict() == expected

0 comments on commit f5b8dfa

Please sign in to comment.