Skip to content

Commit

Permalink
feat: add top level cast
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Oct 13, 2024
1 parent 35656f1 commit bded1c4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,19 @@ def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
exprs_raw = [sort_or_default(expr) for expr in exprs]
return DataFrame(self.df.sort(*exprs_raw))

def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
"""Cast all or a subset of columns to new dtype.
Args:
mapping (dict[str, pa.DataType[Any]]): Mapped with column as key and column
dtype as value.
Returns:
DataFrame after casting columns
"""
exprs = [Expr.column(col).cast(dtype) for col, dtype in mapping.items()]
return self.with_columns(exprs)

def limit(self, count: int, offset: int = 0) -> DataFrame:
"""Return a new :py:class:`DataFrame` with a limited number of rows.
Expand Down
9 changes: 9 additions & 0 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,15 @@ def test_with_columns(df):
assert result.column(6) == pa.array([5, 7, 9])


def test_cast(df):
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
expected = pa.schema(
[("a", pa.float16()), ("b", pa.list_(pa.uint32())), ("c", pa.int64())]
)

assert df.schema() == expected


def test_with_column_renamed(df):
df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")

Expand Down

0 comments on commit bded1c4

Please sign in to comment.