diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index bf77e6dc8d..55df4a66e1 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{pin::Pin, sync::Arc}; use common_error::DaftResult; use common_file_formats::{FileFormatConfig, ParquetSourceConfig}; @@ -8,6 +8,7 @@ use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_micropartition::MicroPartition; use daft_parquet::read::ParquetSchemaInferenceOptions; use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask}; +use daft_table::Table; use futures::{Stream, StreamExt}; use snafu::ResultExt; use tokio_stream::wrappers::ReceiverStream; @@ -288,14 +289,13 @@ async fn stream_scan_task( .map(|t| t.into()) .context(PyIOSnafu) })?; - // SQL Scan cannot be streamed at the moment, so we just return the table Box::pin(futures::stream::once(async { Ok(table) })) } #[cfg(feature = "python")] FileFormatConfig::PythonFunction => { - return Err(common_error::DaftError::TypeError( - "PythonFunction file format not implemented".to_string(), - )); + let iter = daft_micropartition::python::read_pyfunc_into_table_iter(&scan_task)?; + let stream = futures::stream::iter(iter.map(|r| r.map_err(|e| e.into()))); + Box::pin(stream) as Pin> + Send>> } }; diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 2e92ebb922..105fd561e8 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -379,105 +379,8 @@ fn materialize_scan_task( })? } FileFormatConfig::PythonFunction => { - use pyo3::{types::PyAnyMethods, PyObject}; - - let table_iterators = scan_task.sources.iter().map(|source| { - // Call Python function to create an Iterator (Grabs the GIL and then releases it) - match source { - DataSource::PythonFactoryFunction { - module, - func_name, - func_args, - .. - } => { - Python::with_gil(|py| { - let func = py.import_bound(module.as_str()) - .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}")) - .getattr(func_name.as_str()) - .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}")); - func.call(func_args.to_pytuple(py), None) - .with_context(|_| PyIOSnafu) - .map(Into::::into) - }) - } - _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"), - } - }); - - let mut tables = Vec::new(); - let mut rows_seen_so_far = 0; - for iterator in table_iterators { - let iterator = iterator?; - - // Iterate on this iterator to exhaustion, or until the limit is met - while scan_task - .pushdowns - .limit - .map_or(true, |limit| rows_seen_so_far < limit) - { - // Grab the GIL to call next() on the iterator, and then release it once we have the Table - let table = match Python::with_gil(|py| { - iterator - .downcast_bound::(py) - .expect("Function must return an iterator of tables") - .clone() - .next() - .map(|result| { - result - .map(|tbl| { - tbl.extract::() - .expect("Must be a PyTable") - .table - }) - .with_context(|_| PyIOSnafu) - }) - }) { - Some(table) => table, - None => break, - }?; - - // Apply filters - let table = if let Some(filters) = scan_task.pushdowns.filters.as_ref() - { - table - .filter(&[filters.clone()]) - .with_context(|_| DaftCoreComputeSnafu)? - } else { - table - }; - - // Apply limit if necessary, and update `&mut remaining` - let table = if let Some(limit) = scan_task.pushdowns.limit { - let limited_table = if rows_seen_so_far + table.len() > limit { - table - .slice(0, limit - rows_seen_so_far) - .with_context(|_| DaftCoreComputeSnafu)? - } else { - table - }; - - // Update the rows_seen_so_far - rows_seen_so_far += limited_table.len(); - - limited_table - } else { - table - }; - - tables.push(table); - } - - // If seen enough rows, early-terminate - if scan_task - .pushdowns - .limit - .is_some_and(|limit| rows_seen_so_far >= limit) - { - break; - } - } - - tables + let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?; + tables.collect::>>()? } } } diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index eb9a43550e..2a4ac6eb18 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -10,12 +10,18 @@ use daft_dsl::python::PyExpr; use daft_io::{python::IOConfig, IOStatsContext}; use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_parquet::read::ParquetSchemaInferenceOptions; -use daft_scan::{python::pylib::PyScanTask, storage_config::PyStorageConfig, ScanTask}; +use daft_scan::{ + python::pylib::PyScanTask, storage_config::PyStorageConfig, DataSource, ScanTask, ScanTaskRef, +}; use daft_stats::{TableMetadata, TableStatistics}; -use daft_table::python::PyTable; +use daft_table::{python::PyTable, Table}; use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo}; +use snafu::ResultExt; -use crate::micropartition::{MicroPartition, TableState}; +use crate::{ + micropartition::{MicroPartition, TableState}, + DaftCoreComputeSnafu, PyIOSnafu, +}; #[pyclass(module = "daft.daft", frozen)] #[derive(Clone)] @@ -900,6 +906,100 @@ pub fn read_sql_into_py_table( .extract() } +pub fn read_pyfunc_into_table_iter( + scan_task: &ScanTaskRef, +) -> crate::Result>> { + let table_iterators = scan_task.sources.iter().map(|source| { + // Call Python function to create an Iterator (Grabs the GIL and then releases it) + match source { + DataSource::PythonFactoryFunction { + module, + func_name, + func_args, + .. + } => { + Python::with_gil(|py| { + let func = py.import_bound(module.as_str()) + .unwrap_or_else(|_| panic!("Cannot import factory function from module {module}")) + .getattr(func_name.as_str()) + .unwrap_or_else(|_| panic!("Cannot find function {func_name} in module {module}")); + func.call(func_args.to_pytuple(py), None) + .with_context(|_| PyIOSnafu) + .map(Into::::into) + }) + }, + _ => unreachable!("PythonFunction file format must be paired with PythonFactoryFunction data file sources"), + } + }).collect::>>()?; + + let scan_task_limit = scan_task.pushdowns.limit; + let scan_task_filters = scan_task.pushdowns.filters.clone(); + let res = table_iterators + .into_iter() + .filter_map(|iter| { + Python::with_gil(|py| { + iter.downcast_bound::(py) + .expect("Function must return an iterator of tables") + .clone() + .next() + .map(|result| { + result + .map(|tbl| { + tbl.extract::() + .expect("Must be a PyTable") + .table + }) + .with_context(|_| PyIOSnafu) + }) + }) + }) + .scan(0, move |rows_seen_so_far, table| { + if scan_task_limit + .map(|limit| *rows_seen_so_far >= limit) + .unwrap_or(false) + { + return None; + } + match table { + Err(e) => Some(Err(e)), + Ok(table) => { + // Apply filters + let post_pushdown_table = || -> crate::Result { + let table = if let Some(filters) = scan_task_filters.as_ref() { + table + .filter(&[filters.clone()]) + .with_context(|_| DaftCoreComputeSnafu)? + } else { + table + }; + + // Apply limit if necessary, and update `&mut remaining` + if let Some(limit) = scan_task_limit { + let limited_table = if *rows_seen_so_far + table.len() > limit { + table + .slice(0, limit - *rows_seen_so_far) + .with_context(|_| DaftCoreComputeSnafu)? + } else { + table + }; + + // Update the rows_seen_so_far + *rows_seen_so_far += limited_table.len(); + + Ok(limited_table) + } else { + Ok(table) + } + }(); + + Some(post_pushdown_table) + } + } + }); + + Ok(res) +} + impl From for PyMicroPartition { fn from(value: MicroPartition) -> Self { Arc::new(value).into() diff --git a/tests/io/lancedb/test_lancedb_reads.py b/tests/io/lancedb/test_lancedb_reads.py index b1f365f8c1..ad3062ee19 100644 --- a/tests/io/lancedb/test_lancedb_reads.py +++ b/tests/io/lancedb/test_lancedb_reads.py @@ -3,12 +3,6 @@ import pytest import daft -from daft import context - -native_executor_skip = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) TABLE_NAME = "my_table" data = { @@ -18,8 +12,7 @@ } PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0) -py_arrow_skip = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0") -pytestmark = [native_executor_skip, py_arrow_skip] +pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0") @pytest.fixture(scope="function")