diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 59bc8e30..0569ac4b 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -33,8 +33,12 @@ SQLOptions, ) +from .catalog import Catalog, Database, Table + # The following imports are okay to remain as opaque to the user. -from ._internal import Config +from ._internal import Config, LogicalPlan, ExecutionPlan, runtime + +from .record_batch import RecordBatchStream, RecordBatch from .udf import ScalarUDF, AggregateUDF, Accumulator @@ -49,6 +53,8 @@ WindowFrame, ) +from . import functions, object_store, substrait + __version__ = importlib_metadata.version(__name__) __all__ = [ @@ -65,6 +71,20 @@ "column", "literal", "DFSchema", + "runtime", + "Catalog", + "Database", + "Table", + "AggregateUDF", + "LogicalPlan", + "ExecutionPlan", + "RecordBatch", + "RecordBatchStream", + "common", + "expr", + "functions", + "object_store", + "substrait", ] diff --git a/python/datafusion/common.py b/python/datafusion/common.py index 2351845b..225e3330 100644 --- a/python/datafusion/common.py +++ b/python/datafusion/common.py @@ -16,8 +16,34 @@ # under the License. 
"""Common data types used throughout the DataFusion project.""" -from ._internal import common +from ._internal import common as common_internal +# TODO these should all have proper wrapper classes -def __getattr__(name): - return getattr(common, name) +DFSchema = common_internal.DFSchema +DataType = common_internal.DataType +DataTypeMap = common_internal.DataTypeMap +NullTreatment = common_internal.NullTreatment +PythonType = common_internal.PythonType +RexType = common_internal.RexType +SqlFunction = common_internal.SqlFunction +SqlSchema = common_internal.SqlSchema +SqlStatistics = common_internal.SqlStatistics +SqlTable = common_internal.SqlTable +SqlType = common_internal.SqlType +SqlView = common_internal.SqlView + +__all__ = [ + "DFSchema", + "DataType", + "DataTypeMap", + "RexType", + "PythonType", + "SqlType", + "NullTreatment", + "SqlTable", + "SqlSchema", + "SqlView", + "SqlStatistics", + "SqlFunction", +] diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index c04a525a..318b8b9a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -47,6 +47,7 @@ CrossJoin = expr_internal.CrossJoin Distinct = expr_internal.Distinct DropTable = expr_internal.DropTable +EmptyRelation = expr_internal.EmptyRelation Exists = expr_internal.Exists Explain = expr_internal.Explain Extension = expr_internal.Extension @@ -58,6 +59,7 @@ InSubquery = expr_internal.InSubquery IsFalse = expr_internal.IsFalse IsNotTrue = expr_internal.IsNotTrue +IsNull = expr_internal.IsNull IsTrue = expr_internal.IsTrue IsUnknown = expr_internal.IsUnknown IsNotFalse = expr_internal.IsNotFalse @@ -83,6 +85,70 @@ TableScan = expr_internal.TableScan TryCast = expr_internal.TryCast Union = expr_internal.Union +Unnest = expr_internal.Unnest +Window = expr_internal.Window + +__all__ = [ + "Expr", + "Column", + "Literal", + "BinaryExpr", + "Literal", + "AggregateFunction", + "Not", + "IsNotNull", + "IsNull", + "IsTrue", + "IsFalse", + "IsUnknown", + "IsNotTrue", + 
"IsNotFalse", + "IsNotUnknown", + "Negative", + "Like", + "ILike", + "SimilarTo", + "ScalarVariable", + "Alias", + "InList", + "Exists", + "Subquery", + "InSubquery", + "ScalarSubquery", + "Placeholder", + "GroupingSet", + "Case", + "CaseBuilder", + "Cast", + "TryCast", + "Between", + "Explain", + "Limit", + "Aggregate", + "Sort", + "Analyze", + "EmptyRelation", + "Join", + "JoinType", + "JoinConstraint", + "CrossJoin", + "Union", + "Unnest", + "Extension", + "Filter", + "Projection", + "TableScan", + "CreateMemoryTable", + "CreateView", + "Distinct", + "SubqueryAlias", + "DropTable", + "Partitioning", + "Repartition", + "Window", + "WindowFrame", + "WindowFrameBound", +] class Expr: @@ -246,6 +312,25 @@ def __lt__(self, rhs: Any) -> Expr: rhs = Expr.literal(rhs) return Expr(self.expr.__lt__(rhs.expr)) + # Commutative operators can reuse the forward implementation. + __radd__ = __add__ + __rand__ = __and__ + __rmul__ = __mul__ + __ror__ = __or__ + + # Non-commutative operators must keep the literal on the left-hand side. + def __rmod__(self, lhs: Any) -> Expr: + """Reflected modulo: evaluates ``lhs % self``.""" + return Expr.literal(lhs) % self + + def __rsub__(self, lhs: Any) -> Expr: + """Reflected subtraction: evaluates ``lhs - self``.""" + return Expr.literal(lhs) - self + + def __rtruediv__(self, lhs: Any) -> Expr: + """Reflected division: evaluates ``lhs / self``.""" + return Expr.literal(lhs) / self + @staticmethod def literal(value: Any) -> Expr: """Creates a new expression representing a scalar value. 
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 46d2a2f0..91ca935d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -27,6 +27,227 @@ from datafusion.expr import CaseBuilder, Expr, WindowFrame from datafusion.context import SessionContext +__all__ = [ + "abs", + "acos", + "acosh", + "alias", + "approx_distinct", + "approx_median", + "approx_percentile_cont", + "approx_percentile_cont_with_weight", + "array", + "array_agg", + "array_append", + "array_cat", + "array_concat", + "array_dims", + "array_distinct", + "array_element", + "array_except", + "array_extract", + "array_has", + "array_has_all", + "array_has_any", + "array_indexof", + "array_intersect", + "array_join", + "array_length", + "array_ndims", + "array_pop_back", + "array_pop_front", + "array_position", + "array_positions", + "array_prepend", + "array_push_back", + "array_push_front", + "array_remove", + "array_remove_all", + "array_remove_n", + "array_repeat", + "array_replace", + "array_replace_all", + "array_replace_n", + "array_resize", + "array_slice", + "array_sort", + "array_to_string", + "array_union", + "arrow_typeof", + "ascii", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "avg", + "bit_and", + "bit_length", + "bit_or", + "bit_xor", + "bool_and", + "bool_or", + "btrim", + "case", + "cbrt", + "ceil", + "char_length", + "character_length", + "chr", + "coalesce", + "col", + "concat", + "concat_ws", + "corr", + "cos", + "cosh", + "cot", + "count", + "count_star", + "covar", + "covar_pop", + "covar_samp", + "current_date", + "current_time", + "date_bin", + "date_part", + "date_trunc", + "datepart", + "datetrunc", + "decode", + "degrees", + "digest", + "encode", + "ends_with", + "exp", + "factorial", + "find_in_set", + "first_value", + "flatten", + "floor", + "from_unixtime", + "gcd", + "grouping", + "in_list", + "initcap", + "isnan", + "iszero", + "last_value", + "lcm", + "left", + "length", + "levenshtein", + 
"list_append", + "list_dims", + "list_distinct", + "list_element", + "list_except", + "list_extract", + "list_indexof", + "list_intersect", + "list_join", + "list_length", + "list_ndims", + "list_position", + "list_positions", + "list_prepend", + "list_push_back", + "list_push_front", + "list_remove", + "list_remove_all", + "list_remove_n", + "list_replace", + "list_replace_all", + "list_replace_n", + "list_resize", + "list_slice", + "list_sort", + "list_to_string", + "list_union", + "ln", + "log", + "log10", + "log2", + "lower", + "lpad", + "ltrim", + "make_array", + "make_date", + "max", + "md5", + "mean", + "median", + "min", + "named_struct", + "nanvl", + "now", + "nullif", + "octet_length", + "order_by", + "overlay", + "pi", + "pow", + "power", + "radians", + "random", + "range", + "regexp_like", + "regexp_match", + "regexp_replace", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "repeat", + "replace", + "reverse", + "right", + "round", + "rpad", + "rtrim", + "sha224", + "sha256", + "sha384", + "sha512", + "signum", + "sin", + "sinh", + "split_part", + "sqrt", + "starts_with", + "stddev", + "stddev_pop", + "stddev_samp", + "strpos", + "struct", + "substr", + "substr_index", + "substring", + "sum", + "tan", + "tanh", + "to_hex", + "to_timestamp", + "to_timestamp_micros", + "to_timestamp_millis", + "to_timestamp_seconds", + "to_unixtime", + "translate", + "trim", + "trunc", + "upper", + "uuid", + "var", + "var_pop", + "var_samp", + "window", +] + def isnan(expr: Expr) -> Expr: """Returns true if a given number is +NaN or -NaN otherwise returns false.""" diff --git a/python/datafusion/object_store.py b/python/datafusion/object_store.py index a9bb83d2..c927e761 100644 --- a/python/datafusion/object_store.py +++ b/python/datafusion/object_store.py @@ -18,6 +18,18 @@ from ._internal import object_store +AmazonS3 = object_store.AmazonS3 +GoogleCloud = object_store.GoogleCloud 
+LocalFileSystem = object_store.LocalFileSystem +MicrosoftAzure = object_store.MicrosoftAzure + +__all__ = [ + "AmazonS3", + "GoogleCloud", + "LocalFileSystem", + "MicrosoftAzure", +] + def __getattr__(name): return getattr(object_store, name) diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index a199dd73..4b44ad19 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -33,6 +33,13 @@ from datafusion.context import SessionContext from datafusion._internal import LogicalPlan +__all__ = [ + "Plan", + "Consumer", + "Producer", + "Serde", +] + class Plan: """A class representing an encodable substrait plan.""" @@ -73,7 +80,7 @@ def serialize(sql: str, ctx: SessionContext, path: str | pathlib.Path) -> None: ctx: SessionContext to use. path: Path to write the Substrait plan to. """ - return substrait_internal.serde.serialize(sql, ctx.ctx, str(path)) + return substrait_internal.Serde.serialize(sql, ctx.ctx, str(path)) @staticmethod def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: @@ -86,7 +93,7 @@ def serialize_to_plan(sql: str, ctx: SessionContext) -> Plan: Returns: Substrait plan. """ - return Plan(substrait_internal.serde.serialize_to_plan(sql, ctx.ctx)) + return Plan(substrait_internal.Serde.serialize_to_plan(sql, ctx.ctx)) @staticmethod def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: @@ -99,7 +106,7 @@ def serialize_bytes(sql: str, ctx: SessionContext) -> bytes: Returns: Substrait plan as bytes. """ - return substrait_internal.serde.serialize_bytes(sql, ctx.ctx) + return substrait_internal.Serde.serialize_bytes(sql, ctx.ctx) @staticmethod def deserialize(path: str | pathlib.Path) -> Plan: @@ -111,7 +118,7 @@ def deserialize(path: str | pathlib.Path) -> Plan: Returns: Substrait plan. 
""" - return Plan(substrait_internal.serde.deserialize(str(path))) + return Plan(substrait_internal.Serde.deserialize(str(path))) @staticmethod def deserialize_bytes(proto_bytes: bytes) -> Plan: @@ -123,7 +130,7 @@ def deserialize_bytes(proto_bytes: bytes) -> Plan: Returns: Substrait plan. """ - return Plan(substrait_internal.serde.deserialize_bytes(proto_bytes)) + return Plan(substrait_internal.Serde.deserialize_bytes(proto_bytes)) @deprecated("Use `Serde` instead.") @@ -148,7 +155,7 @@ def to_substrait_plan(logical_plan: LogicalPlan, ctx: SessionContext) -> Plan: Substrait plan. """ return Plan( - substrait_internal.producer.to_substrait_plan(logical_plan, ctx.ctx) + substrait_internal.Producer.to_substrait_plan(logical_plan, ctx.ctx) ) @@ -173,7 +180,7 @@ def from_substrait_plan(ctx: SessionContext, plan: Plan) -> LogicalPlan: Returns: LogicalPlan. """ - return substrait_internal.consumer.from_substrait_plan( + return substrait_internal.Consumer.from_substrait_plan( ctx.ctx, plan.plan_internal ) diff --git a/python/datafusion/tests/test_wrapper_coverage.py b/python/datafusion/tests/test_wrapper_coverage.py new file mode 100644 index 00000000..44b9ca83 --- /dev/null +++ b/python/datafusion/tests/test_wrapper_coverage.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import datafusion +import datafusion.functions +import datafusion.object_store +import datafusion.substrait + + +def missing_exports(internal_obj, wrapped_obj) -> None: + for attr in dir(internal_obj): + assert attr in dir(wrapped_obj) + + internal_attr = getattr(internal_obj, attr) + wrapped_attr = getattr(wrapped_obj, attr) + + assert wrapped_attr is not None if internal_attr is not None else True + + if attr in ["__self__", "__class__"]: + continue + if isinstance(internal_attr, list): + assert isinstance(wrapped_attr, list) + for val in internal_attr: + assert val in wrapped_attr + elif hasattr(internal_attr, "__dict__"): + missing_exports(internal_attr, wrapped_attr) + + +def test_datafusion_missing_exports() -> None: + """Check for any missing Python exports. + + This test verifies that every exposed class, attribute, and function in + the internal (pyo3) module is also exposed in our Python wrappers. + """ + missing_exports(datafusion._internal, datafusion) diff --git a/src/substrait.rs b/src/substrait.rs index 60a52380..f89b6b09 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -59,7 +59,7 @@ impl From for PyPlan { /// A PySubstraitSerializer is a representation of a Serializer that is capable of both serializing /// a `LogicalPlan` instance to Substrait Protobuf bytes and also deserialize Substrait Protobuf bytes /// to a valid `LogicalPlan` instance. 
-#[pyclass(name = "serde", module = "datafusion.substrait", subclass)] +#[pyclass(name = "Serde", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitSerializer; @@ -105,7 +105,7 @@ impl PySubstraitSerializer { } } -#[pyclass(name = "producer", module = "datafusion.substrait", subclass)] +#[pyclass(name = "Producer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitProducer; @@ -121,7 +121,7 @@ impl PySubstraitProducer { } } -#[pyclass(name = "consumer", module = "datafusion.substrait", subclass)] +#[pyclass(name = "Consumer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitConsumer;