[CHORE] Enable lancedb reads for native executor #2925

Open · wants to merge 16 commits into base: main
10 changes: 5 additions & 5 deletions src/daft-local-execution/src/sources/scan_task.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{pin::Pin, sync::Arc};

use common_error::DaftResult;
use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
@@ -8,6 +8,7 @@ use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
use daft_micropartition::MicroPartition;
use daft_parquet::read::ParquetSchemaInferenceOptions;
use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask};
+use daft_table::Table;
use futures::{Stream, StreamExt};
use snafu::ResultExt;
use tokio_stream::wrappers::ReceiverStream;
@@ -288,14 +289,13 @@ async fn stream_scan_task(
.map(|t| t.into())
.context(PyIOSnafu)
})?;
// SQL Scan cannot be streamed at the moment, so we just return the table
Box::pin(futures::stream::once(async { Ok(table) }))
}
#[cfg(feature = "python")]
FileFormatConfig::PythonFunction => {
-    return Err(common_error::DaftError::TypeError(
-        "PythonFunction file format not implemented".to_string(),
-    ));
+    let iter = daft_micropartition::python::read_pyfunc_into_table_iter(&scan_task)?;
+    let stream = futures::stream::iter(iter.map(|r| r.map_err(|e| e.into())));
+    Box::pin(stream) as Pin<Box<dyn Stream<Item = DaftResult<Table>> + Send>>
}
};

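For context, the new `PythonFunction` arm bridges a blocking table iterator into the async source stream. Below is a minimal, self-contained sketch of that iterator-to-stream pattern — not Daft's actual code: `Table` and `DaftResult` are placeholder types, and `futures` plus `tokio` are assumed as dependencies.

```rust
use std::pin::Pin;

use futures::{Stream, StreamExt};

// Placeholders standing in for daft_table::Table and common_error::DaftResult.
type Table = Vec<i64>;
type DaftResult<T> = Result<T, String>;

// Wrap a blocking iterator of table results into a boxed, pinned stream,
// matching the common stream type produced by the other match arms.
fn iter_to_stream(
    iter: impl Iterator<Item = DaftResult<Table>> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = DaftResult<Table>> + Send>> {
    Box::pin(futures::stream::iter(iter))
}

#[tokio::main]
async fn main() {
    let batches = vec![Ok(vec![1, 2]), Ok(vec![3])];
    let mut stream = iter_to_stream(batches.into_iter());
    while let Some(batch) = stream.next().await {
        println!("{batch:?}");
    }
}
```

Note that `futures::stream::iter` polls the underlying iterator lazily, so tables are only pulled from the Python factory function as the executor consumes the stream.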
101 changes: 2 additions & 99 deletions src/daft-micropartition/src/micropartition.rs
@@ -379,105 +379,8 @@ fn materialize_scan_task(
})?
}
FileFormatConfig::PythonFunction => {
-    use pyo3::{types::PyAnyMethods, PyObject};
-
-    let table_iterators = scan_task.sources.iter().map(|source| {
-        // Call Python function to create an Iterator (Grabs the GIL and then releases it)
-        match source {
-            DataSource::PythonFactoryFunction {
-                module,
-                func_name,
-                func_args,
-                ..
-            } => {
-                Python::with_gil(|py| {
-                    let func = py.import_bound(module.as_str())
-                        .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
-                        .getattr(func_name.as_str())
-                        .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
-                    func.call(func_args.to_pytuple(py), None)
-                        .with_context(|_| PyIOSnafu)
-                        .map(Into::<PyObject>::into)
-                })
-            }
-            _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
-        }
-    });
-
-    let mut tables = Vec::new();
-    let mut rows_seen_so_far = 0;
-    for iterator in table_iterators {
-        let iterator = iterator?;
-
-        // Iterate on this iterator to exhaustion, or until the limit is met
-        while scan_task
-            .pushdowns
-            .limit
-            .map_or(true, |limit| rows_seen_so_far < limit)
-        {
-            // Grab the GIL to call next() on the iterator, and then release it once we have the Table
-            let table = match Python::with_gil(|py| {
-                iterator
-                    .downcast_bound::<pyo3::types::PyIterator>(py)
-                    .expect("Function must return an iterator of tables")
-                    .clone()
-                    .next()
-                    .map(|result| {
-                        result
-                            .map(|tbl| {
-                                tbl.extract::<daft_table::python::PyTable>()
-                                    .expect("Must be a PyTable")
-                                    .table
-                            })
-                            .with_context(|_| PyIOSnafu)
-                    })
-            }) {
-                Some(table) => table,
-                None => break,
-            }?;
-
-            // Apply filters
-            let table = if let Some(filters) = scan_task.pushdowns.filters.as_ref()
-            {
-                table
-                    .filter(&[filters.clone()])
-                    .with_context(|_| DaftCoreComputeSnafu)?
-            } else {
-                table
-            };
-
-            // Apply limit if necessary, and update `&mut remaining`
-            let table = if let Some(limit) = scan_task.pushdowns.limit {
-                let limited_table = if rows_seen_so_far + table.len() > limit {
-                    table
-                        .slice(0, limit - rows_seen_so_far)
-                        .with_context(|_| DaftCoreComputeSnafu)?
-                } else {
-                    table
-                };
-
-                // Update the rows_seen_so_far
-                rows_seen_so_far += limited_table.len();
-
-                limited_table
-            } else {
-                table
-            };
-
-            tables.push(table);
-        }
-
-        // If seen enough rows, early-terminate
-        if scan_task
-            .pushdowns
-            .limit
-            .is_some_and(|limit| rows_seen_so_far >= limit)
-        {
-            break;
-        }
-    }
-
-    tables
+    let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?;
+    tables.collect::<crate::Result<Vec<_>>>()?
}
}
}
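The two replacement lines rely on the standard `FromIterator` impl for `Result`: collecting an iterator of `Result`s yields `Ok(Vec<..>)` only if every item succeeds, short-circuiting on the first error. A minimal sketch of that pattern, with `String` as a stand-in error type:

```rust
// Collecting Iterator<Item = Result<T, E>> into Result<Vec<T>, E>
// stops at the first Err, mirroring the `?` at the new call site.
fn main() {
    let all_ok: Result<Vec<i32>, String> = vec![Ok(1), Ok(2)].into_iter().collect();
    assert_eq!(all_ok, Ok(vec![1, 2]));

    let failed: Result<Vec<i32>, String> =
        vec![Ok(1), Err("read failed".to_string()), Ok(3)]
            .into_iter()
            .collect();
    // Items after the first error are never materialized.
    assert_eq!(failed, Err("read failed".to_string()));
}
```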
106 changes: 103 additions & 3 deletions src/daft-micropartition/src/python.rs
@@ -10,12 +10,18 @@
use daft_io::{python::IOConfig, IOStatsContext};
use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
use daft_parquet::read::ParquetSchemaInferenceOptions;
-use daft_scan::{python::pylib::PyScanTask, storage_config::PyStorageConfig, ScanTask};
+use daft_scan::{
+    python::pylib::PyScanTask, storage_config::PyStorageConfig, DataSource, ScanTask, ScanTaskRef,
+};
use daft_stats::{TableMetadata, TableStatistics};
-use daft_table::python::PyTable;
+use daft_table::{python::PyTable, Table};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo};
+use snafu::ResultExt;

-use crate::micropartition::{MicroPartition, TableState};
+use crate::{
+    micropartition::{MicroPartition, TableState},
+    DaftCoreComputeSnafu, PyIOSnafu,
+};

#[pyclass(module = "daft.daft", frozen)]
#[derive(Clone)]
@@ -900,6 +906,100 @@
.extract()
}

+pub fn read_pyfunc_into_table_iter(
+    scan_task: &ScanTaskRef,
+) -> crate::Result<impl Iterator<Item = crate::Result<Table>>> {
+    let table_iterators = scan_task.sources.iter().map(|source| {
+        // Call Python function to create an Iterator (Grabs the GIL and then releases it)
+        match source {
+            DataSource::PythonFactoryFunction {
+                module,
+                func_name,
+                func_args,
+                ..
+            } => {
+                Python::with_gil(|py| {
+                    let func = py.import_bound(module.as_str())
+                        .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}"))
+                        .getattr(func_name.as_str())
+                        .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}"));
+                    func.call(func_args.to_pytuple(py), None)
+                        .with_context(|_| PyIOSnafu)
+                        .map(Into::<PyObject>::into)
+                })
+            },
+            _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"),
+        }
+    }).collect::<crate::Result<Vec<_>>>()?;
+
+    let scan_task_limit = scan_task.pushdowns.limit;
+    let scan_task_filters = scan_task.pushdowns.filters.clone();
+    let res = table_iterators
+        .into_iter()
+        .filter_map(|iter| {
+            Python::with_gil(|py| {
+                iter.downcast_bound::<pyo3::types::PyIterator>(py)
+                    .expect("Function must return an iterator of tables")
+                    .clone()
+                    .next()
+                    .map(|result| {
+                        result
+                            .map(|tbl| {
+                                tbl.extract::<daft_table::python::PyTable>()
+                                    .expect("Must be a PyTable")
+                                    .table
+                            })
+                            .with_context(|_| PyIOSnafu)
+                    })
+            })
+        })
+        .scan(0, move |rows_seen_so_far, table| {
+            if scan_task_limit
+                .map(|limit| *rows_seen_so_far >= limit)
+                .unwrap_or(false)
+            {
+                return None;
+            }
+            match table {
+                Err(e) => Some(Err(e)),
+                Ok(table) => {
+                    // Apply filters
+                    let post_pushdown_table = || -> crate::Result<Table> {
+                        let table = if let Some(filters) = scan_task_filters.as_ref() {
+                            table
+                                .filter(&[filters.clone()])
+                                .with_context(|_| DaftCoreComputeSnafu)?
+                        } else {
+                            table
+                        };
+
+                        // Apply limit if necessary, and update `&mut remaining`
+                        if let Some(limit) = scan_task_limit {
+                            let limited_table = if *rows_seen_so_far + table.len() > limit {
+                                table
+                                    .slice(0, limit - *rows_seen_so_far)
+                                    .with_context(|_| DaftCoreComputeSnafu)?
+                            } else {
+                                table
+                            };
+
+                            // Update the rows_seen_so_far
+                            *rows_seen_so_far += limited_table.len();
+
+                            Ok(limited_table)
+                        } else {
+                            Ok(table)
+                        }
+                    }();
+
+                    Some(post_pushdown_table)
+                }
+            }
+        });
+
+    Ok(res)
+}

impl From<MicroPartition> for PyMicroPartition {
fn from(value: MicroPartition) -> Self {
Arc::new(value).into()
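The `scan` adapter above threads `rows_seen_so_far` through the iterator and returns `None` to stop pulling once the limit is satisfied. Below is a standalone sketch of the same limit-pushdown shape — an illustration only, using plain `Vec<i64>` batches in place of `Table`s and omitting the GIL handling, filters, and Daft error types:

```rust
// Truncate a stream of batches to an optional row limit, stopping early.
fn apply_limit(
    batches: impl Iterator<Item = Vec<i64>>,
    limit: Option<usize>,
) -> impl Iterator<Item = Vec<i64>> {
    batches.scan(0usize, move |rows_seen_so_far, mut batch| {
        // Once the limit is met, end the iterator: upstream batches
        // (and thus further Python `next()` calls) are never requested.
        if limit.is_some_and(|l| *rows_seen_so_far >= l) {
            return None;
        }
        // Slice the batch that crosses the limit boundary.
        if let Some(l) = limit {
            if *rows_seen_so_far + batch.len() > l {
                batch.truncate(l - *rows_seen_so_far);
            }
        }
        *rows_seen_so_far += batch.len();
        Some(batch)
    })
}

fn main() {
    let batches = vec![vec![1, 2, 3], vec![4, 5], vec![6]];
    let limited: Vec<_> = apply_limit(batches.into_iter(), Some(4)).collect();
    assert_eq!(limited, vec![vec![1, 2, 3], vec![4]]);
}
```

Because `scan` ends the iteration by returning `None`, the third batch here is never even produced, which is what makes the limit a true pushdown rather than a post-hoc filter.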
9 changes: 1 addition & 8 deletions tests/io/lancedb/test_lancedb_reads.py
@@ -3,12 +3,6 @@
import pytest

import daft
-from daft import context
-
-native_executor_skip = pytest.mark.skipif(
-    context.get_context().daft_execution_config.enable_native_executor is True,
-    reason="Native executor fails for these tests",
-)

TABLE_NAME = "my_table"
data = {
@@ -18,8 +12,7 @@
}

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
-py_arrow_skip = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")
-pytestmark = [native_executor_skip, py_arrow_skip]
+pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")


@pytest.fixture(scope="function")