Refactor BQ to expose all beam's configurations

RustedBones committed Aug 19, 2024 · 1 parent 27fd3ca · commit d64e382

Showing 40 changed files with 832 additions and 1,141 deletions.
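In short, the refactor collapses the old Table.Spec / Table.Ref wrappers into a single Table(...) constructor that accepts spec strings and TableReference values alike, and routes format and read options through the new BigQueryIO.Format and Table parameters. A minimal before/after sketch, assuming a standard Scio project; the table spec and output path are placeholders:

import com.spotify.scio.ContextAndArgs
import com.spotify.scio.bigquery._

object TableApiMigration {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, _) = ContextAndArgs(cmdlineArgs)

    // Before: sc.bigQueryTable(Table.Spec("my-project:my_dataset.my_table"))
    // After: one Table(...) constructor covers specs and references alike.
    sc.bigQueryTable(Table("my-project:my_dataset.my_table"))
      // TableRow is a Map-like GenericJson; extract a field for illustration.
      .map(row => row.get("word").toString)
      .saveAsTextFile("gs://my-bucket/words") // placeholder output path

    sc.run()
  }
}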
@@ -154,7 +154,7 @@ class BigQueryClientIT extends AnyFlatSpec with Matchers {

"TableService.getRows" should "work" in {
val rows =
-      bq.tables.rows(Table.Spec("bigquery-public-data:samples.shakespeare")).take(10).toList
+      bq.tables.rows(Table("bigquery-public-data:samples.shakespeare")).take(10).toList
val columns = Set("word", "word_count", "corpus", "corpus_date")
all(rows.map(_.keySet().asScala)) shouldBe columns
}
@@ -45,7 +45,7 @@ class BigQueryIOIT extends PipelineSpec {

"Select" should "read typed values from a SQL query" in
runWithRealContext(options) { sc =>
-      val scoll = sc.read(BigQueryTyped[ShakespeareFromQuery])
+      val scoll = sc.typedBigQueryStorage[ShakespeareFromQuery]()
scoll should haveSize(10)
scoll should satisfy[ShakespeareFromQuery] {
_.forall(_.getClass == classOf[ShakespeareFromQuery])
@@ -54,7 +54,7 @@

"TableRef" should "read typed values from table" in
runWithRealContext(options) { sc =>
-      val scoll = sc.read(BigQueryTyped[ShakespeareFromTable])
+      val scoll = sc.typedBigQueryStorage[ShakespeareFromTable]()
scoll.take(10) should haveSize(10)
scoll should satisfy[ShakespeareFromTable] {
_.forall(_.getClass == classOf[ShakespeareFromTable])
@@ -20,7 +20,6 @@ package com.spotify.scio.bigquery
import com.google.protobuf.ByteString
import com.spotify.scio._
import com.spotify.scio.avro._
-import com.spotify.scio.bigquery.BigQueryTypedTable.Format
import com.spotify.scio.bigquery.client.BigQuery
import com.spotify.scio.testing._
import magnolify.scalacheck.auto._
@@ -69,7 +68,7 @@ object TypedBigQueryIT {
val now = Instant.now().toString(TIME_FORMATTER)
val spec =
s"data-integration-test:bigquery_avro_it.$name${now}_${Random.nextInt(Int.MaxValue)}"
-    Table.Spec(spec)
+    Table(spec)
}
private val tableRowTable = table("records_tablerow")
private val avroTable = table("records_avro")
@@ -101,37 +100,25 @@ class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {
BigQuery.defaultInstance().tables.delete(avroLogicalTypeTable.ref)
}

"TypedBigQuery" should "read records" in {
"typedBigQuery" should "read records" in {
val sc = ScioContext(options)
sc.typedBigQuery[Record](tableRowTable) should containInAnyOrder(records)
sc.run()
}

it should "convert to avro format" in {
"bigQueryTableFormat" should "read TableRow records" in {
val sc = ScioContext(options)
-    implicit val coder = avroGenericRecordCoder(Record.avroSchema)
-    sc.typedBigQuery[Record](tableRowTable)
-      .map(Record.toAvro)
-      .map(Record.fromAvro) should containInAnyOrder(
-      records
-    )
+    val format = BigQueryIO.Format.Default(BigQueryType[Record])
+    val data = sc.bigQueryTableFormat(tableRowTable, format)
+    data should containInAnyOrder(records)
sc.run()
}

"BigQueryTypedTable" should "read TableRow records" in {
it should "read GenericRecord records" in {
val sc = ScioContext(options)
sc
.bigQueryTable(tableRowTable)
.map(Record.fromTableRow) should containInAnyOrder(records)
sc.run()
}

it should "read GenericRecord recors" in {
val sc = ScioContext(options)
implicit val coder = avroGenericRecordCoder(Record.avroSchema)
sc
.bigQueryTable(tableRowTable, Format.GenericRecord)
.map(Record.fromAvro) should containInAnyOrder(records)
val format = BigQueryIO.Format.Avro(BigQueryType[Record])
val data = sc.bigQueryTableFormat(tableRowTable, format)
data should containInAnyOrder(records)
sc.run()
}

@@ -157,7 +144,7 @@ class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {
|}
""".stripMargin)
val tap = sc
-      .bigQueryTable(tableRowTable, Format.GenericRecord)
+      .bigQueryTableFormat(tableRowTable, BigQueryIO.Format.Avro())
.saveAsBigQueryTable(avroTable, schema = schema, createDisposition = CREATE_IF_NEEDED)

val result = sc.run().waitUntilDone()
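The two tests above replace the removed BigQueryTypedTable.Format with the new BigQueryIO.Format, which bundles the BigQueryType together with the wire format. A hedged sketch of both variants, assuming an annotation-derived case class and a placeholder table spec:

import com.spotify.scio.ScioContext
import com.spotify.scio.bigquery._
import com.spotify.scio.bigquery.types.BigQueryType

object FormatSketch {
  @BigQueryType.toTable
  case class Row(word: String, wordCount: Long)

  def read(sc: ScioContext): Unit = {
    val table = Table("my-project:my_dataset.my_table") // placeholder spec

    // Default: rows are read as TableRow JSON and decoded via the BigQueryType.
    sc.bigQueryTableFormat(table, BigQueryIO.Format.Default(BigQueryType[Row])).debug()

    // Avro: rows are read as Avro GenericRecord, as in the tests above.
    sc.bigQueryTableFormat(table, BigQueryIO.Format.Avro(BigQueryType[Row])).debug()
  }
}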
@@ -155,8 +155,10 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
val (sc, _) = ContextAndArgs(
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
+    val bqt = BigQueryType[NestedWithRestriction]
+    val source = Table(bqt.table.get, "required.int < 3")
    val p = sc
-      .typedBigQueryStorage[NestedWithRestriction](rowRestriction = "required.int < 3")
+      .typedBigQueryStorage[NestedWithRestriction](source)
.map { r =>
val (req, opt, rep) = (r.required, r.optional.get, r.repeated.head)
(req.int, req.string, opt.int, opt.string, rep.int, rep.string)
@@ -172,7 +174,7 @@
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQuery[NestedWithAll](Table.Spec(NestedWithAll.table.format("nested")))
+      .typedBigQuery[NestedWithAll](Table(NestedWithAll.table.format("nested")))
.map(r => (r.required.int, r.required.string, r.optional.get.int))
.internal
PAssert.that(p).containsInAnyOrder(expected)
@@ -243,7 +245,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQueryStorage[ToTableRequired](Table.Spec("data-integration-test:storage.required"))
+      .typedBigQueryStorage[ToTableRequired](Table("data-integration-test:storage.required"))
.internal
PAssert.that(p).containsInAnyOrder(expected)
sc.run()
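Note the shape of the storage-read change above: the row restriction moves off the read call and onto the Table source itself. A small sketch under the same assumptions; the table spec is a placeholder and the annotated class is hypothetical:

import com.spotify.scio.ScioContext
import com.spotify.scio.bigquery._
import com.spotify.scio.bigquery.types.BigQueryType

object StorageReadSketch {
  @BigQueryType.fromStorage("my-project:my_dataset.my_table") // placeholder spec
  class MyRow

  def read(sc: ScioContext): Unit = {
    val bqt = BigQueryType[MyRow]
    // The filter now travels with the source definition.
    val source = Table(bqt.table.get, "required.int < 3")
    sc.typedBigQueryStorage[MyRow](source).debug()
  }
}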
@@ -187,10 +187,9 @@ class BigQueryTypeIT extends AnyFlatSpec with Matchers {
tableReference.setProjectId("data-integration-test")
tableReference.setDatasetId("partition_a")
tableReference.setTableId("table_$LATEST")
-    Table.Ref(tableReference).latest().ref.getTableId shouldBe "table_20170302"
+    Table(tableReference).latest().ref.getTableId shouldBe "table_20170302"

-    Table
-      .Spec("data-integration-test:partition_a.table_$LATEST")
+    Table("data-integration-test:partition_a.table_$LATEST")
.latest()
.ref
.getTableId shouldBe "table_20170302"
@@ -210,7 +209,7 @@
val bqt = BigQueryType[FromTableT]
bqt.isQuery shouldBe false
bqt.isTable shouldBe true
-    bqt.query shouldBe None
+    bqt.queryRaw shouldBe None
bqt.table shouldBe Some("bigquery-public-data:samples.shakespeare")
val fields = bqt.schema.getFields.asScala
fields.size shouldBe 4
23 changes: 23 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
@@ -51,6 +51,8 @@ import scala.reflect.ClassTag
import scala.util.control.NoStackTrace
import scala.util.{Failure, Success, Try}
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler

/** Runner specific context. */
trait RunnerContext {
@@ -851,6 +853,27 @@ class ScioContext private[scio] (
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
}

+  // =======================================================================
+  // Error handler
+  // =======================================================================
+  def registerBadRecordErrorHandler[O <: POutput](
+    sinkTransform: PTransform[PCollection[BadRecord], O]
+  ): BadRecordErrorHandler[O] =
+    pipeline.registerBadRecordErrorHandler(sinkTransform)
+
+  def badRecordErrorHandler(): (BadRecordErrorHandler[PCollectionTuple], SCollection[BadRecord]) = {
+    val tag = new TupleTag[BadRecord]()
+    val sideOutput = PCollectionTuple.empty(pipeline)
+    val sinkTransform = new PTransform[PCollection[BadRecord], PCollectionTuple] {
+      override def expand(input: PCollection[BadRecord]): PCollectionTuple =
+        sideOutput.and(tag, input)
+    }
+
+    val handler = pipeline.registerBadRecordErrorHandler(sinkTransform)
+    val errorOutput = wrap(sideOutput.get(tag))
+    (handler, errorOutput)
+  }
+
// =======================================================================
// Metrics
// =======================================================================
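These additions are the new public hooks for Beam's bad-record error handling: registerBadRecordErrorHandler forwards an arbitrary sink transform, and badRecordErrorHandler() exposes failed records as a plain SCollection[BadRecord] (the BeamTypeCoders change below supplies the matching coder). A hedged usage sketch; how handler is wired into individual IOs is left out because it depends on which transforms accept one:

import com.spotify.scio.ScioContext

object ErrorHandlerSketch {
  def run(sc: ScioContext): Unit = {
    val (handler, badRecords) = sc.badRecordErrorHandler()

    // ... build the main pipeline here, passing `handler` to transforms
    // that can route their failures to a bad-record error handler ...

    // Failed records become ordinary elements that can be inspected or saved.
    badRecords.map(_.toString).debug()

    // Beam requires error handlers to be closed before the pipeline runs.
    handler.close()
    sc.run()
  }
}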
@@ -20,6 +20,7 @@ package com.spotify.scio.coders.instances
import com.google.api.client.json.GenericJson
import com.google.api.client.json.JsonObjectParser
import com.google.api.client.json.gson.GsonFactory
+import com.spotify.scio.ScioContext
import com.spotify.scio.coders.{Coder, CoderGrammar}
import com.spotify.scio.util.ScioUtil

@@ -29,6 +30,7 @@ import org.apache.beam.sdk.io.FileIO.ReadableFile
import org.apache.beam.sdk.io.fs.{MatchResult, MetadataCoderV2, ResourceId, ResourceIdCoder}
import org.apache.beam.sdk.io.ReadableFileCoder
import org.apache.beam.sdk.schemas.{Schema => BSchema}
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord
import org.apache.beam.sdk.transforms.windowing.{
BoundedWindow,
GlobalWindow,
@@ -66,6 +68,8 @@ trait BeamTypeCoders extends CoderGrammar {
str => DefaultJsonObjectParser.parseAndClose(new StringReader(str), ScioUtil.classOf[T]),
DefaultJsonObjectParser.getJsonFactory().toString(_)
)

+  def badRecordCoder(sc: ScioContext): Coder[BadRecord] = beam(BadRecord.getCoder(sc.pipeline))
}

private[coders] object BeamTypeCoders extends BeamTypeCoders {
@@ -145,7 +145,7 @@ object AutoComplete {
if (outputToBigqueryTable) {
tags
.map(kv => Record(kv._1, kv._2.map(p => Tag(p._1, p._2)).toList))
.saveAsTypedBigQueryTable(Table.Spec(args("output")))
.saveAsTypedBigQueryTable(Table(args("output")))
}
if (outputToDatastore) {
val kind = args.getOrElse("kind", "autocomplete-demo")
@@ -49,7 +49,7 @@ object StreamingWordExtract {
.flatMap(_.split("[^a-zA-Z']+").filter(_.nonEmpty))
.map(_.toUpperCase)
.map(s => TableRow("string_field" -> s))
.saveAsBigQueryTable(Table.Spec(args("output")), schema)
.saveAsBigQueryTable(Table(args("output")), schema)

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -126,7 +126,7 @@ object TrafficMaxLaneFlow {
ts
)
}
.saveAsTypedBigQueryTable(Table.Spec(args("output")))
.saveAsTypedBigQueryTable(Table(args("output")))

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -111,7 +111,7 @@ object TrafficRoutes {
.map { case (r, ts) =>
Record(r.route, r.avgSpeed, r.slowdownEvent, ts)
}
.saveAsTypedBigQueryTable(Table.Spec(args("output")))
.saveAsTypedBigQueryTable(Table(args("output")))

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -113,7 +113,7 @@ object GameStats {
// Done using windowing information, convert back to regular `SCollection`
.toSCollection
// Save to the BigQuery table defined by "output" in the arguments passed in + "_team" suffix
.saveAsTypedBigQueryTable(Table.Spec(args("output") + "_team"))
.saveAsTypedBigQueryTable(Table(args("output") + "_team"))

userEvents
// Window over a variable length of time - sessions end after sessionGap minutes no activity
@@ -141,7 +141,7 @@ object GameStats {
AvgSessionLength(mean, fmt.print(w.start()))
}
// Save to the BigQuery table defined by "output" + "_sessions" suffix
.saveAsTypedBigQueryTable(Table.Spec(args("output") + "_sessions"))
.saveAsTypedBigQueryTable(Table(args("output") + "_sessions"))

// Execute the pipeline
val result = sc.run()
@@ -91,7 +91,7 @@ object HourlyTeamScore {
TeamScoreSums(team, score, start)
}
// Save to the BigQuery table defined by "output" in the arguments passed in
.saveAsTypedBigQueryTable(Table.Spec(args("output")))
.saveAsTypedBigQueryTable(Table(args("output")))

// Execute the pipeline
sc.run()
@@ -96,7 +96,7 @@ object LeaderBoard {
// Done with windowing information, convert back to regular `SCollection`
.toSCollection
// Save to the BigQuery table defined by "output" in the arguments passed in + "_team" suffix
.saveAsTypedBigQueryTable(Table.Spec(args("output") + "_team"))
.saveAsTypedBigQueryTable(Table(args("output") + "_team"))

gameEvents
// Use a global window for unbounded data, which updates calculation every 10 minutes,
@@ -126,7 +126,7 @@ object LeaderBoard {
// Map summed results from tuples into `UserScoreSums` case class, so we can save to BQ
.map(kv => UserScoreSums(kv._1, kv._2, fmt.print(Instant.now())))
// Save to the BigQuery table defined by "output" in the arguments passed in + "_user" suffix
.saveAsTypedBigQueryTable(Table.Spec(args("output") + "_user"))
.saveAsTypedBigQueryTable(Table(args("output") + "_user"))

// Execute the pipeline
val result = sc.run()
@@ -62,7 +62,7 @@ object UserScore {
// Map summed results from tuples into `UserScoreSums` case class, so we can save to BQ
.map(UserScoreSums.tupled)
// Save to the BigQuery table defined by "output" in the arguments passed in
.saveAsTypedBigQueryTable(Table.Spec(args("output")))
.saveAsTypedBigQueryTable(Table(args("output")))

// Execute the pipeline
sc.run()
@@ -45,7 +45,7 @@ object BigQueryTornadoes {
)

// Open a BigQuery table as a `SCollection[TableRow]`
val table = Table.Spec(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
val table = Table(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
val resultTap = sc
.bigQueryTable(table)
// Extract months with tornadoes
@@ -55,7 +55,7 @@
// Map `(Long, Long)` tuples into result `TableRow`s
.map(kv => TableRow("month" -> kv._1, "tornado_count" -> kv._2))
// Save result as a BigQuery table
.saveAsBigQueryTable(Table.Spec(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)
.saveAsBigQueryTable(Table(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)

// Access the loaded tables
resultTap
@@ -47,7 +47,7 @@ object CombinePerKeyExamples {
)

// Open a BigQuery table as a `SCollection[TableRow]`
val table = Table.Spec(args.getOrElse("input", ExampleData.SHAKESPEARE_TABLE))
val table = Table(args.getOrElse("input", ExampleData.SHAKESPEARE_TABLE))
sc.bigQueryTable(table)
// Extract words and corresponding play names
.flatMap { row =>
@@ -64,7 +64,7 @@
// Map `(String, String)` tuples into result `TableRow`s
.map(kv => TableRow("word" -> kv._1, "all_plays" -> kv._2))
// Save result as a BigQuery table
.saveAsBigQueryTable(Table.Spec(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)
.saveAsBigQueryTable(Table(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)

// Execute the pipeline
sc.run()
@@ -46,7 +46,7 @@ object DistinctByKeyExample {
)

// Open a BigQuery table as a `SCollection[TableRow]`
val table = Table.Spec(args.getOrElse("input", ExampleData.SHAKESPEARE_TABLE))
val table = Table(args.getOrElse("input", ExampleData.SHAKESPEARE_TABLE))
sc.bigQueryTable(table)
// Extract words and corresponding play names
.flatMap { row =>
@@ -59,7 +59,7 @@
// Map `(String, String)` tuples into result `TableRow`s
.map(kv => TableRow("word" -> kv._1, "reference_play" -> kv._2))
// Save result as a BigQuery table
.saveAsBigQueryTable(Table.Spec(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)
.saveAsBigQueryTable(Table(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)

// Execute the pipeline
sc.run()
@@ -51,7 +51,7 @@ object FilterExamples {
val monthFilter = args.int("monthFilter", 7)

// Open BigQuery table as a `SCollection[TableRow]`
val table = Table.Spec(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
val table = Table(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
val pipe = sc
.bigQueryTable(table)
// Map `TableRow`s into `Record`s
@@ -81,7 +81,7 @@ object FilterExamples {
TableRow("year" -> r.year, "month" -> r.month, "day" -> r.day, "mean_temp" -> r.meanTemp)
}
// Save result as a BigQuery table
.saveAsBigQueryTable(Table.Spec(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)
.saveAsBigQueryTable(Table(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)

// Execute the pipeline
sc.run()