Skip to content

Commit

Permalink
Add more test cases. Add guard to ensure stream doesn't provide too m…
Browse files Browse the repository at this point in the history
…any rows.
  • Loading branch information
westonpace committed Oct 16, 2024
1 parent 4d7dd8e commit ddef402
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
69 changes: 69 additions & 0 deletions python/python/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,75 @@ def triple_a(batch):
assert expected == dataset.to_table()


def test_add_columns_from_rbr(tmp_path):
tab = pa.table({"a": range(100), "b": range(100)})
dataset = lance.write_dataset(tab, tmp_path / "dataset", max_rows_per_file=25)

# New data in smaller chunks than old data
def gen_data():
for i in range(34):
num_rows = 3
if i == 33:
num_rows = 1
yield pa.record_batch(
[pa.array(range(num_rows)), pa.array(range(num_rows))], ["c", "d"]
)

dataset.add_columns(
gen_data(),
reader_schema=pa.schema([pa.field("c", pa.int64()), pa.field("d", pa.int64())]),
)

expected = tab.append_column(
"c", pa.array([i % 3 for i in range(100)])
).append_column("d", pa.array([i % 3 for i in range(100)]))

assert expected == dataset.to_table()

# New data in larger chunks than old data
def gen_data():
for i in range(3):
num_rows = 40
if i == 2:
num_rows = 20
yield pa.record_batch([pa.array(range(num_rows))], ["e"])

dataset.add_columns(
gen_data(),
reader_schema=pa.schema([pa.field("e", pa.int64())]),
)

expected = expected.append_column("e", pa.array([i % 40 for i in range(100)]))

assert expected == dataset.to_table()

# Insufficient number of rows

def gen_data():
yield pa.record_batch([pa.array(range(50))], ["f"])

with pytest.raises(
OSError, match="Stream ended before producing values for all rows in dataset"
):
dataset.add_columns(
gen_data(),
reader_schema=pa.schema([pa.field("f", pa.int64())]),
)

# Too many rows

def gen_data():
yield pa.record_batch([pa.array(range(101))], ["f"])

with pytest.raises(
OSError, match="Stream produced more values than expected for dataset"
):
dataset.add_columns(
gen_data(),
reader_schema=pa.schema([pa.field("f", pa.int64())]),
)


def test_add_columns_from_file(tmp_path):
tab = pa.table({"a": range(100), "b": range(100)})
dataset = lance.write_dataset(tab, tmp_path / "dataset", max_rows_per_file=25)
Expand Down
9 changes: 9 additions & 0 deletions rust/lance/src/dataset/schema_evolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,15 @@ async fn add_columns_from_stream(
}
new_fragments.push(updater.finish().await?);
}

// Ensure the stream is fully consumed
if last_seen_batch.is_some() || stream.next().await.is_some() {
return Err(Error::invalid_input(
format!("Stream produced more values than expected for dataset"),
location!(),
));
}

Ok(new_fragments)
}

Expand Down

0 comments on commit ddef402

Please sign in to comment.