From da05566664c622da0114df9984b6b340b9535d0d Mon Sep 17 00:00:00 2001 From: Claire McGinty Date: Thu, 2 May 2024 13:48:23 +0100 Subject: [PATCH] Update FixAvroCoder to match transformed Avro SCollections (#5351) --- .../scala/fix/v0_14_0/FixAvroCoder16.scala | 13 +++++++++++ .../scala/fix/v0_14_0/FixAvroCoder16.scala | 11 +++++++++ .../main/scala/fix/v0_14_0/FixAvroCoder.scala | 23 +++++++++++++++++-- 3 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 scalafix/input-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala create mode 100644 scalafix/output-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala diff --git a/scalafix/input-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala b/scalafix/input-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala new file mode 100644 index 0000000000..e50ca9fac9 --- /dev/null +++ b/scalafix/input-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala @@ -0,0 +1,13 @@ +/* +rule = FixAvroCoder + */ +package fix.v0_14_0 + +import com.spotify.scio.values.SCollection + +object FixAvroCoder16 { + + def someMethod(data: SCollection[A]): SCollection[(String, A)] = { + data.map(r => ("foo", r)) + } +} diff --git a/scalafix/output-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala b/scalafix/output-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala new file mode 100644 index 0000000000..f1ff5813bf --- /dev/null +++ b/scalafix/output-0_14/src/main/scala/fix/v0_14_0/FixAvroCoder16.scala @@ -0,0 +1,11 @@ +package fix.v0_14_0 + +import com.spotify.scio.values.SCollection +import com.spotify.scio.avro._ + +object FixAvroCoder16 { + + def someMethod(data: SCollection[A]): SCollection[(String, A)] = { + data.map(r => ("foo", r)) + } +} diff --git a/scalafix/rules/src/main/scala/fix/v0_14_0/FixAvroCoder.scala b/scalafix/rules/src/main/scala/fix/v0_14_0/FixAvroCoder.scala index f009b87c22..2881b0b4e2 100644 --- a/scalafix/rules/src/main/scala/fix/v0_14_0/FixAvroCoder.scala +++ b/scalafix/rules/src/main/scala/fix/v0_14_0/FixAvroCoder.scala @@ -42,6 +42,9 @@ object FixAvroCoder { val ParallelizeMatcher = SymbolMatcher.normalized( "com/spotify/scio/ScioContext#parallelize()." ) + val SCollectionMatcher = SymbolMatcher.normalized( + "com/spotify/scio/values/SCollection#" + ) /** @return true if `sym` is a class whose parents include a type matching `parentMatcher` */ def hasParentClass(sym: Symbol, parentMatcher: SymbolMatcher)(implicit @@ -162,9 +165,23 @@ object FixAvroCoder { coderT } .flatMap(_.info.map(_.signature).toList) - .exists { case TypeSignature(_, _, TypeRef(_, maybeAvroType, _)) => - AvroMatcher.matches(maybeAvroType) + .exists { + case TypeSignature(_, _, TypeRef(_, maybeAvroType, _)) => + AvroMatcher.matches(maybeAvroType) + case _ => false } + + def methodReturnsAvroSCollection( + returnType: Option[Type] + )(implicit doc: SemanticDocument): Boolean = returnType match { + case Some(t"$tpe[..$tpesnel]") if SCollectionMatcher.matches(tpe) => + tpesnel.exists { + case tpe if isAvroType(tpe.symbol) => true + case t"(..$tupleTypes)" if tupleTypes.exists(tt => isAvroType(tt.symbol)) => true + case _ => false + } + case _ => false + } } class FixAvroCoder extends SemanticRule("FixAvroCoder") { @@ -209,6 +226,8 @@ class FixAvroCoder extends SemanticRule("FixAvroCoder") { case _ => false } case q"$fn(..$args)" if methodHasAvroCoderTypeBound(fn) => true + case q"..$mods def $ename(...$params): $tpe = $expr" if methodReturnsAvroSCollection(tpe) => + true } .foldLeft(false)(_ || _) val avroValuePatch = if (usesAvroCoders) Patch.addGlobalImport(avroImport) else Patch.empty