Skip to content

Commit

Permalink
Change future return type for BigtableDoFn
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Aug 14, 2024
1 parent 9617337 commit f7140cf
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package com.spotify.scio.transforms;

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.api.core.SettableApiFuture;
import com.google.common.util.concurrent.*;
import java.util.concurrent.*;
import java.util.function.Function;
Expand Down Expand Up @@ -137,4 +141,78 @@ default CompletableFuture<V> addCallback(
});
}
}

/**
* A {@link Base} implementation for Google API {@link ApiFuture}. Similar to Guava's
* ListenableFuture, but redeclared so that Guava could be shaded.
*/
public interface GoogleApi<V> extends Base<ApiFuture<V>, V> {
/**
* Executor used for callbacks. Default is {@link ForkJoinPool#commonPool()}. Consider
* overriding this method if callbacks are blocking.
*
* @return Executor for callbacks.
*/
default Executor getCallbackExecutor() {
return ForkJoinPool.commonPool();
}

@Override
default void waitForFutures(Iterable<ApiFuture<V>> futures)
throws InterruptedException, ExecutionException {
// use Future#successfulAsList instead of Futures#allAsList which only works if all
// futures succeed
ApiFutures.successfulAsList(futures).get();
}

@Override
default ApiFuture<V> addCallback(
ApiFuture<V> future, Function<V, Void> onSuccess, Function<Throwable, Void> onFailure) {
// Futures#transform doesn't allow onFailure callback while Futures#addCallback doesn't
// guarantee that callbacks are called before ListenableFuture#get() unblocks
SettableApiFuture<V> f = SettableApiFuture.create();
// if executor rejects the callback, we have to fail the future
Executor rejectPropagationExecutor =
command -> {
try {
getCallbackExecutor().execute(command);
} catch (RejectedExecutionException e) {
f.setException(e);
}
};
ApiFutures.addCallback(
future,
new ApiFutureCallback<V>() {
@Override
public void onSuccess(@Nullable V result) {
try {
onSuccess.apply(result);
f.set(result);
} catch (Throwable e) {
f.setException(e);
}
}

@Override
public void onFailure(Throwable t) {
Throwable callbackException = null;
try {
onFailure.apply(t);
} catch (Throwable e) {
// do not fail executing thread if callback fails
// record exception and propagate as suppressed
callbackException = e;
} finally {
if (callbackException != null) {
t.addSuppressed(callbackException);
}
f.setException(t);
}
}
},
rejectPropagationExecutor);

return f;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

package com.spotify.scio.bigtable;

import com.google.api.core.ApiFuture;
import com.google.cloud.bigtable.data.v2.BigtableDataClient;
import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
import com.google.common.util.concurrent.ListenableFuture;
import com.spotify.scio.transforms.BaseAsyncLookupDoFn;
import com.spotify.scio.transforms.GuavaAsyncLookupDoFn;
import com.spotify.scio.transforms.FutureHandlers;
import java.io.IOException;
import java.util.function.Supplier;
import org.apache.beam.sdk.transforms.DoFn;
Expand All @@ -32,12 +32,14 @@
* @param <A> input element type.
* @param <B> Bigtable lookup value type.
*/
public abstract class BigtableDoFn<A, B> extends GuavaAsyncLookupDoFn<A, B, BigtableDataClient> {
public abstract class BigtableDoFn<A, B>
extends BaseAsyncLookupDoFn<A, B, BigtableDataClient, ApiFuture<B>, BaseAsyncLookupDoFn.Try<B>>
implements FutureHandlers.GoogleApi<B> {

private final Supplier<BigtableDataSettings> settingsSupplier;

/** Perform asynchronous Bigtable lookup. */
public abstract ListenableFuture<B> asyncLookup(BigtableDataClient client, A input);
public abstract ApiFuture<B> asyncLookup(BigtableDataClient client, A input);

/**
* Create a {@link BigtableDoFn} instance.
Expand Down Expand Up @@ -107,4 +109,14 @@ protected BigtableDataClient newClient() {
throw new RuntimeException(e);
}
}

@Override
public BaseAsyncLookupDoFn.Try<B> success(B output) {
return new Try<>(output);
}

@Override
public BaseAsyncLookupDoFn.Try<B> failure(Throwable throwable) {
return new Try<>(throwable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

package com.spotify.scio.bigtable

import com.google.api.core.{ApiFuture, ApiFutures}
import com.google.cloud.bigtable.data.v2.BigtableDataClient
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.util.concurrent.{Futures, ListenableFuture}
import com.spotify.scio.testing._
import com.spotify.scio.transforms.BaseAsyncLookupDoFn.CacheSupplier
import com.spotify.scio.transforms.JavaAsyncConverters._
Expand Down Expand Up @@ -69,25 +69,25 @@ object BigtableDoFnTest {

class TestBigtableDoFn extends BigtableDoFn[Int, String](null) {
override def newClient(): BigtableDataClient = null
override def asyncLookup(client: BigtableDataClient, input: Int): ListenableFuture[String] =
Futures.immediateFuture(input.toString)
override def asyncLookup(client: BigtableDataClient, input: Int): ApiFuture[String] =
ApiFutures.immediateFuture(input.toString)
}

class TestCachingBigtableDoFn extends BigtableDoFn[Int, String](null, 100, new TestCacheSupplier) {
override def newClient(): BigtableDataClient = null
override def asyncLookup(client: BigtableDataClient, input: Int): ListenableFuture[String] = {
override def asyncLookup(client: BigtableDataClient, input: Int): ApiFuture[String] = {
BigtableDoFnTest.queue.add(input)
Futures.immediateFuture(input.toString)
ApiFutures.immediateFuture(input.toString)
}
}

class TestFailingBigtableDoFn extends BigtableDoFn[Int, String](null) {
override def newClient(): BigtableDataClient = null
override def asyncLookup(client: BigtableDataClient, input: Int): ListenableFuture[String] =
override def asyncLookup(client: BigtableDataClient, input: Int): ApiFuture[String] =
if (input % 2 == 0) {
Futures.immediateFuture("success" + input)
ApiFutures.immediateFuture("success" + input)
} else {
Futures.immediateFailedFuture(new RuntimeException("failure" + input))
ApiFutures.immediateFailedFuture(new RuntimeException("failure" + input))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ package com.spotify.scio.coders.instance.kryo
import com.google.api.gax.grpc.GrpcStatusCode
import com.google.api.gax.rpc.InternalException
import com.google.cloud.bigtable.data.v2.models.MutateRowsException
import com.google.cloud.bigtable.grpc.scanner.BigtableRetriesExhaustedException
import com.spotify.scio.coders.instances.kryo.GrpcSerializerTest._
import io.grpc.Status.Code
import io.grpc.{Metadata, Status, StatusRuntimeException}
import io.grpc.{Status, StatusRuntimeException}
import org.scalactic.Equality
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand All @@ -31,13 +30,6 @@ import scala.jdk.CollectionConverters._

object GcpSerializerTest {

implicit val eqBigtableRetriesExhaustedException: Equality[BigtableRetriesExhaustedException] = {
case (a: BigtableRetriesExhaustedException, b: BigtableRetriesExhaustedException) =>
a.getMessage == b.getMessage &&
eqCause.areEqual(a.getCause, b.getCause)
case _ => false
}

implicit val eqMutateRowsException: Equality[MutateRowsException] = {
case (a: MutateRowsException, b: MutateRowsException) =>
eqCause.areEqual(a.getCause, b.getCause) &&
Expand All @@ -59,16 +51,6 @@ class GcpSerializerTest extends AnyFlatSpec with Matchers {
import GcpSerializerTest._
import com.spotify.scio.testing.CoderAssertions._

"BigtableRetriesExhaustedException" should "roundtrip" in {
val metadata = new Metadata()
metadata.put(Metadata.Key.of[String]("k", Metadata.ASCII_STRING_MARSHALLER), "v")
val cause = new StatusRuntimeException(
Status.OK.withCause(new RuntimeException("bar")).withDescription("bar"),
metadata
)
new BigtableRetriesExhaustedException("Error", cause) coderShould roundtrip()
}

"MutateRowsExceptionSerializer" should "roundtrip" in {
val cause = new StatusRuntimeException(Status.OK)
val code = GrpcStatusCode.of(Status.OK.getCode)
Expand Down

0 comments on commit f7140cf

Please sign in to comment.