Skip to content

Commit

Permalink
test copy path
Browse files Browse the repository at this point in the history
  • Loading branch information
paleolimbot committed Jul 26, 2024
1 parent 181ebc1 commit ae107e0
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 23 deletions.
60 changes: 60 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,66 @@ TEST_F(PostgresStatementTest, MultipleStatementsSingleQuery) {
ASSERT_EQ(reader.array->length, 3);
}

// Exercises the "adbc.postgresql.use_copy" statement option:
//   1. the option reads back as "true" before it has been set,
//   2. a value that is neither enabled nor disabled is rejected with
//      ADBC_STATUS_INVALID_ARGUMENT,
//   3. enabled/disabled values round-trip through SetOption/GetOption, and
//   4. a SELECT executed with the option disabled still yields the expected
//      schema, row count, and null count.
TEST_F(PostgresStatementTest, SetUseCopyFalse) {
  // Single spelling of the option key used by every get/set call below.
  constexpr const char* kUseCopyOption = "adbc.postgresql.use_copy";

  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

  // Create and populate the fixture table: three rows, one of them NULL.
  const char* query = R"(DROP TABLE IF EXISTS test_query_set_copy_false;
CREATE TABLE test_query_set_copy_false (ints INT);
INSERT INTO test_query_set_copy_false VALUES((1));
INSERT INTO test_query_set_copy_false VALUES((NULL));
INSERT INTO test_query_set_copy_false VALUES((3));)";
  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, query, &error), IsOkStatus(&error));
  // No output stream and no rows-affected pointer are requested for the setup.
  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
              IsOkStatus(&error));

  // Check option setting/getting
  ASSERT_EQ(adbc_validation::StatementGetOption(&statement, kUseCopyOption, &error),
            "true");

  ASSERT_THAT(
      AdbcStatementSetOption(&statement, kUseCopyOption, "not true or false", &error),
      IsStatus(ADBC_STATUS_INVALID_ARGUMENT));

  ASSERT_THAT(AdbcStatementSetOption(&statement, kUseCopyOption,
                                     ADBC_OPTION_VALUE_ENABLED, &error),
              IsOkStatus(&error));
  ASSERT_EQ(adbc_validation::StatementGetOption(&statement, kUseCopyOption, &error),
            "true");

  ASSERT_THAT(AdbcStatementSetOption(&statement, kUseCopyOption,
                                     ADBC_OPTION_VALUE_DISABLED, &error),
              IsOkStatus(&error));
  ASSERT_EQ(adbc_validation::StatementGetOption(&statement, kUseCopyOption, &error),
            "false");

  // Run a query with the option disabled and check the result.
  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
                                       "SELECT * FROM test_query_set_copy_false", &error),
              IsOkStatus(&error));

  adbc_validation::StreamReader reader;
  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                        &reader.rows_affected, &error),
              IsOkStatus(&error));

  ASSERT_EQ(reader.rows_affected, 3);

  // Stop here if the helper reports a fatal failure, rather than dereferencing
  // reader.schema below.
  // NOTE(review): assumes GetSchema() uses ASSERT_* internally; the wrapper is
  // harmless if it does not.
  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
  ASSERT_EQ(reader.schema->n_children, 1);
  ASSERT_STREQ(reader.schema->children[0]->format, "i");
  ASSERT_STREQ(reader.schema->children[0]->name, "ints");

  ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
  ASSERT_EQ(reader.array->length, 3);
  ASSERT_EQ(reader.array->n_children, 1);
  ASSERT_EQ(reader.array->children[0]->null_count, 1);

  // A second fetch must succeed and leave no array behind (release == nullptr).
  ASSERT_THAT(reader.MaybeNext(), adbc_validation::IsOkErrno());
  ASSERT_EQ(reader.array->release, nullptr);
}

struct TypeTestCase {
std::string name;
std::string sql_type;
Expand Down
3 changes: 2 additions & 1 deletion c/driver/postgresql/result_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) {

// TODO: If we get an EOVERFLOW here (e.g., big string data), we
// would need to keep track of what row number we're on and start
// from there instead of begin() on the next call
// from there instead of begin() on the next call. We could also
// respect the size hint here to chunk the batches.
struct ArrowBufferView item;
for (auto it = helper_.begin(); it != helper_.end(); it++) {
auto row = *it;
Expand Down
36 changes: 16 additions & 20 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,

// If we have been requested to avoid COPY or there is no output requested,
// execute using the PqResultArrayReader.
if (!stream || !UseCopyIfPossible()) {
if (!stream || !use_copy_) {
PqResultArrayReader reader(connection_->conn(), type_resolver_, query_);
RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error));
return ADBC_STATUS_OK;
Expand Down Expand Up @@ -1347,6 +1347,12 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
}
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
result = std::to_string(reader_.batch_size_hint_bytes_);
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
if (use_copy_) {
result = "true";
} else {
result = "false";
}
} else {
SetError(error, "[libpq] Unknown statement option '%s'", key);
return ADBC_STATUS_NOT_FOUND;
Expand Down Expand Up @@ -1466,6 +1472,15 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
}

