Skip to content

Commit

Permalink
Add support for Zstd coders (#5321)
Browse files Browse the repository at this point in the history
  • Loading branch information
kellen authored May 20, 2024
1 parent e1687eb commit fe93831
Show file tree
Hide file tree
Showing 13 changed files with 636 additions and 54 deletions.
8 changes: 8 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ val jacksonVersion = "2.14.1"
val jodaTimeVersion = "2.10.10"
val nettyTcNativeVersion = "2.0.52.Final"
val slf4jVersion = "1.7.30"
val zstdJniVersion = "1.5.2-5"
// dependent versions
val googleApiServicesBigQueryVersion = s"v2-rev20240229-$googleClientsVersion"
val googleApiServicesDataflowVersion = s"v1b3-rev20240218-$googleClientsVersion"
Expand Down Expand Up @@ -395,6 +396,12 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.testing.TransformOverride.ofSource"
),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"com.spotify.scio.options.ScioOptions.setZstdDictionary"
),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"com.spotify.scio.options.ScioOptions.getZstdDictionary"
),
// removal of private classes
ProblemFilters.exclude[MissingClassProblem](
"com.spotify.scio.coders.instances.kryo.GaxApiExceptionSerializer"
Expand Down Expand Up @@ -655,6 +662,7 @@ lazy val `scio-core` = project
"com.fasterxml.jackson.core" % "jackson-annotations" % jacksonVersion,
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
"com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonVersion,
"com.github.luben" % "zstd-jni" % zstdJniVersion,
"com.google.api" % "gax" % gaxVersion,
"com.google.api-client" % "google-api-client" % googleApiClientVersion,
"com.google.auto.service" % "auto-service-annotations" % autoServiceVersion,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.spotify.scio.options;

import com.fasterxml.jackson.annotation.JsonIgnore;
import java.util.List;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
Expand All @@ -34,7 +35,7 @@ public interface ScioOptions extends PipelineOptions, KryoOptions {

void setScalaVersion(String version);

@Description("Filename to save metrics to.")
@Description("Filename to save metrics to")
String getMetricsLocation();

void setMetricsLocation(String metricsLocation);
Expand Down Expand Up @@ -73,9 +74,14 @@ enum CheckEnabled {
ERROR
}

@Description("Should scio use NullableCoder to serialize data.")
@Description("Should scio use NullableCoder to serialize data")
@Default.Boolean(false)
boolean getNullableCoders();

void setNullableCoders(boolean value);

@Description("Colon-separated mapping of fully-qualified class name to location of Zstd dictionary for that class com.MyClass:gs://bucket/file.bin")
List<String> getZstdDictionary();

void setZstdDictionary(List<String> value);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.spotify.scio.coders

import com.spotify.scio.coders.CoderMaterializer.CoderOptions
import com.spotify.scio.values.SCollection
import org.apache.beam.sdk.coders.{Coder => BCoder, NullableCoder}
import org.apache.beam.sdk.coders.{Coder => BCoder, NullableCoder, ZstdCoder => BZstdCoder}
import org.apache.beam.sdk.values.PCollection

import scala.annotation.tailrec
Expand All @@ -35,6 +35,14 @@ private[scio] object BeamCoders {
case _ => coder
}

/**
 * Strips a Beam `ZstdCoder` wrapper from `coder`, if there is one.
 *
 * When `coder` is a `ZstdCoder`, its first coder argument is taken as the wrapped coder and
 * passed through `unwrap`; any other coder is returned untouched.
 */
private def unwrapZstd[T](options: CoderOptions, coder: BCoder[T]): BCoder[T] =
  coder match {
    case zstd: BZstdCoder[T] =>
      // the wrapped coder is read from the first entry of the coder arguments
      val inner = zstd.getCoderArguments.get(0).asInstanceOf[BCoder[T]]
      unwrap(options, inner)
    case other =>
      other
  }

/** Get coder from a `PCollection[T]`. */
def getCoder[T](coll: PCollection[T]): Coder[T] = {
val options = CoderOptions(coll.getPipeline.getOptions)
Expand All @@ -49,6 +57,7 @@ private[scio] object BeamCoders {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
Some(unwrap(options, coder))
.map(unwrapZstd(options, _))
.map(_.getCoderArguments.asScala.toList)
.collect {
case (c1: BCoder[K @unchecked]) ::
Expand All @@ -75,6 +84,7 @@ private[scio] object BeamCoders {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
Some(unwrap(options, coder))
.map(unwrapZstd(options, _))
.map(_.getCoderArguments.asScala.toList)
.collect {
case (c1: BCoder[A @unchecked]) ::
Expand All @@ -99,6 +109,7 @@ private[scio] object BeamCoders {
val options = CoderOptions(coll.context.options)
val coder = coll.internal.getCoder
Some(unwrap(options, coder))
.map(unwrapZstd(options, _))
.map(_.getCoderArguments.asScala.toList)
.collect {
case (c1: BCoder[A @unchecked]) ::
Expand Down
24 changes: 17 additions & 7 deletions scio-core/src/main/scala/com/spotify/scio/coders/Coder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,20 @@ Cannot find an implicit Coder instance for type:
)
sealed trait Coder[T] extends Serializable

/**
 * Mixin for coders that expose a `typeName` string.
 *
 * `CoderMaterializer` reads `typeName` from coders carrying this trait and uses it as the key
 * when looking up a Zstd dictionary in `CoderOptions.zstdDictMapping`.
 */
sealed private[scio] trait TypeName {
/** Name of the type associated with the coder. */
def typeName: String
}

final private[scio] case class Singleton[T] private (typeName: String, supply: () => T)
extends Coder[T] {
extends Coder[T]
with TypeName {
override def toString: String = s"Singleton[$typeName]"
}

// This should not be a case class: equality must be reference equality to detect recursive coders.
final private[scio] class Ref[T] private (val typeName: String, c: => Coder[T]) extends Coder[T] {
final private[scio] class Ref[T] private (val typeName: String, c: => Coder[T])
extends Coder[T]
with TypeName {
def value: Coder[T] = c
override def toString: String = s"Ref[$typeName]"
}
Expand All @@ -82,15 +89,17 @@ final case class CoderTransform[T, U] private (
typeName: String,
c: Coder[U],
f: BCoder[U] => Coder[T]
) extends Coder[T] {
) extends Coder[T]
with TypeName {
override def toString: String = s"CoderTransform[$typeName]($c)"
}
final case class Transform[T, U] private (
typeName: String,
c: Coder[U],
t: T => U,
f: U => T
) extends Coder[T] {
) extends Coder[T]
with TypeName {
override def toString: String = s"Transform[$typeName]($c)"
}

Expand All @@ -99,7 +108,8 @@ final case class Disjunction[T, Id] private (
idCoder: Coder[Id],
coder: Map[Id, Coder[T]],
id: T => Id
) extends Coder[T] {
) extends Coder[T]
with TypeName {
override def toString: String = {
val body = coder.map { case (id, v) => s"$id -> $v" }.mkString(", ")
s"Disjunction[$typeName]($body)"
Expand All @@ -111,7 +121,8 @@ final case class Record[T] private (
cs: Array[(String, Coder[Any])],
construct: Seq[Any] => T,
destruct: T => IndexedSeq[Any]
) extends Coder[T] {
) extends Coder[T]
with TypeName {
override def toString: String = {
val body = cs.map { case (k, v) => s"($k, $v)" }.mkString(", ")
s"Record[$typeName]($body)"
Expand Down Expand Up @@ -221,7 +232,6 @@ object Coder
with CoderDerivation
with LowPriorityCoders {
@inline final def apply[T](implicit c: Coder[T]): Coder[T] = c

}

trait LowPriorityCoders extends LowPriorityCoders1 { self: CoderDerivation with JavaBeanCoders =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,102 @@

package com.spotify.scio.coders

import org.apache.beam.sdk.coders.{Coder => BCoder, IterableCoder, KvCoder, NullableCoder}
import com.spotify.scio.util.RemoteFileUtil
import org.apache.beam.sdk.coders.{
Coder => BCoder,
IterableCoder,
KvCoder,
NullableCoder,
ZstdCoder
}
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.apache.commons.lang3.ObjectUtils

import java.net.URI
import java.nio.file.Files
import java.util.concurrent.ConcurrentHashMap
import scala.collection.compat._
import scala.collection.concurrent.TrieMap
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.chaining._

object CoderMaterializer {
import com.spotify.scio.ScioContext

private[scio] case class CoderOptions(nullableCoders: Boolean, kryo: KryoOptions)
/**
 * Options that influence how coders are materialized.
 *
 * @param nullableCoders
 *   value of the `nullableCoders` Scio pipeline option
 * @param kryo
 *   Kryo options built from the pipeline options
 * @param zstdDictMapping
 *   Zstd dictionary bytes keyed by class name (with `$` replaced by `.`), as parsed from the
 *   `zstdDictionary` pipeline option
 */
private[scio] case class CoderOptions(
nullableCoders: Boolean,
kryo: KryoOptions,
zstdDictMapping: Map[String, Array[Byte]]
)
private[scio] object CoderOptions {
private val cache: ConcurrentHashMap[String, CoderOptions] = new ConcurrentHashMap()
private val ZstdArgRegex = "([^:]+):(.*)".r
private val ZstdPackageBlacklist =
List("scala.", "java.", "com.spotify.scio.", "org.apache.beam.")

final def apply(o: PipelineOptions): CoderOptions = {
val nullableCoder = o.as(classOf[com.spotify.scio.options.ScioOptions]).getNullableCoders
new CoderOptions(nullableCoder, KryoOptions(o))
// Parsed options (including downloaded dictionaries) are cached under the identity string of
// the PipelineOptions object, so parsing and downloads are not repeated for a cached key.
cache.computeIfAbsent(
  ObjectUtils.identityToString(o),
  { _ =>
    val scioOpts = o.as(classOf[com.spotify.scio.options.ScioOptions])
    val nullableCoder = scioOpts.getNullableCoders

    // Left: validation error message.
    // Right: (class name with `$` replaced by `.`) -> dictionary location.
    val (errors, classPathMapping) = Option(scioOpts.getZstdDictionary)
      .map(_.asScala.toList)
      .getOrElse(List.empty)
      .partitionMap {
        case s @ ZstdArgRegex(className, path) =>
          Option
            .when(ZstdPackageBlacklist.exists(className.startsWith))(
              s"zstdDictionary command-line arguments may not be used for class $className. " +
                s"Provide Zstd coders manually instead."
            )
            .orElse {
              Try(Class.forName(className)).failed.toOption
                .map(_ => s"Class for zstdDictionary argument ${s} not found.")
            }
            .toLeft(className.replaceAll("\\$", ".") -> path)
        case s =>
          Left(
            "zstdDictionary arguments must be in a colon-separated format. " +
              s"Example: `com.spotify.ClassName:gs://path`. Found: $s"
          )
      }

    if (errors.nonEmpty) {
      throw new IllegalArgumentException(
        errors.mkString("Bad zstdDictionary arguments:\n\t", "\n\t", "\n")
      )
    }

    // Distinct dictionary locations per class name.
    val zstdDictPaths = classPathMapping
      .groupBy(_._1)
      .map { case (className, values) => className -> values.map(_._2).toSet }

    val dupes = zstdDictPaths
      .collect {
        case (className, values) if values.size > 1 =>
          s"Class $className -> [${values.mkString(", ")}]"
      }
    // Reject ANY class with more than one dictionary. The previous `dupes.size > 1` check let
    // a single offending class through, after which one of its dictionaries was picked
    // arbitrarily.
    if (dupes.nonEmpty) {
      throw new IllegalArgumentException(
        dupes.mkString("Found multiple Zstd dictionaries for:\n\t", "\n\t", "\n")
      )
    }

    val zstdDictMapping = zstdDictPaths.map { case (clazz, dictUriSet) =>
      // the duplicate check above guarantees dictUriSet contains exactly 1 item
      val dictUri = dictUriSet.head
      val dictPath = RemoteFileUtil.create(o).download(new URI(dictUri))
      val dictBytes = Files.readAllBytes(dictPath)
      // the downloaded copy is only needed until its bytes are in memory
      Files.delete(dictPath)
      clazz -> dictBytes
    }

    new CoderOptions(nullableCoder, KryoOptions(o), zstdDictMapping)
  }
)
}
}

Expand Down Expand Up @@ -119,6 +201,13 @@ object CoderMaterializer {

bCoder
.pipe(bc => if (isNullableCoder(o, coder)) NullableCoder.of(bc) else bc)
.pipe { bc =>
Option(coder)
.collect { case x: TypeName => x.typeName }
.flatMap(o.zstdDictMapping.get)
.map(ZstdCoder.of(bc, _))
.getOrElse(bc)
}
.pipe(bc => if (isWrappableCoder(topLevel, coder)) new MaterializedCoder(bc) else bc)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2024 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.coders.instances

import com.spotify.scio.coders.Coder
import org.apache.beam.sdk.coders.{ZstdCoder => BZstdCoder}

import scala.reflect.ClassTag

/** Factory methods for Scio coders that wrap their underlying Beam coder in Beam's `ZstdCoder`. */
object ZstdCoder {

  /**
   * Builds a coder for `T` whose materialized Beam coder is wrapped in a Beam `ZstdCoder`.
   *
   * @param dict
   *   dictionary bytes handed to `ZstdCoder.of`
   */
  def apply[T: Coder: ClassTag](dict: Array[Byte]): Coder[T] =
    Coder.transform(Coder[T]) { underlying =>
      Coder.beam(BZstdCoder.of(underlying, dict))
    }

  /**
   * Builds a tuple coder whose key and value coders are each optionally wrapped in a Beam
   * `ZstdCoder`.
   *
   * @param keyDict
   *   dictionary bytes for the key coder; when `null` the key coder is left unwrapped
   * @param valueDict
   *   dictionary bytes for the value coder; when `null` the value coder is left unwrapped
   */
  def tuple2[K: Coder, V: Coder](
    keyDict: Array[Byte] = null,
    valueDict: Array[Byte] = null
  ): Coder[(K, V)] =
    Coder.transform(Coder[K]) { keyCoder =>
      val wrappedKey = Option(keyDict) match {
        case Some(bytes) => BZstdCoder.of(keyCoder, bytes)
        case None        => keyCoder
      }
      Coder.transform(Coder[V]) { valueCoder =>
        val wrappedValue = Option(valueDict) match {
          case Some(bytes) => BZstdCoder.of(valueCoder, bytes)
          case None        => valueCoder
        }
        Coder.beam(new Tuple2Coder[K, V](wrappedKey, wrappedValue))
      }
    }
}
Loading

0 comments on commit fe93831

Please sign in to comment.