diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 35ddbd786a..1af87f56a3 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -1429,6 +1429,69 @@ TEST_F(PostgresStatementTest, MultipleStatementsSingleQuery) {
   ASSERT_EQ(reader.array->length, 3);
 }
 
+TEST_F(PostgresStatementTest, SetUseCopyFalse) {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+  // Create a three-row table (one NULL) to read back once COPY is disabled
+  const char* query = R"(DROP TABLE IF EXISTS test_query_set_copy_false;
+    CREATE TABLE test_query_set_copy_false (ints INT);
+    INSERT INTO test_query_set_copy_false VALUES((1));
+    INSERT INTO test_query_set_copy_false VALUES((NULL));
+    INSERT INTO test_query_set_copy_false VALUES((3));)";
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  // Check option setting/getting
+  ASSERT_EQ(
+      adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
+      "true");
+
+  ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
+                                     "not true or false", &error),
+              IsStatus(ADBC_STATUS_INVALID_ARGUMENT));
+
+  ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
+                                     ADBC_OPTION_VALUE_ENABLED, &error),
+              IsOkStatus(&error));
+  ASSERT_EQ(
+      adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
+      "true");
+
+  ASSERT_THAT(AdbcStatementSetOption(&statement, "adbc.postgresql.use_copy",
+                                     ADBC_OPTION_VALUE_DISABLED, &error),
+              IsOkStatus(&error));
+  ASSERT_EQ(
+      adbc_validation::StatementGetOption(&statement, "adbc.postgresql.use_copy", &error),
+      "false");
+
+  // With use_copy disabled, run a query that produces output
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
+                                       "SELECT * FROM test_query_set_copy_false", &error),
+              IsOkStatus(&error));
+
+  adbc_validation::StreamReader reader;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+
+  ASSERT_EQ(reader.rows_affected, 3);
+
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_EQ(reader.schema->n_children, 1);
+  ASSERT_STREQ(reader.schema->children[0]->format, "i");
+  ASSERT_STREQ(reader.schema->children[0]->name, "ints");
+
+  ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
+  ASSERT_EQ(reader.array->length, 3);
+  ASSERT_EQ(reader.array->n_children, 1);
+  ASSERT_EQ(reader.array->children[0]->null_count, 1);
+
+  // Expect no further batch: the array must come back released
+  ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
+  ASSERT_EQ(reader.array->release, nullptr);
+}
+
 struct TypeTestCase {
   std::string name;
   std::string sql_type;
diff --git a/c/driver/postgresql/result_helper.cc b/c/driver/postgresql/result_helper.cc
index 170d58291d..df890a7c51 100644
--- a/c/driver/postgresql/result_helper.cc
+++ b/c/driver/postgresql/result_helper.cc
@@ -253,7 +253,8 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) {
 
   // TODO: If we get an EOVERFLOW here (e.g., big string data), we
   // would need to keep track of what row number we're on and start
-  // from there instead of begin() on the next call
+  // from there instead of begin() on the next call. We could also
+  // respect the size hint here to chunk the batches.
   struct ArrowBufferView item;
   for (auto it = helper_.begin(); it != helper_.end(); it++) {
     auto row = *it;
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 8299848d20..9c3824485d 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -1169,7 +1169,7 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
 
   // If we have been requested to avoid COPY or there is no output requested,
   // execute using the PqResultArrayReader.
-  if (!stream || !UseCopyIfPossible()) {
+  if (!stream || !use_copy_) {
     PqResultArrayReader reader(connection_->conn(), type_resolver_, query_);
     RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error));
     return ADBC_STATUS_OK;
@@ -1347,6 +1347,13 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
     }
   } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
     result = std::to_string(reader_.batch_size_hint_bytes_);
+  } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
+    // Report the same values that SetOption accepts for this key
+    if (use_copy_) {
+      result = ADBC_OPTION_VALUE_ENABLED;
+    } else {
+      result = ADBC_OPTION_VALUE_DISABLED;
+    }
   } else {
     SetError(error, "[libpq] Unknown statement option '%s'", key);
     return ADBC_STATUS_NOT_FOUND;
@@ -1466,6 +1473,15 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
     }
 
     this->reader_.batch_size_hint_bytes_ = int_value;
