Sparkc 577 round three #1264

Open · wants to merge 4 commits into base: b3.0
@@ -1,49 +1,70 @@
package com.datastax.spark.connector

import scala.language.implicitConversions
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata

import com.datastax.spark.connector.cql.TableDef
import scala.language.implicitConversions
import com.datastax.spark.connector.cql.{ColumnDef, StructDef, TableDef}

sealed trait ColumnSelector {
  def aliases: Map[String, String]
  def selectFrom(table: TableDef): IndexedSeq[ColumnRef]
  def selectFrom(table: TableMetadata): IndexedSeq[ColumnRef]
Author: Instances of this trait are used widely throughout the code base, so to avoid broad breakage here I just overloaded selectFrom(). (A usage sketch follows this file's diff.)

}

case object AllColumns extends ColumnSelector {
  override def aliases: Map[String, String] = Map.empty.withDefault(x => x)
  override def selectFrom(table: TableDef) =
    table.columns.map(_.ref)
  override def selectFrom(table: TableMetadata) =
    TableDef.columns(table).map(ColumnDef.toRef(_))
}

case object PartitionKeyColumns extends ColumnSelector {
  override def aliases: Map[String, String] = Map.empty.withDefault(x => x)
  override def selectFrom(table: TableDef) =
    table.partitionKey.map(_.ref).toIndexedSeq
  override def selectFrom(table: TableMetadata) =
    TableDef.partitionKey(table).map(ColumnDef.toRef(_)).toIndexedSeq
}

case object PrimaryKeyColumns extends ColumnSelector {
  override def aliases: Map[String, String] = Map.empty.withDefault(x => x)
  override def selectFrom(table: TableDef) =
    table.primaryKey.map(_.ref)
  override def selectFrom(table: TableMetadata) =
    TableDef.primaryKey(table).map(ColumnDef.toRef(_))
}

case class SomeColumns(columns: ColumnRef*) extends ColumnSelector {

  private def columnsToCheck(): Seq[ColumnRef] = columns flatMap {
    case f: FunctionCallRef => f.requiredColumns //Replaces function calls by their required columns
    case RowCountRef => Seq.empty //Filters RowCountRef from the column list
    case other => Seq(other)
  }

  /** Compute the columns that are not present in the structure. */
  private def missingColumns(table: StructDef, columnsToCheck: Seq[ColumnRef]): Seq[ColumnRef] =
    for (c <- columnsToCheck if !table.columnByName.contains(c.columnName)) yield c

  private def missingColumns(table: TableMetadata, columnsToCheck: Seq[ColumnRef]): Seq[ColumnRef] =
    for (c <- columnsToCheck if !TableDef.containsColumn(c.columnName)(table)) yield c

Author: Formerly in StructDef. Pulled out because it's only really used here, so it seemed better to localize the functionality where it's needed. If this becomes more broadly used, it can always be moved back to StructDef.

  override def aliases: Map[String, String] = columns.map {
    case ref => (ref.selectedAs, ref.cqlValueName)
  }.toMap

  override def selectFrom(table: TableDef): IndexedSeq[ColumnRef] = {
    val missing = table.missingColumns {
      columns flatMap {
        case f: FunctionCallRef => f.requiredColumns //Replaces function calls by their required columns
        case RowCountRef => Seq.empty //Filters RowCountRef from the column list
        case other => Seq(other)
      }
    }
    val missing = missingColumns(table, columnsToCheck)
    if (missing.nonEmpty) throw new NoSuchElementException(
      s"Columns not found in table ${table.name}: ${missing.mkString(", ")}")
    columns.toIndexedSeq
  }

  override def selectFrom(table: TableMetadata): IndexedSeq[ColumnRef] = {
    val missing = missingColumns(table, columnsToCheck)
    if (missing.nonEmpty) throw new NoSuchElementException(
      s"Columns not found in table ${TableDef.tableName(table)}: ${missing.mkString(", ")}")
    columns.toIndexedSeq
  }
}
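To make the overloading note above concrete, here is a minimal usage sketch. It is illustrative only: the keyspace, table, and column names are placeholders, and obtaining the TableMetadata through CqlSession's schema metadata is an assumption about how a caller would get one; the selectFrom overloads themselves are the ones added in this diff.

```scala
import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata
import com.datastax.spark.connector.{AllColumns, ColumnName, ColumnRef, SomeColumns}

// Hypothetical call site: resolve column selections directly against the
// driver's TableMetadata instead of a connector TableDef.
def resolveSelections(session: CqlSession): (IndexedSeq[ColumnRef], IndexedSeq[ColumnRef]) = {
  // Placeholder keyspace/table names; both Optionals are assumed to be present.
  val metadata: TableMetadata =
    session.getMetadata.getKeyspace("ks").get.getTable("tbl").get

  // The same selector values now work through the new TableMetadata overload.
  val everything = AllColumns.selectFrom(metadata)

  // "id" and "value" are placeholder columns; a name missing from the table
  // still fails with NoSuchElementException, as in the TableDef-based path.
  val subset = SomeColumns(ColumnName("id"), ColumnName("value")).selectFrom(metadata)

  (everything, subset)
}
```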
@@ -1,7 +1,8 @@
package com.datastax.spark.connector.datasource

import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata
import com.datastax.oss.driver.internal.core.cql.ResultSets
import com.datastax.spark.connector.cql.{CassandraConnector, TableDef}
import com.datastax.spark.connector.cql.{CassandraConnector}
import com.datastax.spark.connector.rdd.ReadConf
import com.datastax.spark.connector.rdd.reader.PrefetchingResultSetIterator
import com.datastax.spark.connector.util.Logging
@@ -18,7 +19,7 @@ import scala.util.{Failure, Success}

case class CassandraInJoinReaderFactory(
connector: CassandraConnector,
tableDef: TableDef,
tableMetadata: TableMetadata,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
@@ -27,33 +28,32 @@ case class CassandraInJoinReaderFactory(

  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    if (cqlQueryParts.selectedColumnRefs.contains(RowCountRef)) {
      CassandraInJoinCountReader(connector, tableDef, inClauses, readConf, schema, cqlQueryParts, partition)
      CassandraInJoinCountReader(connector, tableMetadata, inClauses, readConf, schema, cqlQueryParts, partition)
    } else {
      CassandraInJoinReader(connector, tableDef, inClauses, readConf, schema, cqlQueryParts, partition)
      CassandraInJoinReader(connector, tableMetadata, inClauses, readConf, schema, cqlQueryParts, partition)
    }
}


abstract class CassandraBaseInJoinReader(
connector: CassandraConnector,
tableDef: TableDef,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
cqlQueryParts: ScanHelper.CqlQueryParts,
partition: InputPartition)
abstract class CassandraBaseInJoinReader(connector: CassandraConnector,
tableMetadata: TableMetadata,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
cqlQueryParts: ScanHelper.CqlQueryParts,
partition: InputPartition)
Author: IntelliJ kept insisting on reformatting arg lists in this way, and nothing I did could convince it to stop. This is a relatively minor thing that can be cleaned up later, but my apologies for the noise while reviewing.

  extends PartitionReader[InternalRow]
    with Logging {

  protected val numberedInputPartition = partition.asInstanceOf[NumberedInputPartition]
  protected val joinColumnNames = inClauses.map(in => ColumnName(in.attribute)).toIndexedSeq
  protected val session = connector.openSession()
  protected val rowWriter = CassandraRowWriter.Factory.rowWriter(tableDef, joinColumnNames)
  protected val rowReader = new UnsafeRowReaderFactory(schema).rowReader(tableDef, cqlQueryParts.selectedColumnRefs)
  protected val rowWriter = CassandraRowWriter.Factory.rowWriter(tableMetadata, joinColumnNames)
  protected val rowReader = new UnsafeRowReaderFactory(schema).rowReader(tableMetadata, cqlQueryParts.selectedColumnRefs)

  protected val keyIterator: Iterator[CassandraRow] = InClauseKeyGenerator.getIterator(numberedInputPartition.index, numberedInputPartition.total, inClauses) //Generate Iterators for this partition here

  protected val stmt = JoinHelper.getJoinQueryString(tableDef, joinColumnNames, cqlQueryParts)
  protected val stmt = JoinHelper.getJoinQueryString(tableMetadata, joinColumnNames, cqlQueryParts)
  protected val preparedStatement = JoinHelper.getJoinPreparedStatement(session, stmt, readConf.consistencyLevel)
  protected val bsb = JoinHelper.getKeyBuilderStatementBuilder(session, rowWriter, preparedStatement, cqlQueryParts.whereClause)
  protected val rowMetadata = JoinHelper.getCassandraRowMetadata(session, preparedStatement, cqlQueryParts.selectedColumnRefs)
@@ -109,28 +109,27 @@ abstract class CassandraBaseInJoinReader(

case class CassandraInJoinReader(
connector: CassandraConnector,
tableDef: TableDef,
tableMetadata: TableMetadata,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
cqlQueryParts: ScanHelper.CqlQueryParts,
partition: InputPartition)
extends CassandraBaseInJoinReader(connector, tableDef, inClauses, readConf, schema, cqlQueryParts, partition)
extends CassandraBaseInJoinReader(connector, tableMetadata, inClauses, readConf, schema, cqlQueryParts, partition)

case class CassandraInJoinCountReader(
connector: CassandraConnector,
tableDef: TableDef,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
cqlQueryParts: ScanHelper.CqlQueryParts,
partition: InputPartition)
extends CassandraBaseInJoinReader(connector, tableDef, inClauses, readConf, schema, cqlQueryParts, partition) {
case class CassandraInJoinCountReader(connector: CassandraConnector,
tableMetadata: TableMetadata,
inClauses: Seq[In],
readConf: ReadConf,
schema: StructType,
cqlQueryParts: ScanHelper.CqlQueryParts,
partition: InputPartition)
extends CassandraBaseInJoinReader(connector, tableMetadata, inClauses, readConf, schema, cqlQueryParts, partition) {

  //Our read is not based on the structure of the table we are reading from
  override val rowReader =
    new UnsafeRowReaderFactory(StructType(Seq(StructField("count", LongType, false))))
      .rowReader(tableDef, cqlQueryParts.selectedColumnRefs)
      .rowReader(tableMetadata, cqlQueryParts.selectedColumnRefs)

/*
Casting issue here for extremely large C* partitions,
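One small clarifying sketch for the count reader above: its overridden rowReader is built from a fixed single-column schema rather than from the table's metadata, which is what the inline comment about not depending on the table's structure refers to. The snippet below only restates that schema with standard Spark SQL types; nothing beyond what the diff already shows is implied.

```scala
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Fixed output schema for the count pushdown: one non-nullable 64-bit column,
// independent of whatever columns the Cassandra table actually has.
val countSchema: StructType = StructType(Seq(StructField("count", LongType, nullable = false)))
```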