Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dtypes to stable api #1087

Merged
merged 40 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
961fe92
wip
MarcoGorelli Sep 28, 2024
c309cc8
use stable dtypes more
MarcoGorelli Sep 28, 2024
c676d56
Merge remote-tracking branch 'upstream/main' into stable-dtypes
MarcoGorelli Sep 28, 2024
9959b71
go further
MarcoGorelli Sep 28, 2024
e79af26
wip
MarcoGorelli Sep 28, 2024
abfda4f
missing file
MarcoGorelli Sep 28, 2024
2e7bff4
wip
MarcoGorelli Sep 28, 2024
e5025c0
fixup
MarcoGorelli Sep 28, 2024
2e08430
fixup
MarcoGorelli Sep 28, 2024
0ee4b70
keep going
MarcoGorelli Sep 29, 2024
3752fb1
keep going
MarcoGorelli Sep 29, 2024
6eb5cd6
add dtypes attribute
MarcoGorelli Sep 29, 2024
6c8ddd1
add dtypes attribute
MarcoGorelli Sep 29, 2024
aabbe37
add dtypes attribute
MarcoGorelli Sep 29, 2024
07372d7
add dtypes attribute
MarcoGorelli Sep 29, 2024
a75477e
add dtypes attribute
MarcoGorelli Sep 29, 2024
239ddce
add dtypes attribute
MarcoGorelli Sep 29, 2024
4626df9
add dtypes attribute
MarcoGorelli Sep 29, 2024
a2c2490
tests passing
MarcoGorelli Sep 29, 2024
453670b
tests passing
MarcoGorelli Sep 29, 2024
cc9104d
wip
MarcoGorelli Sep 29, 2024
b6fa8ab
wip
MarcoGorelli Sep 29, 2024
e21f1ac
wip
MarcoGorelli Sep 29, 2024
1bb35f5
wip
MarcoGorelli Sep 29, 2024
e92d45f
wip
MarcoGorelli Sep 29, 2024
e8fe720
wip
MarcoGorelli Sep 29, 2024
d0a73a3
wip
MarcoGorelli Sep 29, 2024
9861635
wip
MarcoGorelli Sep 29, 2024
546e1f1
wip
MarcoGorelli Sep 29, 2024
9f8c73a
wip
MarcoGorelli Sep 29, 2024
4848292
wip
MarcoGorelli Sep 29, 2024
d2b747b
wip
MarcoGorelli Sep 29, 2024
5733473
wip
MarcoGorelli Sep 29, 2024
1e7a119
wip
MarcoGorelli Sep 29, 2024
7d04857
wip
MarcoGorelli Sep 29, 2024
254ac7f
wip
MarcoGorelli Sep 29, 2024
4891d6a
wip
MarcoGorelli Sep 29, 2024
60ba76e
wip
MarcoGorelli Sep 29, 2024
239410a
fix docs
MarcoGorelli Sep 29, 2024
d843669
fix from_native
MarcoGorelli Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ from narwhals.utils import parse_version
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
print(nw.col("a")._call(pn))
```
Expand All @@ -101,13 +102,15 @@ import pandas as pd
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df = PandasLikeDataFrame(
df_pd,
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
expression = pn.col("a") + 1
result = expression._call(df)
Expand Down Expand Up @@ -196,6 +199,7 @@ import pandas as pd
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
Expand All @@ -210,6 +214,7 @@ backend, and it does so by passing a Narwhals-compliant namespace to `nw.Expr._c
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
expr = (nw.col("a") + 1)._call(pn)
print(expr)
Expand Down
43 changes: 35 additions & 8 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,27 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes


class ArrowDataFrame:
# --- not in the spec ---
def __init__(
self, native_dataframe: pa.Table, *, backend_version: tuple[int, ...]
self,
native_dataframe: pa.Table,
*,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> None:
self._native_frame = native_dataframe
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)
return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
Expand All @@ -63,7 +69,9 @@ def __narwhals_lazyframe__(self) -> Self:
return self

def _from_native_frame(self, df: Any) -> Self:
return self.__class__(df, backend_version=self._backend_version)
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
)