this->reader_.batch_size_hint_bytes_ = int_value;
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
use_copy_ = true;
} else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
use_copy_ = false;
} else {
SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key);
return ADBC_STATUS_INVALID_ARGUMENT;
}
} else {
SetError(error, "[libpq] Unknown statement option '%s'", key);
return ADBC_STATUS_NOT_IMPLEMENTED;
Expand Down Expand Up @@ -1505,23 +1520,4 @@ void PostgresStatement::ClearResult() {
reader_.Release();
}

// Reports whether query execution should use COPY, based on the
// ADBC_POSTGRESQL_OPTION_USE_COPY statement option.
//
// Returns true when the option cannot be read (either GetOption call returns a
// status other than ADBC_STATUS_OK); otherwise returns whether the option's
// value is the string "true".
bool PostgresStatement::UseCopyIfPossible() {
  // First call with a null buffer and zero length to obtain the size needed
  // to hold the option value.
  size_t length = 0;
  if (GetOption(ADBC_POSTGRESQL_OPTION_USE_COPY, nullptr, &length, nullptr) !=
      ADBC_STATUS_OK) {
    return true;
  }

  // Second call fills a buffer of exactly the reported size.
  std::string out(length, '\0');
  if (GetOption(ADBC_POSTGRESQL_OPTION_USE_COPY, out.data(), &length, nullptr) !=
      ADBC_STATUS_OK) {
    return true;
  }

  // Compare as a C string. The reported length counts the terminating NUL, so
  // `out` holds "true\0" and a std::string equality test against "true" would
  // never match. strcmp stops at the first NUL and is correct whether or not
  // the terminator is included in the length.
  return std::strcmp(out.c_str(), "true") == 0;
}

} // namespace adbcpq
10 changes: 8 additions & 2 deletions c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ class TupleReader final {
class PostgresStatement {
public:
PostgresStatement()
: connection_(nullptr), query_(), prepared_(false), reader_(nullptr) {
: connection_(nullptr),
query_(),
prepared_(false),
use_copy_(true),
reader_(nullptr) {
std::memset(&bind_, 0, sizeof(bind_));
}

Expand Down Expand Up @@ -136,7 +140,6 @@ class PostgresStatement {
struct AdbcError* error);
AdbcStatusCode ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected,
struct AdbcError* error);
bool UseCopyIfPossible();

private:
std::shared_ptr<PostgresTypeResolver> type_resolver_;
Expand All @@ -155,6 +158,9 @@ class PostgresStatement {
kCreateAppend,
};

// Options
bool use_copy_;

struct {
std::string db_schema;
std::string target;
Expand Down
14 changes: 14 additions & 0 deletions c/validation/adbc_validation_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ std::optional<std::string> ConnectionGetOption(struct AdbcConnection* connection
return std::string(buffer, buffer_size - 1);
}

std::optional<std::string> StatementGetOption(struct AdbcStatement* statement,
std::string_view option,
struct AdbcError* error) {
char buffer[128];
size_t buffer_size = sizeof(buffer);
AdbcStatusCode status =
AdbcStatementGetOption(statement, option.data(), buffer, &buffer_size, error);
EXPECT_THAT(status, IsOkStatus(error));
if (status != ADBC_STATUS_OK) return std::nullopt;
EXPECT_GT(buffer_size, 0);
if (buffer_size == 0) return std::nullopt;
return std::string(buffer, buffer_size - 1);
}

std::string StatusCodeToString(AdbcStatusCode code) {
#define CASE(CONSTANT) \
case ADBC_STATUS_##CONSTANT: \
Expand Down
4 changes: 4 additions & 0 deletions c/validation/adbc_validation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ std::optional<std::string> ConnectionGetOption(struct AdbcConnection* connection
std::string_view option,
struct AdbcError* error);

// Fetch a string-valued statement option via AdbcStatementGetOption.
//
// Returns the option value, or std::nullopt (after recording a non-fatal test
// failure) when the call does not return ADBC_STATUS_OK or reports a
// zero-length value. `error` is passed through to the driver call.
std::optional<std::string> StatementGetOption(struct AdbcStatement* statement,
std::string_view option,
struct AdbcError* error);

// ------------------------------------------------------------
// Helpers to print values

Expand Down

0 comments on commit ae107e0

Please sign in to comment.