Use estimate spark column to select size for bloom filter (#19)
* use estimate column to select filter size

* use sorted map to ensure smallest matching filter is used

* test aggregator with multiple partitions

* refactoring
elliVM authored Dec 5, 2023
1 parent f8bb6cd commit 0edd189
Showing 4 changed files with 69 additions and 42 deletions.
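Note on the change: instead of a single fixed (expectedItems, fpp) pair, BloomFilterAggregator now receives a sorted map of candidate filter sizes and, per aggregation group, builds the smallest filter whose capacity covers that group's token estimate. A minimal sketch of the selection rule (the map contents here are hypothetical; the committed code takes the map as a constructor argument and implements the rule in selectFilterFromMap below):

    import scala.collection.mutable
    import org.apache.spark.util.sketch.BloomFilter

    // Hypothetical size table: expected items -> false positive probability.
    val sizeMap: mutable.SortedMap[Long, Double] =
      mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01, 100000L -> 0.03)

    // A TreeMap iterates in ascending key order, so find() yields the
    // smallest bucket that can hold the estimate; if the estimate exceeds
    // every key, fall back to the largest bucket.
    def selectFilter(estimate: Long): BloomFilter =
      sizeMap.find { case (expected, _) => estimate <= expected } match {
        case Some((expected, fpp)) => BloomFilter.create(expected, fpp)
        case None =>
          val (expected, fpp) = sizeMap.last
          BloomFilter.create(expected, fpp)
      }

    selectFilter(5000L) // selects the 10000-item bucket

selectFilterFromMap in the diff below folds the same fallback into its loop by remembering the last entry visited.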
@@ -56,27 +56,36 @@ import org.apache.spark.util.sketch.BloomFilter
 import scala.collection.mutable
 import scala.reflect.ClassTag

-class BloomFilterAggregator(final val columnName: String, final val bloomfilterExpectedItems: Long, final val bloomfilterFfp: Double ) extends Aggregator[Row, BloomFilter, Array[Byte]]
+class BloomFilterAggregator(final val columnName: String, final val estimateName: String, sizeMap: mutable.SortedMap[Long, Double]) extends Aggregator[Row, BloomFilter, Array[Byte]]
   with Serializable {

   var tokenizer: Option[Tokenizer] = None

   override def zero(): BloomFilter = {
-    BloomFilter.create(bloomfilterExpectedItems, bloomfilterFfp)
+    BloomFilter.create(1, 0.01)
   }

   override def reduce(buffer: BloomFilter, row: Row): BloomFilter = {
+    var newBuffer = buffer
     val tokens : mutable.WrappedArray[mutable.WrappedArray[Byte]] = row.getAs[mutable.WrappedArray[mutable.WrappedArray[Byte]]](columnName)
+    val estimate: Long = row.getAs[Long](estimateName)
+
+    if (newBuffer.bitSize() == 64) { // zero() will have 64 bitSize
+      newBuffer = selectFilterFromMap(estimate)
+    }

     for (token : mutable.WrappedArray[Byte] <- tokens) {
       val tokenByteArray: Array[Byte] = token.toArray
-      buffer.putBinary(tokenByteArray)
+      newBuffer.putBinary(tokenByteArray)
     }

-    buffer
+    newBuffer
   }

   override def merge(ours: BloomFilter, their: BloomFilter): BloomFilter = {
+    // ignore merge with zero buffer
+    if (!ours.isCompatible(their)) return their
+
     ours.mergeInPlace(their)
   }

@@ -96,4 +105,21 @@ class BloomFilterAggregator(final val columnName: String, final val bloomfilterE
   override def outputEncoder: Encoder[Array[Byte]] = ExpressionEncoder[Array[Byte]]

   implicit def customKryoEncoder[A](implicit ct: ClassTag[A]): Encoder[A] = Encoders.kryo[A](ct)
+
+  private def selectFilterFromMap(estimate: Long): BloomFilter = {
+
+    var backupExpected = 0L
+    var backupFpp = 0.01
+
+    for (entry <- sizeMap) {
+      backupExpected = entry._1
+      backupFpp = entry._2
+
+      if (estimate <= entry._1) {
+        return BloomFilter.create(entry._1, entry._2)
+      }
+    }
+
+    BloomFilter.create(backupExpected, backupFpp)
+  }
 }
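Why the 64-bit check works: Spark's BloomFilter pads its bit array to whole 64-bit words, so the zero() buffer created from BloomFilter.create(1, 0.01) always reports bitSize() == 64. reduce() treats that as a "not yet sized" sentinel and swaps in a filter sized from the row's estimate, while merge() short-circuits by returning their whenever the buffers are incompatible, covering merges against an unsized zero() buffer. A minimal usage sketch under those assumptions (column names and sizes mirror the tests below; the exact groupBy/agg wiring is an assumption, since the test's aggregation lines are collapsed in this view):

    import scala.collection.mutable
    import org.apache.spark.sql.functions

    // Illustrative size table, as in the tests below.
    val sizeMap = mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01)

    // rowDataset is assumed to carry a "tokens" column (from TokenizerUDF)
    // and an "estimate(tokens)" column with the per-group token estimate.
    val agg = new BloomFilterAggregator("tokens", "estimate(tokens)", sizeMap)
    val filters = rowDataset
      .groupBy(functions.col("partition"))
      .agg(agg.toColumn.name("bloomfilter"))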
36 changes: 20 additions & 16 deletions src/test/scala/BloomFilterAggregatorTest.scala
@@ -55,6 +55,7 @@ import org.apache.spark.util.sketch.BloomFilter
 import java.io.ByteArrayInputStream
 import java.sql.Timestamp
 import java.time.{Instant, LocalDateTime, ZoneOffset}
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer

 class BloomFilterAggregatorTest {
@@ -68,16 +69,18 @@ class BloomFilterAggregatorTest {
   val amount: Long = 10

   val testSchema: StructType = new StructType(
-    Array[StructField]
-    (StructField("_time", DataTypes.TimestampType, nullable = false, new MetadataBuilder().build),
-      StructField("_raw", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("index", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("sourcetype", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("host", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("source", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("partition", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      // Offset set as string instead of Long.
-      StructField("offset", DataTypes.StringType, nullable = false, new MetadataBuilder().build)))
+    Array[StructField](
+      StructField("_time", TimestampType, nullable = false, new MetadataBuilder().build),
+      StructField("_raw", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("index", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("sourcetype", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("host", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("source", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("partition", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("offset", LongType, nullable = false, new MetadataBuilder().build),
+      StructField("estimate(tokens)", LongType, nullable = false, new MetadataBuilder().build)
+    )
+  )

   @org.junit.jupiter.api.Test
   def testTokenization(): Unit = {
@@ -88,7 +91,7 @@ class BloomFilterAggregatorTest {
     val rowMemoryStream = new MemoryStream[Row](1,sqlContext)(encoder)

     var rowDataset = rowMemoryStream.toDF
-
+    val sizeMap: mutable.TreeMap[Long, Double] = mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01)


     // create Scala udf
@@ -99,9 +102,8 @@ class BloomFilterAggregatorTest {
     // apply udf to column
     rowDataset = rowDataset.withColumn("tokens", tokenizerUDF.apply(functions.col("_raw")))

-
     // run bloomfilter on the column
-    val tokenAggregator = new BloomFilterAggregator("tokens", 50000L, 0.01)
+    val tokenAggregator = new BloomFilterAggregator("tokens", "estimate(tokens)", sizeMap)
     val tokenAggregatorColumn = tokenAggregator.toColumn

     val aggregatedDataset = rowDataset
@@ -147,14 +149,16 @@ class BloomFilterAggregatorTest {
     val rowData = generateRawData()

     for (i <- 0 until amount.toInt) {
-      val row = RowFactory.create(time,
+      val row = Row(
+        time,
         exampleString,
         "topic",
         "stream",
         "host",
         "input",
-        partition,
-        "0L")
+        i.toString,
+        0L,
+        exampleString.length.toLong)

       rowList += row
     }
5 changes: 2 additions & 3 deletions src/test/scala/BloomFilterBufferTest.scala
@@ -61,8 +61,7 @@ class BloomFilterBufferTest {
   def testNoDuplicateKeys(): Unit = {

     // TODO test other sizes / size categorization
-    val bloomfilterExpectedItems = 50000L
-    val bloomfilterFpp = 0.01D
+    val sizeMap: mutable.TreeMap[Long, Double] = mutable.TreeMap(1000L -> 0.01, 10000L -> 0.01)

     // single token, converted to WrappedArray
     val input: String = "one,one"
@@ -80,7 +79,7 @@ class BloomFilterBufferTest {
     val schema = StructType(Seq(StructField(columnName, ArrayType(ArrayType(ByteType)))))
     val row = new GenericRowWithSchema(columns, schema)

-    val bfAgg : BloomFilterAggregator = new BloomFilterAggregator(columnName, bloomfilterExpectedItems, bloomfilterFpp)
+    val bfAgg : BloomFilterAggregator = new BloomFilterAggregator(columnName, "estimate(tokens)", sizeMap)

     val bfAggBuf = bfAgg.zero()
     bfAgg.reduce(bfAggBuf, row)
36 changes: 17 additions & 19 deletions src/test/scala/TokenizerTest.scala
@@ -44,15 +44,13 @@
  * a licensee so wish it.
  */

-import com.teragrep.functions.dpf_03.{BloomFilterAggregator, ByteArrayListAsStringListUDF, TokenizerUDF}
+import com.teragrep.functions.dpf_03.{ByteArrayListAsStringListUDF, TokenizerUDF}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.sketch.BloomFilter

-import java.io.ByteArrayInputStream
 import java.sql.Timestamp
 import java.time.{Instant, LocalDateTime, ZoneOffset}
 import scala.collection.mutable
@@ -70,15 +68,17 @@ class TokenizerTest {

   val testSchema: StructType = new StructType(
     Array[StructField]
-    (StructField("_time", DataTypes.TimestampType, nullable = false, new MetadataBuilder().build),
-      StructField("_raw", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("index", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("sourcetype", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("host", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("source", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      StructField("partition", DataTypes.StringType, nullable = false, new MetadataBuilder().build),
-      // Offset set as string instead of Long.
-      StructField("offset", DataTypes.StringType, nullable = false, new MetadataBuilder().build)))
+    (StructField("_time", TimestampType, nullable = false, new MetadataBuilder().build),
+      StructField("_raw", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("index", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("sourcetype", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("host", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("source", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("partition", StringType, nullable = false, new MetadataBuilder().build),
+      StructField("offset", LongType, nullable = false, new MetadataBuilder().build),
+      StructField("estimate(tokens)", LongType, nullable = false, new MetadataBuilder().build)
+    )
+  )

   @org.junit.jupiter.api.Test
   def testTokenization(): Unit = {
@@ -90,8 +90,6 @@ class TokenizerTest {

     var rowDataset = rowMemoryStream.toDF

-
-
     // create Scala udf for tokenizer
     val tokenizerUDF = functions.udf(new TokenizerUDF, DataTypes.createArrayType(DataTypes.createArrayType(ByteType, false), false))
     // register tokenizer udf
@@ -138,19 +136,19 @@ class TokenizerTest {
   private def makeRows(time: Timestamp, partition: String): Seq[Row] = {

     val rowList: ArrayBuffer[Row] = new ArrayBuffer[Row]
-    val rowData = generateRawData()

     for (i <- 0 until amount.toInt) {
-      val row = RowFactory.create(time,
+      rowList += Row(
+        time,
         exampleString,
         "topic",
         "stream",
         "host",
         "input",
         partition,
-        "0L")
-
-      rowList += row
+        0L,
+        exampleString.length.toLong
+      )
     }
     rowList
   }
