feat!: add batch_size option to merge_columns (#2896)
There are a few other places where the updater is used (e.g. merge_insert /
add_columns) and we may want to review those paths as well.

This is technically a breaking change to the Rust API, so I will mark
it as such.

BREAKING CHANGE: `Fragment::updater` now accepts a new `batch_size`
argument
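To illustrate the new option, here is a minimal, hypothetical sketch of calling merge_columns with batch_size from Python. The dataset path and column names are made up for the example; the batch_size keyword argument is the one added in this commit.

import lance
import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical path; any existing Lance dataset with an integer column "a" works.
dataset = lance.dataset("/tmp/example.lance")

def add_doubled(batch: pa.RecordBatch) -> pa.RecordBatch:
    # With batch_size set, each batch passed in has at most that many rows
    # (for v2 datasets; v1 datasets keep their row-group size).
    return pa.RecordBatch.from_pydict({"a_doubled": pc.multiply(batch.column("a"), 2)})

fragments = []
for frag in dataset.get_fragments():
    new_frag, new_schema = frag.merge_columns(add_doubled, columns=["a"], batch_size=1024)
    fragments.append(new_frag)

# The resulting fragments are then committed with LanceOperation.Merge,
# as shown in the test added below.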
westonpace authored Sep 17, 2024
1 parent 9ea8d5a commit 7afbdcf
Showing 8 changed files with 78 additions and 36 deletions.
3 changes: 2 additions & 1 deletion python/python/lance/fragment.py
@@ -364,6 +364,7 @@ def merge_columns(
self,
value_func: Callable[[pa.RecordBatch], pa.RecordBatch],
columns: Optional[list[str]] = None,
batch_size: Optional[int] = None,
) -> Tuple[FragmentMetadata, LanceSchema]:
"""Add columns to this Fragment.
@@ -390,7 +391,7 @@
Tuple[FragmentMetadata, LanceSchema]
A new fragment with the added column(s) and the final schema.
"""
updater = self._fragment.updater(columns)
updater = self._fragment.updater(columns, batch_size)

while True:
batch = updater.next()
27 changes: 27 additions & 0 deletions python/python/tests/test_dataset.py
@@ -912,6 +912,33 @@ def test_merge_with_commit(tmp_path: Path):
assert tbl == expected


def test_merge_batch_size(tmp_path: Path):
# Create dataset with 10 fragments with 100 rows each
table = pa.table({"a": range(1000)})
for batch_size in [1, 10, 100, 1000]:
ds_path = str(tmp_path / str(batch_size))
dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100)
fragments = []

def mutate(batch):
assert batch.num_rows <= batch_size
return pa.RecordBatch.from_pydict({"b": batch.column("a")})

for frag in dataset.get_fragments():
merged, schema = frag.merge_columns(mutate, batch_size=batch_size)
fragments.append(merged)

merge = lance.LanceOperation.Merge(fragments, schema)
dataset = lance.LanceDataset.commit(
ds_path, merge, read_version=dataset.version
)

dataset.validate()
tbl = dataset.to_table()
expected = pa.table({"a": range(1000), "b": range(1000)})
assert tbl == expected


def test_merge_with_schema_holes(tmp_path: Path):
# Create table with 3 cols
table = pa.table({"a": range(10)})
6 changes: 4 additions & 2 deletions python/src/fragment.rs
@@ -210,10 +210,12 @@ impl FileFragment {
Ok(Scanner::new(scn))
}

fn updater(&self, columns: Option<Vec<String>>) -> PyResult<Updater> {
fn updater(&self, columns: Option<Vec<String>>, batch_size: Option<u32>) -> PyResult<Updater> {
let cols = columns.as_deref();
let inner = RT
.block_on(None, async { self.fragment.updater(cols, None).await })?
.block_on(None, async {
self.fragment.updater(cols, None, batch_size).await
})?
.map_err(|err| PyIOError::new_err(err.to_string()))?;
Ok(Updater::new(inner))
}
9 changes: 5 additions & 4 deletions rust/lance/src/dataset/fragment.rs
@@ -1132,6 +1132,7 @@ impl FileFragment {
&self,
columns: Option<&[T]>,
schemas: Option<(Schema, Schema)>,
batch_size: Option<u32>,
) -> Result<Updater> {
let mut schema = self.dataset.schema().clone();

@@ -1160,11 +1161,11 @@
let reader = reader?;
let deletion_vector = deletion_vector?.unwrap_or_default();

Updater::try_new(self.clone(), reader, deletion_vector, schemas)
Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size)
}

pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result<Self> {
let mut updater = self.updater(Some(&[join_column]), None).await?;
let mut updater = self.updater(Some(&[join_column]), None, None).await?;

while let Some(batch) = updater.next().await? {
let batch = joiner.collect(batch[join_column].clone()).await?;
@@ -2433,7 +2434,7 @@
}

let fragment = &mut dataset.get_fragment(0).unwrap();
let mut updater = fragment.updater(Some(&["i"]), None).await.unwrap();
let mut updater = fragment.updater(Some(&["i"]), None, None).await.unwrap();
let new_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"double_i",
DataType::Int32,
@@ -2677,7 +2678,7 @@
let fragment = dataset.get_fragments().pop().unwrap();

// Write batch_s using add_columns
let mut updater = fragment.updater(Some(&["i"]), None).await?;
let mut updater = fragment.updater(Some(&["i"]), None, None).await?;
updater.next().await?;
updater.update(batch_s.clone()).await?;
let frag = updater.finish().await?;
32 changes: 15 additions & 17 deletions rust/lance/src/dataset/scanner.rs
@@ -66,7 +66,14 @@ use snafu::{location, Location};
#[cfg(feature = "substrait")]
use lance_datafusion::substrait::parse_substrait;

pub const DEFAULT_BATCH_SIZE: usize = 8192;
const BATCH_SIZE_FALLBACK: usize = 8192;
// For backwards compatibility / historical reasons we re-calculate the default batch size
// on each call
pub fn get_default_batch_size() -> Option<usize> {
std::env::var("LANCE_DEFAULT_BATCH_SIZE")
.map(|val| Some(val.parse().unwrap()))
.unwrap_or(None)
}

pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
lazy_static::lazy_static! {
@@ -262,23 +269,14 @@ impl Scanner {
// 64KB, this is 16K rows. For local file systems, the default block size
// is just 4K, which would mean only 1K rows, which might be a little small.
// So we use a default minimum of 8K rows.
std::env::var("LANCE_DEFAULT_BATCH_SIZE")
.map(|bs| {
bs.parse().unwrap_or_else(|_| {
panic!(
"The value of LANCE_DEFAULT_BATCH_SIZE ({}) is not a valid batch size",
bs
)
})
})
.unwrap_or_else(|_| {
self.batch_size.unwrap_or_else(|| {
std::cmp::max(
self.dataset.object_store().block_size() / 4,
DEFAULT_BATCH_SIZE,
)
})
get_default_batch_size().unwrap_or_else(|| {
self.batch_size.unwrap_or_else(|| {
std::cmp::max(
self.dataset.object_store().block_size() / 4,
BATCH_SIZE_FALLBACK,
)
})
})
}

fn ensure_not_fragment_scan(&self) -> Result<()> {
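As a side note on the refactor above, the scan batch size is now resolved in three steps. The sketch below (plain Python pseudologic, not part of the diff, with made-up parameter names) mirrors Scanner::get_batch_size.

import os

def effective_scan_batch_size(explicit_batch_size, block_size):
    # The LANCE_DEFAULT_BATCH_SIZE environment variable takes precedence,
    # then an explicitly requested batch size, then a fallback derived from
    # the object store block size with a floor of 8192 rows.
    env = os.environ.get("LANCE_DEFAULT_BATCH_SIZE")
    if env is not None:
        return int(env)
    if explicit_batch_size is not None:
        return explicit_batch_size
    return max(block_size // 4, 8192)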
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/schema_evolution.rs
@@ -261,7 +261,7 @@ async fn add_columns_impl(
}

let mut updater = fragment
.updater(read_columns_ref, schemas_ref.clone())
.updater(read_columns_ref, schemas_ref.clone(), None)
.await?;

let mut batch_index = 0;
34 changes: 23 additions & 11 deletions rust/lance/src/dataset/updater.rs
@@ -10,6 +10,7 @@ use lance_table::utils::stream::ReadBatchFutStream;
use snafu::{location, Location};

use super::fragment::FragmentReader;
use super::scanner::get_default_batch_size;
use super::write::{open_writer, GenericWriter};
use super::Dataset;
use crate::dataset::FileFragment;
@@ -59,16 +60,26 @@ impl Updater {
reader: FragmentReader,
deletion_vector: DeletionVector,
schemas: Option<(Schema, Schema)>,
batch_size: Option<u32>,
) -> Result<Self> {
let (write_schema, final_schema) = if let Some((write_schema, final_schema)) = schemas {
(Some(write_schema), Some(final_schema))
} else {
(None, None)
};

let batch_size = reader.legacy_num_rows_in_batch(0);
let legacy_batch_size = reader.legacy_num_rows_in_batch(0);

let input_stream = reader.read_all(1024)?;
let batch_size = match (&legacy_batch_size, batch_size) {
// If this is a v1 dataset we must use the row group size of the file
(Some(num_rows), _) => *num_rows,
// If this is a v2 dataset, let the user pick the batch size
(None, Some(batch_size)) => batch_size,
// Otherwise, fall back to LANCE_DEFAULT_BATCH_SIZE (or 1024) if the user didn't specify anything
(None, None) => get_default_batch_size().unwrap_or(1024) as u32,
};

let input_stream = reader.read_all(batch_size)?;

Ok(Self {
fragment,
@@ -78,7 +89,7 @@ impl Updater {
write_schema,
final_schema,
finished: false,
deletion_restorer: DeletionRestorer::new(deletion_vector, batch_size),
deletion_restorer: DeletionRestorer::new(deletion_vector, legacy_batch_size),
})
}

@@ -226,18 +237,18 @@ struct DeletionRestorer {
current_row_id: u32,

/// Number of rows in each batch, only used in legacy files for validation
batch_size: Option<u32>,
legacy_batch_size: Option<u32>,

deletion_vector_iter: Option<Box<dyn Iterator<Item = u32> + Send>>,

last_deleted_row_id: Option<u32>,
}

impl DeletionRestorer {
fn new(deletion_vector: DeletionVector, batch_size: Option<u32>) -> Self {
fn new(deletion_vector: DeletionVector, legacy_batch_size: Option<u32>) -> Self {
Self {
current_row_id: 0,
batch_size,
legacy_batch_size,
deletion_vector_iter: Some(deletion_vector.into_sorted_iter()),
last_deleted_row_id: None,
}
@@ -248,12 +259,12 @@
}

fn is_full(batch_size: Option<u32>, num_rows: u32) -> bool {
if let Some(batch_size) = batch_size {
if let Some(legacy_batch_size) = batch_size {
// We should never encounter the case that `batch_size < num_rows` because
// that would mean we have a v1 writer and it generated a batch with more rows
// than expected
debug_assert!(batch_size >= num_rows);
batch_size == num_rows
debug_assert!(legacy_batch_size >= num_rows);
legacy_batch_size == num_rows
} else {
false
}
@@ -295,7 +306,8 @@ impl DeletionRestorer {
loop {
if let Some(next_deleted_id) = next_deleted_id {
if next_deleted_id > last_row_id
|| (next_deleted_id == last_row_id && Self::is_full(self.batch_size, num_rows))
|| (next_deleted_id == last_row_id
&& Self::is_full(self.legacy_batch_size, num_rows))
{
// Either the next deleted id is out of range or it is the next row but
// we are full. Either way, stash it and return
@@ -322,7 +334,7 @@ impl DeletionRestorer {
let deleted_batch_offsets = self.deleted_batch_offsets_in_range(batch.num_rows() as u32);
let batch = add_blanks(batch, &deleted_batch_offsets)?;

if let Some(batch_size) = self.batch_size {
if let Some(batch_size) = self.legacy_batch_size {
// validation just in case, when the input has a fixed batch size then the
// output should have the same fixed batch size (except the last batch)
let is_last = self.is_exhausted();
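For readers skimming the Updater change above, the batch-size precedence it implements can be restated as a short sketch (Python pseudologic, not a public API; the function and parameter names are illustrative only).

import os

def updater_batch_size(legacy_row_group_size, requested_batch_size):
    # v1 (legacy) files must keep the file's row-group size.
    if legacy_row_group_size is not None:
        return legacy_row_group_size
    # v2 files honor the batch_size passed to merge_columns / updater.
    if requested_batch_size is not None:
        return requested_batch_size
    # Otherwise fall back to LANCE_DEFAULT_BATCH_SIZE, or 1024 rows.
    env = os.environ.get("LANCE_DEFAULT_BATCH_SIZE")
    return int(env) if env is not None else 1024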
1 change: 1 addition & 0 deletions rust/lance/src/dataset/write/merge_insert.rs
@@ -701,6 +701,7 @@ impl MergeInsertJob {
.updater(
Some(&read_columns),
Some((write_schema, dataset.schema().clone())),
None,
)
.await?;

Expand Down
