diff --git a/datagen/pom.xml b/datagen/pom.xml index 20b3403d3e1..9bdf897cfd7 100644 --- a/datagen/pom.xml +++ b/datagen/pom.xml @@ -33,6 +33,7 @@ **/* package + ${project.build.outputDirectory}/datagen-version-info.properties diff --git a/df_udf/README.md b/df_udf/README.md new file mode 100644 index 00000000000..0226c365a42 --- /dev/null +++ b/df_udf/README.md @@ -0,0 +1,90 @@
+# Scala / Java UDFs implemented using Dataframe operations
+
+User Defined Functions (UDFs) are used for a number of reasons in Apache Spark. Much of the time it is to implement
+logic that is either very difficult or impossible to implement using existing SQL/Dataframe APIs directly. But they
+are also used as a way to standardize processing logic across an organization or for code reuse.
+
+But UDFs come with some downsides. The biggest one is visibility into the processing being done. SQL is a language that
+can be highly optimized. But a UDF in most cases is a black box that the SQL optimizer cannot do anything about.
+This can result in less than ideal query planning. Additionally, accelerated execution environments, like the
+RAPIDS Accelerator for Apache Spark, have no easy way to replace UDFs with accelerated versions, which can result in
+slow performance.
+
+This plugin addresses the code reuse use case, while keeping the processing visible to the optimizer, by providing a way
+to implement a UDF in terms of Dataframe commands.
+
+## Setup
+
+To do this, include `com.nvidia:df_udf_plugin` as a dependency for your project and also include it on the
+classpath for your Apache Spark environment. Then add `com.nvidia.spark.DFUDFPlugin` to the config
+`spark.sql.extensions`. Now you can implement a UDF in terms of Dataframe operations.
+
+## Usage
+
+```scala
+import com.nvidia.spark.functions._
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions._
+
+val sum_array = df_udf((longArray: Column) =>
+  aggregate(longArray,
+    lit(0L),
+    (a, b) => coalesce(a, lit(0L)) + coalesce(b, lit(0L)),
+    a => a))
+spark.udf.register("sum_array", sum_array)
+```
+
+You can then use `sum_array` however you would have used any other UDF. This allows you to provide a drop-in replacement
+implementation of an existing UDF.
+
+```scala
+Seq(Array(1L, 2L, 3L)).toDF("data").selectExpr("sum_array(data) as result").show()
+
++------+
+|result|
++------+
+|     6|
++------+
+```
+
+## Type Checks
+
+DataFrame APIs do not provide type safety when writing the code, and the same is true here. There are no builtin type
+checks for inputs yet. Also, because of how types are resolved in Spark, there is no way to adjust the query based on
+the types passed in. Type checks are handled by the SQL planner/optimizer after the UDF has been replaced. This means
+that the final SQL will not violate any type safety, but it also means that the errors might be confusing.
+For example, if I passed in an `ARRAY<DOUBLE>` to `sum_array` instead of an `ARRAY<LONG>` I would get an error like this:
+
+```scala
+Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
+org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "aggregate(data, 0, lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))" due to data type mismatch: Parameter 3 requires the "BIGINT" type, however "lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable())" has the type "DOUBLE".; line 1 pos 0;
+Project [aggregate(data#46, 0, lambdafunction((cast(coalesce(lambda x_9#49L, 0) as double) + coalesce(lambda y_10#50, cast(0 as double))), lambda x_9#49L, lambda y_10#50, false), lambdafunction(lambda x_11#51L, lambda x_11#51L, false)) AS result#48L]
++- Project [value#43 AS data#46]
+   +- LocalRelation [value#43]
+
+  at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
+  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:269)
+  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:256)
+```
+
+This is not as simple to understand as the error you would get from a normal UDF:
+
+```scala
+val sum_array = udf((a: Array[Long]) => a.sum)
+
+spark.udf.register("sum_array", sum_array)
+
+Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
+org.apache.spark.sql.AnalysisException: [CANNOT_UP_CAST_DATATYPE] Cannot up cast array element from "DOUBLE" to "BIGINT".
+The type path of the target object is:
+- array element class: "long"
+- root class: "[J"
+You can either add an explicit cast to the input data or choose a higher precision type of the field in the target object
+at org.apache.spark.sql.errors.QueryCompilationErrors$.upCastFailureError(QueryCompilationErrors.scala:285)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveUpCast$$fail(Analyzer.scala:3646)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3677)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3654)
+```
+
+We hope to add optional type checks in the future.
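+
+## Java UDF APIs
+
+`df_udf` also accepts the Java UDF interfaces (`org.apache.spark.sql.api.java.UDF0` through `UDF10`), so the same pattern works
+from Java-style code. The snippet below is a minimal sketch based on the overloads in `com.nvidia.spark.functions` and the usage
+in this module's test suite:
+
+```scala
+import com.nvidia.spark.functions._
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.api.java.UDF1
+
+// The UDF body is expressed with Column operations, so the optimizer can still see it.
+val inc = df_udf(new UDF1[Column, Column] {
+  override def call(input: Column): Column = input + 1
+})
+
+spark.udf.register("inc", inc)
+spark.range(2).selectExpr("id", "inc(id)").show()
+```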
\ No newline at end of file diff --git a/df_udf/pom.xml b/df_udf/pom.xml new file mode 100644 index 00000000000..39f33880f34 --- /dev/null +++ b/df_udf/pom.xml @@ -0,0 +1,88 @@ + + + + 4.0.0 + + com.nvidia + rapids-4-spark-shim-deps-parent_2.12 + 24.12.0-SNAPSHOT + ../shim-deps/pom.xml + + df_udf_plugin_2.12 + UDFs implemented in SQL/Dataframe + UDFs for Apache Spark implemented in SQL/Dataframe + 24.12.0-SNAPSHOT + + + df_udf + + **/* + package + ${project.build.outputDirectory}/df_udf-version-info.properties + + + + + org.scala-lang + scala-library + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.test.version} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + + + + ${project.build.directory}/extra-resources + + + + diff --git a/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala b/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala new file mode 100644 index 00000000000..7e1c0451c8a --- /dev/null +++ b/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +class DFUDFPlugin extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectResolutionRule(logicalPlanRules) + } + + def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = { + org.apache.spark.sql.nvidia.LogicalPlanRules() + } +} \ No newline at end of file diff --git a/df_udf/src/main/scala/com/nvidia/spark/functions.scala b/df_udf/src/main/scala/com/nvidia/spark/functions.scala new file mode 100644 index 00000000000..8c8eef3f825 --- /dev/null +++ b/df_udf/src/main/scala/com/nvidia/spark/functions.scala @@ -0,0 +1,232 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark + +import org.apache.spark.sql.Column +import org.apache.spark.sql.api.java.{UDF0, UDF1, UDF10, UDF2, UDF3, UDF4, UDF5, UDF6, UDF7, UDF8, UDF9} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.nvidia._ +import org.apache.spark.sql.types.LongType + +// scalastyle:off +object functions { +// scalastyle:on + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function0[Column]): UserDefinedFunction = + udf(DFUDF0(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function1[Column, Column]): UserDefinedFunction = + udf(DFUDF1(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction = + udf(DFUDF2(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction = + udf(DFUDF3(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction = + udf(DFUDF4(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = + udf(DFUDF5(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = + udf(DFUDF6(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction = + udf(DFUDF7(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. 
+ */ + def df_udf(f: Function8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction = + udf(DFUDF8(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction = + udf(DFUDF9(f), LongType) + + /** + * Defines a Scala closure of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: Function10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction = + udf(DFUDF10(f), LongType) + + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF0[Column]): UserDefinedFunction = { + udf(JDFUDF0(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF1[Column, Column]): UserDefinedFunction = { + udf(JDFUDF1(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF2(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF3(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF4(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF5(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. 
+ */ + def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column, + Column]): UserDefinedFunction = { + udf(JDFUDF6(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column, + Column, Column]): UserDefinedFunction = { + udf(JDFUDF7(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column, + Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF8(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF9(f), LongType) + } + + /** + * Defines a Java UDF instance of Columns as user-defined function (UDF). + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + */ + def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column, Column]): UserDefinedFunction = { + udf(JDFUDF10(f), LongType) + } + +} \ No newline at end of file diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala new file mode 100644 index 00000000000..24a123016d6 --- /dev/null +++ b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.nvidia + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { + val replacePartialFunc: PartialFunction[Expression, Expression] = { + case f: ScalaUDF if DFUDF.getDFUDF(f.function).isDefined => + DFUDF.getDFUDF(f.function).map { + dfudf => DFUDFShims.columnToExpr( + dfudf(f.children.map(DFUDFShims.exprToColumn(_)).toArray)) + }.getOrElse{ + throw new IllegalStateException("Inconsistent results when extracting df_udf") + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.transformExpressions(replacePartialFunc) +} diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala new file mode 100644 index 00000000000..79f71ba4ca0 --- /dev/null +++ b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.nvidia + +import java.lang.invoke.SerializedLambda + +import org.apache.spark.sql.Column +import org.apache.spark.sql.api.java._ +import org.apache.spark.util.Utils + +trait DFUDF { + def apply(input: Array[Column]): Column +} + +case class DFUDF0(f: Function0[Column]) + extends UDF0[Any] with DFUDF { + override def call(): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 0) + f() + } +} + +case class DFUDF1(f: Function1[Column, Column]) + extends UDF1[Any, Any] with DFUDF { + override def call(t1: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 1) + f(input(0)) + } +} + +case class DFUDF2(f: Function2[Column, Column, Column]) + extends UDF2[Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 2) + f(input(0), input(1)) + } +} + +case class DFUDF3(f: Function3[Column, Column, Column, Column]) + extends UDF3[Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any): Any = { + throw new IllegalStateException("TODO better error message. 
This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 3) + f(input(0), input(1), input(2)) + } +} + +case class DFUDF4(f: Function4[Column, Column, Column, Column, Column]) + extends UDF4[Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 4) + f(input(0), input(1), input(2), input(3)) + } +} + +case class DFUDF5(f: Function5[Column, Column, Column, Column, Column, Column]) + extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 5) + f(input(0), input(1), input(2), input(3), input(4)) + } +} + +case class DFUDF6(f: Function6[Column, Column, Column, Column, Column, Column, Column]) + extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 6) + f(input(0), input(1), input(2), input(3), input(4), input(5)) + } +} + +case class DFUDF7(f: Function7[Column, Column, Column, Column, Column, Column, Column, Column]) + extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 7) + f(input(0), input(1), input(2), input(3), input(4), input(5), input(6)) + } +} + +case class DFUDF8(f: Function8[Column, Column, Column, Column, Column, Column, Column, Column, + Column]) + extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 8) + f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7)) + } +} + +case class DFUDF9(f: Function9[Column, Column, Column, Column, Column, Column, Column, Column, + Column, Column]) + extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, + t9: Any): Any = { + throw new IllegalStateException("TODO better error message. 
This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 9) + f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8)) + } +} + +case class DFUDF10(f: Function10[Column, Column, Column, Column, Column, Column, Column, Column, + Column, Column, Column]) + extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, + t9: Any, t10: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 10) + f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8), + input(9)) + } +} + +case class JDFUDF0(f: UDF0[Column]) + extends UDF0[Any] with DFUDF { + override def call(): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 0) + f.call() + } +} + +case class JDFUDF1(f: UDF1[Column, Column]) + extends UDF1[Any, Any] with DFUDF { + override def call(t1: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 1) + f.call(input(0)) + } +} + +case class JDFUDF2(f: UDF2[Column, Column, Column]) + extends UDF2[Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 2) + f.call(input(0), input(1)) + } +} + +case class JDFUDF3(f: UDF3[Column, Column, Column, Column]) + extends UDF3[Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 3) + f.call(input(0), input(1), input(2)) + } +} + +case class JDFUDF4(f: UDF4[Column, Column, Column, Column, Column]) + extends UDF4[Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 4) + f.call(input(0), input(1), input(2), input(3)) + } +} + +case class JDFUDF5(f: UDF5[Column, Column, Column, Column, Column, Column]) + extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 5) + f.call(input(0), input(1), input(2), input(3), input(4)) + } +} + +case class JDFUDF6(f: UDF6[Column, Column, Column, Column, Column, Column, Column]) + extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = { + throw new IllegalStateException("TODO better error message. 
This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 6) + f.call(input(0), input(1), input(2), input(3), input(4), input(5)) + } +} + +case class JDFUDF7(f: UDF7[Column, Column, Column, Column, Column, Column, Column, Column]) + extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 7) + f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6)) + } +} + +case class JDFUDF8(f: UDF8[Column, Column, Column, Column, Column, Column, Column, Column, + Column]) + extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 8) + f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7)) + } +} + +case class JDFUDF9(f: UDF9[Column, Column, Column, Column, Column, Column, Column, Column, + Column, Column]) + extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, + t9: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 9) + f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8)) + } +} + +case class JDFUDF10(f: UDF10[Column, Column, Column, Column, Column, Column, Column, Column, + Column, Column, Column]) + extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF { + override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any, + t9: Any, t10: Any): Any = { + throw new IllegalStateException("TODO better error message. This should have been replaced") + } + + override def apply(input: Array[Column]): Column = { + assert(input.length == 10) + f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8), + input(9)) + } +} + +object DFUDF { + /** + * Determine if the UDF function implements the DFUDF. + */ + def getDFUDF(function: AnyRef): Option[DFUDF] = { + function match { + case f: DFUDF => Some(f) + case f => + try { + // This may be a lambda that Spark's UDFRegistration wrapped around a Java UDF instance. + val clazz = f.getClass + if (Utils.getSimpleName(clazz).toLowerCase().contains("lambda")) { + // Try to find a `writeReplace` method, further indicating it is likely a lambda + // instance, and invoke it to serialize the lambda. Once serialized, captured arguments + // can be examined to locate the Java UDF instance. + // Note this relies on implementation details of Spark's UDFRegistration class. 
+ val writeReplace = clazz.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda] + if (serializedLambda.getCapturedArgCount == 1) { + serializedLambda.getCapturedArg(0) match { + case c: DFUDF => Some(c) + case _ => None + } + } else { + None + } + } else { + None + } + } catch { + case _: ClassCastException | _: NoSuchMethodException | _: SecurityException => None + } + } + } +} diff --git a/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala new file mode 100644 index 00000000000..5b51aeeb991 --- /dev/null +++ b/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.nvidia + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.Expression + +object DFUDFShims { + def columnToExpr(c: Column): Expression = c.expr + def exprToColumn(e: Expression): Column = Column(e) +} diff --git a/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala new file mode 100644 index 00000000000..e67dfb450d8 --- /dev/null +++ b/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.nvidia + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} + +object DFUDFShims { + def columnToExpr(c: Column): Expression = c + def exprToColumn(e: Expression): Column = e +} diff --git a/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala b/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala new file mode 100644 index 00000000000..ae6d46aefdf --- /dev/null +++ b/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala @@ -0,0 +1,443 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark + +import com.nvidia.spark.functions._ + +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.nvidia.SparkTestBase +import org.apache.spark.sql.types._ + +class functionsSuite extends SparkTestBase { + test("basic 0 arg df_udf") { + val zero = df_udf(() => lit(0)) + withSparkSession{ spark => + spark.udf.register("zero", zero) + assertSame(Array( + Row(0L, 0), + Row(1L, 0)), + spark.range(2).selectExpr("id", "zero()").collect()) + assertSame(Array( + Row(0L, 0), + Row(1L, 0)), + spark.range(2).select(col("id"), zero()).collect()) + } + } + + test("basic 1 arg df_udf") { + val inc = df_udf((input: Column) => input + 1) + withSparkSession { spark => + spark.udf.register("inc", inc) + assertSame(Array( + Row(0L, 1L), + Row(1L, 2L)), + spark.range(2).selectExpr("id", "inc(id)").collect()) + assertSame(Array( + Row(0L, 1L), + Row(1L, 2L)), + spark.range(2).select(col("id"), inc(col("id"))).collect()) + } + } + + + test("basic 2 arg df_udf") { + val add = df_udf((a: Column, b:Column) => a + b) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 0L), + Row(1L, 2L)), + spark.range(2).selectExpr("id", "add(id, id)").collect()) + assertSame(Array( + Row(0L, 0L), + Row(1L, 2L)), + spark.range(2).select(col("id"), add(col("id"), col("id"))).collect()) + } + } + + test("basic 3 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column) => a + b + c) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 0L), + Row(1L, 3L)), + spark.range(2).selectExpr("id", "add(id, id, id)").collect()) + assertSame(Array( + Row(0L, 0L), + Row(1L, 3L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), col("id"))).collect()) + } + } + + test("basic 4 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column) => a + b + c + d) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 1L), + Row(1L, 4L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id)").collect()) + assertSame(Array( + Row(0L, 1L), + Row(1L, 4L)), + 
spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), col("id"))).collect()) + } + } + + test("basic 5 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column) => + a + b + c + d + e) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 5L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 5L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1))).collect()) + } + } + + test("basic 6 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, f:Column) => + a + b + c + d + e + f) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 6L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 6L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"))).collect()) + } + } + + test("basic 7 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column) => a + b + c + d + e + f + g) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 7L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 7L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"))).collect()) + } + } + + test("basic 8 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column, h:Column) => a + b + c + d + e + f + g + h) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 9L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 9L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2))).collect()) + } + } + + test("basic 9 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column, h:Column, i:Column) => + a + b + c + d + e + f + g + h + i) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 10L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 10L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2), col("id"))).collect()) + } + } + + test("basic 10 arg df_udf") { + val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column, h:Column, i:Column, j:Column) => + a + b + c + d + e + f + g + h + i + j) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 11L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id, id)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 11L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2), col("id"), col("id"))).collect()) + } + } + + test("nested df_udf") { + val add = df_udf((a: Column, b:Column) => a + b) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 22L), + 
Row(1L, 25L)), + spark.range(2).selectExpr("id", "add(add(id, 12), add(add(id, id), 10))").collect()) + } + } + + test("complex df_udf") { + val extractor = df_udf((json: Column) => { + val schema = StructType(Seq(StructField("values", ArrayType(LongType)))) + val extracted_json = from_json(json, schema, Map.empty[String, String]) + aggregate(extracted_json("values"), + lit(0L), + (a, b) => coalesce(a, lit(0L)) + coalesce(b, lit(0L)), + a => a) + }) + withSparkSession { spark => + import spark.implicits._ + spark.udf.register("extractor", extractor) + assertSame(Array( + Row(6L), + Row(3L)), + Seq("""{"values":[1,2,3]}""", + """{"values":[1, null, null, 2]}""").toDF("json").selectExpr("extractor(json)").collect()) + } + } + + test("j basic 0 arg df_udf") { + val zero = df_udf(new UDF0[Column] { + override def call(): Column = lit(0) + }) + withSparkSession{ spark => + spark.udf.register("zero", zero) + assertSame(Array( + Row(0L, 0), + Row(1L, 0)), + spark.range(2).selectExpr("id", "zero()").collect()) + assertSame(Array( + Row(0L, 0), + Row(1L, 0)), + spark.range(2).select(col("id"), zero()).collect()) + } + } + + test("jbasic 1 arg df_udf") { + val inc = df_udf(new UDF1[Column, Column] { + override def call(a: Column): Column = a + 1 + }) + withSparkSession { spark => + spark.udf.register("inc", inc) + assertSame(Array( + Row(0L, 1L), + Row(1L, 2L)), + spark.range(2).selectExpr("id", "inc(id)").collect()) + assertSame(Array( + Row(0L, 1L), + Row(1L, 2L)), + spark.range(2).select(col("id"), inc(col("id"))).collect()) + } + } + + test("jbasic 2 arg df_udf") { + val add = df_udf(new UDF2[Column, Column, Column] { + override def call(a: Column, b:Column): Column = a + b + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 0L), + Row(1L, 2L)), + spark.range(2).selectExpr("id", "add(id, id)").collect()) + assertSame(Array( + Row(0L, 0L), + Row(1L, 2L)), + spark.range(2).select(col("id"), add(col("id"), col("id"))).collect()) + } + } + + test("jbasic 3 arg df_udf") { + val add = df_udf(new UDF3[Column, Column, Column, Column] { + override def call(a: Column, b: Column, c: Column): Column = a + b + c + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 0L), + Row(1L, 3L)), + spark.range(2).selectExpr("id", "add(id, id, id)").collect()) + assertSame(Array( + Row(0L, 0L), + Row(1L, 3L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), col("id"))).collect()) + } + } + + test("jbasic 4 arg df_udf") { + val add = df_udf(new UDF4[Column, Column, Column, Column, Column] { + override def call(a: Column, b:Column, c:Column, d:Column): Column = a + b + c + d + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 1L), + Row(1L, 4L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id)").collect()) + assertSame(Array( + Row(0L, 1L), + Row(1L, 4L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), col("id"))).collect()) + } + } + + test("jbasic 5 arg df_udf") { + val add = df_udf(new UDF5[Column, Column, Column, Column, Column, Column] { + override def call(a: Column, b: Column, c: Column, d: Column, e: Column): Column = + a + b + c + d + e + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 5L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 5L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), 
+ col("id"), lit(1))).collect()) + } + } + + test("jbasic 6 arg df_udf") { + val add = df_udf(new UDF6[Column, Column, Column, Column, Column, Column, Column] { + override def call(a: Column, b:Column, c:Column, d:Column, e:Column, f:Column) = + a + b + c + d + e + f + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 6L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 6L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"))).collect()) + } + } + + test("jbasic 7 arg df_udf") { + val add = df_udf(new UDF7[Column, Column, Column, Column, Column, Column, Column, + Column] { + override def call(a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column): Column = a + b + c + d + e + f + g + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 2L), + Row(1L, 7L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id)").collect()) + assertSame(Array( + Row(0L, 2L), + Row(1L, 7L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"))).collect()) + } + } + + test("jbasic 8 arg df_udf") { + val add = df_udf(new UDF8[Column, Column, Column, Column, Column, Column, Column, + Column, Column] { + override def call(a: Column, b: Column, c: Column, d: Column, e: Column, + f: Column, g: Column, h: Column): Column = a + b + c + d + e + f + g + h + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 9L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 9L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2))).collect()) + } + } + + test("jbasic 9 arg df_udf") { + val add = df_udf(new UDF9[Column, Column, Column, Column, Column, Column, Column, + Column, Column, Column] { + override def call(a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column, h:Column, i:Column): Column = + a + b + c + d + e + f + g + h + i + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 10L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 10L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2), col("id"))).collect()) + } + } + + test("jbasic 10 arg df_udf") { + val add = df_udf(new UDF10[Column, Column, Column, Column, Column, Column, Column, + Column, Column, Column, Column] { + override def call(a: Column, b:Column, c:Column, d:Column, e:Column, + f:Column, g:Column, h:Column, i:Column, j:Column): Column = + a + b + c + d + e + f + g + h + i + j + }) + withSparkSession { spark => + spark.udf.register("add", add) + assertSame(Array( + Row(0L, 4L), + Row(1L, 11L)), + spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id, id)").collect()) + assertSame(Array( + Row(0L, 4L), + Row(1L, 11L)), + spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), + col("id"), lit(1), col("id"), col("id"), lit(2), col("id"), col("id"))).collect()) + } + } +} \ No newline at end of file diff --git a/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala 
b/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala new file mode 100644 index 00000000000..2bd6697ffad --- /dev/null +++ b/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.nvidia + +import java.io.File +import java.nio.file.Files +import java.util.{Locale, TimeZone} + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} + +object SparkSessionHolder extends Logging { + private var spark = createSparkSession() + private var origConf = spark.conf.getAll + private var origConfKeys = origConf.keys.toSet + + private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach { + case (key, value) if spark.conf.get(key, null) != value => + spark.conf.set(key, value) + case _ => // No need to modify it + } + + private def createSparkSession(): SparkSession = { + SparkSession.cleanupAnyExistingSession() + + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + Locale.setDefault(Locale.US) + + val builder = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", "com.nvidia.spark.DFUDFPlugin") + .config("spark.sql.warehouse.dir", sparkWarehouseDir.getAbsolutePath) + .appName("dataframe udf tests") + + builder.getOrCreate() + } + + private def reinitSession(): Unit = { + spark = createSparkSession() + origConf = spark.conf.getAll + origConfKeys = origConf.keys.toSet + } + + def sparkSession: SparkSession = { + if (SparkSession.getActiveSession.isEmpty) { + reinitSession() + } + spark + } + + def resetSparkSessionConf(): Unit = { + if (SparkSession.getActiveSession.isEmpty) { + reinitSession() + } else { + setAllConfs(origConf.toArray) + val currentKeys = spark.conf.getAll.keys.toSet + val toRemove = currentKeys -- origConfKeys + if (toRemove.contains("spark.shuffle.manager")) { + // cannot unset the config so need to reinitialize + reinitSession() + } else { + toRemove.foreach(spark.conf.unset) + } + } + logDebug(s"RESET CONF TO: ${spark.conf.getAll}") + } + + def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = { + resetSparkSessionConf() + logDebug(s"SETTING CONF: ${conf.getAll.toMap}") + setAllConfs(conf.getAll) + logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n") + f(spark) + } + + private lazy val sparkWarehouseDir: File = { + new File(System.getProperty("java.io.tmpdir")).mkdirs() + val path = Files.createTempDirectory("spark-warehouse") + val file = new File(path.toString) + file.deleteOnExit() + file + } +} + +/** + * Base to be able to run tests with a spark context + */ +trait SparkTestBase extends AnyFunSuite with BeforeAndAfterAll { + def withSparkSession[U](f: SparkSession => U): U = { + withSparkSession(new SparkConf, f) + } + + def withSparkSession[U](conf: SparkConf, f: SparkSession => 
U): U = { + SparkSessionHolder.withSparkSession(conf, f) + } + + override def afterAll(): Unit = { + super.afterAll() + SparkSession.cleanupAnyExistingSession() + } + + def assertSame(expected: Any, actual: Any, epsilon: Double = 0.0, + path: List[String] = List.empty): Unit = { + def assertDoublesAreEqualWithinPercentage(expected: Double, + actual: Double, path: List[String]): Unit = { + if (expected != actual) { + if (expected != 0) { + val v = Math.abs((expected - actual) / expected) + assert(v <= epsilon, + s"$path: ABS($expected - $actual) / ABS($actual) == $v is not <= $epsilon ") + } else { + val v = Math.abs(expected - actual) + assert(v <= epsilon, s"$path: ABS($expected - $actual) == $v is not <= $epsilon ") + } + } + } + (expected, actual) match { + case (a: Float, b: Float) if a.isNaN && b.isNaN => + case (a: Double, b: Double) if a.isNaN && b.isNaN => + case (null, null) => + case (null, other) => fail(s"$path: expected is null, but actual is $other") + case (other, null) => fail(s"$path: expected is $other, but actual is null") + case (a: Array[_], b: Array[_]) => + assert(a.length == b.length, + s"$path: expected (${a.toList}) and actual (${b.toList}) lengths don't match") + a.indices.foreach { i => + assertSame(a(i), b(i), epsilon, path :+ i.toString) + } + case (a: Map[_, _], b: Map[_, _]) => + throw new IllegalStateException(s"Maps are not supported yet for comparison $a vs $b") + case (a: Iterable[_], b: Iterable[_]) => + assert(a.size == b.size, + s"$path: expected (${a.toList}) and actual (${b.toList}) lengths don't match") + var i = 0 + a.zip(b).foreach { + case (l, r) => + assertSame(l, r, epsilon, path :+ i.toString) + i += 1 + } + case (a: Product, b: Product) => + assertSame(a.productIterator.toSeq, b.productIterator.toSeq, epsilon, path) + case (a: Row, b: Row) => + assertSame(a.toSeq, b.toSeq, epsilon, path) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
+ case (a: Double, b: Double) if epsilon <= 0 => + assert(java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b), + s"$path: $a != $b") + case (a: Double, b: Double) if epsilon > 0 => + assertDoublesAreEqualWithinPercentage(a, b, path) + case (a: Float, b: Float) if epsilon <= 0 => + assert(java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b), + s"$path: $a != $b") + case (a: Float, b: Float) if epsilon > 0 => + assertDoublesAreEqualWithinPercentage(a, b, path) + case (a, b) => + assert(a == b, s"$path: $a != $b") + } + } +} diff --git a/pom.xml b/pom.xml index 7a4b7e56d85..bfb8a50946e 100644 --- a/pom.xml +++ b/pom.xml @@ -73,6 +73,7 @@ aggregator datagen + df_udf dist integration_tests shuffle-plugin diff --git a/scala2.13/datagen/pom.xml b/scala2.13/datagen/pom.xml index 6c01e912f94..d53ebc014c7 100644 --- a/scala2.13/datagen/pom.xml +++ b/scala2.13/datagen/pom.xml @@ -33,6 +33,7 @@ **/* package + ${project.build.outputDirectory}/datagen-version-info.properties diff --git a/scala2.13/df_udf/pom.xml b/scala2.13/df_udf/pom.xml new file mode 100644 index 00000000000..04f7a6deb28 --- /dev/null +++ b/scala2.13/df_udf/pom.xml @@ -0,0 +1,88 @@ + + + + 4.0.0 + + com.nvidia + rapids-4-spark-shim-deps-parent_2.13 + 24.12.0-SNAPSHOT + ../shim-deps/pom.xml + + df_udf_plugin_2.13 + UDFs implemented in SQL/Dataframe + UDFs for Apache Spark implemented in SQL/Dataframe + 24.12.0-SNAPSHOT + + + df_udf + + **/* + package + ${project.build.outputDirectory}/df_udf-version-info.properties + + + + + org.scala-lang + scala-library + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.test.version} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + net.alchim31.maven + scala-maven-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + + + + ${project.build.directory}/extra-resources + + + + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index f17a90f4633..e22f311561a 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -73,6 +73,7 @@ aggregator datagen + df_udf dist integration_tests shuffle-plugin