+  } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
+    if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
+      use_copy_ = true;
+    } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
+      use_copy_ = false;
+    } else {
+      SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
   } else {
     SetError(error, "[libpq] Unknown statement option '%s'", key);
     return ADBC_STATUS_NOT_IMPLEMENTED;
@@ -1505,23 +1521,4 @@ void PostgresStatement::ClearResult() {
   reader_.Release();
 }
 
-bool PostgresStatement::UseCopyIfPossible() {
-  // Check if we have been explicitly requested to avoid COPY
-  size_t length = 0;
-  AdbcStatusCode status =
-      GetOption(ADBC_POSTGRESQL_OPTION_USE_COPY, nullptr, &length, nullptr);
-  if (status != ADBC_STATUS_OK) {
-    return true;
-  }
-
-  std::string out;
-  out.resize(length);
-  status = GetOption(ADBC_POSTGRESQL_OPTION_USE_COPY, out.data(), &length, nullptr);
-  if (status != ADBC_STATUS_OK) {
-    return true;
-  }
-
-  return out == "true";
-}
-
 }  // namespace adbcpq
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 8837a0ad35..d29f383873 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -92,7 +92,11 @@ class TupleReader final {
 class PostgresStatement {
  public:
   PostgresStatement()
-      : connection_(nullptr), query_(), prepared_(false), reader_(nullptr) {
+      : connection_(nullptr),
+        query_(),
+        prepared_(false),
+        use_copy_(true),
+        reader_(nullptr) {
     std::memset(&bind_, 0, sizeof(bind_));
   }
 
@@ -136,7 +140,6 @@ class PostgresStatement {
                                struct AdbcError* error);
   AdbcStatusCode ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected,
                              struct AdbcError* error);
-  bool UseCopyIfPossible();
 
  private:
   std::shared_ptr type_resolver_;
@@ -155,6 +158,10 @@ class PostgresStatement {
     kCreateAppend,
   };
 
+  // Options
+  // When false, ExecuteQuery always goes through PqResultArrayReader (default: true)
+  bool use_copy_;
+
   struct {
     std::string db_schema;
     std::string target;
diff --git a/c/validation/adbc_validation_util.cc b/c/validation/adbc_validation_util.cc
index 24310aba3d..6d018c89a6 100644
--- a/c/validation/adbc_validation_util.cc
+++ b/c/validation/adbc_validation_util.cc
@@ -36,6 +36,23 @@ std::optional<std::string> ConnectionGetOption(struct AdbcConnection* connection
   return std::string(buffer, buffer_size - 1);
 }
 
+std::optional<std::string> StatementGetOption(struct AdbcStatement* statement,
+                                              std::string_view option,
+                                              struct AdbcError* error) {
+  char buffer[128];
+  size_t buffer_size = sizeof(buffer);
+  // string_view::data() need not be NUL-terminated, so pass an owned copy
+  AdbcStatusCode status = AdbcStatementGetOption(
+      statement, std::string(option).c_str(), buffer, &buffer_size, error);
+  EXPECT_THAT(status, IsOkStatus(error));
+  if (status != ADBC_STATUS_OK) return std::nullopt;
+  // buffer_size is in/out: a size above the capacity means buffer was too small
+  EXPECT_GT(buffer_size, 0);
+  EXPECT_LE(buffer_size, sizeof(buffer));
+  if (buffer_size == 0 || buffer_size > sizeof(buffer)) return std::nullopt;
+  return std::string(buffer, buffer_size - 1);
+}
+
 std::string StatusCodeToString(AdbcStatusCode code) {
 #define CASE(CONSTANT)         \
   case ADBC_STATUS_##CONSTANT: \
diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h
index 08401f2b46..0027cd5f74 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -43,6 +43,12 @@ std::optional<std::string> ConnectionGetOption(struct AdbcConnection* connection
                                                std::string_view option,
                                                struct AdbcError* error);
 
+// Get a statement option as a string. Records a test failure and returns
+// std::nullopt if the driver does not return a usable value.
+std::optional<std::string> StatementGetOption(struct AdbcStatement* statement,
+                                              std::string_view option,
+                                              struct AdbcError* error);
+
 // ------------------------------------------------------------
 // Helpers to print values
 