@property
def shape(self) -> tuple[int, int]:
Expand Down Expand Up @@ -111,6 +119,7 @@ def get_column(self, name: str) -> ArrowSeries:
self._native_frame[name],
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
Expand Down Expand Up @@ -151,6 +160,7 @@ def __getitem__(
self._native_frame[item],
name=item,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
elif (
isinstance(item, tuple)
Expand Down Expand Up @@ -191,12 +201,14 @@ def __getitem__(
self._native_frame[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
selected_rows = select_rows(self._native_frame, item[0])
return ArrowSeries(
selected_rows[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

elif isinstance(item, slice):
Expand Down Expand Up @@ -234,7 +246,7 @@ def __getitem__(
def schema(self) -> dict[str, DType]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype)
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in zip(schema.names, schema.types)
}

Expand Down Expand Up @@ -410,7 +422,12 @@ def to_dict(self, *, as_series: bool) -> Any:
from narwhals._arrow.series import ArrowSeries

return {
name: ArrowSeries(col, name=name, backend_version=self._backend_version)
name: ArrowSeries(
col,
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
for name, col in names_and_values
}
else:
Expand Down Expand Up @@ -471,7 +488,9 @@ def lazy(self) -> Self:
return self

def collect(self) -> ArrowDataFrame:
return ArrowDataFrame(self._native_frame, backend_version=self._backend_version)
return ArrowDataFrame(
self._native_frame, backend_version=self._backend_version, dtypes=self._dtypes
)

def clone(self) -> Self:
msg = "clone is not yet supported on PyArrow tables"
Expand Down Expand Up @@ -541,7 +560,12 @@ def is_duplicated(self: Self) -> ArrowSeries:
).column(f"{col_token}_count"),
1,
)
return ArrowSeries(is_duplicated, name="", backend_version=self._backend_version)
return ArrowSeries(
is_duplicated,
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def is_unique(self: Self) -> ArrowSeries:
import pyarrow.compute as pc # ignore-banned-import()
Expand All @@ -551,7 +575,10 @@ def is_unique(self: Self) -> ArrowSeries:
is_duplicated = self.is_duplicated()._native_series

return ArrowSeries(
pc.invert(is_duplicated), name="", backend_version=self._backend_version
pc.invert(is_duplicated),
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def unique(
Expand Down
27 changes: 24 additions & 3 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes


class ArrowExpr:
Expand All @@ -29,6 +30,7 @@ def __init__(
root_names: list[str] | None,
output_names: list[str] | None,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> None:
self._call = call
self._depth = depth
Expand All @@ -38,6 +40,7 @@ def __init__(
self._output_names = output_names
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes

def __repr__(self) -> str: # pragma: no cover
return (
Expand All @@ -50,7 +53,10 @@ def __repr__(self) -> str: # pragma: no cover

@classmethod
def from_column_names(
cls: type[Self], *column_names: str, backend_version: tuple[int, ...]
cls: type[Self],
*column_names: str,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -60,6 +66,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
Expand All @@ -71,11 +78,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=list(column_names),
output_names=list(column_names),
backend_version=backend_version,
dtypes=dtypes,
)

@classmethod
def from_column_indices(
cls: type[Self], *column_indices: int, backend_version: tuple[int, ...]
cls: type[Self],
*column_indices: int,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -85,6 +96,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_index],
name=df._native_frame.column_names[column_index],
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_index in column_indices
]
Expand All @@ -96,12 +108,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=None,
output_names=None,
backend_version=backend_version,
dtypes=dtypes,
)

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)
return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __narwhals_expr__(self) -> None: ...

Expand Down Expand Up @@ -246,6 +259,7 @@ def alias(self, name: str) -> Self:
root_names=self._root_names,
output_names=[name],
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def null_count(self) -> Self:
Expand Down Expand Up @@ -352,6 +366,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=self._root_names,
output_names=self._output_names,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -573,6 +588,7 @@ def keep(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=root_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
Expand All @@ -598,6 +614,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
Expand All @@ -621,6 +638,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
Expand All @@ -645,6 +663,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def to_lowercase(self: Self) -> ArrowExpr:
Expand All @@ -669,6 +688,7 @@ def to_lowercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand All @@ -693,4 +713,5 @@ def to_uppercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)
Loading
Loading