Skip to content

Commit

Permalink
use java map (#21)
Browse files Browse the repository at this point in the history
* use java map

* sync with main

* remove unused imports

* update filter selection

* fix exception message
  • Loading branch information
elliVM authored Dec 7, 2023
1 parent 8a38613 commit 9e62f53
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ import org.apache.spark.util.sketch.BloomFilter

import scala.collection.mutable
import scala.reflect.ClassTag
import scala.collection.JavaConverters._

class BloomFilterAggregator(final val columnName: String, final val estimateName: String, sizeMap: mutable.SortedMap[Long, Double]) extends Aggregator[Row, BloomFilter, Array[Byte]]
class BloomFilterAggregator(final val columnName: String, final val estimateName: String, sortedSizeMap: java.util.SortedMap[java.lang.Long, java.lang.Double]) extends Aggregator[Row, BloomFilter, Array[Byte]]
with Serializable {

var tokenizer: Option[Tokenizer] = None
Expand Down Expand Up @@ -113,14 +114,19 @@ class BloomFilterAggregator(final val columnName: String, final val estimateName
implicit def customKryoEncoder[A](implicit ct: ClassTag[A]): Encoder[A] = Encoders.kryo[A](ct)

private def selectFilterFromMap(estimate: Long): BloomFilter = {
var filter = BloomFilter.create(sizeMap.last._1, sizeMap.last._2)
val sortedScalaMap = sortedSizeMap.asScala

for (entry <- sizeMap) {
if (estimate <= entry._1) {
filter = BloomFilter.create(entry._1, entry._2)
// default to largest
var size = sortedScalaMap.last._1

for (entry <- sortedScalaMap) {
if (entry._1 >= estimate && entry._1 < size) {
size = entry._1
}
}
val fpp = sortedScalaMap.getOrElse(size,
throw new RuntimeException("sortedScalaMap did not contain value for key size: " + size))

filter
BloomFilter.create(size, fpp)
}
}
8 changes: 5 additions & 3 deletions src/test/scala/BloomFilterAggregatorTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import org.apache.spark.util.sketch.BloomFilter
import java.io.ByteArrayInputStream
import java.sql.Timestamp
import java.time.{Instant, LocalDateTime, ZoneOffset}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

class BloomFilterAggregatorTest {
Expand Down Expand Up @@ -91,7 +90,10 @@ class BloomFilterAggregatorTest {
val rowMemoryStream = new MemoryStream[Row](1,sqlContext)(encoder)

var rowDataset = rowMemoryStream.toDF
val sizeMap: mutable.TreeMap[Long, Double] = mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01)
val javaMap = new java.util.TreeMap[java.lang.Long, java.lang.Double]() {
put(1000L, 0.01)
put(10000L, 0.01)
}


// create Scala udf
Expand All @@ -103,7 +105,7 @@ class BloomFilterAggregatorTest {
rowDataset = rowDataset.withColumn("tokens", tokenizerUDF.apply(functions.col("_raw")))

// run bloomfilter on the column
val tokenAggregator = new BloomFilterAggregator("tokens", "estimate(tokens)", sizeMap)
val tokenAggregator = new BloomFilterAggregator("tokens", "estimate(tokens)", javaMap)
val tokenAggregatorColumn = tokenAggregator.toColumn

val aggregatedDataset = rowDataset
Expand Down
9 changes: 6 additions & 3 deletions src/test/scala/BloomFilterBufferTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

import com.teragrep.functions.dpf_03.BloomFilterAggregator
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ArrayType, ByteType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, ByteType, StructField, StructType}
import org.apache.spark.util.sketch.BloomFilter
import org.junit.jupiter.api.Disabled

Expand All @@ -61,7 +61,10 @@ class BloomFilterBufferTest {
def testNoDuplicateKeys(): Unit = {

// TODO test other sizes / size categorization
val sizeMap: mutable.TreeMap[Long, Double] = mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01)
val javaMap = new java.util.TreeMap[java.lang.Long, java.lang.Double]() {
put(1000L, 0.01)
put(10000L, 0.01)
}

// single token, converted to WrappedArray
val input: String = "one,one"
Expand All @@ -79,7 +82,7 @@ class BloomFilterBufferTest {
val schema = StructType(Seq(StructField(columnName, ArrayType(ArrayType(ByteType)))))
val row = new GenericRowWithSchema(columns, schema)

val bfAgg : BloomFilterAggregator = new BloomFilterAggregator(columnName, "estimate(tokens)", sizeMap)
val bfAgg : BloomFilterAggregator = new BloomFilterAggregator(columnName, "estimate(tokens)", javaMap)

val bfAggBuf = bfAgg.zero()
bfAgg.reduce(bfAggBuf, row)
Expand Down

0 comments on commit 9e62f53

Please sign in to comment.