Skip to content

Commit

Permalink
add scala.collection.immutable.SortedMap coder (#5443)
Browse files Browse the repository at this point in the history
Co-authored-by: Michel Davit <[email protected]>
  • Loading branch information
pgoggijr and RustedBones authored Aug 12, 2024
1 parent 43e110d commit 25c6987
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 119 deletions.
19 changes: 19 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,25 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
// added new Cache.get method
ProblemFilters.exclude[ReversedMissingMethodProblem](
"com.spotify.scio.util.Cache.get"
),
// added SortedMapCoder
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.coders.instances.MutableMapCoder.encode"
),
ProblemFilters.exclude[DirectAbstractMethodProblem](
"org.apache.beam.sdk.coders.Coder.verifyDeterministic"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.coders.instances.MutableMapCoder.structuralValue"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.coders.instances.MutableMapCoder.isRegisterByteSizeObserverCheap"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.coders.instances.MutableMapCoder.registerByteSizeObserver"
),
ProblemFilters.exclude[DirectAbstractMethodProblem](
"org.apache.beam.sdk.coders.Coder.getCoderArguments"
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package com.spotify.scio.coders.instances

import com.spotify.scio.coders.{Coder, CoderDerivation, CoderGrammar}
import org.apache.beam.sdk.coders.Coder.NonDeterministicException
import org.apache.beam.sdk.coders.{Coder => BCoder, IterableCoder => BIterableCoder, _}
import org.apache.beam.sdk.coders.{Coder => BCoder, _}
import org.apache.beam.sdk.util.common.ElementByteSizeObserver
import org.apache.beam.sdk.util.{BufferedElementCountingOutputStream, CoderUtils, VarInt}

import java.io.{InputStream, OutputStream}
import java.lang.{Iterable => JIterable}
import java.util.{Collections, List => JList}
import scala.collection.compat._
import scala.collection.immutable.SortedMap
import scala.collection.{mutable => m, AbstractIterable, BitSet, SortedSet}
import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
Expand All @@ -39,6 +39,7 @@ private[coders] object UnitCoder extends AtomicCoder[Unit] {
override def consistentWithEquals(): Boolean = true
override def isRegisterByteSizeObserverCheap(value: Unit): Boolean = true
override def getEncodedElementByteSize(value: Unit): Long = 0
override def toString: String = "UnitCoder"
}

private object NothingCoder extends AtomicCoder[Nothing] {
Expand All @@ -49,6 +50,7 @@ private object NothingCoder extends AtomicCoder[Nothing] {
override def consistentWithEquals(): Boolean = true
override def isRegisterByteSizeObserverCheap(value: Nothing): Boolean = true
override def getEncodedElementByteSize(value: Nothing): Long = 0
override def toString: String = "NothingCoder"
}

abstract private[coders] class BaseSeqLikeCoder[M[_], T](val elemCoder: BCoder[T])
Expand All @@ -59,23 +61,13 @@ abstract private[coders] class BaseSeqLikeCoder[M[_], T](val elemCoder: BCoder[T
// delegate methods for determinism and equality checks
override def verifyDeterministic(): Unit = elemCoder.verifyDeterministic()
override def consistentWithEquals(): Boolean = elemCoder.consistentWithEquals()

// delegate methods for byte size estimation
override def isRegisterByteSizeObserverCheap(value: M[T]): Boolean = false
override def registerByteSizeObserver(value: M[T], observer: ElementByteSizeObserver): Unit =
value match {
case JavaCollectionWrappers.JIterableWrapper(underlying) =>
BIterableCoder
.of(elemCoder)
.registerByteSizeObserver(underlying.asInstanceOf[JIterable[T]], observer)
case _ =>
super.registerByteSizeObserver(value, observer)
}
override def toString: String = s"${getClass.getSimpleName}($elemCoder)"
}

abstract private[coders] class SeqLikeCoder[M[_], T](bc: BCoder[T])(implicit
ev: M[T] => IterableOnce[T]
) extends BaseSeqLikeCoder[M, T](bc) {

override def encode(value: M[T], outStream: OutputStream): Unit = {
val traversable = ev(value)
VarInt.encode(traversable.iterator.size, outStream)
Expand Down Expand Up @@ -104,13 +96,16 @@ abstract private[coders] class SeqLikeCoder[M[_], T](bc: BCoder[T])(implicit
b.result()
}

override def toString: String = s"SeqLikeCoder($bc)"
override def registerByteSizeObserver(value: M[T], observer: ElementByteSizeObserver): Unit = {
val traversable = ev(value)
observer.update(VarInt.getLength(traversable.iterator.size))
traversable.iterator.foreach(v => elemCoder.registerByteSizeObserver(v, observer))
}
}

abstract private class BufferedSeqLikeCoder[M[_], T](bc: BCoder[T])(implicit
ev: M[T] => IterableOnce[T]
) extends BaseSeqLikeCoder[M, T](bc) {

override def encode(value: M[T], outStream: OutputStream): Unit = {
val buff = new BufferedElementCountingOutputStream(outStream)
ev(value).iterator.foreach { elem =>
Expand Down Expand Up @@ -138,7 +133,19 @@ abstract private class BufferedSeqLikeCoder[M[_], T](bc: BCoder[T])(implicit
b.result()
}

override def toString: String = s"BufferedSeqLikeCoder($bc)"
override def registerByteSizeObserver(value: M[T], observer: ElementByteSizeObserver): Unit = {
val traversable = ev(value)
val size = traversable.iterator.foldLeft(0) { (count, v) =>
elemCoder.registerByteSizeObserver(v, observer)
count + 1
}
if (size > 0) {
// this is an approximation, but it's better than nothing
// BufferedElementCountingOutputStream can split the iterator in many chunks
observer.update(VarInt.getLength(size))
}
observer.update(1L) // terminator byte
}
}

// keep this for binary compatibility
Expand Down Expand Up @@ -249,155 +256,126 @@ private class MutablePriorityQueueCoder[T: Ordering](bc: BCoder[T])
}

/**
 * Beam coder for [[BitSet]]. The encoded form is the number of members followed by every
 * member, each written as an int through Beam's `VarIntCoder` / `VarInt`.
 */
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped, so several
// statements appear twice: once going through the `lc` coder and once calling `VarInt`
// directly. Presumably the `lc` lines are the removed version and the `VarInt` lines the
// added one -- confirm against the repository before copying code from here.
private[coders] class BitSetCoder extends AtomicCoder[BitSet] {
private[this] val lc = VarIntCoder.of()

// Reads the member count first, then that many ints, adding each one to the builder.
def decode(in: InputStream): BitSet = {
val l = lc.decode(in)
val l = VarInt.decodeInt(in)
val builder = BitSet.newBuilder
builder.sizeHint(l)
(1 to l).foreach(_ => builder += lc.decode(in))
(1 to l).foreach(_ => builder += VarInt.decodeInt(in))

builder.result()
}

// Writes `ts.size`, then every member in the set's iteration order.
def encode(ts: BitSet, out: OutputStream): Unit = {
lc.encode(ts.size, out)
ts.foreach(v => lc.encode(v, out))
VarInt.encode(ts.size, out)
ts.foreach(v => VarInt.encode(v, out))
}

// Two variants are shown: one delegates to `lc`, the other answers `true` directly.
override def consistentWithEquals(): Boolean = lc.consistentWithEquals()
override def consistentWithEquals(): Boolean = true

override def toString: String = "BitSetCoder"
}

// NOTE(review): this span is a rendered diff whose +/- markers were stripped. Lines that use
// `kc` / `vc` / `lc` and the concrete `Map[K, V]` type appear to be the previous `MapCoder`
// body; lines that use `keyCoder` / `valueCoder` / `ev` and `M[K, V]` appear to be the new
// `MapLikeCoder`. Braces do not balance as shown -- confirm against the repository.
private[coders] class MapCoder[K, V](val kc: BCoder[K], val vc: BCoder[V])
extends StructuredCoder[Map[K, V]] {
private[this] val lc = VarIntCoder.of()

override def encode(value: Map[K, V], os: OutputStream): Unit = {
lc.encode(value.size, os)
val it = value.iterator
while (it.hasNext) {
val (k, v) = it.next()
kc.encode(k, os)
vc.encode(v, os)
/**
 * Shared implementation for map-shaped coders. `M[K, V]` only has to be viewable as an
 * `IterableOnce[(K, V)]` through the implicit `ev`; the abstract one-argument `decode` is left
 * to subclasses, which can call the two-argument `decode` below with a concrete builder.
 */
abstract private[coders] class MapLikeCoder[K, V, M[_, _]](
val keyCoder: BCoder[K],
val valueCoder: BCoder[V]
)(implicit
ev: M[K, V] => IterableOnce[(K, V)]
) extends StructuredCoder[M[K, V]] {
override def getCoderArguments: JList[_ <: BCoder[_]] = List(keyCoder, valueCoder).asJava

// Wire format: the entry count as a VarInt, then each key followed by its value, in the
// iteration order of `ev(value)`.
override def encode(value: M[K, V], os: OutputStream): Unit = {
val traversable = ev(value)
VarInt.encode(traversable.iterator.size, os)
traversable.iterator.foreach { case (k, v) =>
keyCoder.encode(k, os)
valueCoder.encode(v, os)
}
}

override def decode(is: InputStream): Map[K, V] = {
val l = lc.decode(is)
val builder = Map.newBuilder[K, V]
builder.sizeHint(l)
// Helper for subclasses: reads the entry count, then that many key/value pairs, appending
// each pair to the supplied builder and returning its result.
def decode(is: InputStream, builder: m.Builder[(K, V), M[K, V]]): M[K, V] = {
val size = VarInt.decodeInt(is)
builder.sizeHint(size)
var i = 0
while (i < l) {
val k = kc.decode(is)
val v = vc.decode(is)
while (i < size) {
val k = keyCoder.decode(is)
val v = valueCoder.decode(is)
builder += (k -> v)
i = i + 1
}
builder.result()
}

// delegate methods for determinism and equality checks
// Always throws: the coder is reported as non-deterministic regardless of its components.
override def verifyDeterministic(): Unit =
throw new NonDeterministicException(
this,
"Ordering of entries in a Map may be non-deterministic."
)

override def consistentWithEquals(): Boolean =
kc.consistentWithEquals() && vc.consistentWithEquals()
override def structuralValue(value: Map[K, V]): AnyRef =
keyCoder.consistentWithEquals() && valueCoder.consistentWithEquals()

// Returns the value itself when both component coders are consistent with equals; otherwise
// builds an immutable Map of the component coders' structural values.
override def structuralValue(value: M[K, V]): AnyRef =
if (consistentWithEquals()) {
value
value.asInstanceOf[AnyRef]
} else {
val b = Map.newBuilder[Any, Any]
b.sizeHint(value.size)
value.foreach { case (k, v) =>
b += kc.structuralValue(k) -> vc.structuralValue(v)
val traversable = ev(value)
b.sizeHint(traversable.iterator.size)
traversable.iterator.foreach { case (k, v) =>
b += keyCoder.structuralValue(k) -> valueCoder.structuralValue(v)
}
b.result()
}

// delegate methods for byte size estimation
override def isRegisterByteSizeObserverCheap(value: Map[K, V]): Boolean = false
override def isRegisterByteSizeObserverCheap(value: M[K, V]): Boolean = false
// Mirrors `encode`: accounts for the VarInt-encoded entry count, then lets the key and value
// coders register every entry with the observer.
override def registerByteSizeObserver(
value: Map[K, V],
value: M[K, V],
observer: ElementByteSizeObserver
): Unit = {
lc.registerByteSizeObserver(value.size, observer)
value.foreach { case (k, v) =>
kc.registerByteSizeObserver(k, observer)
vc.registerByteSizeObserver(v, observer)
val traversable = ev(value)
observer.update(VarInt.getLength(traversable.iterator.size))
traversable.iterator.foreach { case (k, v) =>
keyCoder.registerByteSizeObserver(k, observer)
valueCoder.registerByteSizeObserver(v, observer)
}
}

// The newer variant derives the printed name from the runtime class instead of a literal.
override def toString: String =
s"MapCoder($kc, $vc)"
override def toString: String = s"${getClass.getSimpleName}($keyCoder, $valueCoder)"
}

override def getCoderArguments: JList[_ <: BCoder[_]] = List(kc, vc).asJava
/**
 * Coder for immutable [[Map]]. The only member defined here is `decode`, which hands an
 * immutable `Map` builder to the two-argument `decode` of `MapLikeCoder`.
 */
private[coders] class MapCoder[K, V](kc: BCoder[K], vc: BCoder[V])
    extends MapLikeCoder[K, V, Map](kc, vc) {
  override def decode(is: InputStream): Map[K, V] = {
    val entries = Map.newBuilder[K, V]
    decode(is, entries)
  }
}

/**
 * Coder for `scala.collection.mutable.Map`.
 */
// NOTE(review): rendered diff without +/- markers. Two `extends` clauses and two `decode`
// definitions are shown; the ones built on `MapLikeCoder` appear to be the new version, and
// the members that use `lc` / `kc` / `vc` directly appear to be the removed hand-written
// implementation -- confirm against the repository.
private class MutableMapCoder[K, V](kc: BCoder[K], vc: BCoder[V])
extends StructuredCoder[m.Map[K, V]] {
private[this] val lc = VarIntCoder.of()

override def encode(value: m.Map[K, V], os: OutputStream): Unit = {
lc.encode(value.size, os)
value.foreach { case (k, v) =>
kc.encode(k, os)
vc.encode(v, os)
}
}
extends MapLikeCoder[K, V, m.Map](kc, vc) {

override def decode(is: InputStream): m.Map[K, V] = {
val l = lc.decode(is)
val builder = m.Map.newBuilder[K, V]
builder.sizeHint(l)
var i = 0
while (i < l) {
val k = kc.decode(is)
val v = vc.decode(is)
builder += (k -> v)
i = i + 1
}
builder.result()
}
// Delegates to the two-argument `decode`, supplying a mutable-map builder.
override def decode(inStream: InputStream): m.Map[K, V] =
decode(inStream, m.Map.newBuilder[K, V])
}

// The members below sit after the closing brace in this listing; they reference `kc` / `vc`
// and `m.Map`, so they presumably belong to the removed MutableMapCoder body.
// delegate methods for determinism and equality checks
override def verifyDeterministic(): Unit =
throw new NonDeterministicException(
this,
"Ordering of entries in a Map may be non-deterministic."
)
override def consistentWithEquals(): Boolean =
kc.consistentWithEquals() && vc.consistentWithEquals()
override def structuralValue(value: m.Map[K, V]): AnyRef =
if (consistentWithEquals()) {
value
} else {
val b = m.Map.newBuilder[Any, Any]
b.sizeHint(value.size)
value.foreach { case (k, v) =>
b += kc.structuralValue(k) -> vc.structuralValue(v)
}
b.result()
}
/**
 * Coder for immutable [[SortedMap]], built on `MapLikeCoder`. It requires an implicit
 * `Ordering[K]`, which is used both to validate maps on `encode` and to rebuild the map on
 * `decode`. Unlike the generic map coder shown in this file, `verifyDeterministic` does not
 * throw unconditionally: it delegates to the key and value coders.
 */
// NOTE(review): rendered diff without +/- markers. The members typed against `m.Map[K, V]`
// or referring to `lc` / "MutableMapCoder" appear to be removed lines of the old
// MutableMapCoder that the diff view placed inside this class -- confirm against the
// repository.
private[coders] class SortedMapCoder[K: Ordering, V](kc: BCoder[K], vc: BCoder[V])
extends MapLikeCoder[K, V, SortedMap](kc, vc) {

// delegate methods for byte size estimation
override def isRegisterByteSizeObserverCheap(value: m.Map[K, V]): Boolean = false
override def registerByteSizeObserver(
value: m.Map[K, V],
observer: ElementByteSizeObserver
): Unit = {
lc.registerByteSizeObserver(value.size, observer)
value.foreach { case (k, v) =>
kc.registerByteSizeObserver(k, observer)
vc.registerByteSizeObserver(v, observer)
}
// Fails with IllegalArgumentException (via `require`) when the map's ordering is not `==` to
// the coder's implicit Ordering[K]; otherwise encodes through the parent implementation.
// NOTE(review): `==` on Ordering instances -- whether two equivalent orderings compare equal
// depends on the Ordering implementation; confirm this is acceptable for derived orderings.
override def encode(value: SortedMap[K, V], os: OutputStream): Unit = {
require(
value.ordering == Ordering[K],
"SortedMap ordering does not match SortedMapCoder ordering"
)
super.encode(value, os)
}

override def toString: String =
s"MutableMapCoder($kc, $vc)"
// The builder picks up the implicit Ordering[K] of this coder.
override def decode(is: InputStream): SortedMap[K, V] =
decode(is, SortedMap.newBuilder[K, V])

override def getCoderArguments: JList[_ <: BCoder[_]] = List(kc, vc).asJava
// Deterministic exactly when both component coders are: each check may throw.
override def verifyDeterministic(): Unit = {
keyCoder.verifyDeterministic()
valueCoder.verifyDeterministic()
}
}

private[coders] object SFloatCoder extends BCoder[Float] {
Expand Down Expand Up @@ -509,14 +487,19 @@ trait ScalaCoders extends CoderGrammar with CoderDerivation {
): Coder[m.WrappedArray[T]] =
xmap(Coder[Array[T]])(wrap, _.toArray)

/** Implicit coder for immutable [[Map]], assembled from the key and value coders in scope. */
implicit def mapCoder[K: Coder, V: Coder]: Coder[Map[K, V]] =
  transform(Coder[K]) { keyBeamCoder =>
    transform(Coder[V]) { valueBeamCoder =>
      beam(new MapCoder[K, V](keyBeamCoder, valueBeamCoder))
    }
  }

/** Implicit coder for `mutable.Map`, assembled from the key and value coders in scope. */
implicit def mutableMapCoder[K: Coder, V: Coder]: Coder[m.Map[K, V]] =
  transform(Coder[K]) { keyBeamCoder =>
    transform(Coder[V]) { valueBeamCoder =>
      beam(new MutableMapCoder[K, V](keyBeamCoder, valueBeamCoder))
    }
  }

// NOTE(review): rendered diff without +/- markers. Two signatures and two inner `transform`
// lines are shown; presumably the `mapCoder` / `MapCoder` pair is the removed text and the
// `sortedMapCoder` / `SortedMapCoder` pair is the added definition -- confirm against the
// repository.
// The sortedMapCoder signature additionally demands an `Ordering[K]`, which the
// SortedMapCoder constructor takes as a context bound.
implicit def mapCoder[K: Coder, V: Coder]: Coder[Map[K, V]] =
implicit def sortedMapCoder[K: Coder: Ordering, V: Coder]: Coder[SortedMap[K, V]] =
transform(Coder[K]) { kc =>
transform(Coder[V])(vc => beam(new MapCoder[K, V](kc, vc)))
transform(Coder[V])(vc => beam(new SortedMapCoder[K, V](kc, vc)))
}

implicit def sortedSetCoder[T: Coder: Ordering]: Coder[SortedSet[T]] =
Expand Down
Loading

0 comments on commit 25c6987

Please sign in to comment.