Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent implementation for reduce operation #5498

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ import com.spotify.scio.hash._
import com.spotify.scio.util._
import com.spotify.scio.util.random.{BernoulliValueSampler, PoissonValueSampler}
import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
import org.apache.beam.sdk.transforms.DoFn.{Element, OutputReceiver, ProcessElement, Timestamp}
import org.apache.beam.sdk.transforms._
import org.apache.beam.sdk.values.{KV, PCollection}
import org.joda.time.Duration
import org.joda.time.{Duration, Instant}
import org.slf4j.LoggerFactory

import java.lang.{Double => JDouble}

import scala.collection.compat._

private object PairSCollectionFunctions {
Expand Down Expand Up @@ -718,6 +721,23 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
/**
 * Return a new SCollection of (key, value) pairs with duplicate keys removed.
 * @group transform
 */
def distinctByKey: SCollection[(K, V)] =
  self.distinctBy { case (k, _) => k }

/**
 * Convert values into pairs of (value, timestamp).
 * @group transform
 */
def withTimestampedValues: SCollection[(K, (V, Instant))] =
  self.parDo(new DoFn[(K, V), (K, (V, Instant))] {
    @ProcessElement
    private[scio] def processElement(
      @Element element: (K, V),
      @Timestamp timestamp: Instant,
      out: OutputReceiver[(K, (V, Instant))]
    ): Unit =
      // attach the element's event timestamp to its value, keeping the key unchanged
      out.output((element._1, (element._2, timestamp)))
  })

/**
* Return a new SCollection of (key, value) pairs whose values satisfy the predicate.
* @group transform
Expand Down Expand Up @@ -992,16 +1012,6 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
/**
 * Pass each value through `f` while leaving keys untouched.
 * @group transform
 */
def mapValues[U: Coder](f: V => U): SCollection[(K, U)] =
  self.map { case (k, v) => (k, f(v)) }

/**
 * Return the max of values for each key as defined by the implicit `Ordering[T]`.
 * @return
 *   a new SCollection of (key, maximum value) pairs
 * @group per_key
 */
// Scala lambda is simpler and more powerful than transforms.Max
def maxByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
  this.reduceByKey((x, y) => if (ord.gteq(x, y)) x else y)

/**
 * Return the min of values for each key as defined by the implicit `Ordering[T]`.
 * @return
 *   a new SCollection of (key, minimum value) pairs
 * @group per_key
 */
// Scala lambda is simpler and more powerful than transforms.Min
def minByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
  this.reduceByKey((x, y) => if (ord.lteq(x, y)) x else y)

/**
 * Return the max of values for each key as defined by the implicit `Ordering[T]`.
 * @return
 *   a new SCollection of (key, maximum value) pairs
 * @group per_key
 */
// Scala lambda is simpler and more powerful than transforms.Max
// ord.max is an associative binary op, so reduceByKey can pre-combine per key on each worker
def maxByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
this.reduceByKey(ord.max)

/**
 * Return latest of values for each key according to its event time, or null if there are no
 * elements.
 * @group per_key
 */
def latestByKey: SCollection[(K, V)] = {
  // Beam's Latest.perKey combines per key on the elements' event timestamps
  val latestPerKey = Latest.perKey[K, V]()
  self.applyPerKey(latestPerKey)(kvToTuple)
}

/**
 * Reduce by key with [[com.twitter.algebird.Semigroup Semigroup]]. This could be more powerful
 * and better optimized than [[reduceByKey]] in some cases.
 * @group per_key
 */
def sumByKey(implicit sg: Semigroup[V]): SCollection[(K, V)] = {
  PairSCollectionFunctions.logger.warn(
    "combineByKey/sumByKey does not support default value and may fail in some streaming " +
      "scenarios. Consider aggregateByKey/foldByKey instead."
  )
  val combinePerKey = Combine.perKey(Functions.reduceFn(context, sg))
  this.applyPerKey(combinePerKey)(kvToTuple)
}

/**
 * Return the mean of values for each key as defined by the implicit `Numeric[T]`.
 * @return
 *   a new SCollection of (key, mean value) pairs
 * @group per_key
 */
def meanByKey(implicit ev: Numeric[V]): SCollection[(K, Double)] =
  self.transform { pairs =>
    // widen values to java.lang.Double for Beam's Mean combiner
    val asDoubles = pairs.mapValues[JDouble](ev.toDouble)
    asDoubles.applyPerKey(Mean.perKey[K, JDouble]())(kdToTuple)
  }

/**
* Merge the values for each key using an associative reduce function. This will also perform the
* merging locally on each mapper before sending results to a reducer, similarly to a "combiner"
Expand Down Expand Up @@ -1077,19 +1121,6 @@ class PairSCollectionFunctions[K, V](val self: SCollection[(K, V)]) {
}
}

/**
* Reduce by key with [[com.twitter.algebird.Semigroup Semigroup]]. This could be more powerful
* and better optimized than [[reduceByKey]] in some cases.
* @group per_key
*/
def sumByKey(implicit sg: Semigroup[V]): SCollection[(K, V)] = {
PairSCollectionFunctions.logger.warn(
"combineByKey/sumByKey does not support default value and may fail in some streaming " +
"scenarios. Consider aggregateByKey/foldByKey instead."
)
this.applyPerKey(Combine.perKey(Functions.reduceFn(context, sg)))(kvToTuple)
}

/**
* Swap the keys with the values.
* @group transform
Expand Down
67 changes: 33 additions & 34 deletions scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import org.apache.beam.sdk.util.{CoderUtils, SerializableUtils}
import org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode
import org.apache.beam.sdk.values._
import org.apache.beam.sdk.{io => beam}
import org.joda.time.{Duration, Instant}
import org.joda.time.{Duration, Instant, ReadableInstant}
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -773,6 +773,16 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
*/
def map[U: Coder](f: T => U): SCollection[U] = this.parDo(Functions.mapFn(f))

/**
 * Return the min of this SCollection as defined by the implicit `Ordering[T]`.
 * @return
 *   a new SCollection with the minimum element
 * @group transform
 */
// Scala lambda is simpler and more powerful than transforms.Min
def min(implicit ord: Ordering[T]): SCollection[T] =
  this.reduce((x, y) => if (ord.lteq(x, y)) x else y)

/**
* Return the max of this SCollection as defined by the implicit `Ordering[T]`.
* @return
Expand All @@ -784,38 +794,40 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
this.reduce(ord.max)

/**
* Return the mean of this SCollection as defined by the implicit `Numeric[T]`.
* Return the latest of this SCollection according to its event time.
* @return
* a new SCollection with the mean of elements
* a new SCollection with the latest element
* @group transform
*/
def mean(implicit ev: Numeric[T]): SCollection[Double] = this.transform { in =>
val e = ev // defeat closure
in.map(e.toDouble)
.asInstanceOf[SCollection[JDouble]]
.pApply(Mean.globally())
.asInstanceOf[SCollection[Double]]
}
/**
 * Return the latest of this SCollection according to its event time.
 * @return
 *   a new SCollection with the latest element
 * @group transform
 */
def latest: SCollection[T] = {
  // widen to ReadableInstant for scala 2.12 implicit ordering
  val byEventTime: Ordering[(T, Instant)] = Ordering.by(_._2: ReadableInstant)
  this.withTimestamp.max(byEventTime).keys
}

/**
* Return the min of this SCollection as defined by the implicit `Ordering[T]`.
* @return
* a new SCollection with the minimum element
* Reduce with [[com.twitter.algebird.Semigroup Semigroup]]. This could be more powerful and
* better optimized than [[reduce]] in some cases.
* @group transform
*/
// Scala lambda is simpler and more powerful than transforms.Min
def min(implicit ord: Ordering[T]): SCollection[T] =
this.reduce(ord.min)
/**
 * Reduce with [[com.twitter.algebird.Semigroup Semigroup]]. This could be more powerful and
 * better optimized than [[reduce]] in some cases.
 * @group transform
 */
def sum(implicit sg: Semigroup[T]): SCollection[T] = {
  SCollection.logger.warn(
    "combine/sum does not support default value and may fail in some streaming scenarios. " +
      "Consider aggregate/fold instead."
  )
  // withoutDefaults: no default output for empty windows, hence the warning above
  val combineGlobally = Combine.globally(Functions.reduceFn(context, sg)).withoutDefaults()
  this.pApply(combineGlobally)
}

/**
* Return the latest of this SCollection according to its event time, or null if there are no
* elements.
* Return the mean of this SCollection as defined by the implicit `Numeric[T]`.
* @return
* a new SCollection with the latest element
* a new SCollection with the mean of elements
* @group transform
*/
def latest: SCollection[T] =
this.pApply(Latest.globally())
/**
 * Return the mean of this SCollection as defined by the implicit `Numeric[T]`.
 * @return
 *   a new SCollection with the mean of elements
 * @group transform
 */
def mean(implicit ev: Numeric[T]): SCollection[Double] = this.transform { in =>
  val num = ev // defeat closure
  // widen to java.lang.Double for Beam's Mean combiner, then cast the result back
  val doubles = in.map[JDouble](num.toDouble)
  doubles.pApply(Mean.globally().withoutDefaults()).asInstanceOf[SCollection[Double]]
}

/**
* Compute the SCollection's data distribution using approximate `N`-tiles.
Expand Down Expand Up @@ -946,19 +958,6 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] {
_.map((_, ())).subtractByKey(that).keys
}

/**
* Reduce with [[com.twitter.algebird.Semigroup Semigroup]]. This could be more powerful and
* better optimized than [[reduce]] in some cases.
* @group transform
*/
def sum(implicit sg: Semigroup[T]): SCollection[T] = {
SCollection.logger.warn(
"combine/sum does not support default value and may fail in some streaming scenarios. " +
"Consider aggregate/fold instead."
)
this.pApply(Combine.globally(Functions.reduceFn(context, sg)).withoutDefaults())
}

/**
* Return a sampled subset of any `num` elements of the SCollection.
* @group transform
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ import com.spotify.scio.ScioContext
import com.spotify.scio.util.Functions
import com.spotify.scio.coders.Coder
import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
import org.apache.beam.sdk.transforms.{Combine, Latest, Mean, Reify, Top}
import org.apache.beam.sdk.transforms.{Combine, Mean, Top}
import org.apache.beam.sdk.values.PCollection
import org.joda.time.ReadableInstant

import java.lang.{Double => JDouble, Iterable => JIterable}
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -104,6 +105,24 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In
Combine.globally(Functions.reduceFn(context, op)).withoutDefaults().withFanout(fanout)
)

/** [[SCollection.min]] with fan out. */
def min(implicit ord: Ordering[T]): SCollection[T] =
  this.reduce((x, y) => if (ord.lteq(x, y)) x else y)

/** [[SCollection.max]] with fan out. */
def max(implicit ord: Ordering[T]): SCollection[T] =
  this.reduce((x, y) => if (ord.gteq(x, y)) x else y)

/** [[SCollection.latest]] with fan out. */
def latest: SCollection[T] =
  coll.transform { timestamped =>
    // widen to ReadableInstant for scala 2.12 implicit ordering
    val withTs = timestamped.withTimestamp
    new SCollectionWithFanout(withTs, this.fanout).max(Ordering.by(_._2: ReadableInstant)).keys
  }

/** [[SCollection.sum]] with fan out. */
def sum(implicit sg: Semigroup[T]): SCollection[T] = {
SCollection.logger.warn(
Expand All @@ -115,14 +134,6 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In
)
}

/** [[SCollection.min]] with fan out. */
def min(implicit ord: Ordering[T]): SCollection[T] =
this.reduce(ord.min)

/** [[SCollection.max]] with fan out. */
def max(implicit ord: Ordering[T]): SCollection[T] =
this.reduce(ord.max)

/** [[SCollection.mean]] with fan out. */
def mean(implicit ev: Numeric[T]): SCollection[Double] = {
val e = ev // defeat closure
Expand All @@ -133,14 +144,6 @@ class SCollectionWithFanout[T] private[values] (coll: SCollection[T], fanout: In
}
}

/** [[SCollection.latest]] with fan out. */
def latest: SCollection[T] = {
coll.transform { in =>
// Reify attaches each element's event timestamp so Latest can combine on it;
// the fan-out spreads the hot global combine across intermediate keys
in.pApply("Reify Timestamps", Reify.timestamps[T]())
.pApply("Latest Value", Combine.globally(Latest.combineFn[T]()).withFanout(fanout))
}
}

/** [[SCollection.top]] with fan out. */
def top(num: Int)(implicit ord: Ordering[T]): SCollection[Iterable[T]] = {
coll.transform { in =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,8 @@ import com.spotify.scio.util.TupleFunctions._
import com.twitter.algebird.{Aggregator, Monoid, MonoidAggregator, Semigroup}
import org.apache.beam.sdk.transforms.Combine.PerKeyWithHotKeyFanout
import org.apache.beam.sdk.transforms.Top.TopCombineFn
import org.apache.beam.sdk.transforms.{
Combine,
Latest,
Mean,
PTransform,
Reify,
SerializableFunction
}
import org.apache.beam.sdk.values.{KV, PCollection}
import org.apache.beam.sdk.transforms.{Combine, Mean, SerializableFunction}
import org.joda.time.ReadableInstant

import java.lang.{Double => JDouble}

Expand Down Expand Up @@ -143,15 +136,6 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
/** [[PairSCollectionFunctions.reduceByKey]] with hot key fanout. */
def reduceByKey(op: (V, V) => V): SCollection[(K, V)] = {
  val combinePerKey = withFanout(Combine.perKey(Functions.reduceFn(context, op)))
  self.applyPerKey(combinePerKey)(kvToTuple)
}

/** [[PairSCollectionFunctions.sumByKey]] with hot key fanout. */
def sumByKey(implicit sg: Semigroup[V]): SCollection[(K, V)] = {
SCollection.logger.warn(
"combineByKey/sumByKey does not support default value and may fail in some streaming " +
"scenarios. Consider aggregateByKey/foldByKey instead."
)
self.applyPerKey(withFanout(Combine.perKey(Functions.reduceFn(context, sg))))(kvToTuple)
}

/** [[SCollection.min]] with hot key fan out. */
def minByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
  self.reduceByKey((x, y) => if (ord.lteq(x, y)) x else y)
Expand All @@ -160,6 +144,16 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
/** [[SCollection.max]] with hot key fan out. */
def maxByKey(implicit ord: Ordering[V]): SCollection[(K, V)] =
  self.reduceByKey((x, y) => if (ord.gteq(x, y)) x else y)

/** [[SCollection.latest]] with hot key fan out. */
def latestByKey: SCollection[(K, V)] =
  self.self.transform { in =>
    // widen to ReadableInstant for scala 2.12 implicit ordering
    val fanned = new SCollectionWithHotKeyFanout(in.withTimestampedValues, this.hotKeyFanout)
    fanned.maxByKey(Ordering.by(_._2: ReadableInstant)).mapValues(_._1)
  }

/** [[SCollection.mean]] with hot key fan out. */
def meanByKey(implicit ev: Numeric[V]): SCollection[(K, Double)] = {
val e = ev // defeat closure
Expand All @@ -168,16 +162,13 @@ class SCollectionWithHotKeyFanout[K, V] private[values] (
}
}

/** [[SCollection.latest]] with hot key fan out. */
def latestByKey: SCollection[(K, V)] = {
self.applyPerKey(new PTransform[PCollection[KV[K, V]], PCollection[KV[K, V]]]() {
override def expand(input: PCollection[KV[K, V]]): PCollection[KV[K, V]] = {
input
.apply("Reify Timestamps", Reify.timestampsInValue[K, V])
.apply("Latest Value", withFanout(Combine.perKey(Latest.combineFn[V]())))
.setCoder(input.getCoder)
}
})(kvToTuple)
/** [[PairSCollectionFunctions.sumByKey]] with hot key fanout. */
def sumByKey(implicit sg: Semigroup[V]): SCollection[(K, V)] = {
  SCollection.logger.warn(
    "combineByKey/sumByKey does not support default value and may fail in some streaming " +
      "scenarios. Consider aggregateByKey/foldByKey instead."
  )
  val combinePerKey = withFanout(Combine.perKey(Functions.reduceFn(context, sg)))
  self.applyPerKey(combinePerKey)(kvToTuple)
}

/** [[PairSCollectionFunctions.topByKey]] with hot key fanout. */
Expand Down
Loading