[issue_165] throw runtime exception; use traversable trait (#170)
[issue_165] throw runtime exception; handle multi-valued fields in DataFrames
zouzias authored Apr 11, 2019
1 parent d8714b6 commit a3413dc
Showing 5 changed files with 87 additions and 16 deletions.
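In practice, the change means that a DataFrame column of ArrayType is now indexed value by value instead of being silently dropped. A minimal sketch of the intended usage (assuming a local SparkSession and the spark-lucenerdd implicits on the classpath; the data is illustrative, echoing the tests below):

import org.apache.spark.sql.SparkSession
import org.zouzias.spark.lucenerdd._
import org.zouzias.spark.lucenerdd.LuceneRDD

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// A DataFrame with a multi-valued (ArrayType) column "names".
val df = Seq(
  ("1", Array("water", "retaw")),
  ("2", Array("fear", "raef"))
).toDF("id", "names")

// Every array element is indexed under the same Lucene field name,
// so a term query matches on any single value of the field.
val indexed = LuceneRDD(df)
indexed.termQuery("names", "retaw").count // expected: 1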
35 changes: 30 additions & 5 deletions src/main/scala/org/zouzias/spark/lucenerdd/package.scala
@@ -18,8 +18,10 @@ package org.zouzias.spark

import org.apache.lucene.document._
import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.ArrayType
import org.zouzias.spark.lucenerdd.config.LuceneRDDConfigurable

+import collection.JavaConverters._
import scala.reflect.ClassTag

package object lucenerdd extends LuceneRDDConfigurable {
@@ -47,6 +49,14 @@ package object lucenerdd extends LuceneRDDConfigurable {
}
}

+  private def listPrimitiveToDocument[T: ClassTag](doc: Document,
+                                                   fieldName: String,
+                                                   iter: java.util.List[T])
+  : Document = {
+    iter.asScala.foreach( item => typeToDocument(doc, fieldName, item))
+    doc
+  }

implicit def intToDocument(v: Int): Document = {
val doc = new Document
if (v != null) {
@@ -116,7 +126,6 @@ package object lucenerdd extends LuceneRDDConfigurable {
}

  def typeToDocument[T: ClassTag](doc: Document, fieldName: String, s: T): Document = {
-
    s match {
case x: String if x != null =>
doc.add(new Field(fieldName, x,
@@ -135,7 +144,10 @@ package object lucenerdd extends LuceneRDDConfigurable {
case x: Double if x != null =>
doc.add(new DoublePoint(fieldName, x))
doc.add(new StoredField(fieldName, x))
-      case _ => Unit
+      case null => Unit
+      case _ =>
+        throw new RuntimeException(s"Type ${s.getClass.getName} " +
+          s"on field ${fieldName} is not supported")
}
doc
}
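With the reworked match above, a null field value is still skipped quietly (case null => Unit), while a value of any unhandled type now fails fast instead of being dropped. A sketch of the new failure mode (the LocalDate value is hypothetical, purely for illustration):

import org.apache.lucene.document.Document
import org.zouzias.spark.lucenerdd._

val doc = new Document()
typeToDocument(doc, "age", 42) // supported primitive: indexed and stored
typeToDocument(doc, "created", java.time.LocalDate.now())
// => java.lang.RuntimeException: Type java.time.LocalDate on field created is not supported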
@@ -146,6 +158,12 @@ package object lucenerdd extends LuceneRDDConfigurable {
doc
}

+  implicit def arrayPrimitiveToDocument[T: ClassTag](iter: Array[T]): Document = {
+    val doc = new Document
+    iter.foreach( item => tupleTypeToDocument(doc, 1, item))
+    doc
+  }

implicit def mapToDocument[T: ClassTag](map: Map[String, T]): Document = {
val doc = new Document
map.foreach{ case (key, value) =>
@@ -180,10 +198,17 @@ package object lucenerdd extends LuceneRDDConfigurable {
implicit def sparkRowToDocument(row: Row): Document = {
val doc = new Document

-    val fieldNames = row.schema.fieldNames
-    fieldNames.foreach{ case fieldName =>
+    row.schema.map(field => (field.name, field.dataType))
+      .foreach{ case (fieldName, dataType) =>
       val index = row.fieldIndex(fieldName)
-      typeToDocument(doc, fieldName, row.get(index))
+
+      // TODO: Handle org.apache.spark.sql.types.MapType and more
+      if (dataType.isInstanceOf[ArrayType]) {
+        listPrimitiveToDocument(doc, fieldName, row.getList(index))
+      }
+      else {
+        typeToDocument(doc, fieldName, row.get(index))
+      }
}

doc
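The new listPrimitiveToDocument helper leans on a standard Lucene idiom: a field becomes multi-valued simply by being added to the same Document more than once under the same name. A plain-Lucene sketch of what the helper effectively produces for an array column (field name and values are illustrative):

import org.apache.lucene.document.{Document, Field, TextField}

val doc = new Document()
// Two values added under one field name => a multi-valued field.
doc.add(new TextField("names", "water", Field.Store.YES))
doc.add(new TextField("names", "retaw", Field.Store.YES))
// A TermQuery on "names" now matches this document through either value.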
src/test/scala/org/zouzias/spark/lucenerdd/testing/FavoriteCaseClass.scala
@@ -17,3 +17,6 @@
package org.zouzias.spark.lucenerdd.testing

case class FavoriteCaseClass(name: String, age: Int, myLong: Long, myFloat: Float, email: String)
+
+case class MultivalueFavoriteCaseClass(names: Array[String], age: Int, myLong: Long,
+                                       myFloat: Float, email: String)
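For a case class like the one above, Spark encodes the Array[String] member as ArrayType(StringType), which is exactly the dataType that the new check in sparkRowToDocument dispatches on. A sketch (assuming a SparkSession named spark, as in the specs below):

import spark.implicits._

val df = Seq(MultivalueFavoriteCaseClass(
  Array("fear", "raef"), 1, 10L, 12.3F, "fear@gmail.com")).toDF()

df.schema("names").dataType // ArrayType(StringType, true) -> listPrimitiveToDocument
df.schema("age").dataType   // IntegerType                 -> typeToDocument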
src/test/scala/org/zouzias/spark/lucenerdd/LucenePrimitiveTypesSpec.scala
@@ -40,31 +40,36 @@ class LucenePrimitiveTypesSpec extends FlatSpec with Matchers
}

/**
* Do not work with facets (multi-valued issue)
"LuceneRDD" should "work with RDD[List[String]]" in {
val array = Array(List("aaa", "aaa2"), List("bbb", "bbb2"),
List("ccc", "ccc2"), List("ddd"), List("eee"))
val rdd = sc.parallelize(array)
luceneRDD = LuceneRDD(rdd)
-    luceneRDD.count should be (array.size)
+    luceneRDD.count should be (array.length)
}
+  */

"LuceneRDD" should "work with RDD[Array[String]]" in {
val array = Array(Array("aaa", "aaa2"), Array("bbb", "bbb2"),
Array("ccc", "ccc2"), Array("ddd"), Array("eee"))
val rdd = sc.parallelize(array)
luceneRDD = LuceneRDD(rdd)
luceneRDD.count should be (array.length)
}

"LuceneRDD" should "work with RDD[Set[String]]" in {
val array = Array(Set("aaa", "aaa2"), Set("bbb", "bbb2"),
Set("ccc", "ccc2"), Set("ddd"), Set("eee"))
val rdd = sc.parallelize(array)
luceneRDD = LuceneRDD(rdd)
-    luceneRDD.count should be (array.size)
+    luceneRDD.count should be (array.length)
}

-  */

"LuceneRDD" should "work with RDD[String]" in {
val array = Array("aaa", "bbb", "ccc", "ddd", "eee")
val rdd = sc.parallelize(array)
luceneRDD = LuceneRDD(rdd)
-    luceneRDD.count should be (array.size)
+    luceneRDD.count should be (array.length)
}

"LuceneRDD" should "work with RDD[Int]" in {
@@ -109,7 +114,7 @@ class LucenePrimitiveTypesSpec extends FlatSpec with Matchers
val array = Array("aaa", null, "ccc", null, "eee")
val rdd = sc.parallelize(array)
luceneRDD = LuceneRDD(rdd)
-    luceneRDD.count should be (array.size)
+    luceneRDD.count should be (array.length)
}

}
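A note on why Array[String] needs the dedicated arrayPrimitiveToDocument implicit while Set[String] can be served by one defined on a collection trait (presumably the "traversable trait" of the commit title): Scala implicit conversions do not chain, so the compiler will not wrap an Array as a Traversable and then apply a Traversable-to-Document view in one step. A toy sketch of the rule, with String standing in for Document:

import scala.language.implicitConversions

implicit def traversableToCsv[T](xs: Traversable[T]): String = xs.mkString(",")

val fromList: String = List(1, 2, 3) // compiles: one direct conversion
val fromSet: String = Set(1, 2, 3)   // compiles: Set is a Traversable
// val fromArray: String = Array(1, 2, 3) // would need two chained conversions: does not compile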
src/test/scala/org/zouzias/spark/lucenerdd/LuceneRDDCustomCaseClassImplicitsSpec.scala
@@ -38,7 +38,7 @@ class LuceneRDDCustomCaseClassImplicitsSpec extends FlatSpec
set("spark.ui.enabled", "false").
set("spark.app.id", appID))

-  val elem = Array("fear", "death", "water", "fire", "house")
+  val elem: Array[Person] = Array("fear", "death", "water", "fire", "house")
.zipWithIndex.map{ case (str, index) => Person(str, index, s"${str}@gmail.com")}

"LuceneRDD(case class).count" should "handle nulls properly" in {
src/test/scala/org/zouzias/spark/lucenerdd/LuceneRDDDataFrameImplicitsSpec.scala
@@ -20,7 +20,7 @@ import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
-import org.zouzias.spark.lucenerdd.testing.FavoriteCaseClass
+import org.zouzias.spark.lucenerdd.testing.{FavoriteCaseClass, MultivalueFavoriteCaseClass}

class LuceneRDDDataFrameImplicitsSpec extends FlatSpec
with Matchers
@@ -45,14 +45,26 @@ class LuceneRDDDataFrameImplicitsSpec extends FlatSpec
.zipWithIndex.map{ case (str, index) =>
FavoriteCaseClass(str, index, 10L, 12.3F, s"${str}@gmail.com")}

+  val multiValuesElems = Array("fear", "death", "water", "fire", "house")
+    .zipWithIndex.map{ case (str, index) =>
+      MultivalueFavoriteCaseClass(Array(str, str.reverse), index, 10L, 12.3F, s"${str}@gmail.com")}
+
+  "LuceneRDD(MultivalueFavoriteCaseClass).count" should "return correct number of elements" in {
+    val rdd = sc.parallelize(multiValuesElems)
+    val spark = SparkSession.builder().getOrCreate()
+    import spark.implicits._
+    val df = rdd.toDF()
+    luceneRDD = LuceneRDD(df)
+    luceneRDD.count should equal (multiValuesElems.length)
+  }

"LuceneRDD(case class).count" should "return correct number of elements" in {
val rdd = sc.parallelize(elem)
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._
val df = rdd.toDF()
luceneRDD = LuceneRDD(df)
-    luceneRDD.count should equal (elem.size)
+    luceneRDD.count should equal (elem.length)
}

"LuceneRDD(case class).fields" should "return all fields" in {
@@ -70,6 +82,21 @@ class LuceneRDDDataFrameImplicitsSpec extends FlatSpec
luceneRDD.fields().contains("email") should equal(true)
}

"LuceneRDD(MultivalueFavoriteCaseClass).fields" should "return all fields" in {
val rdd = sc.parallelize(multiValuesElems)
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._
val df = rdd.toDF()
luceneRDD = LuceneRDD(df)

luceneRDD.fields().size should equal(5)
luceneRDD.fields().contains("names") should equal(true)
luceneRDD.fields().contains("age") should equal(true)
luceneRDD.fields().contains("myLong") should equal(true)
luceneRDD.fields().contains("myFloat") should equal(true)
luceneRDD.fields().contains("email") should equal(true)
}

"LuceneRDD(case class).termQuery" should "correctly search with TermQueries" in {
val rdd = sc.parallelize(elem)
val spark = SparkSession.builder().getOrCreate()
@@ -80,4 +107,15 @@ class LuceneRDDDataFrameImplicitsSpec extends FlatSpec
val results = luceneRDD.termQuery("name", "water")
results.count should equal(1)
}

"LuceneRDD(MultivalueFavoriteCaseClass).termQuery" should "correctly search with TermQueries" in {
val rdd = sc.parallelize(multiValuesElems)
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._
val df = rdd.toDF()
luceneRDD = LuceneRDD(df)

val results = luceneRDD.termQuery("names", "retaw")
results.count should equal(1)
}
}
