Skip to content

Commit

Permalink
Update FixAvroCoder to match transformed Avro SCollections (#5351)
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty authored May 2, 2024
1 parent 8ff8bba commit da05566
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
rule = FixAvroCoder
*/
package fix.v0_14_0

import com.spotify.scio.values.SCollection

object FixAvroCoder16 {

def someMethod(data: SCollection[A]): SCollection[(String, A)] = {
data.map(r => ("foo", r))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package fix.v0_14_0

import com.spotify.scio.values.SCollection
import com.spotify.scio.avro._

object FixAvroCoder16 {

def someMethod(data: SCollection[A]): SCollection[(String, A)] = {
data.map(r => ("foo", r))
}
}
23 changes: 21 additions & 2 deletions scalafix/rules/src/main/scala/fix/v0_14_0/FixAvroCoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ object FixAvroCoder {
val ParallelizeMatcher = SymbolMatcher.normalized(
"com/spotify/scio/ScioContext#parallelize()."
)
val SCollectionMatcher = SymbolMatcher.normalized(
"com/spotify/scio/values/SCollection#"
)

/** @return true if `sym` is a class whose parents include a type matching `parentMatcher` */
def hasParentClass(sym: Symbol, parentMatcher: SymbolMatcher)(implicit
Expand Down Expand Up @@ -162,9 +165,23 @@ object FixAvroCoder {
coderT
}
.flatMap(_.info.map(_.signature).toList)
.exists { case TypeSignature(_, _, TypeRef(_, maybeAvroType, _)) =>
AvroMatcher.matches(maybeAvroType)
.exists {
case TypeSignature(_, _, TypeRef(_, maybeAvroType, _)) =>
AvroMatcher.matches(maybeAvroType)
case _ => false
}

def methodReturnsAvroSCollection(
returnType: Option[Type]
)(implicit doc: SemanticDocument): Boolean = returnType match {
case Some(t"$tpe[..$tpesnel]") if SCollectionMatcher.matches(tpe) =>
tpesnel.exists {
case tpe if isAvroType(tpe.symbol) => true
case t"(..$tupleTypes)" if tupleTypes.exists(tt => isAvroType(tt.symbol)) => true
case _ => false
}
case _ => false
}
}

class FixAvroCoder extends SemanticRule("FixAvroCoder") {
Expand Down Expand Up @@ -209,6 +226,8 @@ class FixAvroCoder extends SemanticRule("FixAvroCoder") {
case _ => false
}
case q"$fn(..$args)" if methodHasAvroCoderTypeBound(fn) => true
case q"..$mods def $ename(...$params): $tpe = $expr" if methodReturnsAvroSCollection(tpe) =>
true
}
.foldLeft(false)(_ || _)
val avroValuePatch = if (usesAvroCoders) Patch.addGlobalImport(avroImport) else Patch.empty
Expand Down

0 comments on commit da05566

Please sign in to comment.