Add in a basic plugin for dataframe UDF support in Apache Spark (#11561)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Oct 8, 2024
1 parent cd46572 commit 6897713
Showing 15 changed files with 1,607 additions and 0 deletions.
1 change: 1 addition & 0 deletions datagen/pom.xml
@@ -33,6 +33,7 @@
        <target.classifier/>
        <rapids.default.jar.excludePattern>**/*</rapids.default.jar.excludePattern>
        <rapids.shim.jar.phase>package</rapids.shim.jar.phase>
        <build.info.path>${project.build.outputDirectory}/datagen-version-info.properties</build.info.path>
    </properties>
    <dependencies>
        <dependency>
90 changes: 90 additions & 0 deletions df_udf/README.md
@@ -0,0 +1,90 @@
# Scala / Java UDFs implemented using Dataframe operations

User Defined Functions (UDFs) are used for a number of reasons in Apache Spark. Much of the time they implement
logic that is either very difficult or impossible to express directly with the existing SQL/Dataframe APIs. But they
are also used as a way to standardize processing logic across an organization or to enable code reuse.

But UDFs come with some downsides. The biggest one is the loss of visibility into the processing being done. SQL is a
language that can be highly optimized, but in most cases a UDF is a black box that the SQL optimizer cannot do anything
about. This can result in less than ideal query planning. Additionally, accelerated execution environments, like the
RAPIDS Accelerator for Apache Spark, have no easy way to replace UDFs with accelerated versions, which can result in
slow performance.

This plugin attempts to restore that visibility for the code reuse use case by providing a way to implement a UDF in
terms of Dataframe commands.

## Setup

Include `com.nvidia:df_udf_plugin` as a dependency of your project and also place it on the classpath of your
Apache Spark environment. Then add `com.nvidia.spark.DFUDFPlugin` to the `spark.sql.extensions` config, for example
as sketched below. Now you can implement a UDF in terms of Dataframe operations.
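
A minimal sketch of wiring this up when building a `SparkSession` yourself. The application name is an assumption of
this example, and it assumes the plugin jar is already on the classpath:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: assumes the df_udf plugin jar is on the driver and executor classpath.
val spark = SparkSession.builder()
  .appName("df-udf-example")
  .config("spark.sql.extensions", "com.nvidia.spark.DFUDFPlugin")
  .getOrCreate()
```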

## Usage

```scala
import com.nvidia.spark.functions._

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

val sum_array = df_udf((longArray: Column) =>
  aggregate(longArray,
    lit(0L),
    (a, b) => coalesce(a, lit(0L)) + coalesce(b, lit(0L)),
    a => a))
spark.udf.register("sum_array", sum_array)
```

You can then use `sum_array` just as you would any other UDF. This allows you to provide a drop-in replacement
implementation of an existing UDF.

```scala
Seq(Array(1L, 2L, 3L)).toDF("data").selectExpr("sum_array(data) as result").show()

+------+
|result|
+------+
| 6|
+------+
```

## Type Checks

Dataframe APIs do not provide compile-time type safety, and the same is true here. There are no built-in type
checks for inputs yet. Also, because of how types are resolved in Spark, there is no way to adjust the query based on
the types passed in. Type checks are handled by the SQL planner/optimizer after the UDF has been replaced. This means
that the final SQL will not violate any type safety, but it also means that the errors might be confusing. For example,
if you pass an `ARRAY<DOUBLE>` to `sum_array` instead of an `ARRAY<LONG>` you get an error like

```scala
Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "aggregate(data, 0, lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))" due to data type mismatch: Parameter 3 requires the "BIGINT" type, however "lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable())" has the type "DOUBLE".; line 1 pos 0;
Project [aggregate(data#46, 0, lambdafunction((cast(coalesce(lambda x_9#49L, 0) as double) + coalesce(lambda y_10#50, cast(0 as double))), lambda x_9#49L, lambda y_10#50, false), lambdafunction(lambda x_11#51L, lambda x_11#51L, false)) AS result#48L]
+- Project [value#43 AS data#46]
+- LocalRelation [value#43]

at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:269)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:256)
```

This is not as easy to understand as the error you would get from a normal UDF:

```scala
val sum_array = udf((a: Array[Long]) => a.sum)

spark.udf.register("sum_array", sum_array)

Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
org.apache.spark.sql.AnalysisException: [CANNOT_UP_CAST_DATATYPE] Cannot up cast array element from "DOUBLE" to "BIGINT".
The type path of the target object is:
- array element class: "long"
- root class: "[J"
You can either add an explicit cast to the input data or choose a higher precision type of the field in the target object
at org.apache.spark.sql.errors.QueryCompilationErrors$.upCastFailureError(QueryCompilationErrors.scala:285)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveUpCast$$fail(Analyzer.scala:3646)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3677)
at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3654)
```

We hope to add optional type checks in the future.
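
Until then, one workaround the error message itself suggests is to add an explicit cast on the input. A minimal
sketch, assuming the `DOUBLE` values in this query can safely be truncated to `BIGINT`:

```scala
// Workaround sketch: cast the input so the planner sees the element type sum_array expects.
Seq(Array(1.0, 2.0, 3.0)).toDF("data")
  .selectExpr("sum_array(CAST(data AS ARRAY<BIGINT>)) AS result")
  .show()
```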
88 changes: 88 additions & 0 deletions df_udf/pom.xml
@@ -0,0 +1,88 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Copyright (c) 2024 NVIDIA CORPORATION.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>com.nvidia</groupId>
        <artifactId>rapids-4-spark-shim-deps-parent_2.12</artifactId>
        <version>24.12.0-SNAPSHOT</version>
        <relativePath>../shim-deps/pom.xml</relativePath>
    </parent>
    <artifactId>df_udf_plugin_2.12</artifactId>
    <name>UDFs implemented in SQL/Dataframe</name>
    <description>UDFs for Apache Spark implemented in SQL/Dataframe</description>
    <version>24.12.0-SNAPSHOT</version>

    <properties>
        <rapids.module>df_udf</rapids.module>
        <target.classifier/>
        <rapids.default.jar.excludePattern>**/*</rapids.default.jar.excludePattern>
        <rapids.shim.jar.phase>package</rapids.shim.jar.phase>
        <build.info.path>${project.build.outputDirectory}/df_udf-version-info.properties</build.info.path>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
        </dependency>
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_${scala.binary.version}</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.binary.version}</artifactId>
            <version>${spark.test.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- disable surefire as we are using scalatest only -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <configuration>
                    <skipTests>true</skipTests>
                </configuration>
            </plugin>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.scalatest</groupId>
                <artifactId>scalatest-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.rat</groupId>
                <artifactId>apache-rat-plugin</artifactId>
            </plugin>
        </plugins>

        <resources>
            <resource>
                <!-- Include the properties file to provide the build information. -->
                <directory>${project.build.directory}/extra-resources</directory>
            </resource>
        </resources>
    </build>
</project>
31 changes: 31 additions & 0 deletions df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

class DFUDFPlugin extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Register our logical plan rule with the analyzer's resolution phase.
    extensions.injectResolutionRule(logicalPlanRules)
  }

  def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = {
    org.apache.spark.sql.nvidia.LogicalPlanRules()
  }
}
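
Since the plugin is simply a `SparkSessionExtensions => Unit` function, it can also be applied programmatically
instead of through `spark.sql.extensions`. A sketch, assuming a local test session:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: register the plugin via the builder rather than the spark.sql.extensions config.
val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions(new com.nvidia.spark.DFUDFPlugin())
  .getOrCreate()
```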