diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 350be46d..a37abe53 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -89,6 +89,10 @@ jobs: name: python-wheel-license path: . + # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved + - name: Install gtest + uses: MarkusJx/googletest-installer@v1.1 + - name: Install Protoc uses: arduino/setup-protoc@v1 with: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4f47dc98..c9a365bb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -55,6 +55,10 @@ jobs: version: '3.20.2' repo-token: ${{ secrets.GITHUB_TOKEN }} + # To remove once https://github.com/MaterializeInc/rust-protobuf-native/issues/20 is resolved + - name: Install gtest + uses: MarkusJx/googletest-installer@v1.1 + - name: Setup Python uses: actions/setup-python@v5 with: diff --git a/benchmarks/db-benchmark/join-datafusion.py b/benchmarks/db-benchmark/join-datafusion.py index 4d59c7dc..811ad870 100755 --- a/benchmarks/db-benchmark/join-datafusion.py +++ b/benchmarks/db-benchmark/join-datafusion.py @@ -74,7 +74,8 @@ def ans_shape(batches): ctx = df.SessionContext() print(ctx) -# TODO we should be applying projections to these table reads to crete relations of different sizes +# TODO we should be applying projections to these table reads to create relations +# of different sizes x_data = pacsv.read_csv( src_jn_x, convert_options=pacsv.ConvertOptions(auto_dict_encode=True) diff --git a/conda/recipes/meta.yaml b/conda/recipes/meta.yaml index 72ac7f50..b0784253 100644 --- a/conda/recipes/meta.yaml +++ b/conda/recipes/meta.yaml @@ -51,6 +51,7 @@ requirements: run: - python - pyarrow >=11.0.0 + - typing_extensions test: imports: diff --git a/docs/source/api/functions.rst b/docs/source/api/functions.rst index 958606df..6f10d826 100644 --- a/docs/source/api/functions.rst +++ b/docs/source/api/functions.rst @@ -24,4 +24,4 @@ Functions .. autosummary:: :toctree: ../generated/ - functions.functions + functions diff --git a/docs/source/conf.py b/docs/source/conf.py index c0da8b2c..308069b6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""Documenation generation.""" + # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full @@ -78,6 +80,25 @@ autosummary_generate = True + +def autodoc_skip_member(app, what, name, obj, skip, options): + exclude_functions = "__init__" + exclude_classes = ("Expr", "DataFrame") + + class_name = "" + if hasattr(obj, "__qualname__"): + if obj.__qualname__ is not None: + class_name = obj.__qualname__.split(".")[0] + + should_exclude = name in exclude_functions and class_name in exclude_classes + + return True if should_exclude else None + + +def setup(app): + app.connect("autodoc-skip-member", autodoc_skip_member) + + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. 
See the documentation for diff --git a/examples/substrait.py b/examples/substrait.py index 23cd7464..fd4d0f9c 100644 --- a/examples/substrait.py +++ b/examples/substrait.py @@ -18,16 +18,13 @@ from datafusion import SessionContext from datafusion import substrait as ss - # Create a DataFusion context ctx = SessionContext() # Register table with context ctx.register_csv("aggregate_test_data", "./testing/data/csv/aggregate_test_100.csv") -substrait_plan = ss.substrait.serde.serialize_to_plan( - "SELECT * FROM aggregate_test_data", ctx -) +substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM aggregate_test_data", ctx) # type(substrait_plan) -> # Encode it to bytes @@ -38,17 +35,15 @@ # Alternative serialization approaches # type(substrait_bytes) -> , at this point the bytes can be distributed to file, network, etc safely # where they could subsequently be deserialized on the receiving end. -substrait_bytes = ss.substrait.serde.serialize_bytes( - "SELECT * FROM aggregate_test_data", ctx -) +substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM aggregate_test_data", ctx) # Imagine here bytes would be read from network, file, etc ... for example brevity this is omitted and variable is simply reused # type(substrait_plan) -> -substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes) +substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes) # type(df_logical_plan) -> -df_logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan) +df_logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan) # Back to Substrait Plan just for demonstration purposes # type(substrait_plan) -> -substrait_plan = ss.substrait.producer.to_substrait_plan(df_logical_plan) +substrait_plan = ss.Producer.to_substrait_plan(df_logical_plan) diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index 3f973d9f..903b5354 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -21,6 +21,7 @@ from datafusion import col, lit, functions as F from util import get_answer_file + def df_selection(col_name, col_type): if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type): return F.round(col(col_name), lit(2)).alias(col_name) @@ -29,6 +30,7 @@ def df_selection(col_name, col_type): else: return col(col_name) + def load_schema(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return col_name, pa.string() @@ -36,7 +38,8 @@ def load_schema(col_name, col_type): return col_name, pa.float64() else: return col_name, col_type - + + def expected_selection(col_name, col_type): if col_type == pa.int64() or col_type == pa.int32(): return F.trim(col(col_name)).cast(col_type).alias(col_name) @@ -45,20 +48,23 @@ def expected_selection(col_name, col_type): else: return col(col_name) + def selections_and_schema(original_schema): - columns = [ (c, original_schema.field(c).type) for c in original_schema.names ] + columns = [(c, original_schema.field(c).type) for c in original_schema.names] - df_selections = [ df_selection(c, t) for (c, t) in columns] - expected_schema = [ load_schema(c, t) for (c, t) in columns] - expected_selections = [ expected_selection(c, t) for (c, t) in columns] + df_selections = [df_selection(c, t) for (c, t) in columns] + expected_schema = [load_schema(c, t) for (c, t) in columns] + expected_selections = [expected_selection(c, t) for (c, t) in columns] return (df_selections, expected_schema, expected_selections) + def check_q17(df): raw_value = float(df.collect()[0]["avg_yearly"][0].as_py()) value = round(raw_value, 2) 
assert abs(value - 348406.05) < 0.001 + @pytest.mark.parametrize( ("query_code", "answer_file"), [ @@ -72,9 +78,7 @@ def check_q17(df): ("q08_market_share", "q8"), ("q09_product_type_profit_measure", "q9"), ("q10_returned_item_reporting", "q10"), - pytest.param( - "q11_important_stock_identification", "q11", - ), + ("q11_important_stock_identification", "q11"), ("q12_ship_mode_order_priority", "q12"), ("q13_customer_distribution", "q13"), ("q14_promotion_effect", "q14"), @@ -92,18 +96,26 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): module = import_module(query_code) df = module.df - # Treat q17 as a special case. The answer file does not match the spec. Running at - # scale factor 1, we have manually verified this result does match the expected value. + # Treat q17 as a special case. The answer file does not match the spec. + # Running at scale factor 1, we have manually verified this result does + # match the expected value. if answer_file == "q17": return check_q17(df) - (df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema()) + (df_selections, expected_schema, expected_selections) = selections_and_schema( + df.schema() + ) df = df.select(*df_selections) read_schema = pa.schema(expected_schema) - df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out") + df_expected = module.ctx.read_csv( + get_answer_file(answer_file), + schema=read_schema, + delimiter="|", + file_extension=".out", + ) df_expected = df_expected.select(*expected_selections) diff --git a/examples/tpch/convert_data_to_parquet.py b/examples/tpch/convert_data_to_parquet.py index d81ec290..a8091a70 100644 --- a/examples/tpch/convert_data_to_parquet.py +++ b/examples/tpch/convert_data_to_parquet.py @@ -117,7 +117,6 @@ curr_dir = os.path.dirname(os.path.abspath(__file__)) for filename, curr_schema in all_schemas.items(): - # For convenience, go ahead and convert the schema column names to lowercase curr_schema = [(s[0].lower(), s[1]) for s in curr_schema] @@ -125,7 +124,7 @@ # in to handle the trailing | in the file output_cols = [r[0] for r in curr_schema] - curr_schema = [ pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema] + curr_schema = [pyarrow.field(r[0], r[1], nullable=False) for r in curr_schema] # Trailing | requires extra field for in processing curr_schema.append(("some_null", pyarrow.null())) diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index d13a71df..cd6bc1fa 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -47,7 +47,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_type" +) df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns( "s_suppkey", "s_nationkey" ) diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 29ffceed..b4a7369f 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -39,7 +39,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_name" +) df_supplier = 
ctx.read_parquet(get_data_path("supplier.parquet")).select_columns( "s_suppkey", "s_nationkey" ) diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index 2b6e7e20..bc0a5bd1 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -41,7 +41,9 @@ df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns( "o_custkey", "o_comment" ) -df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns("c_custkey") +df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns( + "c_custkey" +) # Use a regex to remove special cases df_orders = df_orders.filter( diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 75fa363a..8cb1e4c5 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -44,7 +44,9 @@ df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns( "l_partkey", "l_shipdate", "l_extendedprice", "l_discount" ) -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_type") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_type" +) # Check part type begins with PROMO diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index 0db2d1b8..fdcb5b4d 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -62,7 +62,8 @@ # Select the parts we are interested in df_part = df_part.filter(col("p_brand") != lit(BRAND)) df_part = df_part.filter( - F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) != lit(TYPE_TO_IGNORE) + F.substring(col("p_type"), lit(0), lit(len(TYPE_TO_IGNORE) + 1)) + != lit(TYPE_TO_IGNORE) ) # Python conversion of integer to literal casts it to int64 but the data for diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index 5880e7ed..e0ee8bb9 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -56,7 +56,13 @@ # Find the average quantity window_frame = WindowFrame("rows", None, None) df = df.with_column( - "avg_quantity", F.window("avg", [col("l_quantity")], window_frame=window_frame, partition_by=[col("l_partkey")]) + "avg_quantity", + F.window( + "avg", + [col("l_quantity")], + window_frame=window_frame, + partition_by=[col("l_partkey")], + ), ) df = df.filter(col("l_quantity") < lit(0.2) * col("avg_quantity")) diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 85e7226f..05a26745 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -40,7 +40,9 @@ ctx = SessionContext() -df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns("p_partkey", "p_name") +df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns( + "p_partkey", "p_name" +) df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns( "l_shipdate", "l_partkey", "l_suppkey", "l_quantity" ) diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index dfde19cb..622c1429 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -38,7 +38,9 @@ df_customer = 
ctx.read_parquet(get_data_path("customer.parquet")).select_columns( "c_phone", "c_acctbal", "c_custkey" ) -df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns("o_custkey") +df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns( + "o_custkey" +) # The nation code is a two digit number, but we need to convert it to a string literal nation_codes = F.make_array(*[lit(str(n)) for n in NATION_CODES]) diff --git a/examples/tpch/util.py b/examples/tpch/util.py index 191fa609..7e3d659d 100644 --- a/examples/tpch/util.py +++ b/examples/tpch/util.py @@ -20,14 +20,17 @@ """ import os -from pathlib import Path + def get_data_path(filename: str) -> str: path = os.path.dirname(os.path.abspath(__file__)) return os.path.join(path, "data", filename) + def get_answer_file(answer_file: str) -> str: path = os.path.dirname(os.path.abspath(__file__)) - return os.path.join(path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out") + return os.path.join( + path, "../../benchmarks/tpch/data/answers", f"{answer_file}.out" + ) diff --git a/pyproject.toml b/pyproject.toml index b706065a..a18ef0e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,3 +64,21 @@ exclude = [".github/**", "ci/**", ".asf.yaml"] # Require Cargo.lock is up to date locked = true features = ["substrait"] + +# Enable docstring linting using the google style guide +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "D", "W"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 + +# Disable docstring checking for these directories +[tool.ruff.lint.per-file-ignores] +"python/datafusion/tests/*" = ["D"] +"examples/*" = ["D", "W505"] +"dev/*" = ["D"] +"benchmarks/*" = ["D", "F"] +"docs/*" = ["D"] diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 846b1a45..59bc8e30 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -15,80 +15,44 @@ # specific language governing permissions and limitations # under the License. -from abc import ABCMeta, abstractmethod -from typing import List +"""DataFusion python package. + +This is a Python library that binds to Apache Arrow in-memory query engine DataFusion. +See https://datafusion.apache.org/python for more information. +""" try: import importlib.metadata as importlib_metadata except ImportError: import importlib_metadata -import pyarrow as pa - -from ._internal import ( - AggregateUDF, - Config, - DataFrame, +from .context import ( SessionContext, SessionConfig, RuntimeConfig, - ScalarUDF, SQLOptions, ) +# The following imports are okay to remain as opaque to the user. 
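As an editor's illustration (not part of the patch) of the public surface these re-exports keep stable, here is a minimal sketch using only top-level names shown in this module and in the TPC-H examples above; the CSV path and column names are placeholders:

from datafusion import SessionContext, col, lit

ctx = SessionContext()
df = ctx.read_csv("example.csv")  # placeholder file
df = df.filter(col("a") > lit(1)).select("a", col("b").alias("b_renamed"))
df.show()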
+from ._internal import Config + +from .udf import ScalarUDF, AggregateUDF, Accumulator + from .common import ( DFSchema, ) +from .dataframe import DataFrame + from .expr import ( - Alias, - Analyze, Expr, - Filter, - Limit, - Like, - ILike, - Projection, - SimilarTo, - ScalarVariable, - Sort, - TableScan, - Not, - IsNotNull, - IsTrue, - IsFalse, - IsUnknown, - IsNotTrue, - IsNotFalse, - IsNotUnknown, - Negative, - InList, - Exists, - Subquery, - InSubquery, - ScalarSubquery, - GroupingSet, - Placeholder, - Case, - Cast, - TryCast, - Between, - Explain, - CreateMemoryTable, - SubqueryAlias, - Extension, - CreateView, - Distinct, - DropTable, - Repartition, - Partitioning, - Window, WindowFrame, ) __version__ = importlib_metadata.version(__name__) __all__ = [ + "Accumulator", "Config", "DataFrame", "SessionContext", @@ -96,78 +60,16 @@ "SQLOptions", "RuntimeConfig", "Expr", - "AggregateUDF", "ScalarUDF", - "Window", "WindowFrame", "column", "literal", - "TableScan", - "Projection", "DFSchema", - "DFField", - "Analyze", - "Sort", - "Limit", - "Filter", - "Like", - "ILike", - "SimilarTo", - "ScalarVariable", - "Alias", - "Not", - "IsNotNull", - "IsTrue", - "IsFalse", - "IsUnknown", - "IsNotTrue", - "IsNotFalse", - "IsNotUnknown", - "Negative", - "ScalarFunction", - "BuiltinScalarFunction", - "InList", - "Exists", - "Subquery", - "InSubquery", - "ScalarSubquery", - "GroupingSet", - "Placeholder", - "Case", - "Cast", - "TryCast", - "Between", - "Explain", - "SubqueryAlias", - "Extension", - "CreateMemoryTable", - "CreateView", - "Distinct", - "DropTable", - "Repartition", - "Partitioning", ] -class Accumulator(metaclass=ABCMeta): - @abstractmethod - def state(self) -> List[pa.Scalar]: - pass - - @abstractmethod - def update(self, values: pa.Array) -> None: - pass - - @abstractmethod - def merge(self, states: pa.Array) -> None: - pass - - @abstractmethod - def evaluate(self) -> pa.Scalar: - pass - - -def column(value): +def column(value: str): + """Create a column expression.""" return Expr.column(value) @@ -175,46 +77,12 @@ def column(value): def literal(value): - if not isinstance(value, pa.Scalar): - value = pa.scalar(value) + """Create a literal expression.""" return Expr.literal(value) lit = literal +udf = ScalarUDF.udf -def udf(func, input_types, return_type, volatility, name=None): - """ - Create a new User Defined Function - """ - if not callable(func): - raise TypeError("`func` argument must be callable") - if name is None: - name = func.__qualname__.lower() - return ScalarUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - -def udaf(accum, input_type, return_type, state_type, volatility, name=None): - """ - Create a new User Defined Aggregate Function - """ - if not issubclass(accum, Accumulator): - raise TypeError("`accum` must implement the abstract base class Accumulator") - if name is None: - name = accum.__qualname__.lower() - if isinstance(input_type, pa.lib.DataType): - input_type = [input_type] - return AggregateUDF( - name=name, - accumulator=accum, - input_type=input_type, - return_type=return_type, - state_type=state_type, - volatility=volatility, - ) +udaf = AggregateUDF.udaf diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py new file mode 100644 index 00000000..cec0be76 --- /dev/null +++ b/python/datafusion/catalog.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Data catalog providers.""" + +from __future__ import annotations + +import datafusion._internal as df_internal + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow + + +class Catalog: + """DataFusion data catalog.""" + + def __init__(self, catalog: df_internal.Catalog) -> None: + """This constructor is not typically called by the end user.""" + self.catalog = catalog + + def names(self) -> list[str]: + """Returns the list of databases in this catalog.""" + return self.catalog.names() + + def database(self, name: str = "public") -> Database: + """Returns the database with the given `name` from this catalog.""" + return Database(self.catalog.database(name)) + + +class Database: + """DataFusion Database.""" + + def __init__(self, db: df_internal.Database) -> None: + """This constructor is not typically called by the end user.""" + self.db = db + + def names(self) -> set[str]: + """Returns the list of all tables in this database.""" + return self.db.names() + + def table(self, name: str) -> Table: + """Return the table with the given `name` from this database.""" + return Table(self.db.table(name)) + + +class Table: + """DataFusion table.""" + + def __init__(self, table: df_internal.Table) -> None: + """This constructor is not typically called by the end user.""" + self.table = table + + def schema(self) -> pyarrow.Schema: + """Returns the schema associated with this table.""" + return self.table.schema() + + @property + def kind(self) -> str: + """Returns the kind of table.""" + return self.table.kind() diff --git a/python/datafusion/common.py b/python/datafusion/common.py index dd56640a..2351845b 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Common data types used throughout the DataFusion project.""" from ._internal import common diff --git a/python/datafusion/context.py b/python/datafusion/context.py new file mode 100644 index 00000000..a717db10 --- /dev/null +++ b/python/datafusion/context.py @@ -0,0 +1,1003 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Session Context and it's associated configuration.""" + +from __future__ import annotations + +from ._internal import SessionConfig as SessionConfigInternal +from ._internal import RuntimeConfig as RuntimeConfigInternal +from ._internal import SQLOptions as SQLOptionsInternal +from ._internal import SessionContext as SessionContextInternal +from ._internal import LogicalPlan, ExecutionPlan + +from datafusion._internal import AggregateUDF +from datafusion.catalog import Catalog, Table +from datafusion.dataframe import DataFrame +from datafusion.expr import Expr +from datafusion.record_batch import RecordBatchStream +from datafusion.udf import ScalarUDF + +from typing import Any, TYPE_CHECKING +from typing_extensions import deprecated + +if TYPE_CHECKING: + import pyarrow + import pandas + import polars + import pathlib + + +class SessionConfig: + """Session configuration options.""" + + def __init__(self, config_options: dict[str, str] | None = None) -> None: + """Create a new `SessionConfig` with the given configuration options. + + Args: + config_options: Configuration options. + """ + self.config_internal = SessionConfigInternal(config_options) + + def with_create_default_catalog_and_schema( + self, enabled: bool = True + ) -> SessionConfig: + """Control if the default catalog and schema will be automatically created. + + Args: + enabled: Whether the default catalog and schema will be + automatically created. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = ( + self.config_internal.with_create_default_catalog_and_schema(enabled) + ) + return self + + def with_default_catalog_and_schema( + self, catalog: str, schema: str + ) -> SessionConfig: + """Select a name for the default catalog and shcema. + + Args: + catalog: Catalog name. + schema: Schema name. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_default_catalog_and_schema( + catalog, schema + ) + return self + + def with_information_schema(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the inclusion of `information_schema` virtual tables. + + Args: + enabled: Whether to include `information_schema` virtual tables. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_information_schema(enabled) + return self + + def with_batch_size(self, batch_size: int) -> SessionConfig: + """Customize batch size. + + Args: + batch_size: Batch size. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_batch_size(batch_size) + return self + + def with_target_partitions(self, target_partitions: int) -> SessionConfig: + """Customize the number of target partitions for query execution. + + Increasing partitions can increase concurrency. + + Args: + target_partitions: Number of target partitions. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_target_partitions( + target_partitions + ) + return self + + def with_repartition_aggregations(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of repartitioning for aggregations. + + Enabling this improves parallelism. + + Args: + enabled: Whether to use repartitioning for aggregations. 
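An editor's sketch of chaining the SessionConfig builder methods documented above; each method returns the updated config, so calls compose. The catalog/schema names and partition count are arbitrary placeholders:

config = (
    SessionConfig()
    .with_default_catalog_and_schema("my_catalog", "my_schema")
    .with_information_schema(True)
    .with_target_partitions(8)
)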
+ + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_aggregations( + enabled + ) + return self + + def with_repartition_joins(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of repartitioning for joins to improve parallelism. + + Args: + enabled: Whether to use repartitioning for joins. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_joins(enabled) + return self + + def with_repartition_windows(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of repartitioning for window functions. + + This may improve parallelism. + + Args: + enabled: Whether to use repartitioning for window functions. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_windows(enabled) + return self + + def with_repartition_sorts(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of repartitioning for window functions. + + This may improve parallelism. + + Args: + enabled: Whether to use repartitioning for window functions. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_sorts(enabled) + return self + + def with_repartition_file_scans(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of repartitioning for file scans. + + Args: + enabled: Whether to use repartitioning for file scans. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_file_scans(enabled) + return self + + def with_repartition_file_min_size(self, size: int) -> SessionConfig: + """Set minimum file range size for repartitioning scans. + + Args: + size: Minimum file range size. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_repartition_file_min_size(size) + return self + + def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: + """Enable or disable the use of pruning predicate for parquet readers. + + Pruning predicates will enable the reader to skip row groups. + + Args: + enabled: Whether to use pruning predicate for parquet readers. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_parquet_pruning(enabled) + return self + + def set(self, key: str, value: str) -> SessionConfig: + """Set a configuration option. + + Args: + key: Option key. + value: Option value. + + Returns: + A new `SessionConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.set(key, value) + return self + + +class RuntimeConfig: + """Runtime configuration options.""" + + def __init__(self) -> None: + """Create a new `RuntimeConfig` with default values.""" + self.config_internal = RuntimeConfigInternal() + + def with_disk_manager_disabled(self) -> RuntimeConfig: + """Disable the disk manager, attempts to create temporary files will error. + + Returns: + A new `RuntimeConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_disk_manager_disabled() + return self + + def with_disk_manager_os(self) -> RuntimeConfig: + """Use the operating system's temporary directory for disk manager. 
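A brief sketch of the generic set() option documented above; the configuration key shown is an assumption based on standard DataFusion option names, not something this patch defines:

config = SessionConfig().set("datafusion.execution.batch_size", "4096")  # key is an assumption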
+ + Returns: + A new `RuntimeConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_disk_manager_os() + return self + + def with_disk_manager_specified(self, *paths: str | pathlib.Path) -> RuntimeConfig: + """Use the specified paths for the disk manager's temporary files. + + Args: + paths: Paths to use for the disk manager's temporary files. + + Returns: + A new `RuntimeConfig` object with the updated setting. + """ + paths = [str(p) for p in paths] + self.config_internal = self.config_internal.with_disk_manager_specified(paths) + return self + + def with_unbounded_memory_pool(self) -> RuntimeConfig: + """Use an unbounded memory pool. + + Returns: + A new `RuntimeConfig` object with the updated setting. + """ + self.config_internal = self.config_internal.with_unbounded_memory_pool() + return self + + def with_fair_spill_pool(self, size: int) -> RuntimeConfig: + """Use a fair spill pool with the specified size. + + This pool works best when you know beforehand the query has multiple spillable + operators that will likely all need to spill. Sometimes it will cause spills + even when there was sufficient memory (reserved for other operators) to avoid + doing so:: + + ┌───────────────────────z──────────────────────z───────────────┐ + │ z z │ + │ z z │ + │ Spillable z Unspillable z Free │ + │ Memory z Memory z Memory │ + │ z z │ + │ z z │ + └───────────────────────z──────────────────────z───────────────┘ + + Args: + size: Size of the memory pool in bytes. + + Returns: + A new ``RuntimeConfig`` object with the updated setting. + + Examples usage:: + + config = RuntimeConfig().with_fair_spill_pool(1024) + """ + self.config_internal = self.config_internal.with_fair_spill_pool(size) + return self + + def with_greedy_memory_pool(self, size: int) -> RuntimeConfig: + """Use a greedy memory pool with the specified size. + + This pool works well for queries that do not need to spill or have a single + spillable operator. See `RuntimeConfig.with_fair_spill_pool` if there are + multiple spillable operators that all will spill. + + Args: + size: Size of the memory pool in bytes. + + Returns: + A new `RuntimeConfig` object with the updated setting. + + Example usage:: + + config = RuntimeConfig().with_greedy_memory_pool(1024) + """ + self.config_internal = self.config_internal.with_greedy_memory_pool(size) + return self + + def with_temp_file_path(self, path: str | pathlib.Path) -> RuntimeConfig: + """Use the specified path to create any needed temporary files. + + Args: + path: Path to use for temporary files. + + Returns: + A new `RuntimeConfig` object with the updated setting. + + Example usage:: + + config = RuntimeConfig().with_temp_file_path("/tmp") + """ + self.config_internal = self.config_internal.with_temp_file_path(str(path)) + return self + + +class SQLOptions: + """Options to be used when performing SQL queries on the ``SessionContext``.""" + + def __init__(self) -> None: + """Create a new `SQLOptions` with default values. + + The default values are: + - DDL commands are allowed + - DML commands are allowed + - Statements are allowed + """ + self.options_internal = SQLOptionsInternal() + + def with_allow_ddl(self, allow: bool = True) -> SQLOptions: + """Should DDL (Data Definition Language) commands be run? + + Examples of DDL commands include `CREATE TABLE` and `DROP TABLE`. + + Args: + allow: Allow DDL commands to be run. + + Returns: + A new `SQLOptions` object with the updated setting. 
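A hedged sketch combining the RuntimeConfig builders documented above; the pool size and spill path are placeholders:

runtime = (
    RuntimeConfig()
    .with_greedy_memory_pool(512 * 1024 * 1024)    # 512 MiB, placeholder size
    .with_temp_file_path("/tmp/datafusion-spill")  # placeholder path
)
ctx = SessionContext(runtime=runtime)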
+
+        Example usage::
+
+            options = SQLOptions().with_allow_ddl(True)
+        """
+        self.options_internal = self.options_internal.with_allow_ddl(allow)
+        return self
+
+    def with_allow_dml(self, allow: bool = True) -> SQLOptions:
+        """Should DML (Data Manipulation Language) commands be run?
+
+        Examples of DML commands include `INSERT INTO` and `DELETE`.
+
+        Args:
+            allow: Allow DML commands to be run.
+
+        Returns:
+            A new `SQLOptions` object with the updated setting.
+
+        Example usage::
+
+            options = SQLOptions().with_allow_dml(True)
+        """
+        self.options_internal = self.options_internal.with_allow_dml(allow)
+        return self
+
+    def with_allow_statements(self, allow: bool = True) -> SQLOptions:
+        """Should statements such as `SET VARIABLE` and `BEGIN TRANSACTION` be run?
+
+        Args:
+            allow: Allow statements to be run.
+
+        Returns:
+            A new `SQLOptions` object with the updated setting.
+
+        Example usage::
+
+            options = SQLOptions().with_allow_statements(True)
+        """
+        self.options_internal = self.options_internal.with_allow_statements(allow)
+        return self
+
+
+class SessionContext:
+    """This is the main interface for executing queries and creating DataFrames.
+
+    See https://datafusion.apache.org/python/user-guide/basics.html for
+    additional information.
+    """
+
+    def __init__(
+        self, config: SessionConfig | None = None, runtime: RuntimeConfig | None = None
+    ) -> None:
+        """Main interface for executing queries with DataFusion.
+
+        Maintains the state of the connection between a user and an instance
+        of the DataFusion engine.
+
+        Args:
+            config: Session configuration options.
+            runtime: Runtime configuration options.
+
+        Example usage:
+
+        The following example demonstrates how to use the context to execute
+        a query against a CSV data source using the ``DataFrame`` API::
+
+            from datafusion import SessionContext
+
+            ctx = SessionContext()
+            df = ctx.read_csv("data.csv")
+        """
+        config = config.config_internal if config is not None else None
+        runtime = runtime.config_internal if runtime is not None else None
+
+        self.ctx = SessionContextInternal(config, runtime)
+
+    def register_object_store(self, schema: str, store: Any, host: str | None) -> None:
+        """Add a new object store into the session.
+
+        Args:
+            schema: The data source schema.
+            store: The `ObjectStore` to register.
+            host: URL for the host.
+        """
+        self.ctx.register_object_store(schema, store, host)
+
+    def register_listing_table(
+        self,
+        name: str,
+        path: str | pathlib.Path,
+        table_partition_cols: list[tuple[str, str]] | None = None,
+        file_extension: str = ".parquet",
+        schema: pyarrow.Schema | None = None,
+        file_sort_order: list[list[Expr]] | None = None,
+    ) -> None:
+        """Register multiple files as a single table.
+
+        Registers a `Table` that can assemble multiple files from locations in
+        an `ObjectStore` instance.
+
+        Args:
+            name: Name of the resultant table.
+            path: Path to the file to register.
+            table_partition_cols: Partition columns.
+            file_extension: File extension of the provided table.
+            schema: The data source schema.
+            file_sort_order: Sort order for the file.
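An editor's sketch of calling register_listing_table as documented above; the table name, directory, and partition column are placeholders:

ctx.register_listing_table(
    "events",                                    # placeholder table name
    "data/events/",                              # placeholder directory of .parquet files
    table_partition_cols=[("year", "string")],   # placeholder partition column
)
df = ctx.table("events")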
+ """ + if table_partition_cols is None: + table_partition_cols = [] + if file_sort_order is not None: + file_sort_order = [[x.expr for x in xs] for xs in file_sort_order] + self.ctx.register_listing_table( + name, + str(path), + table_partition_cols, + file_extension, + schema, + file_sort_order, + ) + + def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: + """Create a `DataFrame` from SQL query text. + + Note: This API implements DDL statements such as `CREATE TABLE` and + `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory + default implementation. See `SessionContext.sql_with_options`. + + Args: + query: SQL query text. + options: If provided, the query will be validated against these options. + + Returns: + DataFrame representation of the SQL query. + """ + if options is None: + return DataFrame(self.ctx.sql(query)) + return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) + + def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: + """Create a `DataFrame` from SQL query text. + + This function will first validating that the query is allowed by the + provided options. + + Args: + query: SQL query text. + options: SQL options. + + Returns: + DataFrame representation of the SQL query. + """ + return self.sql(query, options) + + def create_dataframe( + self, + partitions: list[list[pyarrow.RecordBatch]], + name: str | None = None, + schema: pyarrow.Schema | None = None, + ) -> DataFrame: + """Create and return a dataframe using the provided partitions. + + Args: + partitions: `RecordBatch` partitions to register. + name: Resultant dataframe name. + schema: Schema for the partitions. + + Returns: + DataFrame representation of the SQL query. + """ + return DataFrame(self.ctx.create_dataframe(partitions, name, schema)) + + def create_dataframe_from_logical_plan(self, plan: LogicalPlan) -> DataFrame: + """Create a `DataFrame` from an existing logical plan. + + Args: + plan: Logical plan. + + Returns: + DataFrame representation of the logical plan. + """ + return DataFrame(self.ctx.create_dataframe_from_logical_plan(plan)) + + def from_pylist( + self, data: list[dict[str, Any]], name: str | None = None + ) -> DataFrame: + """Create a `DataFrame` from a list of dictionaries. + + Args: + data: List of dictionaries. + name: Name of the DataFrame. + + Returns: + DataFrame representation of the list of dictionaries. + """ + return DataFrame(self.ctx.from_pylist(data, name)) + + def from_pydict( + self, data: dict[str, list[Any]], name: str | None = None + ) -> DataFrame: + """Create a `DataFrame` from a dictionary of lists. + + Args: + data: Dictionary of lists. + name: Name of the DataFrame. + + Returns: + DataFrame representation of the dictionary of lists. + """ + return DataFrame(self.ctx.from_pydict(data, name)) + + def from_arrow_table( + self, data: pyarrow.Table, name: str | None = None + ) -> DataFrame: + """Create a `DataFrame` from an Arrow table. + + Args: + data: Arrow table. + name: Name of the DataFrame. + + Returns: + DataFrame representation of the Arrow table. + """ + return DataFrame(self.ctx.from_arrow_table(data, name)) + + def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame: + """Create a `DataFrame` from a Pandas DataFrame. + + Args: + data: Pandas DataFrame. + name: Name of the DataFrame. + + Returns: + DataFrame representation of the Pandas DataFrame. 
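A sketch tying together from_pydict and sql with SQLOptions, both documented in this file; the data and query are illustrative, and it assumes the name passed to from_pydict is visible to SQL as in the upstream examples:

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]}, name="t")
options = SQLOptions().with_allow_ddl(False).with_allow_statements(False)
result = ctx.sql("SELECT a FROM t WHERE a > 1", options)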
+ """ + return DataFrame(self.ctx.from_pandas(data, name)) + + def from_polars(self, data: polars.DataFrame, name: str | None = None) -> DataFrame: + """Create a `DataFrame` from a Polars DataFrame. + + Args: + data: Polars DataFrame. + name: Name of the DataFrame. + + Returns: + DataFrame representation of the Polars DataFrame. + """ + return DataFrame(self.ctx.from_polars(data, name)) + + def register_table(self, name: str, table: pyarrow.Table) -> None: + """Register a table with the given name into the session. + + Args: + name: Name of the resultant table. + table: PyArrow table to add to the session context. + """ + self.ctx.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Remove a table from the session.""" + self.ctx.deregister_table(name) + + def register_record_batches( + self, name: str, partitions: list[list[pyarrow.RecordBatch]] + ) -> None: + """Register record batches as a table. + + This function will convert the provided partitions into a table and + register it into the session using the given name. + + Args: + name: Name of the resultant table. + partitions: Record batches to register as a table. + """ + self.ctx.register_record_batches(name, partitions) + + def register_parquet( + self, + name: str, + path: str | pathlib.Path, + table_partition_cols: list[tuple[str, str]] | None = None, + parquet_pruning: bool = True, + file_extension: str = ".parquet", + skip_metadata: bool = True, + schema: pyarrow.Schema | None = None, + file_sort_order: list[list[Expr]] | None = None, + ) -> None: + """Register a Parquet file as a table. + + The registered table can be referenced from SQL statement executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the Parquet file. + table_partition_cols: Partition columns. + parquet_pruning: Whether the parquet reader should use the + predicate to prune row groups. + file_extension: File extension; only files with this extension are + selected for data input. + skip_metadata: Whether the parquet reader should skip any metadata + that may be in the file schema. This can help avoid schema + conflicts due to metadata. + schema: The data source schema. + file_sort_order: Sort order for the file. + """ + if table_partition_cols is None: + table_partition_cols = [] + self.ctx.register_parquet( + name, + str(path), + table_partition_cols, + parquet_pruning, + file_extension, + skip_metadata, + schema, + file_sort_order, + ) + + def register_csv( + self, + name: str, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + has_header: bool = True, + delimiter: str = ",", + schema_infer_max_records: int = 1000, + file_extension: str = ".csv", + file_compression_type: str | None = None, + ) -> None: + """Register a CSV file as a table. + + The registered table can be referenced from SQL statement executed against. + + Args: + name: Name of the table to register. + path: Path to the CSV file. + schema: An optional schema representing the CSV file. If None, the + CSV reader will try to infer it based on data in file. + has_header: Whether the CSV file have a header. If schema inference + is run on a file with no headers, default column names are + created. + delimiter: An optional column delimiter. + schema_infer_max_records: Maximum number of rows to read from CSV + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + file_compression_type: File compression type. 
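A sketch of the registration calls documented above; the file names, column names, and join condition are placeholders:

ctx.register_csv("sales", "sales.csv", has_header=True, delimiter=",")
ctx.register_parquet("items", "items.parquet")
df = ctx.sql("SELECT s.item_id, i.id FROM sales s JOIN items i ON s.item_id = i.id")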
+ """ + self.ctx.register_csv( + name, + str(path), + schema, + has_header, + delimiter, + schema_infer_max_records, + file_extension, + file_compression_type, + ) + + def register_json( + self, + name: str, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + schema_infer_max_records: int = 1000, + file_extension: str = ".json", + table_partition_cols: list[tuple[str, str]] | None = None, + file_compression_type: str | None = None, + ) -> None: + """Register a JSON file as a table. + + The registered table can be referenced from SQL statement executed + against this context. + + Args: + name: Name of the table to register. + path: Path to the JSON file. + schema: The data source schema. + schema_infer_max_records: Maximum number of rows to read from JSON + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. + """ + if table_partition_cols is None: + table_partition_cols = [] + self.ctx.register_json( + name, + str(path), + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) + + def register_avro( + self, + name: str, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + file_extension: str = ".avro", + table_partition_cols: list[tuple[str, str]] | None = None, + ) -> None: + """Register an Avro file as a table. + + The registered table can be referenced from SQL statement executed against + this context. + + Args: + name: Name of the table to register. + path: Path to the Avro file. + schema: The data source schema. + file_extension: File extension to select. + table_partition_cols: Partition columns. + """ + if table_partition_cols is None: + table_partition_cols = [] + self.ctx.register_avro( + name, str(path), schema, file_extension, table_partition_cols + ) + + def register_dataset(self, name: str, dataset: pyarrow.dataset.Dataset) -> None: + """Register a `pyarrow.dataset.Dataset` as a table. + + Args: + name: Name of the table to register. + dataset: PyArrow dataset. 
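An illustrative use of register_dataset per the signature above, building the dataset with pyarrow.dataset; the directory path is a placeholder:

import pyarrow.dataset as ds

dataset = ds.dataset("data/events/", format="parquet")  # placeholder path
ctx.register_dataset("events_ds", dataset)
df = ctx.table("events_ds")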
+ """ + self.ctx.register_dataset(name, dataset) + + def register_udf(self, udf: ScalarUDF) -> None: + """Register a user-defined function (UDF) with the context.""" + self.ctx.register_udf(udf.udf) + + def register_udaf(self, udaf: AggregateUDF) -> None: + """Register a user-defined aggregation function (UDAF) with the context.""" + self.ctx.register_udaf(udaf) + + def catalog(self, name: str = "datafusion") -> Catalog: + """Retrieve a catalog by name.""" + return self.ctx.catalog(name) + + @deprecated( + "Use the catalog provider interface `SessionContext.catalog` to " + "examine available catalogs, schemas and tables" + ) + def tables(self) -> set[str]: + """Deprecated.""" + return self.ctx.tables() + + def table(self, name: str) -> DataFrame: + """Retrieve a `DataFrame` representing a previously registered table.""" + return DataFrame(self.ctx.table(name)) + + def table_exist(self, name: str) -> bool: + """Return whether a table with the given name exists.""" + return self.ctx.table_exist(name) + + def empty_table(self) -> DataFrame: + """Create an empty `DataFrame`.""" + return DataFrame(self.ctx.empty_table()) + + def session_id(self) -> str: + """Retrun an id that uniquely identifies this `SessionContext`.""" + return self.ctx.session_id() + + def read_json( + self, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + schema_infer_max_records: int = 1000, + file_extension: str = ".json", + table_partition_cols: list[tuple[str, str]] | None = None, + file_compression_type: str | None = None, + ) -> DataFrame: + """Create a `DataFrame` for reading a line-delimited JSON data source. + + Args: + path: Path to the JSON file. + schema: The data source schema. + schema_infer_max_records: Maximum number of rows to read from JSON + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. + + Returns: + DataFrame representation of the read JSON files. + """ + if table_partition_cols is None: + table_partition_cols = [] + return DataFrame( + self.ctx.read_json( + str(path), + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) + ) + + def read_csv( + self, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + has_header: bool = True, + delimiter: str = ",", + schema_infer_max_records: int = 1000, + file_extension: str = ".csv", + table_partition_cols: list[tuple[str, str]] | None = None, + file_compression_type: str | None = None, + ) -> DataFrame: + """Create a `DataFrame` for reading a CSV data source. + + Args: + path: Path to the CSV file + schema: An optional schema representing the CSV files. If None, the + CSV reader will try to infer it based on data in file. + has_header: Whether the CSV file have a header. If schema inference + is run on a file with no headers, default column names are + created. + delimiter: An optional column delimiter. + schema_infer_max_records: Maximum number of rows to read from CSV + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. 
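An editor's sketch of navigating the catalog interface described above, based on the annotated return types and the Catalog/Database/Table wrappers added in datafusion/catalog.py earlier in this patch; it assumes a table named "example" was already registered:

catalog = ctx.catalog()           # default catalog name is "datafusion"
db = catalog.database("public")   # default database name
print(db.names())                 # registered table names
tbl = db.table("example")
print(tbl.schema())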
+ + Returns: + DataFrame representation of the read CSV files + """ + if table_partition_cols is None: + table_partition_cols = [] + return DataFrame( + self.ctx.read_csv( + str(path), + schema, + has_header, + delimiter, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) + ) + + def read_parquet( + self, + path: str | pathlib.Path, + table_partition_cols: list[tuple[str, str]] | None = None, + parquet_pruning: bool = True, + file_extension: str = ".parquet", + skip_metadata: bool = True, + schema: pyarrow.Schema | None = None, + file_sort_order: list[list[Expr]] | None = None, + ) -> DataFrame: + """Create a `DataFrame` for reading Parquet data source. + + Args: + path: Path to the Parquet file. + table_partition_cols: Partition columns. + parquet_pruning: Whether the parquet reader should use the predicate + to prune row groups. + file_extension: File extension; only files with this extension are + selected for data input. + skip_metadata: Whether the parquet reader should skip any metadata + that may be in the file schema. This can help avoid schema + conflicts due to metadata. + schema: An optional schema representing the parquet files. If None, + the parquet reader will try to infer it based on data in the + file. + file_sort_order: Sort order for the file. + + Returns: + DataFrame representation of the read Parquet files + """ + if table_partition_cols is None: + table_partition_cols = [] + return DataFrame( + self.ctx.read_parquet( + str(path), + table_partition_cols, + parquet_pruning, + file_extension, + skip_metadata, + schema, + file_sort_order, + ) + ) + + def read_avro( + self, + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + file_partition_cols: list[tuple[str, str]] | None = None, + file_extension: str = ".avro", + ) -> DataFrame: + """Create a ``DataFrame`` for reading Avro data source. + + Args: + path: Path to the Avro file. + schema: The data source schema. + file_partition_cols: Partition columns. + file_extension: File extension to select. + + Returns: + DataFrame representation of the read Avro file + """ + if file_partition_cols is None: + file_partition_cols = [] + return DataFrame( + self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) + ) + + def read_table(self, table: Table) -> DataFrame: + """Creates a ``DataFrame`` for a ``Table`` such as a ``ListingTable``.""" + return DataFrame(self.ctx.read_table(table)) + + def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: + """Execute the `plan` and return the results.""" + return RecordBatchStream(self.ctx.execute(plan, partitions)) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py new file mode 100644 index 00000000..68e6298f --- /dev/null +++ b/python/datafusion/dataframe.py @@ -0,0 +1,527 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""DataFrame is one of the core concepts in DataFusion. + +See https://datafusion.apache.org/python/user-guide/basics.html for more +information. +""" + +from __future__ import annotations + +from typing import Any, List, TYPE_CHECKING +from datafusion.record_batch import RecordBatchStream +from typing_extensions import deprecated + +if TYPE_CHECKING: + import pyarrow as pa + import pandas as pd + import polars as pl + import pathlib + +from datafusion._internal import DataFrame as DataFrameInternal +from datafusion.expr import Expr +from datafusion._internal import ( + LogicalPlan, + ExecutionPlan, +) + + +class DataFrame: + """Two dimensional table representation of data. + + See https://datafusion.apache.org/python/user-guide/basics.html for more + information. + """ + + def __init__(self, df: DataFrameInternal) -> None: + """This constructor is not to be used by the end user. + + See ``SessionContext`` for methods to create DataFrames. + """ + self.df = df + + def __getitem__(self, key: str | List[str]) -> DataFrame: + """Return a new `DataFrame` with the specified column or columns. + + Args: + key: Column name or list of column names to select. + + Returns: + DataFrame with the specified column or columns. + """ + return DataFrame(self.df.__getitem__(key)) + + def __repr__(self) -> str: + """Return a string representation of the DataFrame. + + Returns: + String representation of the DataFrame. + """ + return self.df.__repr__() + + def describe(self) -> DataFrame: + """Return a new `DataFrame` that has statistics for a DataFrame. + + Only summarized numeric datatypes at the moments and returns nulls + for non-numeric datatypes. + + The output format is modeled after pandas. + + Returns: + A summary DataFrame containing statistics. + """ + return DataFrame(self.df.describe()) + + def schema(self) -> pa.Schema: + """Return the `pyarrow.Schema` describing the output of this DataFrame. + + The output schema contains information on the name, data type, and + nullability for each column. + + Returns: + Describing schema of the DataFrame + """ + return self.df.schema() + + def select_columns(self, *args: str) -> DataFrame: + """Filter the DataFrame by columns. + + Returns: + DataFrame only containing the specified columns. + """ + return self.select(*args) + + def select(self, *exprs: Expr | str) -> DataFrame: + """Project arbitrary expressions into a new `DataFrame`. + + Args: + exprs: Either column names or `Expr` to select. + + Returns: + DataFrame after projection. It has one column for each expression. + + Example usage: + + The following example will return 3 columns from the original dataframe. + The first two columns will be the original column `a` and `b` since the + string "a" is assumed to refer to column selection. Also a duplicate of + column `a` will be returned with the column name `alternate_a`:: + + df = df.select("a", col("b"), col("a").alias("alternate_a")) + + """ + exprs = [ + arg.expr if isinstance(arg, Expr) else Expr.column(arg).expr + for arg in exprs + ] + return DataFrame(self.df.select(*exprs)) + + def filter(self, *predicates: Expr) -> DataFrame: + """Return a DataFrame for which `predicate` evaluates to `True`. + + Rows for which `predicate` evaluates to `False` or `None` are filtered + out. If more than one predicate is provided, these predicates will be + combined as a logical AND. 
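A short sketch of the select and filter behavior documented above: string names and Expr values may be mixed in select, and multiple filter predicates are combined with a logical AND. Column names are placeholders:

from datafusion import col, lit

df = df.select("a", col("b").alias("b2"))
df = df.filter(col("a") > lit(0), col("b2") < lit(10))
df.show()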
If more complex logic is required, see the + logical operations in `datafusion.functions`. + + Args: + predicates: Predicate expression(s) to filter the DataFrame. + + Returns: + DataFrame after filtering. + """ + df = self.df + for p in predicates: + df = df.filter(p.expr) + return DataFrame(df) + + def with_column(self, name: str, expr: Expr) -> DataFrame: + """Add an additional column to the DataFrame. + + Args: + name: Name of the column to add. + expr: Expression to compute the column. + + Returns: + DataFrame with the new column. + """ + return DataFrame(self.df.with_column(name, expr.expr)) + + def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame: + """Rename one column by applying a new projection. + + This is a no-op if the column to be renamed does not exist. + + The method supports case sensitive rename with wrapping column name + into one the following symbols (" or ' or `). + + Args: + old_name: Old column name. + new_name: New column name. + + Returns: + DataFrame with the column renamed. + """ + return DataFrame(self.df.with_column_renamed(old_name, new_name)) + + def aggregate(self, group_by: list[Expr], aggs: list[Expr]) -> DataFrame: + """Aggregates the rows of the current DataFrame. + + Args: + group_by: List of expressions to group by. + aggs: List of expressions to aggregate. + + Returns: + DataFrame after aggregation. + """ + group_by = [e.expr for e in group_by] + aggs = [e.expr for e in aggs] + return DataFrame(self.df.aggregate(group_by, aggs)) + + def sort(self, *exprs: Expr) -> DataFrame: + """Sort the DataFrame by the specified sorting expressions. + + Note that any expression can be turned into a sort expression by + calling its `sort` method. + + Args: + exprs: Sort expressions, applied in order. + + Returns: + DataFrame after sorting. + """ + exprs = [expr.expr for expr in exprs] + return DataFrame(self.df.sort(*exprs)) + + def limit(self, count: int, offset: int = 0) -> DataFrame: + """Return a new `DataFrame` with a limited number of rows. + + Args: + count: Number of rows to limit the DataFrame to. + offset: Number of rows to skip. + + Returns: + DataFrame after limiting. + """ + return DataFrame(self.df.limit(count, offset)) + + def collect(self) -> list[pa.RecordBatch]: + """Execute this `DataFrame` and collect results into memory. + + Prior to calling `collect`, modifying a DataFrme simply updates a plan + (no actual computation is performed). Calling `collect` triggers the + computation. + + Returns: + List of `pyarrow.RecordBatch`es collected from the DataFrame. + """ + return self.df.collect() + + def cache(self) -> DataFrame: + """Cache the DataFrame as a memory table. + + Returns: + Cached DataFrame. + """ + return DataFrame(self.df.cache()) + + def collect_partitioned(self) -> list[list[pa.RecordBatch]]: + """Execute this DataFrame and collect all partitioned results. + + This operation returns ``RecordBatch`` maintaining the input + partitioning. + + Returns: + List of list of ``RecordBatch`` collected from the + DataFrame. + """ + return self.df.collect_partitioned() + + def show(self, num: int = 20) -> None: + """Execute the DataFrame and print the result to the console. + + Args: + num: Number of lines to show. + """ + self.df.show(num) + + def distinct(self) -> DataFrame: + """Return a new `DataFrame` with all duplicated rows removed. + + Returns: + DataFrame after removing duplicates. 
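A sketch of the aggregate/sort/limit/collect chain documented above; F.count comes from datafusion.functions, the column names are placeholders, and the keyword arguments to Expr.sort are an assumption:

from datafusion import col, functions as F

df = (
    df.aggregate([col("a")], [F.count(col("b")).alias("n")])
    .sort(col("n").sort(ascending=False))  # kwarg name is an assumption
    .limit(10)
)
batches = df.collect()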
+ """ + return DataFrame(self.df.distinct()) + + def join( + self, + right: DataFrame, + join_keys: tuple[list[str], list[str]], + how: str, + ) -> DataFrame: + """Join this `DataFrame` with another `DataFrame`. + + Join keys are a pair of lists of column names in the left and right + dataframes, respectively. These lists must have the same length. + + Args: + right: Other DataFrame to join with. + join_keys: Tuple of two lists of column names to join on. + how: Type of join to perform. Supported types are "inner", "left", + "right", "full", "semi", "anti". + + Returns: + DataFrame after join. + """ + return DataFrame(self.df.join(right.df, join_keys, how)) + + def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: + """Return a DataFrame with the explanation of its plan so far. + + If `analyze` is specified, runs the plan and reports metrics. + + Args: + verbose: If `True`, more details will be included. + analyze: If `True`, the plan will run and metrics reported. + + Returns: + DataFrame with the explanation of its plan. + """ + return DataFrame(self.df.explain(verbose, analyze)) + + def logical_plan(self) -> LogicalPlan: + """Return the unoptimized `LogicalPlan` that comprises this `DataFrame`. + + Returns: + Unoptimized logical plan. + """ + return self.df.logical_plan() + + def optimized_logical_plan(self) -> LogicalPlan: + """Return the optimized `LogicalPlan` that comprises this `DataFrame`. + + Returns: + Optimized logical plan. + """ + return self.df.optimized_logical_plan() + + def execution_plan(self) -> ExecutionPlan: + """Return the execution/physical plan that comprises this `DataFrame`. + + Returns: + Execution plan. + """ + return self.df.execution_plan() + + def repartition(self, num: int) -> DataFrame: + """Repartition a DataFrame into `num` partitions. + + The batches allocation uses a round-robin algorithm. + + Args: + num: Number of partitions to repartition the DataFrame into. + + Returns: + Repartitioned DataFrame. + """ + return DataFrame(self.df.repartition(num)) + + def repartition_by_hash(self, *exprs: Expr, num: int) -> DataFrame: + """Repartition a DataFrame using a hash partitioning scheme. + + Args: + exprs: Expressions to evaluate and perform hashing on. + num: Number of partitions to repartition the DataFrame into. + + Returns: + Repartitioned DataFrame. + """ + exprs = [expr.expr for expr in exprs] + return DataFrame(self.df.repartition_by_hash(*exprs, num=num)) + + def union(self, other: DataFrame, distinct: bool = False) -> DataFrame: + """Calculate the union of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Args: + other: DataFrame to union with. + distinct: If `True`, duplicate rows will be removed. + + Returns: + DataFrame after union. + """ + return DataFrame(self.df.union(other.df, distinct)) + + def union_distinct(self, other: DataFrame) -> DataFrame: + """Calculate the distinct union of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + Any duplicate rows are discarded. + + Args: + other: DataFrame to union with. + + Returns: + DataFrame after union. + """ + return DataFrame(self.df.union_distinct(other.df)) + + def intersect(self, other: DataFrame) -> DataFrame: + """Calculate the intersection of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Args: + other: DataFrame to intersect with. + + Returns: + DataFrame after intersection. 
+ """ + return DataFrame(self.df.intersect(other.df)) + + def except_all(self, other: DataFrame) -> DataFrame: + """Calculate the exception of two `DataFrame`s. + + The two `DataFrame`s must have exactly the same schema. + + Args: + other: DataFrame to calculate exception with. + + Returns: + DataFrame after exception. + """ + return DataFrame(self.df.except_all(other.df)) + + def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None: + """Execute the `DataFrame` and write the results to a CSV file. + + Args: + path: Path of the CSV file to write. + with_header: If true, output the CSV header row. + """ + self.df.write_csv(str(path), with_header) + + def write_parquet( + self, + path: str | pathlib.Path, + compression: str = "uncompressed", + compression_level: int | None = None, + ) -> None: + """Execute the `DataFrame` and write the results to a Parquet file. + + Args: + path: Path of the Parquet file to write. + compression: Compression type to use. + compression_level: Compression level to use. + """ + self.df.write_parquet(str(path), compression, compression_level) + + def write_json(self, path: str | pathlib.Path) -> None: + """Execute the `DataFrame` and write the results to a JSON file. + + Args: + path: Path of the JSON file to write. + """ + self.df.write_json(str(path)) + + def to_arrow_table(self) -> pa.Table: + """Execute the `DataFrame` and convert it into an Arrow Table. + + Returns: + Arrow Table. + """ + return self.df.to_arrow_table() + + def execute_stream(self) -> RecordBatchStream: + """Executes this DataFrame and returns a stream over a single partition. + + Returns: + Record Batch Stream over a single partition. + """ + return RecordBatchStream(self.df.execute_stream()) + + def execute_stream_partitioned(self) -> list[RecordBatchStream]: + """Executes this DataFrame and returns a stream for each partition. + + Returns: + One record batch stream per partition. + """ + streams = self.df.execute_stream_partitioned() + return [RecordBatchStream(rbs) for rbs in streams] + + def to_pandas(self) -> pd.DataFrame: + """Execute the `DataFrame` and convert it into a Pandas DataFrame. + + Returns: + Pandas DataFrame. + """ + return self.df.to_pandas() + + def to_pylist(self) -> list[dict[str, Any]]: + """Execute the `DataFrame` and convert it into a list of dictionaries. + + Returns: + List of dictionaries. + """ + return self.df.to_pylist() + + def to_pydict(self) -> dict[str, list[Any]]: + """Execute the `DataFrame` and convert it into a dictionary of lists. + + Returns: + Dictionary of lists. + """ + return self.df.to_pydict() + + def to_polars(self) -> pl.DataFrame: + """Execute the `DataFrame` and convert it into a Polars DataFrame. + + Returns: + Polars DataFrame. + """ + return self.df.to_polars() + + def count(self) -> int: + """Return the total number of rows in this `DataFrame`. + + Note that this method will actually run a plan to calculate the + count, which may be slow for large or complicated DataFrames. + + Returns: + Number of rows in the DataFrame. + """ + return self.df.count() + + @deprecated("Use :func:`unnest_columns` instead.") + def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame: + """See ``unnest_columns``.""" + return DataFrame(self.df.unnest_column(column, preserve_nulls=preserve_nulls)) + + def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFrame: + """Expand columns of arrays into a single row per array element. + + Args: + columns: Column names to perform unnest operation on. 
+ preserve_nulls: If False, rows with null entries will not be + returned. + + Returns: + A DataFrame with the columns expanded. + """ + columns = [c for c in columns] + return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e914b85d..c04a525a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -15,9 +15,417 @@ # specific language governing permissions and limitations # under the License. +"""This module supports expressions, one of the core concepts in DataFusion. -from ._internal import expr +See ``Expr`` for more details. +""" +from __future__ import annotations -def __getattr__(name): - return getattr(expr, name) +from ._internal import expr as expr_internal, LogicalPlan +from datafusion.common import RexType, DataTypeMap +from typing import Any +import pyarrow as pa + +# The following are imported from the internal representation. We may choose to +# give these all proper wrappers, or to simply leave as is. These were added +# in order to support passing the `test_imports` unit test. +# Tim Saucer note: It is not clear to me what the use case is for exposing +# these definitions to the end user. + +Alias = expr_internal.Alias +Analyze = expr_internal.Analyze +Aggregate = expr_internal.Aggregate +AggregateFunction = expr_internal.AggregateFunction +Between = expr_internal.Between +BinaryExpr = expr_internal.BinaryExpr +Case = expr_internal.Case +Cast = expr_internal.Cast +Column = expr_internal.Column +CreateMemoryTable = expr_internal.CreateMemoryTable +CreateView = expr_internal.CreateView +CrossJoin = expr_internal.CrossJoin +Distinct = expr_internal.Distinct +DropTable = expr_internal.DropTable +Exists = expr_internal.Exists +Explain = expr_internal.Explain +Extension = expr_internal.Extension +Filter = expr_internal.Filter +GroupingSet = expr_internal.GroupingSet +Join = expr_internal.Join +ILike = expr_internal.ILike +InList = expr_internal.InList +InSubquery = expr_internal.InSubquery +IsFalse = expr_internal.IsFalse +IsNotTrue = expr_internal.IsNotTrue +IsTrue = expr_internal.IsTrue +IsUnknown = expr_internal.IsUnknown +IsNotFalse = expr_internal.IsNotFalse +IsNotNull = expr_internal.IsNotNull +IsNotUnknown = expr_internal.IsNotUnknown +JoinConstraint = expr_internal.JoinConstraint +JoinType = expr_internal.JoinType +Like = expr_internal.Like +Limit = expr_internal.Limit +Literal = expr_internal.Literal +Negative = expr_internal.Negative +Not = expr_internal.Not +Partitioning = expr_internal.Partitioning +Placeholder = expr_internal.Placeholder +Projection = expr_internal.Projection +Repartition = expr_internal.Repartition +ScalarSubquery = expr_internal.ScalarSubquery +ScalarVariable = expr_internal.ScalarVariable +SimilarTo = expr_internal.SimilarTo +Sort = expr_internal.Sort +Subquery = expr_internal.Subquery +SubqueryAlias = expr_internal.SubqueryAlias +TableScan = expr_internal.TableScan +TryCast = expr_internal.TryCast +Union = expr_internal.Union + + +class Expr: + """Expression object. + + Expressions are one of the core concepts in DataFusion. See + https://datafusion.apache.org/python/user-guide/common-operations/expressions.html + for more information. 
+ """ + + def __init__(self, expr: expr_internal.Expr) -> None: + """This constructor should not be called by the end user.""" + self.expr = expr + + def to_variant(self) -> Any: + """Convert this expression into a python object if possible.""" + return self.expr.to_variant() + + def display_name(self) -> str: + """Returns the name of this expression as it should appear in a schema. + + This name will not include any CAST expressions. + """ + return self.expr.display_name() + + def canonical_name(self) -> str: + """Returns a complete string representation of this expression.""" + return self.expr.canonical_name() + + def variant_name(self) -> str: + """Returns the name of the Expr variant. + + Ex: ``IsNotNull``, ``Literal``, ``BinaryExpr``, etc + """ + return self.expr.variant_name() + + def __richcmp__(self, other: Expr, op: int) -> Expr: + """Comparison operator.""" + return Expr(self.expr.__richcmp__(other, op)) + + def __repr__(self) -> str: + """Generate a string representation of this expression.""" + return self.expr.__repr__() + + def __add__(self, rhs: Any) -> Expr: + """Addition operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__add__(rhs.expr)) + + def __sub__(self, rhs: Any) -> Expr: + """Subtraction operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__sub__(rhs.expr)) + + def __truediv__(self, rhs: Any) -> Expr: + """Division operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__truediv__(rhs.expr)) + + def __mul__(self, rhs: Any) -> Expr: + """Multiplication operator. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__mul__(rhs.expr)) + + def __mod__(self, rhs: Any) -> Expr: + """Modulo operator (%). + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__mod__(rhs.expr)) + + def __and__(self, rhs: Expr) -> Expr: + """Logical AND.""" + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__and__(rhs.expr)) + + def __or__(self, rhs: Expr) -> Expr: + """Logical OR.""" + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__or__(rhs.expr)) + + def __invert__(self) -> Expr: + """Binary not (~).""" + return Expr(self.expr.__invert__()) + + def __getitem__(self, key: str) -> Expr: + """For struct data types, return the field indicated by ``key``.""" + return Expr(self.expr.__getitem__(key)) + + def __eq__(self, rhs: Any) -> Expr: + """Equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__eq__(rhs.expr)) + + def __ne__(self, rhs: Any) -> Expr: + """Not equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__ne__(rhs.expr)) + + def __ge__(self, rhs: Any) -> Expr: + """Greater than or equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. 
+ """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__ge__(rhs.expr)) + + def __gt__(self, rhs: Any) -> Expr: + """Greater than. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__gt__(rhs.expr)) + + def __le__(self, rhs: Any) -> Expr: + """Less than or equal to. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__le__(rhs.expr)) + + def __lt__(self, rhs: Any) -> Expr: + """Less than. + + Accepts either an expression or any valid PyArrow scalar literal value. + """ + if not isinstance(rhs, Expr): + rhs = Expr.literal(rhs) + return Expr(self.expr.__lt__(rhs.expr)) + + @staticmethod + def literal(value: Any) -> Expr: + """Creates a new expression representing a scalar value. + + `value` must be a valid PyArrow scalar value or easily castable to one. + """ + if not isinstance(value, pa.Scalar): + value = pa.scalar(value) + return Expr(expr_internal.Expr.literal(value)) + + @staticmethod + def column(value: str) -> Expr: + """Creates a new expression representing a column in a ``DataFrame``.""" + return Expr(expr_internal.Expr.column(value)) + + def alias(self, name: str) -> Expr: + """Assign a name to the expression.""" + return Expr(self.expr.alias(name)) + + def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr: + """Creates a sort ``Expr`` from an existing ``Expr``. + + Args: + ascending: If true, sort in ascending order. + nulls_first: Return null values first. + """ + return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first)) + + def is_null(self) -> Expr: + """Returns ``True`` if this expression is null.""" + return Expr(self.expr.is_null()) + + def cast(self, to: pa.DataType[Any]) -> Expr: + """Cast to a new data type.""" + return Expr(self.expr.cast(to)) + + def rex_type(self) -> RexType: + """Return the Rex Type of this expression. + + A Rex (Row Expression) specifies a single row of data.That specification + could include user defined functions or types. RexType identifies the + row as one of the possible valid ``RexType``(s). + """ + return self.expr.rex_type() + + def types(self) -> DataTypeMap: + """Return the ``DataTypeMap``. + + Returns: + DataTypeMap which represents the PythonType, Arrow DataType, and + SqlType Enum which this expression represents. + """ + return self.expr.types() + + def python_value(self) -> Any: + """Extracts the Expr value into a PyObject. + + This is only valid for literal expressions. + + Returns: + Python object representing literal value of the expression. + """ + return self.expr.python_value() + + def rex_call_operands(self) -> list[Expr]: + """Return the operands of the expression based on it's variant type. + + Row expressions, Rex(s), operate on the concept of operands. Different + variants of Expressions, Expr(s), store those operands in different + datastructures. This function examines the Expr variant and returns + the operands to the calling logic. 
+ """ + return [Expr(e) for e in self.expr.rex_call_operands()] + + def rex_call_operator(self) -> str: + """Extracts the operator associated with a row expression type call.""" + return self.expr.rex_call_operator() + + def column_name(self, plan: LogicalPlan) -> str: + """Compute the output column name based on the provided logical plan.""" + return self.expr.column_name(plan) + + +class WindowFrame: + """Defines a window frame for performing window operations.""" + + def __init__( + self, units: str, start_bound: int | None, end_bound: int | None + ) -> None: + """Construct a window frame using the given parameters. + + Args: + units: Should be one of `rows`, `range`, or `groups`. + start_bound: Sets the preceeding bound. Must be >= 0. If none, this + will be set to unbounded. If unit type is `groups`, this + parameter must be set. + end_bound: Sets the following bound. Must be >= 0. If none, this + will be set to unbounded. If unit type is `groups`, this + parameter must be set. + """ + self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) + + def get_frame_units(self) -> str: + """Returns the window frame units for the bounds.""" + return self.window_frame.get_frame_units() + + def get_lower_bound(self) -> WindowFrameBound: + """Returns starting bound.""" + return WindowFrameBound(self.window_frame.get_lower_bound()) + + def get_upper_bound(self): + """Returns end bound.""" + return WindowFrameBound(self.window_frame.get_upper_bound()) + + +class WindowFrameBound: + """Defines a single window frame bound. + + ``WindowFrame`` typically requires a start and end bound. + """ + + def __init__(self, frame_bound: expr_internal.WindowFrameBound) -> None: + """Constructs a window frame bound.""" + self.frame_bound = frame_bound + + def get_offset(self) -> int | None: + """Returns the offset of the window frame.""" + return self.frame_bound.get_offset() + + def is_current_row(self) -> bool: + """Returns if the frame bound is current row.""" + return self.frame_bound.is_current_row() + + def is_following(self) -> bool: + """Returns if the frame bound is following.""" + return self.frame_bound.is_following() + + def is_preceding(self) -> bool: + """Returns if the frame bound is preceding.""" + return self.frame_bound.is_preceding() + + def is_unbounded(self) -> bool: + """Returns if the frame bound is unbounded.""" + return self.frame_bound.is_unbounded() + + +class CaseBuilder: + """Builder class for constructing case statements. + + An example usage would be as follows:: + + import datafusion.functions as f + from datafusion import lit, col + df.select( + f.case(col("column_a") + .when(lit(1), lit("One")) + .when(lit(2), lit("Two")) + .otherwise(lit("Unknown")) + ) + """ + + def __init__(self, case_builder: expr_internal.CaseBuilder) -> None: + """Constructs a case builder. + + This is not typically called by the end user directly. See + ``datafusion.functions.case`` instead. + """ + self.case_builder = case_builder + + def when(self, when_expr: Expr, then_expr: Expr) -> CaseBuilder: + """Add a case to match against.""" + return CaseBuilder(self.case_builder.when(when_expr.expr, then_expr.expr)) + + def otherwise(self, else_expr: Expr) -> Expr: + """Set a default value for the case statement.""" + return Expr(self.case_builder.otherwise(else_expr.expr)) + + def end(self) -> Expr: + """Finish building a case statement. + + Any non-matching cases will end in a `null` value. 
+ """ + return Expr(self.case_builder.end()) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 782ecba2..ad77712e 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -14,10 +14,1475 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module contains the user functions for operating on ``Expr``.""" +from __future__ import annotations -from ._internal import functions +# from datafusion._internal.context import SessionContext +# from datafusion._internal.expr import Expr +# from datafusion._internal.expr.conditional_expr import CaseBuilder +# from datafusion._internal.expr.window import WindowFrame +from datafusion._internal import functions as f, common +from datafusion.expr import CaseBuilder, Expr, WindowFrame +from datafusion.context import SessionContext -def __getattr__(name): - return getattr(functions, name) + +def isnan(expr: Expr) -> Expr: + """Returns true if a given number is +NaN or -NaN otherwise returns false.""" + return Expr(f.isnan(expr.expr)) + + +def nullif(expr1: Expr, expr2: Expr) -> Expr: + """Returns NULL if expr1 equals expr2; otherwise it returns expr1. + + This can be used to perform the inverse operation of the COALESCE expression. + """ + return Expr(f.nullif(expr1.expr, expr2.expr)) + + +def encode(input: Expr, encoding: Expr) -> Expr: + """Encode the `input`, using the `encoding`. encoding can be base64 or hex.""" + return Expr(f.encode(input.expr, encoding.expr)) + + +def decode(input: Expr, encoding: Expr) -> Expr: + """Decode the `input`, using the `encoding`. encoding can be base64 or hex.""" + return Expr(f.decode(input.expr, encoding.expr)) + + +def array_to_string(expr: Expr, delimiter: Expr) -> Expr: + """Converts each element to its text representation.""" + return Expr(f.array_to_string(expr.expr, delimiter.expr)) + + +def array_join(expr: Expr, delimiter: Expr) -> Expr: + """Converts each element to its text representation. + + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def list_to_string(expr: Expr, delimiter: Expr) -> Expr: + """Converts each element to its text representation. + + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def list_join(expr: Expr, delimiter: Expr) -> Expr: + """Converts each element to its text representation. + + This is an alias for :func:`array_to_string`. + """ + return array_to_string(expr, delimiter) + + +def in_list(arg: Expr, values: list[Expr], negated: bool = False) -> Expr: + """Returns whether the argument is contained within the list `values`.""" + values = [v.expr for v in values] + return Expr(f.in_list(arg.expr, values, negated)) + + +def digest(value: Expr, method: Expr) -> Expr: + """Computes the binary hash of an expression using the specified algorithm. + + Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, + blake2b, and blake3. + """ + return Expr(f.digest(value.expr, method.expr)) + + +def concat(*args: Expr) -> Expr: + """Concatenates the text representations of all the arguments. + + NULL arguments are ignored. + """ + args = [arg.expr for arg in args] + return Expr(f.concat(*args)) + + +def concat_ws(separator: str, *args: Expr) -> Expr: + """Concatenates the list `args` with the separator. + + `NULL` arugments are ignored. `separator` should not be `NULL`. 
+ """ + args = [arg.expr for arg in args] + return Expr(f.concat_ws(separator, *args)) + + +def order_by(expr: Expr, ascending: bool = True, nulls_first: bool = True) -> Expr: + """Creates a new sort expression.""" + return Expr(f.order_by(expr.expr, ascending, nulls_first)) + + +def alias(expr: Expr, name: str) -> Expr: + """Creates an alias expression.""" + return Expr(f.alias(expr.expr, name)) + + +def col(name: str) -> Expr: + """Creates a column reference expression.""" + return Expr(f.col(name)) + + +def count_star() -> Expr: + """Create a COUNT(1) aggregate expression.""" + return Expr(f.count_star()) + + +def case(expr: Expr) -> CaseBuilder: + """Create a ``CaseBuilder`` to match cases for the expression ``expr``. + + See ``datafusion.expr.CaseBuilder`` for detailed usage of ``CaseBuilder``. + """ + return CaseBuilder(f.case(expr.expr)) + + +def window( + name: str, + args: list[Expr], + partition_by: list[Expr] | None = None, + order_by: list[Expr] | None = None, + window_frame: WindowFrame | None = None, + ctx: SessionContext | None = None, +) -> Expr: + """Creates a new Window function expression.""" + args = [a.expr for a in args] + partition_by = [e.expr for e in partition_by] if partition_by is not None else None + order_by = [o.expr for o in order_by] if order_by is not None else None + window_frame = window_frame.window_frame if window_frame is not None else None + return Expr(f.window(name, args, partition_by, order_by, window_frame, ctx)) + + +# scalar functions +def abs(arg: Expr) -> Expr: + """Return the absolute value of a given number. + + Returns: + -------- + Expr + A new expression representing the absolute value of the input expression. + """ + return Expr(f.abs(arg.expr)) + + +def acos(arg: Expr) -> Expr: + """Returns the arc cosine or inverse cosine of a number. + + Returns: + -------- + Expr + A new expression representing the arc cosine of the input expression. 
+ """ + return Expr(f.acos(arg.expr)) + + +def acosh(arg: Expr) -> Expr: + """Returns inverse hyperbolic cosine.""" + return Expr(f.acosh(arg.expr)) + + +def ascii(arg: Expr) -> Expr: + """Returns the numeric code of the first character of the argument.""" + return Expr(f.ascii(arg.expr)) + + +def asin(arg: Expr) -> Expr: + """Returns the arc sine or inverse sine of a number.""" + return Expr(f.asin(arg.expr)) + + +def asinh(arg: Expr) -> Expr: + """Returns inverse hyperbolic sine.""" + return Expr(f.asinh(arg.expr)) + + +def atan(arg: Expr) -> Expr: + """Returns inverse tangent of a number.""" + return Expr(f.atan(arg.expr)) + + +def atanh(arg: Expr) -> Expr: + """Returns inverse hyperbolic tangent.""" + return Expr(f.atanh(arg.expr)) + + +def atan2(y: Expr, x: Expr) -> Expr: + """Returns inverse tangent of a division given in the argument.""" + return Expr(f.atan2(y.expr, x.expr)) + + +def bit_length(arg: Expr) -> Expr: + """Returns the number of bits in the string argument.""" + return Expr(f.bit_length(arg.expr)) + + +def btrim(arg: Expr) -> Expr: + """Removes all characters, spaces by default, from both sides of a string.""" + return Expr(f.btrim(arg.expr)) + + +def cbrt(arg: Expr) -> Expr: + """Returns the cube root of a number.""" + return Expr(f.cbrt(arg.expr)) + + +def ceil(arg: Expr) -> Expr: + """Returns the nearest integer greater than or equal to argument.""" + return Expr(f.ceil(arg.expr)) + + +def character_length(arg: Expr) -> Expr: + """Returns the number of characters in the argument.""" + return Expr(f.character_length(arg.expr)) + + +def length(string: Expr) -> Expr: + """The number of characters in the `string`.""" + return Expr(f.length(string.expr)) + + +def char_length(string: Expr) -> Expr: + """The number of characters in the `string`.""" + return Expr(f.char_length(string.expr)) + + +def chr(arg: Expr) -> Expr: + """Converts the Unicode code point to a UTF8 character.""" + return Expr(f.chr(arg.expr)) + + +def coalesce(*args: Expr) -> Expr: + """Returns the value of the first expr in `args` which is not NULL.""" + args = [arg.expr for arg in args] + return Expr(f.coalesce(*args)) + + +def cos(arg: Expr) -> Expr: + """Returns the cosine of the argument.""" + return Expr(f.cos(arg.expr)) + + +def cosh(arg: Expr) -> Expr: + """Returns the hyperbolic cosine of the argument.""" + return Expr(f.cosh(arg.expr)) + + +def cot(arg: Expr) -> Expr: + """Returns the cotangent of the argument.""" + return Expr(f.cot(arg.expr)) + + +def degrees(arg: Expr) -> Expr: + """Converts the argument from radians to degrees.""" + return Expr(f.degrees(arg.expr)) + + +def ends_with(arg: Expr, suffix: Expr) -> Expr: + """Returns true if the `string` ends with the `suffix`, false otherwise.""" + return Expr(f.ends_with(arg.expr, suffix.expr)) + + +def exp(arg: Expr) -> Expr: + """Returns the exponential of the arugment.""" + return Expr(f.exp(arg.expr)) + + +def factorial(arg: Expr) -> Expr: + """Returns the factorial of the argument.""" + return Expr(f.factorial(arg.expr)) + + +def find_in_set(string: Expr, string_list: Expr) -> Expr: + """Find a string in a list of strings. + + Returns a value in the range of 1 to N if the string is in the string list + `string_list` consisting of N substrings. + + The string list is a string composed of substrings separated by `,` characters. 
+ """ + return Expr(f.find_in_set(string.expr, string_list.expr)) + + +def floor(arg: Expr) -> Expr: + """Returns the nearest integer less than or equal to the argument.""" + return Expr(f.floor(arg.expr)) + + +def gcd(x: Expr, y: Expr) -> Expr: + """Returns the greatest common divisor.""" + return Expr(f.gcd(x.expr, y.expr)) + + +def initcap(string: Expr) -> Expr: + """Set the initial letter of each word to capital. + + Converts the first letter of each word in `string` to uppercase and the remaining + characters to lowercase. + """ + return Expr(f.initcap(string.expr)) + + +def instr(string: Expr, substring: Expr) -> Expr: + """Finds the position from where the `substring` matches the `string`. + + This is an alias for :func:`strpos`. + """ + return strpos(string, substring) + + +def iszero(arg: Expr) -> Expr: + """Returns true if a given number is +0.0 or -0.0 otherwise returns false.""" + return Expr(f.iszero(arg.expr)) + + +def lcm(x: Expr, y: Expr) -> Expr: + """Returns the least common multiple.""" + return Expr(f.lcm(x.expr, y.expr)) + + +def left(string: Expr, n: Expr) -> Expr: + """Returns the first `n` characters in the `string`.""" + return Expr(f.left(string.expr, n.expr)) + + +def levenshtein(string1: Expr, string2: Expr) -> Expr: + """Returns the Levenshtein distance between the two given strings.""" + return Expr(f.levenshtein(string1.expr, string2.expr)) + + +def ln(arg: Expr) -> Expr: + """Returns the natural logarithm (base e) of the argument.""" + return Expr(f.ln(arg.expr)) + + +def log(base: Expr, num: Expr) -> Expr: + """Returns the logarithm of a number for a particular `base`.""" + return Expr(f.log(base.expr, num.expr)) + + +def log10(arg: Expr) -> Expr: + """Base 10 logarithm of the argument.""" + return Expr(f.log10(arg.expr)) + + +def log2(arg: Expr) -> Expr: + """Base 2 logarithm of the argument.""" + return Expr(f.log2(arg.expr)) + + +def lower(arg: Expr) -> Expr: + """Converts a string to lowercase.""" + return Expr(f.lower(arg.expr)) + + +def lpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: + """Add left padding to a string. + + Extends the string to length length by prepending the characters fill (a + space by default). If the string is already longer than length then it is + truncated (on the right). + """ + characters = characters if characters is not None else Expr.literal(" ") + return Expr(f.lpad(string.expr, count.expr, characters.expr)) + + +def ltrim(arg: Expr) -> Expr: + """Removes all characters, spaces by default, from the beginning of a string.""" + return Expr(f.ltrim(arg.expr)) + + +def md5(arg: Expr) -> Expr: + """Computes an MD5 128-bit checksum for a string expression.""" + return Expr(f.md5(arg.expr)) + + +def nanvl(x: Expr, y: Expr) -> Expr: + """Returns `x` if `x` is not `NaN`. Otherwise returns `y`.""" + return Expr(f.nanvl(x.expr, y.expr)) + + +def octet_length(arg: Expr) -> Expr: + """Returns the number of bytes of a string.""" + return Expr(f.octet_length(arg.expr)) + + +def overlay( + string: Expr, substring: Expr, start: Expr, length: Expr | None = None +) -> Expr: + """Replace a substring with a new substring. + + Replace the substring of string that starts at the `start`'th character and + extends for `length` characters with new substring. 
+ """ + if length is None: + return Expr(f.overlay(string.expr, substring.expr, start.expr)) + return Expr(f.overlay(string.expr, substring.expr, start.expr, length.expr)) + + +def pi() -> Expr: + """Returns an approximate value of π.""" + return Expr(f.pi()) + + +def position(string: Expr, substring: Expr) -> Expr: + """Finds the position from where the `substring` matches the `string`. + + This is an alias for :func:`strpos`. + """ + return strpos(string, substring) + + +def power(base: Expr, exponent: Expr) -> Expr: + """Returns `base` raised to the power of `exponent`.""" + return Expr(f.power(base.expr, exponent.expr)) + + +def pow(base: Expr, exponent: Expr) -> Expr: + """Returns `base` raised to the power of `exponent`. + + This is an alias of `power`. + """ + return power(base, exponent) + + +def radians(arg: Expr) -> Expr: + """Converts the argument from degrees to radians.""" + return Expr(f.radians(arg.expr)) + + +def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: + """Find if any regular expression (regex) matches exist. + + Tests a string using a regular expression returning true if at least one match, + false otherwise. + """ + if flags is not None: + flags = flags.expr + return Expr(f.regexp_like(string.expr, regex.expr, flags)) + + +def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr: + """Perform regular expression (regex) matching. + + Returns an array with each element containing the leftmost-first match of the + corresponding index in `regex` to string in `string`. + """ + if flags is not None: + flags = flags.expr + return Expr(f.regexp_match(string.expr, regex.expr, flags)) + + +def regexp_replace( + string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None +) -> Expr: + """Replaces substring(s) matching a PCRE-like regular expression. + + The full list of supported features and syntax can be found at + + + Supported flags with the addition of 'g' can be found at + + """ + if flags is not None: + flags = flags.expr + return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags)) + + +def repeat(string: Expr, n: Expr) -> Expr: + """Repeats the `string` to `n` times.""" + return Expr(f.repeat(string.expr, n.expr)) + + +def replace(string: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces all occurrences of `from` with `to` in the `string`.""" + return Expr(f.replace(string.expr, from_val.expr, to_val.expr)) + + +def reverse(arg: Expr) -> Expr: + """Reverse the string argument.""" + return Expr(f.reverse(arg.expr)) + + +def right(string: Expr, n: Expr) -> Expr: + """Returns the last `n` characters in the `string`.""" + return Expr(f.right(string.expr, n.expr)) + + +def round(value: Expr, decimal_places: Expr = Expr.literal(0)) -> Expr: + """Round the argument to the nearest integer. + + If the optional ``decimal_places`` is specified, round to the nearest number of + decimal places. You can specify a negative number of decimal places. For example + `round(lit(125.2345), lit(-2))` would yield a value of `100.0`. + """ + return Expr(f.round(value.expr, decimal_places.expr)) + + +def rpad(string: Expr, count: Expr, characters: Expr | None = None) -> Expr: + """Add right padding to a string. + + Extends the string to length length by appending the characters fill (a space + by default). If the string is already longer than length then it is truncated. 
+ """ + characters = characters if characters is not None else Expr.literal(" ") + return Expr(f.rpad(string.expr, count.expr, characters.expr)) + + +def rtrim(arg: Expr) -> Expr: + """Removes all characters, spaces by default, from the end of a string.""" + return Expr(f.rtrim(arg.expr)) + + +def sha224(arg: Expr) -> Expr: + """Computes the SHA-224 hash of a binary string.""" + return Expr(f.sha224(arg.expr)) + + +def sha256(arg: Expr) -> Expr: + """Computes the SHA-256 hash of a binary string.""" + return Expr(f.sha256(arg.expr)) + + +def sha384(arg: Expr) -> Expr: + """Computes the SHA-384 hash of a binary string.""" + return Expr(f.sha384(arg.expr)) + + +def sha512(arg: Expr) -> Expr: + """Computes the SHA-512 hash of a binary string.""" + return Expr(f.sha512(arg.expr)) + + +def signum(arg: Expr) -> Expr: + """Returns the sign of the argument (-1, 0, +1).""" + return Expr(f.signum(arg.expr)) + + +def sin(arg: Expr) -> Expr: + """Returns the sine of the argument.""" + return Expr(f.sin(arg.expr)) + + +def sinh(arg: Expr) -> Expr: + """Returns the hyperbolic sine of the argument.""" + return Expr(f.sinh(arg.expr)) + + +def split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr: + """Split a string and return one part. + + Splits a string based on a delimiter and picks out the desired field based + on the index. + """ + return Expr(f.split_part(string.expr, delimiter.expr, index.expr)) + + +def sqrt(arg: Expr) -> Expr: + """Returns the square root of the argument.""" + return Expr(f.sqrt(arg.expr)) + + +def starts_with(string: Expr, prefix: Expr) -> Expr: + """Returns true if string starts with prefix.""" + return Expr(f.starts_with(string.expr, prefix.expr)) + + +def strpos(string: Expr, substring: Expr) -> Expr: + """Finds the position from where the `substring` matches the `string`.""" + return Expr(f.strpos(string.expr, substring.expr)) + + +def substr(string: Expr, position: Expr) -> Expr: + """Substring from the `position` to the end.""" + return Expr(f.substr(string.expr, position.expr)) + + +def substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr: + """Returns the substring from `string` before `count` occurrences of `delimiter`.""" + return Expr(f.substr_index(string.expr, delimiter.expr, count.expr)) + + +def substring(string: Expr, position: Expr, length: Expr) -> Expr: + """Substring from the `position` with `length` characters.""" + return Expr(f.substring(string.expr, position.expr, length.expr)) + + +def tan(arg: Expr) -> Expr: + """Returns the tangent of the argument.""" + return Expr(f.tan(arg.expr)) + + +def tanh(arg: Expr) -> Expr: + """Returns the hyperbolic tangent of the argument.""" + return Expr(f.tanh(arg.expr)) + + +def to_hex(arg: Expr) -> Expr: + """Converts an integer to a hexadecimal string.""" + return Expr(f.to_hex(arg.expr)) + + +def now() -> Expr: + """Returns the current timestamp in nanoseconds. + + This will use the same value for all instances of now() in same statement. + """ + return Expr(f.now()) + + +def to_timestamp(arg: Expr, *formatters: Expr) -> Expr: + """Converts a string and optional formats to a `Timestamp` in nanoseconds. + + For usage of ``formatters`` see the rust chrono package ``strftime`` package. 
+ + [Documentation here.](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) + """ + if formatters is None: + return f.to_timestamp(arg.expr) + + formatters = [f.expr for f in formatters] + return Expr(f.to_timestamp(arg.expr, *formatters)) + + +def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr: + """Converts a string and optional formats to a `Timestamp` in milliseconds. + + See `to_timestamp` for a description on how to use formatters. + """ + return Expr(f.to_timestamp_millis(arg.expr, *formatters)) + + +def to_timestamp_micros(arg: Expr, *formatters: Expr) -> Expr: + """Converts a string and optional formats to a `Timestamp` in microseconds. + + See `to_timestamp` for a description on how to use formatters. + """ + return Expr(f.to_timestamp_micros(arg.expr, *formatters)) + + +def to_timestamp_nanos(arg: Expr, *formatters: Expr) -> Expr: + """Converts a string and optional formats to a `Timestamp` in nanoseconds. + + See `to_timestamp` for a description on how to use formatters. + """ + return Expr(f.to_timestamp_nanos(arg.expr, *formatters)) + + +def to_timestamp_seconds(arg: Expr, *formatters: Expr) -> Expr: + """Converts a string and optional formats to a `Timestamp` in seconds. + + See `to_timestamp` for a description on how to use formatters. + """ + return Expr(f.to_timestamp_seconds(arg.expr, *formatters)) + + +def to_unixtime(string: Expr, *format_arguments: Expr) -> Expr: + """Converts a string and optional formats to a Unixtime.""" + args = [f.expr for f in format_arguments] + return Expr(f.to_unixtime(string.expr, *args)) + + +def current_date() -> Expr: + """Returns current UTC date as a Date32 value.""" + return Expr(f.current_date()) + + +def current_time() -> Expr: + """Returns current UTC time as a Time64 value.""" + return Expr(f.current_time()) + + +def datepart(part: Expr, date: Expr) -> Expr: + """Return a specified part of a date. + + This is an alias for `date_part`. + """ + return date_part(part, date) + + +def date_part(part: Expr, date: Expr) -> Expr: + """Extracts a subfield from the date.""" + return Expr(f.date_part(part.expr, date.expr)) + + +def date_trunc(part: Expr, date: Expr) -> Expr: + """Truncates the date to a specified level of precision.""" + return Expr(f.date_trunc(part.expr, date.expr)) + + +def datetrunc(part: Expr, date: Expr) -> Expr: + """Truncates the date to a specified level of precision. + + This is an alias for `date_trunc`. 
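# A sketch of the timestamp/date helpers above; the format string follows
# chrono's strftime syntax and the column name is hypothetical.
from datafusion import col, lit, functions as f

ts = f.to_timestamp(col("ts_str"), lit("%Y-%m-%d %H:%M:%S"))
year = f.date_part(lit("year"), ts)
month_start = f.date_trunc(lit("month"), ts)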
+ """ + return date_trunc(part, date) + + +def date_bin(stride: Expr, source: Expr, origin: Expr) -> Expr: + """Coerces an arbitrary timestamp to the start of the nearest specified interval.""" + return Expr(f.date_bin(stride.expr, source.expr, origin.expr)) + + +def make_date(year: Expr, month: Expr, day: Expr) -> Expr: + """Make a date from year, month and day component parts.""" + return Expr(f.make_date(year.expr, month.expr, day.expr)) + + +def translate(string: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces the characters in `from_val` with the counterpart in `to_val`.""" + return Expr(f.translate(string.expr, from_val.expr, to_val.expr)) + + +def trim(arg: Expr) -> Expr: + """Removes all characters, spaces by default, from both sides of a string.""" + return Expr(f.trim(arg.expr)) + + +def trunc(num: Expr, precision: Expr | None = None) -> Expr: + """Truncate the number toward zero with optional precision.""" + if precision is not None: + return Expr(f.trunc(num.expr, precision.expr)) + return Expr(f.trunc(num.expr)) + + +def upper(arg: Expr) -> Expr: + """Converts a string to uppercase.""" + return Expr(f.upper(arg.expr)) + + +def make_array(*args: Expr) -> Expr: + """Returns an array using the specified input expressions.""" + args = [arg.expr for arg in args] + return Expr(f.make_array(*args)) + + +def array(*args: Expr) -> Expr: + """Returns an array using the specified input expressions. + + This is an alias for `make_array`. + """ + return make_array(args) + + +def range(start: Expr, stop: Expr, step: Expr) -> Expr: + """Create a list of values in the range between start and stop.""" + return Expr(f.range(start.expr, stop.expr, step.expr)) + + +def uuid(arg: Expr) -> Expr: + """Returns uuid v4 as a string value.""" + return Expr(f.uuid(arg.expr)) + + +def struct(*args: Expr) -> Expr: + """Returns a struct with the given arguments.""" + args = [arg.expr for arg in args] + return Expr(f.struct(*args)) + + +def named_struct(name_pairs: list[(str, Expr)]) -> Expr: + """Returns a struct with the given names and arguments pairs.""" + name_pairs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] + + # flatten + name_pairs = [x.expr for xs in name_pairs for x in xs] + return Expr(f.named_struct(*name_pairs)) + + +def from_unixtime(arg: Expr) -> Expr: + """Converts an integer to RFC3339 timestamp format string.""" + return Expr(f.from_unixtime(arg.expr)) + + +def arrow_typeof(arg: Expr) -> Expr: + """Returns the Arrow type of the expression.""" + return Expr(f.arrow_typeof(arg.expr)) + + +def random() -> Expr: + """Returns a random value in the range `0.0 <= x < 1.0`.""" + return Expr(f.random()) + + +def array_append(array: Expr, element: Expr) -> Expr: + """Appends an element to the end of an array.""" + return Expr(f.array_append(array.expr, element.expr)) + + +def array_push_back(array: Expr, element: Expr) -> Expr: + """Appends an element to the end of an array. + + This is an alias for `array_append`. + """ + return array_append(array, element) + + +def list_append(array: Expr, element: Expr) -> Expr: + """Appends an element to the end of an array. + + This is an alias for `array_append`. + """ + return array_append(array, element) + + +def list_push_back(array: Expr, element: Expr) -> Expr: + """Appends an element to the end of an array. + + This is an alias for `array_append`. 
+ """ + return array_append(array, element) + + +def array_concat(*args: Expr) -> Expr: + """Concatenates the input arrays.""" + args = [arg.expr for arg in args] + return Expr(f.array_concat(*args)) + + +def array_cat(*args: Expr) -> Expr: + """Concatenates the input arrays. + + This is an alias for `array_concat`. + """ + return array_concat(*args) + + +def array_dims(array: Expr) -> Expr: + """Returns an array of the array's dimensions.""" + return Expr(f.array_dims(array.expr)) + + +def array_distinct(array: Expr) -> Expr: + """Returns distinct values from the array after removing duplicates.""" + return Expr(f.array_distinct(array.expr)) + + +def list_distinct(array: Expr) -> Expr: + """Returns distinct values from the array after removing duplicates. + + This is an alias for `array_distinct`. + """ + return array_distinct(array) + + +def list_dims(array: Expr) -> Expr: + """Returns an array of the array's dimensions. + + This is an alias for `array_dims`. + """ + return array_dims(array) + + +def array_element(array: Expr, n: Expr) -> Expr: + """Extracts the element with the index n from the array.""" + return Expr(f.array_element(array.expr, n.expr)) + + +def array_extract(array: Expr, n: Expr) -> Expr: + """Extracts the element with the index n from the array. + + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def list_element(array: Expr, n: Expr) -> Expr: + """Extracts the element with the index n from the array. + + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def list_extract(array: Expr, n: Expr) -> Expr: + """Extracts the element with the index n from the array. + + This is an alias for `array_element`. + """ + return array_element(array, n) + + +def array_length(array: Expr) -> Expr: + """Returns the length of the array.""" + return Expr(f.array_length(array.expr)) + + +def list_length(array: Expr) -> Expr: + """Returns the length of the array. + + This is an alias for `array_length`. + """ + return array_length(array) + + +def array_has(first_array: Expr, second_array: Expr) -> Expr: + """Returns true if the element appears in the first array, otherwise false.""" + return Expr(f.array_has(first_array.expr, second_array.expr)) + + +def array_has_all(first_array: Expr, second_array: Expr) -> Expr: + """Determines if there is complete overlap ``second_array`` in ``first_array``. + + Returns true if each element of the second array appears in the first array. + Otherwise, it returns false. + """ + return Expr(f.array_has_all(first_array.expr, second_array.expr)) + + +def array_has_any(first_array: Expr, second_array: Expr) -> Expr: + """Determine if there is an overlap between ``first_array`` and ``second_array``. + + Returns true if at least one element of the second array appears in the first + array. Otherwise, it returns false. + """ + return Expr(f.array_has_any(first_array.expr, second_array.expr)) + + +def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """Return the position of the first occurrence of ``element`` in ``array``.""" + return Expr(f.array_position(array.expr, element.expr, index)) + + +def array_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """Return the position of the first occurrence of ``element`` in ``array``. + + This is an alias for `array_position`. 
+ """ + return array_position(array, element, index) + + +def list_position(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """Return the position of the first occurrence of ``element`` in ``array``. + + This is an alias for `array_position`. + """ + return array_position(array, element, index) + + +def list_indexof(array: Expr, element: Expr, index: int | None = 1) -> Expr: + """Return the position of the first occurrence of ``element`` in ``array``. + + This is an alias for `array_position`. + """ + return array_position(array, element, index) + + +def array_positions(array: Expr, element: Expr) -> Expr: + """Searches for an element in the array and returns all occurrences.""" + return Expr(f.array_positions(array.expr, element.expr)) + + +def list_positions(array: Expr, element: Expr) -> Expr: + """Searches for an element in the array and returns all occurrences. + + This is an alias for `array_positions`. + """ + return array_positions(array, element) + + +def array_ndims(array: Expr) -> Expr: + """Returns the number of dimensions of the array.""" + return Expr(f.array_ndims(array.expr)) + + +def list_ndims(array: Expr) -> Expr: + """Returns the number of dimensions of the array. + + This is an alias for `array_ndims`. + """ + return array_ndims(array) + + +def array_prepend(element: Expr, array: Expr) -> Expr: + """Prepends an element to the beginning of an array.""" + return Expr(f.array_prepend(element.expr, array.expr)) + + +def array_push_front(element: Expr, array: Expr) -> Expr: + """Prepends an element to the beginning of an array. + + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def list_prepend(element: Expr, array: Expr) -> Expr: + """Prepends an element to the beginning of an array. + + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def list_push_front(element: Expr, array: Expr) -> Expr: + """Prepends an element to the beginning of an array. + + This is an alias for `array_prepend`. + """ + return array_prepend(element, array) + + +def array_pop_back(array: Expr) -> Expr: + """Returns the array without the last element.""" + return Expr(f.array_pop_back(array.expr)) + + +def array_pop_front(array: Expr) -> Expr: + """Returns the array without the first element.""" + return Expr(f.array_pop_front(array.expr)) + + +def array_remove(array: Expr, element: Expr) -> Expr: + """Removes the first element from the array equal to the given value.""" + return Expr(f.array_remove(array.expr, element.expr)) + + +def list_remove(array: Expr, element: Expr) -> Expr: + """Removes the first element from the array equal to the given value. + + This is an alias for `array_remove`. + """ + return array_remove(array, element) + + +def array_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: + """Removes the first `max` elements from the array equal to the given value.""" + return Expr(f.array_remove_n(array.expr, element.expr, max.expr)) + + +def list_remove_n(array: Expr, element: Expr, max: Expr) -> Expr: + """Removes the first `max` elements from the array equal to the given value. + + This is an alias for `array_remove_n`. 
+ """ + return array_remove_n(array, element, max) + + +def array_remove_all(array: Expr, element: Expr) -> Expr: + """Removes all elements from the array equal to the given value.""" + return Expr(f.array_remove_all(array.expr, element.expr)) + + +def list_remove_all(array: Expr, element: Expr) -> Expr: + """Removes all elements from the array equal to the given value. + + This is an alias for `array_remove_all`. + """ + return array_remove_all(array, element) + + +def array_repeat(element: Expr, count: Expr) -> Expr: + """Returns an array containing `element` `count` times.""" + return Expr(f.array_repeat(element.expr, count.expr)) + + +def array_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces the first occurrence of ``from_val`` with ``to_val``.""" + return Expr(f.array_replace(array.expr, from_val.expr, to_val.expr)) + + +def list_replace(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces the first occurrence of ``from_val`` with ``to_val``. + + This is an alias for `array_replace`. + """ + return array_replace(array, from_val, to_val) + + +def array_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: + """Replace `n` occurrences of ``from_val`` with ``to_val``. + + Replaces the first `max` occurrences of the specified element with another + specified element. + """ + return Expr(f.array_replace_n(array.expr, from_val.expr, to_val.expr, max.expr)) + + +def list_replace_n(array: Expr, from_val: Expr, to_val: Expr, max: Expr) -> Expr: + """Replace `n` occurrences of ``from_val`` with ``to_val``. + + Replaces the first `max` occurrences of the specified element with another + specified element. + + This is an alias for `array_replace_n`. + """ + return array_replace_n(array, from_val, to_val, max) + + +def array_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces all occurrences of ``from_val`` with ``to_val``.""" + return Expr(f.array_replace_all(array.expr, from_val.expr, to_val.expr)) + + +def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: + """Replaces all occurrences of ``from_val`` with ``to_val``. + + This is an alias for `array_replace_all`. + """ + return array_replace_all(array, from_val, to_val) + + +def array_slice( + array: Expr, begin: Expr, end: Expr, stride: Expr | None = None +) -> Expr: + """Returns a slice of the array.""" + if stride is not None: + stride = stride.expr + return Expr(f.array_slice(array.expr, begin.expr, end.expr, stride)) + + +def list_slice(array: Expr, begin: Expr, end: Expr, stride: Expr | None = None) -> Expr: + """Returns a slice of the array. + + This is an alias for `array_slice`. + """ + return array_slice(array, begin, end, stride) + + +def array_intersect(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements in the intersection of array1 and array2.""" + return Expr(f.array_intersect(array1.expr, array2.expr)) + + +def list_intersect(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements in the intersection of `array1` and `array2`. + + This is an alias for `array_intersect`. + """ + return array_intersect(array1, array2) + + +def array_union(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements in the union of array1 and array2. + + Duplicate rows will not be returned. + """ + return Expr(f.array_union(array1.expr, array2.expr)) + + +def list_union(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements in the union of array1 and array2. 
+ + Duplicate rows will not be returned. + + This is an alias for `array_union`. + """ + return array_union(array1, array2) + + +def array_except(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements that appear in `array1` but not in `array2`.""" + return Expr(f.array_except(array1.expr, array2.expr)) + + +def list_except(array1: Expr, array2: Expr) -> Expr: + """Returns an array of the elements that appear in `array1` but not in the `array2`. + + This is an alias for `array_except`. + """ + return array_except(array1, array2) + + +def array_resize(array: Expr, size: Expr, value: Expr) -> Expr: + """Returns an array with the specified size filled. + + If `size` is greater than the `array` length, the additional entries will be filled + with the given `value`. + """ + return Expr(f.array_resize(array.expr, size.expr, value.expr)) + + +def list_resize(array: Expr, size: Expr, value: Expr) -> Expr: + """Returns an array with the specified size filled. + + If `size` is greater than the `array` length, the additional entries will be + filled with the given `value`. This is an alias for `array_resize`. + """ + return array_resize(array, size, value) + + +def flatten(array: Expr) -> Expr: + """Flattens an array of arrays into a single array.""" + return Expr(f.flatten(array.expr)) + + +# aggregate functions +def approx_distinct(arg: Expr) -> Expr: + """Returns the approximate number of distinct values.""" + return Expr(f.approx_distinct(arg.expr, distinct=True)) + + +def approx_median(arg: Expr, distinct: bool = False) -> Expr: + """Returns the approximate median value.""" + return Expr(f.approx_median(arg.expr, distinct=distinct)) + + +def approx_percentile_cont( + expr: Expr, + percentile: Expr, + num_centroids: int | None = None, + distinct: bool = False, +) -> Expr: + """Returns the value that is approximately at a given percentile of ``expr``.""" + if num_centroids is None: + return Expr( + f.approx_percentile_cont(expr.expr, percentile.expr, distinct=distinct) + ) + + return Expr( + f.approx_percentile_cont( + expr.expr, percentile.expr, num_centroids, distinct=distinct + ) + ) + + +def approx_percentile_cont_with_weight( + arg: Expr, weight: Expr, percentile: Expr, distinct: bool = False +) -> Expr: + """Returns the value of the approximate percentile. + + This function is similar to ``approx_percentile_cont`` except that it uses + the associated associated weights. + """ + return Expr( + f.approx_percentile_cont_with_weight( + arg.expr, weight.expr, percentile.expr, distinct=distinct + ) + ) + + +def array_agg(arg: Expr, distinct: bool = False) -> Expr: + """Aggregate values into an array.""" + return Expr(f.array_agg(arg.expr, distinct=distinct)) + + +def avg(arg: Expr, distinct: bool = False) -> Expr: + """Returns the average value.""" + return Expr(f.avg(arg.expr, distinct=distinct)) + + +def corr(value1: Expr, value2: Expr, distinct: bool = False) -> Expr: + """Returns the correlation coefficient between `value1` and `value2`.""" + return Expr(f.corr(value1.expr, value2.expr, distinct=distinct)) + + +def count(args: Expr | list[Expr] | None = None, distinct: bool = False) -> Expr: + """Returns the number of rows that match the given arguments.""" + if isinstance(args, list): + args = [arg.expr for arg in args] + elif isinstance(args, Expr): + args = [args.expr] + return Expr(f.count(*args, distinct=distinct)) + + +def covar(y: Expr, x: Expr) -> Expr: + """Computes the sample covariance. + + This is an alias for `covar_samp`. 
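# A sketch of the aggregate helpers above used through DataFrame.aggregate;
# the data and column names are hypothetical.
from datafusion import SessionContext, col, lit, functions as f

ctx = SessionContext()
df = ctx.from_pydict({"g": ["a", "a", "b"], "v": [1.0, 2.0, 10.0]})

stats = df.aggregate(
    [col("g")],
    [
        f.avg(col("v")).alias("mean_v"),
        f.count(col("v")).alias("n"),
        f.approx_percentile_cont(col("v"), lit(0.5)).alias("p50"),
    ],
)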
+ """ + return Expr(f.covar(y.expr, x.expr)) + + +def covar_pop(y: Expr, x: Expr) -> Expr: + """Computes the population covariance.""" + return Expr(f.covar_pop(y.expr, x.expr)) + + +def covar_samp(y: Expr, x: Expr) -> Expr: + """Computes the sample covariance.""" + return Expr(f.covar_samp(y.expr, x.expr)) + + +def grouping(arg: Expr, distinct: bool = False) -> Expr: + """Indicates if the expression is aggregated or not. + + Returns 1 if the value of the argument is aggregated, 0 if not. + """ + return Expr(f.grouping([arg.expr], distinct=distinct)) + + +def max(arg: Expr, distinct: bool = False) -> Expr: + """Returns the maximum value of the arugment.""" + return Expr(f.max(arg.expr, distinct=distinct)) + + +def mean(arg: Expr, distinct: bool = False) -> Expr: + """Returns the average (mean) value of the argument. + + This is an alias for `avg`. + """ + return avg(arg, distinct) + + +def median(arg: Expr) -> Expr: + """Computes the median of a set of numbers.""" + return Expr(f.median(arg.expr)) + + +def min(arg: Expr, distinct: bool = False) -> Expr: + """Returns the minimum value of the argument.""" + return Expr(f.min(arg.expr, distinct=distinct)) + + +def sum(arg: Expr) -> Expr: + """Computes the sum of a set of numbers.""" + return Expr(f.sum(arg.expr)) + + +def stddev(arg: Expr, distinct: bool = False) -> Expr: + """Computes the standard deviation of the argument.""" + return Expr(f.stddev(arg.expr, distinct=distinct)) + + +def stddev_pop(arg: Expr, distinct: bool = False) -> Expr: + """Computes the population standard deviation of the argument.""" + return Expr(f.stddev_pop(arg.expr, distinct=distinct)) + + +def stddev_samp(arg: Expr, distinct: bool = False) -> Expr: + """Computes the sample standard deviation of the argument. + + This is an alias for `stddev`. + """ + return stddev(arg, distinct) + + +def var(arg: Expr) -> Expr: + """Computes the sample variance of the argument. + + This is an alias for `var_samp`. + """ + return var_samp(arg) + + +def var_pop(arg: Expr, distinct: bool = False) -> Expr: + """Computes the population variance of the argument.""" + return Expr(f.var_pop(arg.expr, distinct=distinct)) + + +def var_samp(arg: Expr) -> Expr: + """Computes the sample variance of the argument.""" + return Expr(f.var_samp(arg.expr)) + + +def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the average of the independent variable `x`. + + Only non-null pairs of the inputs are evaluated. + """ + return Expr(f.regr_avgx[y.expr, x.expr], distinct) + + +def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the average of the dependent variable ``y``. + + Only non-null pairs of the inputs are evaluated. 
+ """ + return Expr(f.regr_avgy[y.expr, x.expr], distinct) + + +def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Counts the number of rows in which both expressions are not null.""" + return Expr(f.regr_count[y.expr, x.expr], distinct) + + +def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the intercept from the linear regression.""" + return Expr(f.regr_intercept[y.expr, x.expr], distinct) + + +def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the R-squared value from linear regression.""" + return Expr(f.regr_r2[y.expr, x.expr], distinct) + + +def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the slope from linear regression.""" + return Expr(f.regr_slope[y.expr, x.expr], distinct) + + +def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the sum of squares of the independent variable `x`.""" + return Expr(f.regr_sxx[y.expr, x.expr], distinct) + + +def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the sum of products of pairs of numbers.""" + return Expr(f.regr_sxy[y.expr, x.expr], distinct) + + +def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr: + """Computes the sum of squares of the dependent variable `y`.""" + return Expr(f.regr_syy[y.expr, x.expr], distinct) + + +def first_value( + arg: Expr, + distinct: bool = False, + filter: bool = None, + order_by: Expr | None = None, + null_treatment: common.NullTreatment | None = None, +) -> Expr: + """Returns the first value in a group of values.""" + return Expr( + f.first_value( + arg.expr, + distinct=distinct, + filter=filter, + order_by=order_by, + null_treatment=null_treatment, + ) + ) + + +def last_value( + arg: Expr, + distinct: bool = False, + filter: bool = None, + order_by: Expr | None = None, + null_treatment: common.NullTreatment | None = None, +) -> Expr: + """Returns the last value in a group of values.""" + return Expr( + f.last_value( + arg.expr, + distinct=distinct, + filter=filter, + order_by=order_by, + null_treatment=null_treatment, + ) + ) + + +def bit_and(*args: Expr, distinct: bool = False) -> Expr: + """Computes the bitwise AND of the argument.""" + args = [arg.expr for arg in args] + return Expr(f.bit_and(*args, distinct=distinct)) + + +def bit_or(*args: Expr, distinct: bool = False) -> Expr: + """Computes the bitwise OR of the argument.""" + args = [arg.expr for arg in args] + return Expr(f.bit_or(*args, distinct=distinct)) + + +def bit_xor(*args: Expr, distinct: bool = False) -> Expr: + """Computes the bitwise XOR of the argument.""" + args = [arg.expr for arg in args] + return Expr(f.bit_xor(*args, distinct=distinct)) + + +def bool_and(*args: Expr, distinct: bool = False) -> Expr: + """Computes the boolean AND of the arugment.""" + args = [arg.expr for arg in args] + return Expr(f.bool_and(*args, distinct=distinct)) + + +def bool_or(*args: Expr, distinct: bool = False) -> Expr: + """Computes the boolean OR of the arguement.""" + args = [arg.expr for arg in args] + return Expr(f.bool_or(*args, distinct=distinct)) diff --git a/python/datafusion/input/__init__.py b/python/datafusion/input/__init__.py index 27e39b8c..f85ce21f 100644 --- a/python/datafusion/input/__init__.py +++ b/python/datafusion/input/__init__.py @@ -15,6 +15,11 @@ # specific language governing permissions and limitations # under the License. +"""This package provides for input sources. + +The primary class used within DataFusion is ``LocationInputPlugin``. 
+""" + from .location import LocationInputPlugin __all__ = [ diff --git a/python/datafusion/input/base.py b/python/datafusion/input/base.py index efcaf769..4eba1978 100644 --- a/python/datafusion/input/base.py +++ b/python/datafusion/input/base.py @@ -15,6 +15,11 @@ # specific language governing permissions and limitations # under the License. +"""This module provides ``BaseInputSource``. + +A user can extend this to provide a custom input source. +""" + from abc import ABC, abstractmethod from typing import Any @@ -22,18 +27,22 @@ class BaseInputSource(ABC): - """ - If a consuming library would like to provider their own InputSource - this is the class they should extend to write their own. Once - completed the Plugin InputSource can be registered with the + """Base Input Source class. + + If a consuming library would like to provider their own InputSource this is + the class they should extend to write their own. + + Once completed the Plugin InputSource can be registered with the SessionContext to ensure that it will be used in order to obtain the SqlTable information from the custom datasource. """ @abstractmethod def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool: + """Returns `True` if the input is valid.""" pass @abstractmethod def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable: + """Create a table from the input source.""" pass diff --git a/python/datafusion/input/location.py b/python/datafusion/input/location.py index 16e632d1..566a63da 100644 --- a/python/datafusion/input/location.py +++ b/python/datafusion/input/location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +"""The default input source for DataFusion.""" + import os import glob from typing import Any @@ -24,12 +26,13 @@ class LocationInputPlugin(BaseInputSource): - """ - Input Plugin for everything, which can be read - in from a file (on disk, remote etc.) + """Input Plugin for everything. + + This can be read in from a file (on disk, remote etc.). """ def is_correct_input(self, input_item: Any, table_name: str, **kwargs): + """Returns `True` if the input is valid.""" return isinstance(input_item, str) def build_table( @@ -38,6 +41,7 @@ def build_table( table_name: str, **kwargs, ) -> SqlTable: + """Create a table from the input source.""" _, extension = os.path.splitext(input_file) format = extension.lstrip(".").lower() num_rows = 0 # Total number of rows in the file. Used for statistics diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index 70ecbd2b..a9bb83d2 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Object store functionality.""" from ._internal import object_store diff --git a/python/datafusion/py.typed b/python/datafusion/py.typed new file mode 100644 index 00000000..d216be4d --- /dev/null +++ b/python/datafusion/py.typed @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py new file mode 100644 index 00000000..dcfd5548 --- /dev/null +++ b/python/datafusion/record_batch.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module provides the classes for handling record batches. + +These are typically the result of dataframe `execute_stream` operations. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pyarrow + import datafusion._internal as df_internal + import typing_extensions + + +class RecordBatch: + """This class is essentially a wrapper for ``pyarrow.RecordBatch``.""" + + def __init__(self, record_batch: df_internal.RecordBatch) -> None: + """This constructor is generally not called by the end user. + + See the ``RecordBatchStream`` iterator for generating this class. + """ + self.record_batch = record_batch + + def to_pyarrow(self) -> pyarrow.RecordBatch: + """Convert to pyarrow ``RecordBatch``.""" + return self.record_batch.to_pyarrow() + + +class RecordBatchStream: + """This class represents a stream of record batches. + + These are typically the result of a ``DataFrame::execute_stream`` operation. + """ + + def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: + """This constructor is typically not called by the end user.""" + self.rbs = record_batch_stream + + def next(self) -> RecordBatch | None: + """See ``__next__`` for the iterator function.""" + try: + next_batch = next(self) + except StopIteration: + return None + + return next_batch + + def __next__(self) -> RecordBatch: + """Iterator function.""" + next_batch = next(self.rbs) + return RecordBatch(next_batch) + + def __iter__(self) -> typing_extensions.Self: + """Iterator function.""" + return self diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index eff809a0..a199dd73 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -15,9 +15,171 @@ # specific language governing permissions and limitations # under the License. +"""This module provides support for using substrait with datafusion. -from ._internal import substrait +For additional information about substrait, see https://substrait.io/ for more +information about substrait. 
+""" +from __future__ import annotations -def __getattr__(name): - return getattr(substrait, name) +from ._internal import substrait as substrait_internal + +from typing import TYPE_CHECKING +from typing_extensions import deprecated +import pathlib + +if TYPE_CHECKING: + from datafusion.context import SessionContext + from datafusion._internal import LogicalPlan + + +class Plan: + """A class representing an encodable substrait plan.""" + + def __init__(self, plan: substrait_internal.Plan) -> None: + """Create a substrait plan. + + The user should not have to call this constructor directly. Rather, it + should be created via `Serde` or `Producer` classes in this module. + """ + self.plan_internal = plan + + def encode(self) -> bytes: + """Encode the plan to bytes. + + Returns: + Encoded plan. + """ + return self.plan_internal.encode() + + +@deprecated("Use `Plan` instead.") +class plan(Plan): + """See `Plan`.""" + + pass + + +class Serde: + """Provides the ``Substrait`` serialization and deserialization.""" + + @staticmethod + def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: + """Serialize a SQL query to a Substrait plan and write it to a file. + + Args: + sql:SQL query to serialize. + ctx: SessionContext to use. + path: Path to write the Substrait plan to. + """ + return substrait_internal.serde.serialize(sql, ctx.ctx, str(path)) + + @staticmethod + def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: + """Serialize a SQL query to a Substrait plan. + + Args: + sql: SQL query to serialize. + ctx: SessionContext to use. + + Returns: + Substrait plan. + """ + return Plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) + + @staticmethod + def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: + """Serialize a SQL query to a Substrait plan as bytes. + + Args: + sql: SQL query to serialize. + ctx: SessionContext to use. + + Returns: + Substrait plan as bytes. + """ + return substrait_internal.serde.serialize_bytes(sql, ctx.ctx) + + @staticmethod + def deserialize(path: str | pathlib.Path) -> Plan: + """Deserialize a Substrait plan from a file. + + Args: + path: Path to read the Substrait plan from. + + Returns: + Substrait plan. + """ + return Plan(substrait_internal.serde.deserialize(str(path))) + + @staticmethod + def deserialize_bytes(proto_bytes: bytes) -> Plan: + """Deserialize a Substrait plan from bytes. + + Args: + proto_bytes: Bytes to read the Substrait plan from. + + Returns: + Substrait plan. + """ + return Plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) + + +@deprecated("Use `Serde` instead.") +class serde(Serde): + """See `Serde` instead.""" + + pass + + +class Producer: + """Generates substrait plans from a logical plan.""" + + @staticmethod + def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: + """Convert a DataFusion LogicalPlan to a Substrait plan. + + Args: + logical_plan: LogicalPlan to convert. + ctx: SessionContext to use. + + Returns: + Substrait plan. + """ + return Plan( + substrait_internal.producer.to_substrait_plan(logical_plan, ctx.ctx) + ) + + +@deprecated("Use `Producer` instead.") +class producer(Producer): + """Use `Producer` instead.""" + + pass + + +class Consumer: + """Generates a logical plan from a substrait plan.""" + + @staticmethod + def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: + """Convert a Substrait plan to a DataFusion LogicalPlan. + + Args: + ctx: SessionContext to use. + plan: Substrait plan to convert. 
+ + Returns: + LogicalPlan. + """ + return substrait_internal.consumer.from_substrait_plan( + ctx.ctx, plan.plan_internal + ) + + +@deprecated("Use `Consumer` instead.") +class consumer(Consumer): + """Use `Consumer` instead.""" + + pass diff --git a/python/datafusion/tests/conftest.py b/python/datafusion/tests/conftest.py index a4eec41e..1cc07e50 100644 --- a/python/datafusion/tests/conftest.py +++ b/python/datafusion/tests/conftest.py @@ -18,6 +18,7 @@ import pytest from datafusion import SessionContext import pyarrow as pa +from pyarrow.csv import write_csv @pytest.fixture @@ -37,7 +38,7 @@ def database(ctx, tmp_path): ], names=["int", "str", "float"], ) - pa.csv.write_csv(table, path) + write_csv(table, path) ctx.register_csv("csv", path) ctx.register_csv("csv1", str(path)) diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index abc324db..8373659b 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -17,6 +17,7 @@ import gzip import os import datetime as dt +import pathlib import pyarrow as pa import pyarrow.dataset as ds @@ -37,6 +38,36 @@ def test_create_context_no_args(): SessionContext() +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_runtime_configs(tmp_path, path_to_str): + path1 = tmp_path / "dir1" + path2 = tmp_path / "dir2" + + path1 = str(path1) if path_to_str else path1 + path2 = str(path2) if path_to_str else path2 + + runtime = RuntimeConfig().with_disk_manager_specified(path1, path2) + config = SessionConfig().with_default_catalog_and_schema("foo", "bar") + ctx = SessionContext(config, runtime) + assert ctx is not None + + db = ctx.catalog("foo").database("bar") + assert db is not None + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_temporary_files(tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + runtime = RuntimeConfig().with_temp_file_path(path) + config = SessionConfig().with_default_catalog_and_schema("foo", "bar") + ctx = SessionContext(config, runtime) + assert ctx is not None + + db = ctx.catalog("foo").database("bar") + assert db is not None + + def test_create_context_with_all_valid_args(): runtime = RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000) config = ( @@ -68,7 +99,7 @@ def test_register_record_batches(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -84,7 +115,7 @@ def test_create_dataframe_registers_unique_table_name(ctx): ) df = ctx.create_dataframe([[batch]]) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -104,7 +135,7 @@ def test_create_dataframe_registers_with_defined_table_name(ctx): ) df = ctx.create_dataframe([[batch]], name="tbl") - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -118,11 +149,11 @@ def test_from_arrow_table(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -134,7 +165,7 @@ def test_from_arrow_table_with_name(ctx): # convert to DataFrame with optional name df = ctx.from_arrow_table(table, name="tbl") - 
tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert tables[0] == "tbl" @@ -147,7 +178,7 @@ def test_from_arrow_table_empty(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -162,7 +193,7 @@ def test_from_arrow_table_empty_no_schema(ctx): # convert to DataFrame df = ctx.from_arrow_table(table) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 @@ -180,11 +211,11 @@ def test_from_pylist(ctx): ] df = ctx.from_pylist(data) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -194,11 +225,11 @@ def test_from_pydict(ctx): data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = ctx.from_pydict(data) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -210,11 +241,11 @@ def test_from_pandas(ctx): pandas_df = pd.DataFrame(data) df = ctx.from_pandas(pandas_df) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -226,11 +257,11 @@ def test_from_polars(ctx): polars_df = pd.DataFrame(data) df = ctx.from_polars(polars_df) - tables = list(ctx.tables()) + tables = list(ctx.catalog().database().names()) assert df assert len(tables) == 1 - assert type(df) == DataFrame + assert isinstance(df, DataFrame) assert set(df.schema().names) == {"a", "b"} assert df.collect()[0].num_rows == 3 @@ -273,7 +304,7 @@ def test_register_dataset(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT a+b, a-b FROM t").collect() @@ -290,7 +321,7 @@ def test_dataset_filter(ctx, capfd): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5") # Make sure the filter was pushed down in Physical Plan @@ -370,7 +401,7 @@ def test_dataset_filter_nested_data(ctx): dataset = ds.dataset([batch]) ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} df = ctx.table("t") @@ -468,13 +499,23 @@ def test_read_csv_compressed(ctx, tmp_path): def test_read_parquet(ctx): - csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") - csv_df.show() + parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") + parquet_df.show() + assert parquet_df is not None + + path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" + parquet_df = ctx.read_parquet(path=path) + assert parquet_df is not None def test_read_avro(ctx): - csv_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro") - csv_df.show() + avro_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro") + avro_df.show() + assert avro_df is not None + + path = 
pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" + avro_df = ctx.read_avro(path=path) + assert avro_df is not None def test_create_sql_options(): diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index 2f6a818e..25875da7 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -17,6 +17,7 @@ import os import pyarrow as pa +from pyarrow.csv import write_csv import pyarrow.parquet as pq import pytest @@ -96,6 +97,16 @@ def test_select(df): assert result.column(1) == pa.array([-3, -3, -3]) +def test_select_mixed_expr_string(df): + df = df.select_columns(column("b"), "a") + + # execute and collect the first (and only) batch + result = df.collect()[0] + + assert result.column(0) == pa.array([4, 5, 6]) + assert result.column(1) == pa.array([1, 2, 3]) + + def test_select_columns(df): df = df.select_columns("b", "a") @@ -107,17 +118,29 @@ def test_select_columns(df): def test_filter(df): - df = df.filter(column("a") > literal(2)).select( + df1 = df.filter(column("a") > literal(2)).select( column("a") + column("b"), column("a") - column("b"), ) # execute and collect the first (and only) batch - result = df.collect()[0] + result = df1.collect()[0] assert result.column(0) == pa.array([9]) assert result.column(1) == pa.array([-3]) + df.show() + # verify that if there is no filter applied, internal dataframe is unchanged + df2 = df.filter() + assert df.df == df2.df + + df3 = df.filter(column("a") > literal(1), column("b") != literal(6)) + result = df3.collect()[0] + + assert result.column(0) == pa.array([2]) + assert result.column(1) == pa.array([5]) + assert result.column(2) == pa.array([5]) + def test_sort(df): df = df.sort(column("b").sort(ascending=False)) @@ -175,7 +198,7 @@ def test_with_column_renamed(df): def test_unnest(nested_df): - nested_df = nested_df.unnest_column("a") + nested_df = nested_df.unnest_columns("a") # execute and collect the first (and only) batch result = nested_df.collect()[0] @@ -185,7 +208,7 @@ def test_unnest(nested_df): def test_unnest_without_nulls(nested_df): - nested_df = nested_df.unnest_column("a", preserve_nulls=False) + nested_df = nested_df.unnest_columns("a", preserve_nulls=False) # execute and collect the first (and only) batch result = nested_df.collect()[0] @@ -379,7 +402,7 @@ def test_get_dataframe(tmp_path): ], names=["int", "str", "float"], ) - pa.csv.write_csv(table, path) + write_csv(table, path) ctx.register_csv("csv", path) @@ -611,7 +634,7 @@ def test_to_pandas(df): # Convert datafusion dataframe to pandas dataframe pandas_df = df.to_pandas() - assert type(pandas_df) == pd.DataFrame + assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (3, 3) assert set(pandas_df.columns) == {"a", "b", "c"} @@ -622,7 +645,7 @@ def test_empty_to_pandas(df): # Convert empty datafusion dataframe to pandas dataframe pandas_df = df.limit(0).to_pandas() - assert type(pandas_df) == pd.DataFrame + assert isinstance(pandas_df, pd.DataFrame) assert pandas_df.shape == (0, 3) assert set(pandas_df.columns) == {"a", "b", "c"} @@ -633,7 +656,7 @@ def test_to_polars(df): # Convert datafusion dataframe to polars dataframe polars_df = df.to_polars() - assert type(polars_df) == pl.DataFrame + assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (3, 3) assert set(polars_df.columns) == {"a", "b", "c"} @@ -644,7 +667,7 @@ def test_empty_to_polars(df): # Convert empty datafusion dataframe to polars dataframe polars_df = df.limit(0).to_polars() - 
assert type(polars_df) == pl.DataFrame + assert isinstance(polars_df, pl.DataFrame) assert polars_df.shape == (0, 3) assert set(polars_df.columns) == {"a", "b", "c"} @@ -652,13 +675,15 @@ def test_empty_to_polars(df): def test_to_arrow_table(df): # Convert datafusion dataframe to pyarrow Table pyarrow_table = df.to_arrow_table() - assert type(pyarrow_table) == pa.Table + assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (3, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} def test_execute_stream(df): stream = df.execute_stream() + for s in stream: + print(type(s)) assert all(batch is not None for batch in stream) assert not list(stream) # after one iteration the generator must be exhausted @@ -690,7 +715,7 @@ def test_execute_stream_partitioned(df): def test_empty_to_arrow_table(df): # Convert empty datafusion dataframe to pyarrow Table pyarrow_table = df.limit(0).to_arrow_table() - assert type(pyarrow_table) == pa.Table + assert isinstance(pyarrow_table, pa.Table) assert pyarrow_table.shape == (0, 3) assert set(pyarrow_table.column_names) == {"a", "b", "c"} @@ -736,8 +761,35 @@ def test_describe(df): } -def test_write_parquet(df, tmp_path): - path = tmp_path +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_csv(ctx, df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + df.write_csv(path, with_header=True) + + ctx.register_csv("csv", path) + result = ctx.table("csv").to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_json(ctx, df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path + + df.write_json(path) + + ctx.register_json("json", path) + result = ctx.table("json").to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_write_parquet(df, tmp_path, path_to_str): + path = str(tmp_path) if path_to_str else tmp_path df.write_parquet(str(path)) result = pq.read_table(str(path)).to_pydict() @@ -795,3 +847,15 @@ def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compre with pytest.raises(ValueError): df.write_parquet(str(path), compression=compression) + + +# ctx = SessionContext() + +# # create a RecordBatch and a new DataFrame from it +# batch = pa.RecordBatch.from_arrays( +# [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])], +# names=["a", "b", "c"], +# ) + +# df = ctx.create_dataframe([[batch]]) +# test_execute_stream(df) diff --git a/python/datafusion/tests/test_expr.py b/python/datafusion/tests/test_expr.py index 73f7d087..c9f0e98d 100644 --- a/python/datafusion/tests/test_expr.py +++ b/python/datafusion/tests/test_expr.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from datafusion import SessionContext +from datafusion import SessionContext, col from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction from datafusion.expr import ( Projection, @@ -25,6 +25,7 @@ Sort, TableScan, ) +import pyarrow import pytest @@ -116,3 +117,25 @@ def test_sort(test_ctx): plan = plan.to_variant() assert isinstance(plan, Sort) + + +def test_relational_expr(test_ctx): + ctx = SessionContext() + + batch = pyarrow.RecordBatch.from_arrays( + [pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], name="batch_array") + + assert df.filter(col("a") == 1).count() == 1 + assert df.filter(col("a") != 1).count() == 2 + assert df.filter(col("a") >= 1).count() == 3 + assert df.filter(col("a") > 1).count() == 2 + assert df.filter(col("a") <= 3).count() == 3 + assert df.filter(col("a") < 3).count() == 2 + + assert df.filter(col("b") == "beta").count() == 1 + assert df.filter(col("b") != "beta").count() == 2 + + assert df.filter(col("a") == "beta").count() == 0 diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 449f706c..2384b6ab 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -54,12 +54,11 @@ def test_named_struct(df): df = df.with_column( "d", f.named_struct( - literal("a"), - column("a"), - literal("b"), - column("b"), - literal("c"), - column("c"), + [ + ("a", column("a")), + ("b", column("b")), + ("c", column("c")), + ] ), ) @@ -97,9 +96,7 @@ def test_literal(df): def test_lit_arith(df): - """ - Test literals with arithmetic operations - """ + """Test literals with arithmetic operations""" df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!"))) result = df.collect() assert len(result) == 1 @@ -140,6 +137,7 @@ def test_math_functions(): f.power(col_v, literal(pa.scalar(3))), f.pow(col_v, literal(pa.scalar(4))), f.round(col_v), + f.round(col_v, literal(pa.scalar(3))), f.sqrt(col_v), f.signum(col_v), f.trunc(col_v), @@ -183,29 +181,30 @@ def test_math_functions(): np.testing.assert_array_almost_equal(result.column(15), np.power(values, 3)) np.testing.assert_array_almost_equal(result.column(16), np.power(values, 4)) np.testing.assert_array_almost_equal(result.column(17), np.round(values)) - np.testing.assert_array_almost_equal(result.column(18), np.sqrt(values)) - np.testing.assert_array_almost_equal(result.column(19), np.sign(values)) - np.testing.assert_array_almost_equal(result.column(20), np.trunc(values)) - np.testing.assert_array_almost_equal(result.column(21), np.arcsinh(values)) - np.testing.assert_array_almost_equal(result.column(22), np.arccosh(values)) - np.testing.assert_array_almost_equal(result.column(23), np.arctanh(values)) - np.testing.assert_array_almost_equal(result.column(24), np.cbrt(values)) - np.testing.assert_array_almost_equal(result.column(25), np.cosh(values)) - np.testing.assert_array_almost_equal(result.column(26), np.degrees(values)) - np.testing.assert_array_almost_equal(result.column(27), np.gcd(9, 3)) - np.testing.assert_array_almost_equal(result.column(28), np.lcm(6, 4)) + np.testing.assert_array_almost_equal(result.column(18), np.round(values, 3)) + np.testing.assert_array_almost_equal(result.column(19), np.sqrt(values)) + np.testing.assert_array_almost_equal(result.column(20), np.sign(values)) + np.testing.assert_array_almost_equal(result.column(21), np.trunc(values)) + 
np.testing.assert_array_almost_equal(result.column(22), np.arcsinh(values)) + np.testing.assert_array_almost_equal(result.column(23), np.arccosh(values)) + np.testing.assert_array_almost_equal(result.column(24), np.arctanh(values)) + np.testing.assert_array_almost_equal(result.column(25), np.cbrt(values)) + np.testing.assert_array_almost_equal(result.column(26), np.cosh(values)) + np.testing.assert_array_almost_equal(result.column(27), np.degrees(values)) + np.testing.assert_array_almost_equal(result.column(28), np.gcd(9, 3)) + np.testing.assert_array_almost_equal(result.column(29), np.lcm(6, 4)) np.testing.assert_array_almost_equal( - result.column(29), np.where(np.isnan(na_values), 5, na_values) + result.column(30), np.where(np.isnan(na_values), 5, na_values) ) - np.testing.assert_array_almost_equal(result.column(30), np.pi) - np.testing.assert_array_almost_equal(result.column(31), np.radians(values)) - np.testing.assert_array_almost_equal(result.column(32), np.sinh(values)) - np.testing.assert_array_almost_equal(result.column(33), np.tanh(values)) - np.testing.assert_array_almost_equal(result.column(34), math.factorial(6)) - np.testing.assert_array_almost_equal(result.column(35), np.isnan(na_values)) - np.testing.assert_array_almost_equal(result.column(36), na_values == 0) + np.testing.assert_array_almost_equal(result.column(31), np.pi) + np.testing.assert_array_almost_equal(result.column(32), np.radians(values)) + np.testing.assert_array_almost_equal(result.column(33), np.sinh(values)) + np.testing.assert_array_almost_equal(result.column(34), np.tanh(values)) + np.testing.assert_array_almost_equal(result.column(35), math.factorial(6)) + np.testing.assert_array_almost_equal(result.column(36), np.isnan(na_values)) + np.testing.assert_array_almost_equal(result.column(37), na_values == 0) np.testing.assert_array_almost_equal( - result.column(37), np.emath.logn(3, values + 1.0) + result.column(38), np.emath.logn(3, values + 1.0) ) @@ -591,7 +590,12 @@ def test_string_functions(df): f.trim(column("c")), f.upper(column("c")), f.ends_with(column("a"), literal("llo")), + f.overlay(column("a"), literal("--"), literal(2)), + f.regexp_like(column("a"), literal("(ell|orl)")), + f.regexp_match(column("a"), literal("(ell|orl)")), + f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")), ) + result = df.collect() assert len(result) == 1 result = result[0] @@ -632,6 +636,10 @@ def test_string_functions(df): assert result.column(26) == pa.array(["hello", "world", "!"]) assert result.column(27) == pa.array(["HELLO ", " WORLD ", " !"]) assert result.column(28) == pa.array([True, False, False]) + assert result.column(29) == pa.array(["H--lo", "W--ld", "--"]) + assert result.column(30) == pa.array([True, True, False]) + assert result.column(31) == pa.array([["ell"], ["orl"], None]) + assert result.column(32) == pa.array(["H-o", "W-d", "!"]) def test_hash_functions(df): diff --git a/python/datafusion/tests/test_imports.py b/python/datafusion/tests/test_imports.py index bd4e7c31..3d324fb6 100644 --- a/python/datafusion/tests/test_imports.py +++ b/python/datafusion/tests/test_imports.py @@ -94,13 +94,24 @@ def test_datafusion_python_version(): def test_class_module_is_datafusion(): + # context for klass in [ SessionContext, + ]: + assert klass.__module__ == "datafusion.context" + + # dataframe + for klass in [ DataFrame, - ScalarUDF, + ]: + assert klass.__module__ == "datafusion.dataframe" + + # udf + for klass in [ AggregateUDF, + ScalarUDF, ]: - assert klass.__module__ == "datafusion" + assert 
klass.__module__ == "datafusion.udf" # expressions for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]: diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index 8ec2ffb1..d85f380e 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -19,6 +19,7 @@ import numpy as np import pyarrow as pa +from pyarrow.csv import write_csv import pyarrow.dataset as ds import pytest from datafusion.object_store import LocalFileSystem @@ -45,7 +46,7 @@ def test_register_csv(ctx, tmp_path): ], names=["int", "str", "float"], ) - pa.csv.write_csv(table, path) + write_csv(table, path) with open(path, "rb") as csv_file: with gzip.open(gzip_path, "wb") as gzipped_file: @@ -76,7 +77,13 @@ def test_register_csv(ctx, tmp_path): ) ctx.register_csv("csv3", path, schema=alternative_schema) - assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"} + assert ctx.catalog().database().names() == { + "csv", + "csv1", + "csv2", + "csv3", + "csv_gzip", + } for table in ["csv", "csv1", "csv2", "csv_gzip"]: result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() @@ -100,14 +107,16 @@ def test_register_csv(ctx, tmp_path): def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) - assert ctx.tables() == {"t"} + ctx.register_parquet("t1", str(path)) + assert ctx.catalog().database().names() == {"t", "t1"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) assert result.to_pydict() == {"cnt": [100]} -def test_register_parquet_partitioned(ctx, tmp_path): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a").mkdir(exist_ok=False) @@ -124,14 +133,16 @@ def test_register_parquet_partitioned(ctx, tmp_path): pa.parquet.write_table(table.slice(0, 3), dir_root / "grp=a/file.parquet") pa.parquet.write_table(table.slice(3, 4), dir_root / "grp=b/file.parquet") + dir_root = str(dir_root) if path_to_str else dir_root + ctx.register_parquet( "datapp", - str(dir_root), + dir_root, table_partition_cols=[("grp", "string")], parquet_pruning=True, file_extension=".parquet", ) - assert ctx.tables() == {"datapp"} + assert ctx.catalog().database().names() == {"datapp"} result = ctx.sql("SELECT grp, COUNT(*) AS cnt FROM datapp GROUP BY grp").collect() result = pa.Table.from_batches(result) @@ -140,12 +151,14 @@ def test_register_parquet_partitioned(ctx, tmp_path): assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1} -def test_register_dataset(ctx, tmp_path): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_dataset(ctx, tmp_path, path_to_str): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) + path = str(path) if path_to_str else path dataset = ds.dataset(path, format="parquet") ctx.register_dataset("t", dataset) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect() result = pa.Table.from_batches(result) @@ -174,6 +187,12 @@ def test_register_json(ctx, tmp_path): file_extension="gz", file_compression_type="gzip", ) + ctx.register_json( + "json_gzip1", + str(gzip_path), + file_extension="gz", + file_compression_type="gzip", + ) alternative_schema = pa.schema( [ @@ -184,7 +203,14 @@ def 
test_register_json(ctx, tmp_path): ) ctx.register_json("json3", path, schema=alternative_schema) - assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"} + assert ctx.catalog().database().names() == { + "json", + "json1", + "json2", + "json3", + "json_gzip", + "json_gzip1", + } for table in ["json", "json1", "json2", "json_gzip"]: result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect() @@ -234,7 +260,7 @@ def test_execute(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data)) ctx.register_parquet("t", path) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} # count result = ctx.sql("SELECT COUNT(a) AS cnt FROM t WHERE a IS NOT NULL").collect() @@ -280,9 +306,7 @@ def test_execute(ctx, tmp_path): def test_cast(ctx, tmp_path): - """ - Verify that we can cast - """ + """Verify that we can cast""" path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) @@ -379,7 +403,10 @@ def test_simple_select(ctx, tmp_path, arr): @pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]])) @pytest.mark.parametrize("pass_schema", (True, False)) -def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order): +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_register_listing_table( + ctx, tmp_path, pass_schema, file_sort_order, path_to_str +): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True) @@ -404,16 +431,18 @@ def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order): table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet" ) + dir_root = f"file://{dir_root}/" if path_to_str else dir_root + ctx.register_object_store("file://local", LocalFileSystem(), None) ctx.register_listing_table( "my_table", - f"file://{dir_root}/", + dir_root, table_partition_cols=[("grp", "string"), ("date_id", "int")], file_extension=".parquet", schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, ) - assert ctx.tables() == {"my_table"} + assert ctx.catalog().database().names() == {"my_table"} result = ctx.sql( "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp" diff --git a/python/datafusion/tests/test_substrait.py b/python/datafusion/tests/test_substrait.py index 62f6413a..2071c8f3 100644 --- a/python/datafusion/tests/test_substrait.py +++ b/python/datafusion/tests/test_substrait.py @@ -35,17 +35,43 @@ def test_substrait_serialization(ctx): ctx.register_record_batches("t", [[batch]]) - assert ctx.tables() == {"t"} + assert ctx.catalog().database().names() == {"t"} # For now just make sure the method calls blow up - substrait_plan = ss.substrait.serde.serialize_to_plan("SELECT * FROM t", ctx) + substrait_plan = ss.Serde.serialize_to_plan("SELECT * FROM t", ctx) substrait_bytes = substrait_plan.encode() assert isinstance(substrait_bytes, bytes) - substrait_bytes = ss.substrait.serde.serialize_bytes("SELECT * FROM t", ctx) - substrait_plan = ss.substrait.serde.deserialize_bytes(substrait_bytes) - logical_plan = ss.substrait.consumer.from_substrait_plan(ctx, substrait_plan) + substrait_bytes = ss.Serde.serialize_bytes("SELECT * FROM t", ctx) + substrait_plan = ss.Serde.deserialize_bytes(substrait_bytes) + logical_plan = ss.Consumer.from_substrait_plan(ctx, substrait_plan) # demonstrate how to create a DataFrame from a deserialized logical plan df = 
ctx.create_dataframe_from_logical_plan(logical_plan) - substrait_plan = ss.substrait.producer.to_substrait_plan(df.logical_plan(), ctx) + substrait_plan = ss.Producer.to_substrait_plan(df.logical_plan(), ctx) + + +@pytest.mark.parametrize("path_to_str", (True, False)) +def test_substrait_file_serialization(ctx, tmp_path, path_to_str): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + + ctx.register_record_batches("t", [[batch]]) + + assert ctx.catalog().database().names() == {"t"} + + path = tmp_path / "substrait_plan" + path = str(path) if path_to_str else path + + sql_command = "SELECT * FROM T" + ss.Serde.serialize(sql_command, ctx, path) + + expected_plan = ss.Serde.serialize_to_plan(sql_command, ctx) + actual_plan = ss.Serde.deserialize(path) + + expected_logical_plan = ss.Consumer.from_substrait_plan(ctx, expected_plan) + expected_actual_plan = ss.Consumer.from_substrait_plan(ctx, actual_plan) + + assert str(expected_logical_plan) == str(expected_actual_plan) diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index c2b29d19..81194927 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -25,9 +25,7 @@ class Summarize(Accumulator): - """ - Interface of a user-defined accumulation. - """ + """Interface of a user-defined accumulation.""" def __init__(self): self._sum = pa.scalar(0.0) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py new file mode 100644 index 00000000..4bfbabe6 --- /dev/null +++ b/python/datafusion/udf.py @@ -0,0 +1,248 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Provides the user defined functions for evaluation of dataframes.""" + +from __future__ import annotations + +import datafusion._internal as df_internal +from datafusion.expr import Expr +from typing import Callable, TYPE_CHECKING, TypeVar +from abc import ABCMeta, abstractmethod +from typing import List +from enum import Enum +import pyarrow + +if TYPE_CHECKING: + _R = TypeVar("_R", bound=pyarrow.DataType) + + +class Volatility(Enum): + """Defines how stable or volatile a function is. + + When setting the volatility of a function, you can either pass this + enumeration or a `str`. The `str` equivalent is the lower case value of the + name (`"immutable"`, `"stable"`, or `"volatile"`). + """ + + Immutable = 1 + """An immutable function will always return the same output when given the + same input. + + DataFusion will attempt to inline immutable functions during planning. + """ + + Stable = 2 + """ + Returns the same value for a given input within a single queries. 
+ + A stable function may return different values given the same input across + different queries but must return the same value for a given input within a + query. An example of this is the `Now` function. DataFusion will attempt to + inline `Stable` functions during planning, when possible. For query + `select col1, now() from t1`, it might take a while to execute but `now()` + column will be the same for each output row, which is evaluated during + planning. + """ + + Volatile = 3 + """A volatile function may change the return value from evaluation to + evaluation. + + Multiple invocations of a volatile function may return different results + when used in the same query. An example of this is the random() function. + DataFusion can not evaluate such functions during planning. In the query + `select col1, random() from t1`, `random()` function will be evaluated + for each output row, resulting in a unique random value for each row. + """ + + def __str__(self): + """Returns the string equivalent.""" + return self.name.lower() + + +class ScalarUDF: + """Class for performing scalar user defined functions (UDF). + + Scalar UDFs operate on a row by row basis. See also ``AggregateUDF`` for + operating on a group of rows. + """ + + def __init__( + self, + name: str | None, + func: Callable[..., _R], + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + ) -> None: + """Instantiate a scalar user defined function (UDF). + + See helper method ``udf`` for argument details. + """ + self.udf = df_internal.ScalarUDF( + name, func, input_types, return_type, str(volatility) + ) + + def __call__(self, *args: Expr) -> Expr: + """Execute the UDF. + + This function is not typically called by an end user. These calls will + occur during the evaluation of the dataframe. + """ + args = [arg.expr for arg in args] + return Expr(self.udf.__call__(*args)) + + @staticmethod + def udf( + func: Callable[..., _R], + input_types: list[pyarrow.DataType], + return_type: _R, + volatility: Volatility | str, + name: str | None = None, + ) -> ScalarUDF: + """Create a new User Defined Function. + + Args: + func: A callable python function. + input_types: The data types of the arguments to `func`. This list + must be of the same length as the number of arguments. + return_type: The data type of the return value from the python + function. + volatility: See ``Volatility`` for allowed values. + name: A descriptive name for the function. + + Returns: + A user defined aggregate function, which can be used in either data + aggregation or window function calls. 
+ """ + if not callable(func): + raise TypeError("`func` argument must be callable") + if name is None: + name = func.__qualname__.lower() + return ScalarUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) + + +class Accumulator(metaclass=ABCMeta): + """Defines how an `AggregateUDF` accumulates values during an evaluation.""" + + @abstractmethod + def state(self) -> List[pyarrow.Scalar]: + """Return the current state.""" + pass + + @abstractmethod + def update(self, values: pyarrow.Array) -> None: + """Evalute an array of values and update state.""" + pass + + @abstractmethod + def merge(self, states: pyarrow.Array) -> None: + """Merge a set of states.""" + pass + + @abstractmethod + def evaluate(self) -> pyarrow.Scalar: + """Return the resultant value.""" + pass + + +if TYPE_CHECKING: + _A = TypeVar("_A", bound=(Callable[..., _R], Accumulator)) + + +class AggregateUDF: + """Class for performing scalar user defined functions (UDF). + + Aggregate UDFs operate on a group of rows and return a single value. See + also ``ScalarUDF`` for operating on a row by row basis. + """ + + def __init__( + self, + name: str | None, + accumulator: _A, + input_types: list[pyarrow.DataType], + return_type: _R, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + ) -> None: + """Instantiate a user defined aggregate function (UDAF). + + See ``Aggregate::udaf`` for a convenience function and arugment + descriptions. + """ + self.udf = df_internal.AggregateUDF( + name, accumulator, input_types, return_type, state_type, str(volatility) + ) + + def __call__(self, *args: Expr) -> Expr: + """Execute the UDAF. + + This function is not typically called by an end user. These calls will + occur during the evaluation of the dataframe. + """ + args = [arg.expr for arg in args] + return Expr(self.udf.__call__(*args)) + + @staticmethod + def udaf( + accum: _A, + input_types: list[pyarrow.DataType], + return_type: _R, + state_type: list[pyarrow.DataType], + volatility: Volatility | str, + name: str | None = None, + ) -> AggregateUDF: + """Create a new User Defined Aggregate Function. + + The accumulator function must be callable and implement `Accumulator`. + + Args: + accum: The accumulator python function. + input_types: The data types of the arguments to `accum`. + return_type: The data type of the return value. + state_type: The data types of the intermediate accumulation. + volatility: See `Volatility` for allowed values. + name: A descriptive name for the function. + + Returns: + A user defined aggregate function, which can be used in either data + aggregation or window function calls. 
+ """ + if not issubclass(accum, Accumulator): + raise TypeError( + "`accum` must implement the abstract base class Accumulator" + ) + if name is None: + name = accum.__qualname__.lower() + if isinstance(input_types, pyarrow.lib.DataType): + input_types = [input_types] + return AggregateUDF( + name=name, + accumulator=accum, + input_types=input_types, + return_type=return_type, + state_type=state_type, + volatility=volatility, + ) diff --git a/src/common.rs b/src/common.rs index 094e70c0..453bf67a 100644 --- a/src/common.rs +++ b/src/common.rs @@ -27,6 +27,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 313318fc..3299a46f 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -764,7 +764,7 @@ pub enum SqlType { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "PythonType", module = "datafusion.common")] +#[pyclass(name = "NullTreatment", module = "datafusion.common")] pub enum NullTreatment { IGNORE_NULLS, RESPECT_NULLS, diff --git a/src/dataframe.rs b/src/dataframe.rs index 9e36be2c..4db59d4f 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; -use datafusion::config::TableParquetOptions; +use datafusion::config::{CsvOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; @@ -320,6 +320,18 @@ impl PyDataFrame { Ok(Self::new(df)) } + #[pyo3(signature = (columns, preserve_nulls=true))] + fn unnest_columns(&self, columns: Vec, preserve_nulls: bool) -> PyResult { + let unnest_options = UnnestOptions { preserve_nulls }; + let cols = columns.iter().map(|s| s.as_ref()).collect::>(); + let df = self + .df + .as_ref() + .clone() + .unnest_columns_with_options(&cols, unnest_options)?; + Ok(Self::new(df)) + } + /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema fn intersect(&self, py_df: PyDataFrame) -> PyResult { let new_df = self @@ -337,13 +349,18 @@ impl PyDataFrame { } /// Write a `DataFrame` to a CSV file. 
-    fn write_csv(&self, path: &str, py: Python) -> PyResult<()> {
+    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyResult<()> {
+        let csv_options = CsvOptions {
+            has_header: Some(with_header),
+            ..Default::default()
+        };
         wait_for_future(
             py,
-            self.df
-                .as_ref()
-                .clone()
-                .write_csv(path, DataFrameWriteOptions::new(), None),
+            self.df.as_ref().clone().write_csv(
+                path,
+                DataFrameWriteOptions::new(),
+                Some(csv_options),
+            ),
         )?;
         Ok(())
     }
diff --git a/src/expr.rs b/src/expr.rs
index dc1de669..aab0daa6 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -583,6 +583,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/src/functions.rs b/src/functions.rs
index b39d98b3..d2f3c7ed 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -232,6 +232,12 @@ fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
     Ok(functions::string::expr_fn::concat_ws(lit(sep), args).into())
 }
 
+#[pyfunction]
+#[pyo3(signature = (values, regex, flags = None))]
+fn regexp_like(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
+    Ok(functions::expr_fn::regexp_like(values.expr, regex.expr, flags.map(|x| x.expr)).into())
+}
+
 #[pyfunction]
 #[pyo3(signature = (values, regex, flags = None))]
 fn regexp_match(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
@@ -256,12 +262,12 @@ fn regexp_replace(
 }
 /// Creates a new Sort Expr
 #[pyfunction]
-fn order_by(expr: PyExpr, asc: Option<bool>, nulls_first: Option<bool>) -> PyResult<PyExpr> {
+fn order_by(expr: PyExpr, asc: bool, nulls_first: bool) -> PyResult<PyExpr> {
     Ok(PyExpr {
         expr: datafusion_expr::Expr::Sort(Sort {
             expr: Box::new(expr.expr),
-            asc: asc.unwrap_or(true),
-            nulls_first: nulls_first.unwrap_or(true),
+            asc,
+            nulls_first,
         }),
     })
 }
@@ -488,6 +494,7 @@ expr_fn!(chr, arg, "Returns the character with the given code.");
 expr_fn_vec!(coalesce);
 expr_fn!(cos, num);
 expr_fn!(cosh, num);
+expr_fn!(cot, num);
 expr_fn!(degrees, num);
 expr_fn!(decode, input encoding);
 expr_fn!(encode, input encoding);
@@ -499,6 +506,7 @@ expr_fn!(gcd, x y);
 expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
 expr_fn!(isnan, num);
 expr_fn!(iszero, num);
+expr_fn!(levenshtein, string1 string2);
 expr_fn!(lcm, x y);
 expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
 expr_fn!(ln, num);
@@ -520,6 +528,7 @@ expr_fn!(
 );
 expr_fn!(nullif, arg_1 arg_2);
 expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
+expr_fn_vec!(overlay);
 expr_fn!(pi);
 expr_fn!(power, base exponent);
 expr_fn!(pow, power, base exponent);
@@ -555,7 +564,9 @@ expr_fn!(sqrt, num);
 expr_fn!(starts_with, string prefix, "Returns true if string starts with prefix.");
 expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
 expr_fn!(substr, string position);
+expr_fn!(substr_index, string delimiter count);
 expr_fn!(substring, string position length);
+expr_fn!(find_in_set, string string_list);
 expr_fn!(tan, num);
 expr_fn!(tanh, num);
 expr_fn!(
@@ -568,6 +579,7 @@ expr_fn_vec!(to_timestamp);
 expr_fn_vec!(to_timestamp_millis);
 expr_fn_vec!(to_timestamp_micros);
 expr_fn_vec!(to_timestamp_seconds);
+expr_fn_vec!(to_unixtime);
 expr_fn!(current_date);
 expr_fn!(current_time);
 expr_fn!(date_part, part date);
@@ -575,6 +587,7 @@ expr_fn!(datepart, date_part, part date);
 expr_fn!(date_trunc, part date);
 expr_fn!(datetrunc, date_trunc, part date);
 expr_fn!(date_bin, stride source origin);
+expr_fn!(make_date, year month day);
 expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
 expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
@@ -712,6 +725,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(corr))?;
     m.add_wrapped(wrap_pyfunction!(cos))?;
     m.add_wrapped(wrap_pyfunction!(cosh))?;
+    m.add_wrapped(wrap_pyfunction!(cot))?;
     m.add_wrapped(wrap_pyfunction!(count))?;
     m.add_wrapped(wrap_pyfunction!(count_star))?;
     m.add_wrapped(wrap_pyfunction!(covar))?;
@@ -725,6 +739,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(date_part))?;
     m.add_wrapped(wrap_pyfunction!(datetrunc))?;
     m.add_wrapped(wrap_pyfunction!(date_trunc))?;
+    m.add_wrapped(wrap_pyfunction!(make_date))?;
     m.add_wrapped(wrap_pyfunction!(digest))?;
     m.add_wrapped(wrap_pyfunction!(ends_with))?;
     m.add_wrapped(wrap_pyfunction!(exp))?;
@@ -737,6 +752,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(initcap))?;
     m.add_wrapped(wrap_pyfunction!(isnan))?;
     m.add_wrapped(wrap_pyfunction!(iszero))?;
+    m.add_wrapped(wrap_pyfunction!(levenshtein))?;
     m.add_wrapped(wrap_pyfunction!(lcm))?;
     m.add_wrapped(wrap_pyfunction!(left))?;
     m.add_wrapped(wrap_pyfunction!(length))?;
@@ -759,11 +775,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(nullif))?;
     m.add_wrapped(wrap_pyfunction!(octet_length))?;
     m.add_wrapped(wrap_pyfunction!(order_by))?;
+    m.add_wrapped(wrap_pyfunction!(overlay))?;
     m.add_wrapped(wrap_pyfunction!(pi))?;
     m.add_wrapped(wrap_pyfunction!(power))?;
     m.add_wrapped(wrap_pyfunction!(pow))?;
     m.add_wrapped(wrap_pyfunction!(radians))?;
     m.add_wrapped(wrap_pyfunction!(random))?;
+    m.add_wrapped(wrap_pyfunction!(regexp_like))?;
     m.add_wrapped(wrap_pyfunction!(regexp_match))?;
     m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
     m.add_wrapped(wrap_pyfunction!(repeat))?;
@@ -789,7 +807,9 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(strpos))?;
     m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
     m.add_wrapped(wrap_pyfunction!(substr))?;
+    m.add_wrapped(wrap_pyfunction!(substr_index))?;
     m.add_wrapped(wrap_pyfunction!(substring))?;
+    m.add_wrapped(wrap_pyfunction!(find_in_set))?;
     m.add_wrapped(wrap_pyfunction!(sum))?;
     m.add_wrapped(wrap_pyfunction!(tan))?;
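
Note: once these registrations land, the new expressions should be reachable from Python via `datafusion.functions`. A rough sketch, assuming the Python-side `functions` module re-exports them under the same names as the Rust registrations above; the table and column used here are invented for illustration.

import pyarrow as pa
from datafusion import SessionContext, col, lit, functions as F

ctx = SessionContext()
# Single string column holding a comma-separated list, just to exercise the new functions.
batch = pa.RecordBatch.from_arrays([pa.array(["foo,bar,baz"])], names=["s"])
df = ctx.create_dataframe([[batch]])

df.select(
    F.levenshtein(col("s"), lit("foo,baz,bar")).alias("edit_distance"),
    F.substr_index(col("s"), lit(","), lit(2)).alias("first_two_items"),
    F.find_in_set(lit("bar"), col("s")).alias("bar_position"),
    F.regexp_like(col("s"), lit("^foo")).alias("starts_with_foo"),
    F.overlay(col("s"), lit("BAR"), lit(5)).alias("overlaid"),
).show()
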
     m.add_wrapped(wrap_pyfunction!(tanh))?;
@@ -798,6 +818,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(to_timestamp_millis))?;
     m.add_wrapped(wrap_pyfunction!(to_timestamp_micros))?;
     m.add_wrapped(wrap_pyfunction!(to_timestamp_seconds))?;
+    m.add_wrapped(wrap_pyfunction!(to_unixtime))?;
     m.add_wrapped(wrap_pyfunction!(translate))?;
     m.add_wrapped(wrap_pyfunction!(trim))?;
     m.add_wrapped(wrap_pyfunction!(trunc))?;
diff --git a/src/lib.rs b/src/lib.rs
index 71c27e1a..357eaacd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -92,6 +92,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::()?;
+    m.add_class::()?;
 
     // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/
     let common = PyModule::new_bound(py, "common")?;
diff --git a/src/substrait.rs b/src/substrait.rs
index 1e9e16c7..60a52380 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -27,7 +27,7 @@ use datafusion_substrait::serializer;
 use datafusion_substrait::substrait::proto::Plan;
 use prost::Message;
 
-#[pyclass(name = "plan", module = "datafusion.substrait", subclass)]
+#[pyclass(name = "Plan", module = "datafusion.substrait", subclass)]
 #[derive(Debug, Clone)]
 pub struct PyPlan {
     pub plan: Plan,
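
Note: together with the new `unnest_columns` binding earlier in this diff, the `write_csv` change surfaces on the Python DataFrame roughly as below. Only the Rust side is shown in the diff, so the Python wrapper signatures (`unnest_columns(..., preserve_nulls=...)`, `write_csv(path, with_header=...)`) and the output path are assumptions.

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array([[10, 20], [30]])],
    names=["id", "values"],
)
df = ctx.create_dataframe([[batch]])

# Explode the list column into one row per element; preserve_nulls maps onto
# the UnnestOptions { preserve_nulls } built in PyDataFrame::unnest_columns.
flat = df.unnest_columns("values", preserve_nulls=True)

# with_header feeds CsvOptions::has_header in the new write_csv signature.
flat.write_csv("/tmp/unnested.csv", with_header=True)
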