diff --git a/r/adbcdrivermanager/NAMESPACE b/r/adbcdrivermanager/NAMESPACE index 63b2552224..a37824acb9 100644 --- a/r/adbcdrivermanager/NAMESPACE +++ b/r/adbcdrivermanager/NAMESPACE @@ -8,11 +8,13 @@ S3method("[[",adbc_async_task) S3method("[[",adbc_error) S3method("[[",adbc_xptr) S3method("[[<-",adbc_xptr) -S3method(adbc_async_task_cancel,adbc_async_execute_query) +S3method(adbc_async_task_cancel,adbc_async_statement_cancellable) S3method(adbc_async_task_cancel,default) S3method(adbc_async_task_result,adbc_async_execute_query) +S3method(adbc_async_task_result,adbc_async_prepare) S3method(adbc_async_task_result,adbc_async_sleep) S3method(adbc_async_task_result,adbc_async_statement_stream_get_next) +S3method(adbc_async_task_result,adbc_async_statement_stream_schema) S3method(adbc_connection_init,adbc_database_log) S3method(adbc_connection_init,adbc_database_monkey) S3method(adbc_connection_init,default) diff --git a/r/adbcdrivermanager/R/async.R b/r/adbcdrivermanager/R/async.R index 1f39607ff0..78bbe60eb5 100644 --- a/r/adbcdrivermanager/R/async.R +++ b/r/adbcdrivermanager/R/async.R @@ -134,8 +134,37 @@ adbc_async_task_result.adbc_async_sleep <- function(task) { task$user_data$duration_ms } +#' @export +adbc_async_task_cancel.adbc_async_statement_cancellable <- function(task) { + adbc_statement_cancel(task$user_data$statement) + TRUE +} + +adbc_statement_prepare_async <- function(statement) { + task <- adbc_async_task( + c("adbc_async_prepare", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + .Call(RAdbcAsyncTaskLaunchPrepare, task, statement) + + task +} + +#' @export +adbc_async_task_result.adbc_async_prepare <- function(task) { + if (!identical(task$return_code, 0L)) { + stop_for_error(task$return_code, task$error_xptr) + } + + task$user_data$statement +} + adbc_statement_execute_query_async <- function(statement, stream = NULL) { - task <- adbc_async_task("adbc_async_execute_query") + task <- adbc_async_task( + c("adbc_async_execute_query", "adbc_async_statement_cancellable") + ) user_data <- task$user_data user_data$statement <- statement @@ -151,12 +180,6 @@ adbc_statement_execute_query_async <- function(statement, stream = NULL) { task } -#' @export -adbc_async_task_cancel.adbc_async_execute_query <- function(task) { - adbc_statement_cancel(task$user_data$statement) - TRUE -} - #' @export adbc_async_task_result.adbc_async_execute_query <- function(task) { if (!identical(task$return_code, 0L)) { @@ -170,8 +193,44 @@ adbc_async_task_result.adbc_async_execute_query <- function(task) { ) } +adbc_statement_stream_get_schema_async <- function(statement, stream) { + task <- adbc_async_task( + c("adbc_async_statement_stream_get_next", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + user_data$stream <- stream + user_data$schema <- nanoarrow::nanoarrow_allocate_schema() + + user_data$rows_affected <- .Call( + RAdbcAsyncTaskLaunchStreamGetSchema, + task, + stream, + user_data$schema + ) + + task +} + + +#' @export +adbc_async_task_result.adbc_async_statement_stream_schema <- function(task) { + if (!identical(task$return_code, 0L)) { + adbc_statement_release(task$user_data$statement) + stop(task$user_data$stream$get_last_error()) + } + + list( + statement = task$user_data$statement, + array = task$user_data$schema + ) +} + adbc_statement_stream_get_next_async <- function(statement, stream) { - task <- adbc_async_task("adbc_async_statement_stream_get_next") + task <- adbc_async_task( + c("adbc_async_statement_stream_get_next", "adbc_async_statement_cancellable") + ) user_data <- task$user_data user_data$statement <- statement diff --git a/r/adbcdrivermanager/src/async.cc b/r/adbcdrivermanager/src/async.cc index bf436a6d69..6ef130192b 100644 --- a/r/adbcdrivermanager/src/async.cc +++ b/r/adbcdrivermanager/src/async.cc @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -42,6 +43,17 @@ static inline void later_ensure_initialized() { static void later_task_callback_wrapper(void* data); +struct ArrowArrayCustomDeleter { + void operator()(ArrowArray* array) const { + if (array->release != nullptr) { + array->release(array); + } + delete array; + } +}; + +using UniqueArrowArrayPtr = std::unique_ptr; + enum class RAdbcAsyncTaskStatus { NOT_STARTED, STARTED, READY }; struct RAdbcAsyncTask { @@ -217,6 +229,22 @@ extern "C" SEXP RAdbcAsyncTaskLaunchSleep(SEXP task_xptr, SEXP duration_ms_sexp) return R_NilValue; } +extern "C" SEXP RAdbcAsyncTaskLaunchPrepare(SEXP task_xptr, SEXP statement_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto statement = adbc_from_xptr(statement_xptr); + + task->result = std::async(std::launch::async, [task, statement] { + *(task->return_code) = AdbcStatementPrepare(statement, task->return_error); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + UNPROTECT(1); + return R_NilValue; +} + extern "C" SEXP RAdbcAsyncTaskLaunchExecuteQuery(SEXP task_xptr, SEXP statement_xptr, SEXP stream_xptr) { auto task = adbc_from_xptr(task_xptr); @@ -245,6 +273,23 @@ extern "C" SEXP RAdbcAsyncTaskLaunchExecuteQuery(SEXP task_xptr, SEXP statement_ return rows_affected_sexp; } +extern "C" SEXP RAdbcAsyncTaskLaunchStreamGetSchema(SEXP task_xptr, SEXP stream_xptr, + SEXP schema_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto stream = adbc_from_xptr(stream_xptr); + auto schema = adbc_from_xptr(schema_xptr); + + task->result = std::async(std::launch::async, [task, stream, schema] { + *(task->return_code) = stream->get_schema(stream, schema); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + return R_NilValue; +} + extern "C" SEXP RAdbcAsyncTaskLaunchStreamGetNext(SEXP task_xptr, SEXP stream_xptr, SEXP array_xptr) { auto task = adbc_from_xptr(task_xptr); diff --git a/r/adbcdrivermanager/src/init.c b/r/adbcdrivermanager/src/init.c index 743e5bb194..2ea4c7dde3 100644 --- a/r/adbcdrivermanager/src/init.c +++ b/r/adbcdrivermanager/src/init.c @@ -27,8 +27,11 @@ SEXP RAdbcAsyncTaskData(SEXP task_xptr); SEXP RAdbcAsyncTaskWaitFor(SEXP task_xptr, SEXP duration_ms_sexp); SEXP RAdbcAsyncTaskWait(SEXP task_xptr, SEXP resolution_ms_sexp); SEXP RAdbcAsyncTaskLaunchSleep(SEXP task_xptr, SEXP duration_ms_sexp); +SEXP RAdbcAsyncTaskLaunchPrepare(SEXP task_xptr, SEXP statement_xptr); SEXP RAdbcAsyncTaskLaunchExecuteQuery(SEXP task_xptr, SEXP statement_xptr, SEXP stream_xptr); +SEXP RAdbcAsyncTaskLaunchStreamGetSchema(SEXP task_xptr, SEXP stream_xptr, + SEXP schema_xptr); SEXP RAdbcAsyncTaskLaunchStreamGetNext(SEXP task_xptr, SEXP stream_xptr, SEXP array_xptr); SEXP RAdbcVoidDriverInitFunc(void); SEXP RAdbcMonkeyDriverInitFunc(void); @@ -118,7 +121,10 @@ static const R_CallMethodDef CallEntries[] = { {"RAdbcAsyncTaskWaitFor", (DL_FUNC)&RAdbcAsyncTaskWaitFor, 2}, {"RAdbcAsyncTaskWait", (DL_FUNC)&RAdbcAsyncTaskWait, 2}, {"RAdbcAsyncTaskLaunchSleep", (DL_FUNC)&RAdbcAsyncTaskLaunchSleep, 2}, + {"RAdbcAsyncTaskLaunchPrepare", (DL_FUNC)&RAdbcAsyncTaskLaunchPrepare, 2}, {"RAdbcAsyncTaskLaunchExecuteQuery", (DL_FUNC)&RAdbcAsyncTaskLaunchExecuteQuery, 3}, + {"RAdbcAsyncTaskLaunchStreamGetSchema", (DL_FUNC)&RAdbcAsyncTaskLaunchStreamGetSchema, + 3}, {"RAdbcAsyncTaskLaunchStreamGetNext", (DL_FUNC)&RAdbcAsyncTaskLaunchStreamGetNext, 3}, {"RAdbcVoidDriverInitFunc", (DL_FUNC)&RAdbcVoidDriverInitFunc, 0}, {"RAdbcMonkeyDriverInitFunc", (DL_FUNC)&RAdbcMonkeyDriverInitFunc, 0},