Merge branch 'main' into feat--cast-dataframe
timsaucer authored Oct 21, 2024
2 parents b39a5f0 + 7cca028 commit c698c97
Showing 11 changed files with 270 additions and 54 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default.

13 changes: 4 additions & 9 deletions pyproject.toml
@@ -23,8 +23,8 @@ build-backend = "maturin"
name = "datafusion"
description = "Build and run queries against data"
readme = "README.md"
-license = {file = "LICENSE.txt"}
-requires-python = ">=3.6"
+license = { file = "LICENSE.txt" }
+requires-python = ">=3.7"
keywords = ["datafusion", "dataframe", "rust", "query-engine"]
classifier = [
"Development Status :: 2 - Pre-Alpha",
@@ -42,10 +42,7 @@ classifier = [
"Programming Language :: Python",
"Programming Language :: Rust",
]
-dependencies = [
-    "pyarrow>=11.0.0",
-    "typing-extensions;python_version<'3.13'",
-]
+dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]

[project.urls]
homepage = "https://datafusion.apache.org/python"
@@ -58,9 +55,7 @@ profile = "black"
[tool.maturin]
python-source = "python"
module-name = "datafusion._internal"
-include = [
-    { path = "Cargo.lock", format = "sdist" }
-]
+include = [{ path = "Cargo.lock", format = "sdist" }]
exclude = [".github/**", "ci/**", ".asf.yaml"]
# Require Cargo.lock is up to date
locked = true
34 changes: 31 additions & 3 deletions python/datafusion/context.py
@@ -30,7 +30,7 @@
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF

-from typing import Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING, Protocol
from typing_extensions import deprecated

if TYPE_CHECKING:
@@ -41,6 +41,28 @@
from datafusion.plan import LogicalPlan, ExecutionPlan


class ArrowStreamExportable(Protocol):
"""Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
"""

def __arrow_c_stream__( # noqa: D105
self, requested_schema: object | None = None
) -> object: ...


class ArrowArrayExportable(Protocol):
"""Type hint for object exporting Arrow C Array via Arrow PyCapsule Interface.
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
"""

def __arrow_c_array__( # noqa: D105
self, requested_schema: object | None = None
) -> tuple[object, object]: ...


class SessionConfig:
"""Session configuration options."""

@@ -592,12 +614,18 @@ def from_pydict(
"""
return DataFrame(self.ctx.from_pydict(data, name))

-def from_arrow(self, data: Any, name: str | None = None) -> DataFrame:
+def from_arrow(
+    self,
+    data: ArrowStreamExportable | ArrowArrayExportable,
+    name: str | None = None,
+) -> DataFrame:
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow source.
The Arrow data source can be any object that implements either
``__arrow_c_stream__`` or ``__arrow_c_array__``. For the latter, it must return
-a struct array. Common examples of sources from pyarrow include
+a struct array.
+
+Arrow data can come from Polars, Pandas, PyArrow, etc.
Args:
data: Arrow data source.
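Note: the two protocol classes make the accepted input explicit; anything that implements the Arrow PyCapsule Interface can be passed straight to from_arrow. A minimal sketch of the new signature in use, assuming pyarrow 14+ (where Table exposes __arrow_c_stream__ and RecordBatch exposes __arrow_c_array__):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# pa.Table implements __arrow_c_stream__, matching ArrowStreamExportable.
table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
df = ctx.from_arrow(table, name="t")

# pa.RecordBatch implements __arrow_c_array__ and exports a struct array,
# matching ArrowArrayExportable.
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
df2 = ctx.from_arrow(batch)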
56 changes: 51 additions & 5 deletions python/datafusion/dataframe.py
@@ -21,7 +21,8 @@

from __future__ import annotations

-from typing import Any, Iterable, List, TYPE_CHECKING
+from typing import Any, Iterable, List, Literal, TYPE_CHECKING
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -129,6 +130,17 @@ def select(self, *exprs: Expr | str) -> DataFrame:
]
return DataFrame(self.df.select(*exprs_internal))

def drop(self, *columns: str) -> DataFrame:
"""Drop arbitrary amount of columns.
Args:
columns: Column names to drop from the dataframe.
Returns:
DataFrame with those columns removed in the projection.
"""
return DataFrame(self.df.drop(*columns))

def filter(self, *predicates: Expr) -> DataFrame:
"""Return a DataFrame for which ``predicate`` evaluates to ``True``.
@@ -163,14 +175,25 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
def with_columns(
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
) -> DataFrame:
"""Add an additional column to the DataFrame.
"""Add columns to the DataFrame.
By passing expressions, iteratables of expressions, or named expressions. To
pass named expressions use the form name=Expr.
Example usage: The following will add 4 columns labeled a, b, c, and d::
df = df.with_columns(
lit(0).alias('a'),
[lit(1).alias('b'), lit(2).alias('c')],
d=lit(3)
)
Args:
-*exprs: Name of the column to add.
-**named_exprs: Expression to compute the column.
+exprs: Either a single expression or an iterable of expressions to add.
+named_exprs: Named expressions in the form of ``name=expr``.
Returns:
-DataFrame with the new column.
+DataFrame with the new columns added.
"""

def _simplify_expression(
@@ -339,6 +362,29 @@ def join(
"""
return DataFrame(self.df.join(right.df, join_keys, how))

def join_on(
self,
right: DataFrame,
*on_exprs: Expr,
how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner",
) -> DataFrame:
"""Join two :py:class:`DataFrame`using the specified expressions.
On expressions are used to support in-equality predicates. Equality
predicates are correctly optimized
Args:
right: Other DataFrame to join with.
on_exprs: Single or multiple (in)equality predicates.
how: Type of join to perform. Supported types are "inner", "left",
"right", "full", "semi", "anti".
Returns:
DataFrame after join.
"""
exprs = [expr.expr for expr in on_exprs]
return DataFrame(self.df.join_on(right.df, exprs, how))

def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame:
"""Return a DataFrame with the explanation of its plan so far.
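Note: a short sketch of the three new DataFrame methods together; the table names "l"/"r" and the columns are illustrative, mirroring the tests below:

from datafusion import SessionContext, col, lit

ctx = SessionContext()
left = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}, name="l")
right = ctx.from_pydict({"a": [1, 2], "c": [10, 20]}, name="r")

# drop(): remove columns from the projection.
trimmed = left.drop("b")

# with_columns(): add several columns at once from expressions,
# iterables of expressions, and name=expr keywords.
extended = left.with_columns(
    lit(0).alias("zero"),
    [(col("a") * lit(2)).alias("a2")],
    b_plus_one=col("b") + lit(1),
)

# join_on(): join on arbitrary predicates, including inequalities.
joined = left.join_on(
    right,
    col("l.a") == col("r.a"),
    col("l.a") < col("r.c"),
    how="inner",
)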
12 changes: 12 additions & 0 deletions python/datafusion/expr.py
@@ -406,6 +406,18 @@ def is_not_null(self) -> Expr:
"""Returns ``True`` if this expression is not null."""
return Expr(self.expr.is_not_null())

def fill_nan(self, value: Any | Expr | None = None) -> Expr:
"""Fill NaN values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
return Expr(functions_internal.nanvl(self.expr, value.expr))

def fill_null(self, value: Any | Expr | None = None) -> Expr:
"""Fill NULL values with a provided value."""
if not isinstance(value, Expr):
value = Expr.literal(value)
return Expr(functions_internal.nvl(self.expr, value.expr))

_to_pyarrow_types = {
float: pa.float64(),
int: pa.int64(),
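Note: a usage sketch for the two new helpers (the column values are illustrative). Plain Python values are wrapped via Expr.literal before being handed to nvl/nanvl:

import pyarrow as pa
from datafusion import SessionContext, col

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1.0, None, float("nan")])], names=["x"]
)
df = ctx.from_arrow(batch)

df = df.select(
    col("x").fill_null(0.0).alias("nulls_replaced"),  # rewrites to nvl()
    col("x").fill_nan(0.0).alias("nans_replaced"),  # rewrites to nanvl()
)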
6 changes: 6 additions & 0 deletions python/datafusion/functions.py
@@ -186,6 +186,7 @@
"min",
"named_struct",
"nanvl",
"nvl",
"now",
"nth_value",
"nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
return Expr(f.nanvl(x.expr, y.expr))


def nvl(x: Expr, y: Expr) -> Expr:
"""Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
return Expr(f.nvl(x.expr, y.expr))


def octet_length(arg: Expr) -> Expr:
"""Returns the number of bytes of a string."""
return Expr(f.octet_length(arg.expr))
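Note: nvl and nanvl target different sentinels, NULL and NaN respectively. A sketch, with illustrative data:

from datafusion import SessionContext, col
from datafusion import functions as f

ctx = SessionContext()
df = ctx.from_pydict({"x": [1.0, None], "y": [0.0, 0.0]})

df = df.select(
    f.nvl(col("x"), col("y")).alias("no_nulls"),  # y where x IS NULL
    f.nanvl(col("x"), col("y")).alias("no_nans"),  # y where x is NaN
)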
47 changes: 47 additions & 0 deletions python/tests/test_dataframe.py
@@ -169,6 +169,17 @@ def test_sort(df):
assert table.to_pydict() == expected


def test_drop(df):
df = df.drop("c")

# execute and collect the first (and only) batch
result = df.collect()[0]

assert df.schema().names == ["a", "b"]
assert result.column(0) == pa.array([1, 2, 3])
assert result.column(1) == pa.array([4, 5, 6])


def test_limit(df):
df = df.limit(1)

@@ -299,6 +310,42 @@ def test_join():
assert table.to_pydict() == expected


def test_join_on():
ctx = SessionContext()

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], "l")

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2]), pa.array([-8, 10])],
names=["a", "c"],
)
df1 = ctx.create_dataframe([[batch]], "r")

df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner")
df2.show()
df2 = df2.sort(column("l.a"))
table = pa.Table.from_batches(df2.collect())

expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]}
assert table.to_pydict() == expected

df3 = df.join_on(
df1,
column("l.a").__eq__(column("r.a")),
column("l.a").__lt__(column("r.c")),
how="inner",
)
df3.show()
df3 = df3.sort(column("l.a"))
table = pa.Table.from_batches(df3.collect())
expected = {"a": [2], "c": [10], "b": [5]}
assert table.to_pydict() == expected


def test_distinct():
ctx = SessionContext()

33 changes: 30 additions & 3 deletions python/tests/test_expr.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

-import pyarrow
+import pyarrow as pa
import pytest
from datafusion import SessionContext, col
from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
def test_relational_expr(test_ctx):
ctx = SessionContext()

-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
# returns appropriate result
assert name == expr.schema_name()
assert name == "foo"


@pytest.fixture
def df():
ctx = SessionContext()

# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
names=["a", "b", "c"],
)

return ctx.from_arrow(batch)


def test_fill_null(df):
df = df.select(
col("a").fill_null(100).alias("a"),
col("b").fill_null(25).alias("b"),
col("c").fill_null(1234).alias("c"),
)
df.show()
result = df.collect()[0]

assert result.column(0) == pa.array([1, 2, 100])
assert result.column(1) == pa.array([4, 25, 6])
assert result.column(2) == pa.array([1234, 1234, 8])