From a3cf469cd3dfd18bffff92bcddba804a8c29bc58 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Thu, 19 Sep 2024 10:27:35 +0200 Subject: [PATCH] Add missing hot-key SCollection API --- .../spotify/scio/util/TupleFunctions.scala | 3 ++ .../scio/values/SCollectionWithFanout.scala | 31 ++++++++++-- .../values/SCollectionWithHotKeyFanout.scala | 40 +++++++++++++++- .../values/SCollectionWithFanoutTest.scala | 48 ++++++++++++++++++- .../SCollectionWithHotKeyFanoutTest.scala | 45 ++++++++++++++++- 5 files changed, 161 insertions(+), 6 deletions(-) diff --git a/scio-core/src/main/scala/com/spotify/scio/util/TupleFunctions.scala b/scio-core/src/main/scala/com/spotify/scio/util/TupleFunctions.scala index 8db5569199..674ed5c7b4 100644 --- a/scio-core/src/main/scala/com/spotify/scio/util/TupleFunctions.scala +++ b/scio-core/src/main/scala/com/spotify/scio/util/TupleFunctions.scala @@ -31,6 +31,9 @@ private[scio] object TupleFunctions { def klToTuple[K](kv: KV[K, java.lang.Long]): (K, Long) = (kv.getKey, kv.getValue) + def kdToTuple[K](kv: KV[K, java.lang.Double]): (K, Double) = + (kv.getKey, kv.getValue) + def kvIterableToTuple[K, V](kv: KV[K, JIterable[V]]): (K, Iterable[V]) = (kv.getKey, kv.getValue.asScala) diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala index 1a185f7519..0cb50b031f 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithFanout.scala @@ -21,11 +21,10 @@ import com.spotify.scio.ScioContext import com.spotify.scio.util.Functions import com.spotify.scio.coders.Coder import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup} -import org.apache.beam.sdk.transforms.{Combine, Top} +import org.apache.beam.sdk.transforms.{Combine, Latest, Mean, Reify, Top} import org.apache.beam.sdk.values.PCollection -import 
java.lang.{Iterable => JIterable} - +import java.lang.{Double => JDouble, Iterable => JIterable} import scala.jdk.CollectionConverters._ /** @@ -116,6 +115,32 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In ) } + /** [[SCollection.min]] with fan out. */ + def min(implicit ord: Ordering[T]): SCollection[T] = + this.reduce(ord.min) + + /** [[SCollection.max]] with fan out. */ + def max(implicit ord: Ordering[T]): SCollection[T] = + this.reduce(ord.max) + + /** [[SCollection.mean]] with fan out. */ + def mean(implicit ev: Numeric[T]): SCollection[Double] = { + val e = ev // defeat closure + coll.transform { in => + in.map[JDouble](e.toDouble) + .pApply(Mean.globally().withFanout(fanout)) + .asInstanceOf[SCollection[Double]] + } + } + + /** [[SCollection.latest]] with fan out. */ + def latest: SCollection[T] = { + coll.transform { in => + in.pApply("Reify Timestamps", Reify.timestamps[T]()) + .pApply("Latest Value", Combine.globally(Latest.combineFn[T]()).withFanout(fanout)) + } + } + /** [[SCollection.top]] with fan out. 
*/ def top(num: Int)(implicit ord: Ordering[T]): SCollection[Iterable[T]] = { coll.transform { in => diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala index fa45a3c4ac..20c0679ea3 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollectionWithHotKeyFanout.scala @@ -24,7 +24,17 @@ import com.spotify.scio.util.TupleFunctions._ import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup} import org.apache.beam.sdk.transforms.Combine.PerKeyWithHotKeyFanout import org.apache.beam.sdk.transforms.Top.TopCombineFn -import org.apache.beam.sdk.transforms.{Combine, SerializableFunction} +import org.apache.beam.sdk.transforms.{ + Combine, + Latest, + Mean, + PTransform, + Reify, + SerializableFunction +} +import org.apache.beam.sdk.values.{KV, PCollection} + +import java.lang.{Double => JDouble} /** * An enhanced SCollection that uses an intermediate node to combine "hot" keys partially before @@ -142,6 +152,34 @@ class SCollectionWithHotKeyFanout[K, V] private[values] ( self.applyPerKey(withFanout(Combine.perKey(Functions.reduceFn(context, sg))))(kvToTuple) } + /** [[PairSCollectionFunctions.minByKey]] with hot key fan out. */ + def minByKey(implicit ord: Ordering[V]): SCollection[(K, V)] = + this.reduceByKey(ord.min) + + /** [[PairSCollectionFunctions.maxByKey]] with hot key fan out. */ + def maxByKey(implicit ord: Ordering[V]): SCollection[(K, V)] = + this.reduceByKey(ord.max) + + /** [[PairSCollectionFunctions.meanByKey]] with hot key fan out. */ + def meanByKey(implicit ev: Numeric[V]): SCollection[(K, Double)] = { + val e = ev // defeat closure + self.self.transform { in => + in.mapValues[JDouble](e.toDouble).applyPerKey(withFanout(Mean.perKey[K, JDouble]()))(kdToTuple) + } + } + + /** [[PairSCollectionFunctions.latestByKey]] with hot key fan out. 
*/ + def latestByKey: SCollection[(K, V)] = { + self.applyPerKey(new PTransform[PCollection[KV[K, V]], PCollection[KV[K, V]]]() { + override def expand(input: PCollection[KV[K, V]]): PCollection[KV[K, V]] = { + input + .apply("Reify Timestamps", Reify.timestampsInValue[K, V]) + .apply("Latest Value", withFanout(Combine.perKey(Latest.combineFn[V]()))) + .setCoder(input.getCoder) + } + })(kvToTuple) + } + /** [[PairSCollectionFunctions.topByKey]] with hot key fanout. */ def topByKey(num: Int)(implicit ord: Ordering[V]): SCollection[(K, Iterable[V])] = self.applyPerKey(withFanout(Combine.perKey(new TopCombineFn[V, Ordering[V]](num, ord))))( diff --git a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala index e890298408..862ca33cfb 100644 --- a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala +++ b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithFanoutTest.scala @@ -19,6 +19,7 @@ package com.spotify.scio.values import com.twitter.algebird.{Aggregator, Semigroup} import com.spotify.scio.coders.Coder +import org.joda.time.Instant class SCollectionWithFanoutTest extends NamedTransformSpec { "SCollectionWithFanout" should "support aggregate()" in { @@ -60,7 +61,7 @@ class SCollectionWithFanoutTest extends NamedTransformSpec { } } - it should "support sum()" in { + it should "support sum" in { runWithContext { sc => def sum[T: Coder: Semigroup](elems: T*): SCollection[T] = sc.parallelize(elems).withFanout(10).sum @@ -72,6 +73,51 @@ class SCollectionWithFanoutTest extends NamedTransformSpec { } } + it should "support min" in { + runWithContext { sc => + def min[T: Coder: Ordering](elems: T*): SCollection[T] = + sc.parallelize(elems).withFanout(10).min + min(1, 2, 3) should containSingleValue(1) + min(1L, 2L, 3L) should containSingleValue(1L) + min(1f, 2f, 3f) should containSingleValue(1f) + min(1.0, 2.0, 3.0) should 
containSingleValue(1.0) + min(1 to 100: _*) should containSingleValue(1) + } + } + + it should "support max" in { + runWithContext { sc => + def max[T: Coder: Ordering](elems: T*): SCollection[T] = + sc.parallelize(elems).withFanout(10).max + max(1, 2, 3) should containSingleValue(3) + max(1L, 2L, 3L) should containSingleValue(3L) + max(1f, 2f, 3f) should containSingleValue(3f) + max(1.0, 2.0, 3.0) should containSingleValue(3.0) + max(1 to 100: _*) should containSingleValue(100) + } + } + + it should "support mean" in { + runWithContext { sc => + def mean[T: Coder: Numeric](elems: T*): SCollection[Double] = + sc.parallelize(elems).withFanout(10).mean + mean(1, 2, 3) should containSingleValue(2.0) + mean(1L, 2L, 3L) should containSingleValue(2.0) + mean(1f, 2f, 3f) should containSingleValue(2.0) + mean(1.0, 2.0, 3.0) should containSingleValue(2.0) + mean(0 to 100: _*) should containSingleValue(50.0) + } + } + + it should "support latest" in { + runWithContext { sc => + def latest(elems: Long*): SCollection[Long] = + sc.parallelize(elems).timestampBy(Instant.ofEpochMilli).withFanout(10).latest + latest(1L, 2L, 3L) should containSingleValue(3L) + latest(1L to 100L: _*) should containSingleValue(100L) + } + } + it should "support top()" in { runWithContext { sc => def top3[T: Ordering: Coder](elems: T*): SCollection[Iterable[T]] = diff --git a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala index 865a922f5c..f253939dd6 100644 --- a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala +++ b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionWithHotKeyFanoutTest.scala @@ -18,6 +18,7 @@ package com.spotify.scio.values import com.twitter.algebird.Aggregator +import org.joda.time.Instant class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec { "SCollectionWithHotKeyFanout" should "support 
aggregateByKey()" in { @@ -83,7 +84,7 @@ class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec { } } - it should "support sumByKey()" in { + it should "support sumByKey" in { runWithContext { sc => val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 2)) ++ (1 to 100).map(("c", _))) val r1 = p.withHotKeyFanout(10).sumByKey @@ -93,6 +94,48 @@ class SCollectionWithHotKeyFanoutTest extends NamedTransformSpec { } } + it should "support minByKey" in { + runWithContext { sc => + val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (1 to 100).map(("c", _))) + val r1 = p.withHotKeyFanout(10).minByKey + val r2 = p.withHotKeyFanout(_.hashCode).minByKey + r1 should containInAnyOrder(Seq(("a", 1), ("b", 2), ("c", 1))) + r2 should containInAnyOrder(Seq(("a", 1), ("b", 2), ("c", 1))) + } + } + + it should "support maxByKey" in { + runWithContext { sc => + val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (1 to 100).map(("c", _))) + val r1 = p.withHotKeyFanout(10).maxByKey + val r2 = p.withHotKeyFanout(_.hashCode).maxByKey + r1 should containInAnyOrder(Seq(("a", 1), ("b", 3), ("c", 100))) + r2 should containInAnyOrder(Seq(("a", 1), ("b", 3), ("c", 100))) + } + } + + it should "support meanByKey" in { + runWithContext { sc => + val p = sc.parallelize(List(("a", 1), ("b", 2), ("b", 3)) ++ (0 to 100).map(("c", _))) + val r1 = p.withHotKeyFanout(10).meanByKey + val r2 = p.withHotKeyFanout(_.hashCode).meanByKey + r1 should containInAnyOrder(Seq(("a", 1.0), ("b", 2.5), ("c", 50.0))) + r2 should containInAnyOrder(Seq(("a", 1.0), ("b", 2.5), ("c", 50.0))) + } + } + + it should "support latestByKey" in { + runWithContext { sc => + val p = sc + .parallelize(List(("a", 1L), ("b", 2L), ("b", 3L)) ++ (1L to 100L).map(("c", _))) + .timestampBy { case (_, v) => Instant.ofEpochMilli(v) } + val r1 = p.withHotKeyFanout(10).latestByKey + val r2 = p.withHotKeyFanout(_.hashCode).latestByKey + r1 should containInAnyOrder(Seq(("a", 1L), ("b", 3L), ("c", 100L))) + r2 
should containInAnyOrder(Seq(("a", 1L), ("b", 3L), ("c", 100L))) + } + } + it should "support topByKey()" in { runWithContext { sc => val p = sc.parallelize(