From 44bf25d529dc2b18445f54850c0be6f08b348510 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 11:23:34 -0300 Subject: [PATCH 01/25] separate the bind stream --- c/driver/postgresql/bind_stream.h | 585 ++++++++++++++++++++++++++++++ c/driver/postgresql/statement.cc | 4 +- 2 files changed, 586 insertions(+), 3 deletions(-) create mode 100644 c/driver/postgresql/bind_stream.h diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h new file mode 100644 index 0000000000..7375cc2c17 --- /dev/null +++ b/c/driver/postgresql/bind_stream.h @@ -0,0 +1,585 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include + +#include "copy/writer.h" +#include "driver/common/utils.h" +#include "error.h" +#include "postgres_type.h" +#include "postgres_util.h" + +namespace adbcpq { + +/// The flag indicating to PostgreSQL that we want binary-format values. +constexpr int kPgBinaryFormat = 1; + +/// Helper to manage bind parameters with a prepared statement +struct BindStream { + Handle bind; + Handle bind_schema; + struct ArrowSchemaView bind_schema_view; + std::vector bind_schema_fields; + + // OIDs for parameter types + std::vector param_types; + std::vector param_values; + std::vector param_lengths; + std::vector param_formats; + std::vector param_values_offsets; + std::vector param_values_buffer; + // XXX: this assumes fixed-length fields only - will need more + // consideration to deal with variable-length fields + + bool has_tz_field = false; + std::string tz_setting; + + struct ArrowError na_error; + + explicit BindStream(struct ArrowArrayStream&& bind) { + this->bind.value = std::move(bind); + std::memset(&na_error, 0, sizeof(na_error)); + } + + template + AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { + CHECK_NA(INTERNAL, bind->get_schema(&bind.value, &bind_schema.value), error); + CHECK_NA( + INTERNAL, + ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, /*error*/ nullptr), + error); + + if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { + SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); + return ADBC_STATUS_INVALID_STATE; + } + + bind_schema_fields.resize(bind_schema->n_children); + for (size_t i = 0; i < bind_schema_fields.size(); i++) { + CHECK_NA(INTERNAL, + ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], + /*error*/ nullptr), + error); + } + + return std::move(callback)(); + } + + AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver, + struct AdbcError* error) { + param_types.resize(bind_schema->n_children); + param_values.resize(bind_schema->n_children); + param_lengths.resize(bind_schema->n_children); + param_formats.resize(bind_schema->n_children, kPgBinaryFormat); + param_values_offsets.reserve(bind_schema->n_children); + + for (size_t i = 0; i < bind_schema_fields.size(); i++) { + PostgresTypeId type_id; + switch (bind_schema_fields[i].type) { + case ArrowType::NANOARROW_TYPE_BOOL: + type_id = PostgresTypeId::kBool; + param_lengths[i] = 1; + break; + case ArrowType::NANOARROW_TYPE_INT8: + case ArrowType::NANOARROW_TYPE_INT16: + type_id = PostgresTypeId::kInt2; + param_lengths[i] = 2; + break; + case ArrowType::NANOARROW_TYPE_INT32: + type_id = PostgresTypeId::kInt4; + param_lengths[i] = 4; + break; + case ArrowType::NANOARROW_TYPE_INT64: + type_id = PostgresTypeId::kInt8; + param_lengths[i] = 8; + break; + case ArrowType::NANOARROW_TYPE_FLOAT: + type_id = PostgresTypeId::kFloat4; + param_lengths[i] = 4; + break; + case ArrowType::NANOARROW_TYPE_DOUBLE: + type_id = PostgresTypeId::kFloat8; + param_lengths[i] = 8; + break; + case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: + type_id = PostgresTypeId::kText; + param_lengths[i] = 0; + break; + case ArrowType::NANOARROW_TYPE_BINARY: + type_id = PostgresTypeId::kBytea; + param_lengths[i] = 0; + break; + case ArrowType::NANOARROW_TYPE_DATE32: + type_id = PostgresTypeId::kDate; + param_lengths[i] = 4; + break; + case ArrowType::NANOARROW_TYPE_TIMESTAMP: + type_id = PostgresTypeId::kTimestamp; + param_lengths[i] = 8; + break; + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: + type_id = PostgresTypeId::kInterval; + param_lengths[i] = 16; + break; + case ArrowType::NANOARROW_TYPE_DECIMAL128: + case ArrowType::NANOARROW_TYPE_DECIMAL256: + type_id = PostgresTypeId::kNumeric; + param_lengths[i] = 0; + break; + case ArrowType::NANOARROW_TYPE_DICTIONARY: { + struct ArrowSchemaView value_view; + CHECK_NA(INTERNAL, + ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary, + nullptr), + error); + switch (value_view.type) { + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + type_id = PostgresTypeId::kBytea; + param_lengths[i] = 0; + break; + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + type_id = PostgresTypeId::kText; + param_lengths[i] = 0; + break; + default: + SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", + static_cast(i + 1), " ('", + bind_schema->children[i]->name, + "') has unsupported dictionary value parameter type ", + ArrowTypeString(value_view.type)); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + break; + } + default: + SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", + static_cast(i + 1), " ('", bind_schema->children[i]->name, + "') has unsupported parameter type ", + ArrowTypeString(bind_schema_fields[i].type)); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + param_types[i] = type_resolver.GetOID(type_id); + if (param_types[i] == 0) { + SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", + static_cast(i + 1), " ('", bind_schema->children[i]->name, + "') has type with no corresponding PostgreSQL type ", + ArrowTypeString(bind_schema_fields[i].type)); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + } + + size_t param_values_length = 0; + for (int length : param_lengths) { + param_values_offsets.push_back(param_values_length); + param_values_length += length; + } + param_values_buffer.resize(param_values_length); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Prepare(const PostgresConnection* conn, const std::string& query, + struct AdbcError* error, const bool autocommit) { + // tz-aware timestamps require special handling to set the timezone to UTC + // prior to sending over the binary protocol; must be reset after execute + const auto pg_conn = conn->conn(); + for (int64_t col = 0; col < bind_schema->n_children; col++) { + if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && + (strcmp("", bind_schema_fields[col].timezone))) { + has_tz_field = true; + + if (autocommit) { + PGresult* begin_result = PQexec(pg_conn, "BEGIN"); + if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, begin_result, + "[libpq] Failed to begin transaction for timezone data: %s", + PQerrorMessage(pg_conn)); + PQclear(begin_result); + return code; + } + PQclear(begin_result); + } + + PGresult* get_tz_result = PQexec(pg_conn, "SELECT current_setting('TIMEZONE')"); + if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { + AdbcStatusCode code = SetError(error, get_tz_result, + "[libpq] Could not query current timezone: %s", + PQerrorMessage(pg_conn)); + PQclear(get_tz_result); + return code; + } + + tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); + PQclear(get_tz_result); + + PGresult* set_utc_result = PQexec(pg_conn, "SET TIME ZONE 'UTC'"); + if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = SetError(error, set_utc_result, + "[libpq] Failed to set time zone to UTC: %s", + PQerrorMessage(pg_conn)); + PQclear(set_utc_result); + return code; + } + PQclear(set_utc_result); + break; + } + } + + PGresult* result = PQprepare(pg_conn, /*stmtName=*/"", query.c_str(), + /*nParams=*/bind_schema->n_children, param_types.data()); + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", + PQerrorMessage(pg_conn), query.c_str()); + PQclear(result); + return code; + } + PQclear(result); + return ADBC_STATUS_OK; + } + + AdbcStatusCode Execute(const PostgresConnection* conn, int64_t* rows_affected, + struct AdbcError* error) { + if (rows_affected) *rows_affected = 0; + PGresult* result = nullptr; + const auto pg_conn = conn->conn(); + + while (true) { + Handle array; + int res = bind->get_next(&bind.value, &array.value); + if (res != 0) { + SetError(error, + "[libpq] Failed to read next batch from stream of bind parameters: " + "(%d) %s %s", + res, std::strerror(res), bind->get_last_error(&bind.value)); + return ADBC_STATUS_IO; + } + if (!array->release) break; + + Handle array_view; + // TODO: include error messages + CHECK_NA( + INTERNAL, + ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr), + error); + CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr), + error); + + for (int64_t row = 0; row < array->length; row++) { + for (int64_t col = 0; col < array_view->n_children; col++) { + if (ArrowArrayViewIsNull(array_view->children[col], row)) { + param_values[col] = nullptr; + continue; + } else { + param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + } + switch (bind_schema_fields[col].type) { + case ArrowType::NANOARROW_TYPE_BOOL: { + const int8_t val = ArrowBitGet( + array_view->children[col]->buffer_views[1].data.as_uint8, row); + std::memcpy(param_values[col], &val, sizeof(int8_t)); + break; + } + + case ArrowType::NANOARROW_TYPE_INT8: { + const int16_t val = + array_view->children[col]->buffer_views[1].data.as_int8[row]; + const uint16_t value = ToNetworkInt16(val); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT16: { + const uint16_t value = ToNetworkInt16( + array_view->children[col]->buffer_views[1].data.as_int16[row]); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT32: { + const uint32_t value = ToNetworkInt32( + array_view->children[col]->buffer_views[1].data.as_int32[row]); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT64: { + const int64_t value = ToNetworkInt64( + array_view->children[col]->buffer_views[1].data.as_int64[row]); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_FLOAT: { + const uint32_t value = ToNetworkFloat4( + array_view->children[col]->buffer_views[1].data.as_float[row]); + std::memcpy(param_values[col], &value, sizeof(uint32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DOUBLE: { + const uint64_t value = ToNetworkFloat8( + array_view->children[col]->buffer_views[1].data.as_double[row]); + std::memcpy(param_values[col], &value, sizeof(uint64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: + case ArrowType::NANOARROW_TYPE_BINARY: { + const ArrowBufferView view = + ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); + // TODO: overflow check? + param_lengths[col] = static_cast(view.size_bytes); + param_values[col] = const_cast(view.data.as_char); + break; + } + case ArrowType::NANOARROW_TYPE_DATE32: { + // 2000-01-01 + constexpr int32_t kPostgresDateEpoch = 10957; + const int32_t raw_value = + array_view->children[col]->buffer_views[1].data.as_int32[row]; + if (raw_value < INT32_MIN + kPostgresDateEpoch) { + SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, + "('", bind_schema->children[col]->name, "') Row #", row + 1, + "has value which exceeds postgres date limits"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_TIMESTAMP: { + int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; + + bool overflow_safe = true; + + auto unit = bind_schema_fields[col].time_unit; + + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + overflow_safe = + val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; + if (overflow_safe) { + val *= 1000000; + } + + break; + case NANOARROW_TIME_UNIT_MILLI: + overflow_safe = + val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; + if (overflow_safe) { + val *= 1000; + } + break; + case NANOARROW_TIME_UNIT_MICRO: + break; + case NANOARROW_TIME_UNIT_NANO: + val /= 1000; + break; + } + + if (!overflow_safe) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 + "' which exceeds PostgreSQL timestamp limits", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 "' which would underflow", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { + const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + } else if (bind_schema_fields[col].type == + ArrowType::NANOARROW_TYPE_DURATION) { + // postgres stores an interval as a 64 bit offset in microsecond + // resolution alongside a 32 bit day and 32 bit month + // for now we just send 0 for the day / month values + const uint64_t value = ToNetworkInt64(val); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); + } + break; + } + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + struct ArrowInterval interval; + ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); + ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); + + const uint32_t months = ToNetworkInt32(interval.months); + const uint32_t days = ToNetworkInt32(interval.days); + const uint64_t ms = ToNetworkInt64(interval.ns / 1000); + + std::memcpy(param_values[col], &ms, sizeof(uint64_t)); + std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); + std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), + &months, sizeof(uint32_t)); + break; + } + default: + SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", + bind_schema->children[col]->name, + "') has unsupported type for ingestion ", + ArrowTypeString(bind_schema_fields[col].type)); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + } + + result = PQexecPrepared(pg_conn, /*stmtName=*/"", + /*nParams=*/bind_schema->n_children, param_values.data(), + param_lengths.data(), param_formats.data(), + /*resultFormat=*/0 /*text*/); + + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + AdbcStatusCode code = SetError( + error, result, "[libpq] Failed to execute prepared statement: %s %s", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); + PQclear(result); + return code; + } + + PQclear(result); + } + if (rows_affected) *rows_affected += array->length; + + if (has_tz_field) { + std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; + PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); + if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", + PQerrorMessage(pg_conn)); + PQclear(reset_tz_result); + return code; + } + PQclear(reset_tz_result); + + PGresult* commit_result = PQexec(pg_conn, "COMMIT"); + if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", + PQerrorMessage(pg_conn)); + PQclear(commit_result); + return code; + } + PQclear(commit_result); + } + } + return ADBC_STATUS_OK; + } + + AdbcStatusCode ExecuteCopy(const PostgresConnection* conn, int64_t* rows_affected, + struct AdbcError* error) { + // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max + // size for a single message that we need to respect (1 GiB - 1). Since + // the buffer can be chunked up as much as we want, go for 16 MiB as our + // limit. + // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 + constexpr int64_t kMaxCopyBufferSize = 0x1000000; + if (rows_affected) *rows_affected = 0; + const auto pg_conn = conn->conn(); + + PostgresCopyStreamWriter writer; + CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error); + CHECK_NA(INTERNAL, writer.InitFieldWriters(*conn->type_resolver(), nullptr), error); + + CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error); + + while (true) { + Handle array; + int res = bind->get_next(&bind.value, &array.value); + if (res != 0) { + SetError(error, + "[libpq] Failed to read next batch from stream of bind parameters: " + "(%d) %s %s", + res, std::strerror(res), bind->get_last_error(&bind.value)); + return ADBC_STATUS_IO; + } + if (!array->release) break; + + CHECK_NA(INTERNAL, writer.SetArray(&array.value), error); + + // build writer buffer + int write_result; + do { + write_result = writer.WriteRecord(nullptr); + } while (write_result == NANOARROW_OK); + + // check if not ENODATA at exit + if (write_result != ENODATA) { + SetError(error, "Error occurred writing COPY data: %s", PQerrorMessage(pg_conn)); + return ADBC_STATUS_IO; + } + + ArrowBuffer buffer = writer.WriteBuffer(); + { + auto* data = reinterpret_cast(buffer.data); + int64_t remaining = buffer.size_bytes; + while (remaining > 0) { + int64_t to_write = std::min(remaining, kMaxCopyBufferSize); + if (PQputCopyData(pg_conn, data, to_write) <= 0) { + SetError(error, "Error writing tuple field data: %s", + PQerrorMessage(pg_conn)); + return ADBC_STATUS_IO; + } + remaining -= to_write; + data += to_write; + } + } + + if (rows_affected) *rows_affected += array->length; + writer.Rewind(); + } + + if (PQputCopyEnd(pg_conn, NULL) <= 0) { + SetError(error, "Error message returned by PQputCopyEnd: %s", + PQerrorMessage(pg_conn)); + return ADBC_STATUS_IO; + } + + PGresult* result = PQgetResult(pg_conn); + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); + PQclear(result); + return code; + } + + PQclear(result); + return ADBC_STATUS_OK; + } +}; +} // namespace adbcpq diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index c6e012581c..e9bd25ee01 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -36,8 +36,8 @@ #include #include +#include "bind_stream.h" #include "connection.h" -#include "copy/writer.h" #include "driver/common/options.h" #include "driver/common/utils.h" #include "error.h" @@ -48,8 +48,6 @@ namespace adbcpq { namespace { -/// The flag indicating to PostgreSQL that we want binary-format values. -constexpr int kPgBinaryFormat = 1; /// One-value ArrowArrayStream used to unify the implementations of Bind struct OneValueStream { From 38637c452e82bb0efd5feca444cffeb9097eadfb Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 11:28:44 -0300 Subject: [PATCH 02/25] use nanoarrow's single-value stream --- c/driver/postgresql/statement.cc | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index e9bd25ee01..b887af82be 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -900,13 +900,7 @@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArray* values, if (bind_.release) bind_.release(&bind_); // Make a one-value stream - bind_.private_data = new OneValueStream{*schema, *values}; - bind_.get_schema = &OneValueStream::GetSchema; - bind_.get_next = &OneValueStream::GetNext; - bind_.get_last_error = &OneValueStream::GetLastError; - bind_.release = &OneValueStream::Release; - std::memset(values, 0, sizeof(*values)); - std::memset(schema, 0, sizeof(*schema)); + nanoarrow::VectorArrayStream(schema, values).ToArrayStream(&bind_); return ADBC_STATUS_OK; } From efe046af5d1ed37280e42bb0302a957730263696 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 11:39:40 -0300 Subject: [PATCH 03/25] split result reader from result helper --- c/driver/postgresql/CMakeLists.txt | 1 + c/driver/postgresql/result_helper.cc | 154 ----------------------- c/driver/postgresql/result_helper.h | 47 ------- c/driver/postgresql/result_reader.cc | 180 +++++++++++++++++++++++++++ c/driver/postgresql/result_reader.h | 74 +++++++++++ c/driver/postgresql/statement.cc | 1 + 6 files changed, 256 insertions(+), 201 deletions(-) create mode 100644 c/driver/postgresql/result_reader.cc create mode 100644 c/driver/postgresql/result_reader.h diff --git a/c/driver/postgresql/CMakeLists.txt b/c/driver/postgresql/CMakeLists.txt index 8a004e0d09..a720696c6a 100644 --- a/c/driver/postgresql/CMakeLists.txt +++ b/c/driver/postgresql/CMakeLists.txt @@ -33,6 +33,7 @@ add_arrow_lib(adbc_driver_postgresql database.cc postgresql.cc result_helper.cc + result_reader.cc statement.cc OUTPUTS ADBC_LIBRARIES diff --git a/c/driver/postgresql/result_helper.cc b/c/driver/postgresql/result_helper.cc index df890a7c51..157b100b73 100644 --- a/c/driver/postgresql/result_helper.cc +++ b/c/driver/postgresql/result_helper.cc @@ -20,7 +20,6 @@ #include #include -#include "copy/reader.h" #include "driver/common/utils.h" #include "error.h" @@ -216,157 +215,4 @@ int64_t PqResultHelper::AffectedRows() { } } -int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { - ResetErrors(); - - if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(&error_); - if (status != ADBC_STATUS_OK) { - return EINVAL; - } - } - - return ArrowSchemaDeepCopy(schema_.get(), out); -} - -int PqResultArrayReader::GetNext(struct ArrowArray* out) { - ResetErrors(); - - if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(&error_); - if (status != ADBC_STATUS_OK) { - return EINVAL; - } - } - - if (!helper_.HasResult()) { - out->release = nullptr; - return NANOARROW_OK; - } - - nanoarrow::UniqueArray tmp; - NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(tmp.get(), schema_.get(), &na_error_)); - NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(tmp.get())); - for (int i = 0; i < helper_.NumColumns(); i++) { - NANOARROW_RETURN_NOT_OK(field_readers_[i]->InitArray(tmp->children[i])); - } - - // TODO: If we get an EOVERFLOW here (e.g., big string data), we - // would need to keep track of what row number we're on and start - // from there instead of begin() on the next call. We could also - // respect the size hint here to chunk the batches. - struct ArrowBufferView item; - for (auto it = helper_.begin(); it != helper_.end(); it++) { - auto row = *it; - for (int i = 0; i < helper_.NumColumns(); i++) { - auto pg_item = row[i]; - item.data.data = pg_item.data; - - if (pg_item.is_null) { - item.size_bytes = -1; - } else { - item.size_bytes = pg_item.len; - } - - NANOARROW_RETURN_NOT_OK( - field_readers_[i]->Read(&item, item.size_bytes, tmp->children[i], &na_error_)); - } - } - - for (int i = 0; i < helper_.NumColumns(); i++) { - NANOARROW_RETURN_NOT_OK(field_readers_[i]->FinishArray(tmp->children[i], &na_error_)); - } - - tmp->length = helper_.NumRows(); - tmp->null_count = 0; - NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(tmp.get(), &na_error_)); - - // Ensure that the next call to GetNext() will signal the end of the stream - helper_.ClearResult(); - - // Canonically return zero-size results as an empty stream - if (tmp->length == 0) { - out->release = nullptr; - return NANOARROW_OK; - } - - ArrowArrayMove(tmp.get(), out); - return NANOARROW_OK; -} - -const char* PqResultArrayReader::GetLastError() { - if (error_.message != nullptr) { - return error_.message; - } else { - return na_error_.message; - } -} - -AdbcStatusCode PqResultArrayReader::Initialize(struct AdbcError* error) { - helper_.set_output_format(PqResultHelper::Format::kBinary); - RAISE_ADBC(helper_.Execute(error)); - - ArrowSchemaInit(schema_.get()); - CHECK_NA_DETAIL(INTERNAL, ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns()), - &na_error_, error); - - for (int i = 0; i < helper_.NumColumns(); i++) { - PostgresType child_type; - CHECK_NA_DETAIL(INTERNAL, - type_resolver_->Find(helper_.FieldType(i), &child_type, &na_error_), - &na_error_, error); - - CHECK_NA(INTERNAL, child_type.SetSchema(schema_->children[i]), error); - CHECK_NA(INTERNAL, ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i)), - error); - - std::unique_ptr child_reader; - CHECK_NA_DETAIL( - INTERNAL, - MakeCopyFieldReader(child_type, schema_->children[i], &child_reader, &na_error_), - &na_error_, error); - - child_reader->Init(child_type); - CHECK_NA_DETAIL(INTERNAL, child_reader->InitSchema(schema_->children[i]), &na_error_, - error); - - field_readers_.push_back(std::move(child_reader)); - } - - return ADBC_STATUS_OK; -} - -AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, - struct ArrowArrayStream* out, - struct AdbcError* error) { - if (out == nullptr) { - // If there is no output requested, we still need to execute and set - // affected_rows if needed. We don't need an output schema or to set - // up a copy reader, so we can skip those steps by going straight - // to Execute(). This also enables us to support queries with multiple - // statements because we can call PQexec() instead of PQexecParams(). - RAISE_ADBC(helper_.Execute(error)); - - if (affected_rows != nullptr) { - *affected_rows = helper_.AffectedRows(); - } - - return ADBC_STATUS_OK; - } - - // Execute eagerly. We need this to provide row counts for DELETE and - // CREATE TABLE queries as well as to provide more informative errors - // until this reader class is wired up to provide extended AdbcError - // information. - RAISE_ADBC(Initialize(error)); - if (affected_rows != nullptr) { - *affected_rows = helper_.AffectedRows(); - } - - nanoarrow::ArrayStreamFactory::InitArrayStream( - new PqResultArrayReader(this), out); - - return ADBC_STATUS_OK; -} - } // namespace adbcpq diff --git a/c/driver/postgresql/result_helper.h b/c/driver/postgresql/result_helper.h index 43083b8bcb..d18ee8222e 100644 --- a/c/driver/postgresql/result_helper.h +++ b/c/driver/postgresql/result_helper.h @@ -169,51 +169,4 @@ class PqResultHelper { struct AdbcError* error); }; -class PqResultArrayReader { - public: - PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, - std::string query) - : helper_(conn, std::move(query)), type_resolver_(type_resolver) { - ArrowErrorInit(&na_error_); - error_ = ADBC_ERROR_INIT; - } - - ~PqResultArrayReader() { ResetErrors(); } - - int GetSchema(struct ArrowSchema* out); - int GetNext(struct ArrowArray* out); - const char* GetLastError(); - - AdbcStatusCode ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, - struct AdbcError* error); - - AdbcStatusCode Initialize(struct AdbcError* error); - - private: - PqResultHelper helper_; - std::shared_ptr type_resolver_; - std::vector> field_readers_; - nanoarrow::UniqueSchema schema_; - struct AdbcError error_; - struct ArrowError na_error_; - - explicit PqResultArrayReader(PqResultArrayReader* other) - : helper_(std::move(other->helper_)), - type_resolver_(std::move(other->type_resolver_)), - field_readers_(std::move(other->field_readers_)), - schema_(std::move(other->schema_)) { - ArrowErrorInit(&na_error_); - error_ = ADBC_ERROR_INIT; - } - - void ResetErrors() { - ArrowErrorInit(&na_error_); - - if (error_.private_data != nullptr) { - error_.release(&error_); - } - error_ = ADBC_ERROR_INIT; - } -}; - } // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc new file mode 100644 index 0000000000..9b68ef66c8 --- /dev/null +++ b/c/driver/postgresql/result_reader.cc @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "result_reader.h" + +#include "copy/reader.h" +#include "driver/common/utils.h" + +#include "error.h" + +namespace adbcpq { + +int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { + ResetErrors(); + + if (schema_->release == nullptr) { + AdbcStatusCode status = Initialize(&error_); + if (status != ADBC_STATUS_OK) { + return EINVAL; + } + } + + return ArrowSchemaDeepCopy(schema_.get(), out); +} + +int PqResultArrayReader::GetNext(struct ArrowArray* out) { + ResetErrors(); + + if (schema_->release == nullptr) { + AdbcStatusCode status = Initialize(&error_); + if (status != ADBC_STATUS_OK) { + return EINVAL; + } + } + + if (!helper_.HasResult()) { + out->release = nullptr; + return NANOARROW_OK; + } + + nanoarrow::UniqueArray tmp; + NANOARROW_RETURN_NOT_OK(ArrowArrayInitFromSchema(tmp.get(), schema_.get(), &na_error_)); + NANOARROW_RETURN_NOT_OK(ArrowArrayStartAppending(tmp.get())); + for (int i = 0; i < helper_.NumColumns(); i++) { + NANOARROW_RETURN_NOT_OK(field_readers_[i]->InitArray(tmp->children[i])); + } + + // TODO: If we get an EOVERFLOW here (e.g., big string data), we + // would need to keep track of what row number we're on and start + // from there instead of begin() on the next call. We could also + // respect the size hint here to chunk the batches. + struct ArrowBufferView item; + for (auto it = helper_.begin(); it != helper_.end(); it++) { + auto row = *it; + for (int i = 0; i < helper_.NumColumns(); i++) { + auto pg_item = row[i]; + item.data.data = pg_item.data; + + if (pg_item.is_null) { + item.size_bytes = -1; + } else { + item.size_bytes = pg_item.len; + } + + NANOARROW_RETURN_NOT_OK( + field_readers_[i]->Read(&item, item.size_bytes, tmp->children[i], &na_error_)); + } + } + + for (int i = 0; i < helper_.NumColumns(); i++) { + NANOARROW_RETURN_NOT_OK(field_readers_[i]->FinishArray(tmp->children[i], &na_error_)); + } + + tmp->length = helper_.NumRows(); + tmp->null_count = 0; + NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(tmp.get(), &na_error_)); + + // Ensure that the next call to GetNext() will signal the end of the stream + helper_.ClearResult(); + + // Canonically return zero-size results as an empty stream + if (tmp->length == 0) { + out->release = nullptr; + return NANOARROW_OK; + } + + ArrowArrayMove(tmp.get(), out); + return NANOARROW_OK; +} + +const char* PqResultArrayReader::GetLastError() { + if (error_.message != nullptr) { + return error_.message; + } else { + return na_error_.message; + } +} + +AdbcStatusCode PqResultArrayReader::Initialize(struct AdbcError* error) { + helper_.set_output_format(PqResultHelper::Format::kBinary); + RAISE_ADBC(helper_.Execute(error)); + + ArrowSchemaInit(schema_.get()); + CHECK_NA_DETAIL(INTERNAL, ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns()), + &na_error_, error); + + for (int i = 0; i < helper_.NumColumns(); i++) { + PostgresType child_type; + CHECK_NA_DETAIL(INTERNAL, + type_resolver_->Find(helper_.FieldType(i), &child_type, &na_error_), + &na_error_, error); + + CHECK_NA(INTERNAL, child_type.SetSchema(schema_->children[i]), error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i)), + error); + + std::unique_ptr child_reader; + CHECK_NA_DETAIL( + INTERNAL, + MakeCopyFieldReader(child_type, schema_->children[i], &child_reader, &na_error_), + &na_error_, error); + + child_reader->Init(child_type); + CHECK_NA_DETAIL(INTERNAL, child_reader->InitSchema(schema_->children[i]), &na_error_, + error); + + field_readers_.push_back(std::move(child_reader)); + } + + return ADBC_STATUS_OK; +} + +AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (out == nullptr) { + // If there is no output requested, we still need to execute and set + // affected_rows if needed. We don't need an output schema or to set + // up a copy reader, so we can skip those steps by going straight + // to Execute(). This also enables us to support queries with multiple + // statements because we can call PQexec() instead of PQexecParams(). + RAISE_ADBC(helper_.Execute(error)); + + if (affected_rows != nullptr) { + *affected_rows = helper_.AffectedRows(); + } + + return ADBC_STATUS_OK; + } + + // Execute eagerly. We need this to provide row counts for DELETE and + // CREATE TABLE queries as well as to provide more informative errors + // until this reader class is wired up to provide extended AdbcError + // information. + RAISE_ADBC(Initialize(error)); + if (affected_rows != nullptr) { + *affected_rows = helper_.AffectedRows(); + } + + nanoarrow::ArrayStreamFactory::InitArrayStream( + new PqResultArrayReader(this), out); + + return ADBC_STATUS_OK; +} + +} // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h new file mode 100644 index 0000000000..11429a6902 --- /dev/null +++ b/c/driver/postgresql/result_reader.h @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "copy/reader.h" +#include "result_helper.h" + +namespace adbcpq { + +class PqResultArrayReader { + public: + PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, + std::string query) + : helper_(conn, std::move(query)), type_resolver_(type_resolver) { + ArrowErrorInit(&na_error_); + error_ = ADBC_ERROR_INIT; + } + + ~PqResultArrayReader() { ResetErrors(); } + + int GetSchema(struct ArrowSchema* out); + int GetNext(struct ArrowArray* out); + const char* GetLastError(); + + AdbcStatusCode ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, + struct AdbcError* error); + + AdbcStatusCode Initialize(struct AdbcError* error); + + private: + PqResultHelper helper_; + std::shared_ptr type_resolver_; + std::vector> field_readers_; + nanoarrow::UniqueSchema schema_; + struct AdbcError error_; + struct ArrowError na_error_; + + explicit PqResultArrayReader(PqResultArrayReader* other) + : helper_(std::move(other->helper_)), + type_resolver_(std::move(other->type_resolver_)), + field_readers_(std::move(other->field_readers_)), + schema_(std::move(other->schema_)) { + ArrowErrorInit(&na_error_); + error_ = ADBC_ERROR_INIT; + } + + void ResetErrors() { + ArrowErrorInit(&na_error_); + + if (error_.private_data != nullptr) { + error_.release(&error_); + } + error_ = ADBC_ERROR_INIT; + } +}; + +} // namespace adbcpq diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index b887af82be..b218a4a5c1 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -44,6 +44,7 @@ #include "postgres_type.h" #include "postgres_util.h" #include "result_helper.h" +#include "result_reader.h" namespace adbcpq { From 2024e7e5807afb91394146979c2e9cdf84a0dbc4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 12:48:39 -0300 Subject: [PATCH 04/25] consolidate some common actions --- c/driver/postgresql/bind_stream.h | 79 ++++++++++++++----------------- c/driver/postgresql/statement.cc | 7 +-- 2 files changed, 40 insertions(+), 46 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 7375cc2c17..fd393612cf 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -35,7 +35,10 @@ constexpr int kPgBinaryFormat = 1; /// Helper to manage bind parameters with a prepared statement struct BindStream { Handle bind; + Handle array_view; + Handle current; Handle bind_schema; + struct ArrowSchemaView bind_schema_view; std::vector bind_schema_fields; @@ -200,11 +203,10 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode Prepare(const PostgresConnection* conn, const std::string& query, + AdbcStatusCode Prepare(PGconn* pg_conn, const std::string& query, struct AdbcError* error, const bool autocommit) { // tz-aware timestamps require special handling to set the timezone to UTC // prior to sending over the binary protocol; must be reset after execute - const auto pg_conn = conn->conn(); for (int64_t col = 0; col < bind_schema->n_children; col++) { if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && (strcmp("", bind_schema_fields[col].timezone))) { @@ -261,34 +263,33 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode Execute(const PostgresConnection* conn, int64_t* rows_affected, + AdbcStatusCode PullNextArray(AdbcError* error) { + if (current->release != nullptr) ArrowArrayRelease(¤t.value); + + CHECK_NA_DETAIL(IO, ArrowArrayStreamGetNext(&bind.value, ¤t.value, &na_error), + &na_error, error); + + return ADBC_STATUS_OK; + } + + AdbcStatusCode Execute(PGconn* pg_conn, int64_t* rows_affected, struct AdbcError* error) { if (rows_affected) *rows_affected = 0; PGresult* result = nullptr; - const auto pg_conn = conn->conn(); + CHECK_NA_DETAIL( + INTERNAL, + ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error), + &na_error, error); while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if (!array->release) break; - - Handle array_view; - // TODO: include error messages - CHECK_NA( - INTERNAL, - ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr), - error); - CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr), - error); + RAISE_ADBC(PullNextArray(error)); + if (!current->release) break; - for (int64_t row = 0; row < array->length; row++) { + CHECK_NA_DETAIL( + INTERNAL, ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error), + &na_error, error); + + for (int64_t row = 0; row < current->length; row++) { for (int64_t col = 0; col < array_view->n_children; col++) { if (ArrowArrayViewIsNull(array_view->children[col], row)) { param_values[col] = nullptr; @@ -471,7 +472,7 @@ struct BindStream { PQclear(result); } - if (rows_affected) *rows_affected += array->length; + if (rows_affected) *rows_affected += current->length; if (has_tz_field) { std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; @@ -499,8 +500,8 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode ExecuteCopy(const PostgresConnection* conn, int64_t* rows_affected, - struct AdbcError* error) { + AdbcStatusCode ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + int64_t* rows_affected, struct AdbcError* error) { // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max // size for a single message that we need to respect (1 GiB - 1). Since // the buffer can be chunked up as much as we want, go for 16 MiB as our @@ -508,32 +509,24 @@ struct BindStream { // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 constexpr int64_t kMaxCopyBufferSize = 0x1000000; if (rows_affected) *rows_affected = 0; - const auto pg_conn = conn->conn(); PostgresCopyStreamWriter writer; CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error); - CHECK_NA(INTERNAL, writer.InitFieldWriters(*conn->type_resolver(), nullptr), error); + CHECK_NA_DETAIL(INTERNAL, writer.InitFieldWriters(type_resolver, &na_error), + &na_error, error); - CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error); + CHECK_NA_DETAIL(INTERNAL, writer.WriteHeader(&na_error), &na_error, error); while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if (!array->release) break; + RAISE_ADBC(PullNextArray(error)); + if (!current->release) break; - CHECK_NA(INTERNAL, writer.SetArray(&array.value), error); + CHECK_NA(INTERNAL, writer.SetArray(¤t.value), error); // build writer buffer int write_result; do { - write_result = writer.WriteRecord(nullptr); + write_result = writer.WriteRecord(&na_error); } while (write_result == NANOARROW_OK); // check if not ENODATA at exit @@ -558,7 +551,7 @@ struct BindStream { } } - if (rows_affected) *rows_affected += array->length; + if (rows_affected) *rows_affected += current->length; writer.Rewind(); } diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index b218a4a5c1..56c24ccdb3 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1144,8 +1144,8 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); RAISE_ADBC( - bind_stream.Prepare(connection_.get(), query_, error, connection_->autocommit())); - RAISE_ADBC(bind_stream.Execute(connection_.get(), rows_affected, error)); + bind_stream.Prepare(connection_->conn(), query_, error, connection_->autocommit())); + RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error)); return ADBC_STATUS_OK; } @@ -1321,7 +1321,8 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, } PQclear(result); - RAISE_ADBC(bind_stream.ExecuteCopy(connection_.get(), rows_affected, error)); + RAISE_ADBC(bind_stream.ExecuteCopy(connection_->conn(), *connection_->type_resolver(), + rows_affected, error)); return ADBC_STATUS_OK; } From 10893ac19c984e029bb92554f99c8cbd9043289c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 13:20:27 -0300 Subject: [PATCH 05/25] set array stream array on pull next --- c/driver/postgresql/bind_stream.h | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index fd393612cf..8fb951f9a8 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -64,11 +64,12 @@ struct BindStream { template AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { - CHECK_NA(INTERNAL, bind->get_schema(&bind.value, &bind_schema.value), error); - CHECK_NA( - INTERNAL, - ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, /*error*/ nullptr), - error); + CHECK_NA_DETAIL(INTERNAL, + ArrowArrayStreamGetSchema(&bind.value, &bind_schema.value, &na_error), + &na_error, error); + CHECK_NA_DETAIL(INTERNAL, + ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, &na_error), + &na_error, error); if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); @@ -83,6 +84,11 @@ struct BindStream { error); } + CHECK_NA_DETAIL( + INTERNAL, + ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error), + &na_error, error); + return std::move(callback)(); } @@ -269,6 +275,12 @@ struct BindStream { CHECK_NA_DETAIL(IO, ArrowArrayStreamGetNext(&bind.value, ¤t.value, &na_error), &na_error, error); + if (current->release != nullptr) { + CHECK_NA_DETAIL( + INTERNAL, ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error), + &na_error, error); + } + return ADBC_STATUS_OK; } @@ -276,19 +288,11 @@ struct BindStream { struct AdbcError* error) { if (rows_affected) *rows_affected = 0; PGresult* result = nullptr; - CHECK_NA_DETAIL( - INTERNAL, - ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error), - &na_error, error); while (true) { RAISE_ADBC(PullNextArray(error)); if (!current->release) break; - CHECK_NA_DETAIL( - INTERNAL, ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error), - &na_error, error); - for (int64_t row = 0; row < current->length; row++) { for (int64_t col = 0; col < array_view->n_children; col++) { if (ArrowArrayViewIsNull(array_view->children[col], row)) { From 0de6afcfbbf72d92d9a2e9803d6d0c4aa20da0e1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Aug 2024 13:30:22 -0300 Subject: [PATCH 06/25] add extra file --- r/adbcpostgresql/src/Makevars.in | 1 + r/adbcpostgresql/src/Makevars.ucrt | 1 + r/adbcpostgresql/src/Makevars.win | 1 + 3 files changed, 3 insertions(+) diff --git a/r/adbcpostgresql/src/Makevars.in b/r/adbcpostgresql/src/Makevars.in index d3c62fad6a..c34b0e1604 100644 --- a/r/adbcpostgresql/src/Makevars.in +++ b/r/adbcpostgresql/src/Makevars.in @@ -27,5 +27,6 @@ OBJECTS = init.o \ c/driver/postgresql/error.o \ c/driver/postgresql/postgresql.o \ c/driver/postgresql/result_helper.o \ + c/driver/postgresql/result_reader.o \ c/driver/postgresql/statement.o \ c/vendor/nanoarrow/nanoarrow.o diff --git a/r/adbcpostgresql/src/Makevars.ucrt b/r/adbcpostgresql/src/Makevars.ucrt index 857c45b776..37b1013444 100644 --- a/r/adbcpostgresql/src/Makevars.ucrt +++ b/r/adbcpostgresql/src/Makevars.ucrt @@ -28,5 +28,6 @@ OBJECTS = init.o \ c/driver/postgresql/error.o \ c/driver/postgresql/postgresql.o \ c/driver/postgresql/result_helper.o \ + c/driver/postgresql/result_reader.o \ c/driver/postgresql/statement.o \ c/vendor/nanoarrow/nanoarrow.o diff --git a/r/adbcpostgresql/src/Makevars.win b/r/adbcpostgresql/src/Makevars.win index abd5d82aa2..fe715ef3b5 100644 --- a/r/adbcpostgresql/src/Makevars.win +++ b/r/adbcpostgresql/src/Makevars.win @@ -30,6 +30,7 @@ OBJECTS = init.o \ c/driver/postgresql/error.o \ c/driver/postgresql/postgresql.o \ c/driver/postgresql/result_helper.o \ + c/driver/postgresql/result_reader.o \ c/driver/postgresql/statement.o \ c/vendor/nanoarrow/nanoarrow.o From a91cec03e7803e333713f92b146e08888feae803 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 8 Aug 2024 22:07:40 -0300 Subject: [PATCH 07/25] pull out "next row" advancing --- c/driver/postgresql/bind_stream.h | 401 ++++++++++++++++-------------- 1 file changed, 212 insertions(+), 189 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 8fb951f9a8..9b1edfb3d9 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -38,6 +38,7 @@ struct BindStream { Handle array_view; Handle current; Handle bind_schema; + int64_t current_row = -1; struct ArrowSchemaView bind_schema_view; std::vector bind_schema_fields; @@ -284,223 +285,245 @@ struct BindStream { return ADBC_STATUS_OK; } + AdbcStatusCode EnsureNextRow(AdbcError* error) { + if (current->release != nullptr) { + current_row++; + if (current_row < current->length) { + return ADBC_STATUS_OK; + } + } + + // Pull until we have an array with at least one row or the stream is finished + do { + RAISE_ADBC(PullNextArray(error)); + if (current->release == nullptr) { + current_row = -1; + return ADBC_STATUS_OK; + } + } while (current->length == 0); + + current_row = 0; + return ADBC_STATUS_OK; + } + AdbcStatusCode Execute(PGconn* pg_conn, int64_t* rows_affected, struct AdbcError* error) { if (rows_affected) *rows_affected = 0; PGresult* result = nullptr; + int64_t row = -1; while (true) { - RAISE_ADBC(PullNextArray(error)); + RAISE_ADBC(EnsureNextRow(error)); if (!current->release) break; + row = current_row; + + for (int64_t col = 0; col < array_view->n_children; col++) { + if (ArrowArrayViewIsNull(array_view->children[col], row)) { + param_values[col] = nullptr; + continue; + } else { + param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + } + switch (bind_schema_fields[col].type) { + case ArrowType::NANOARROW_TYPE_BOOL: { + const int8_t val = ArrowBitGet( + array_view->children[col]->buffer_views[1].data.as_uint8, row); + std::memcpy(param_values[col], &val, sizeof(int8_t)); + break; + } - for (int64_t row = 0; row < current->length; row++) { - for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; - } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + case ArrowType::NANOARROW_TYPE_INT8: { + const int16_t val = + array_view->children[col]->buffer_views[1].data.as_int8[row]; + const uint16_t value = ToNetworkInt16(val); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = ArrowBitGet( - array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; + case ArrowType::NANOARROW_TYPE_INT16: { + const uint16_t value = ToNetworkInt16( + array_view->children[col]->buffer_views[1].data.as_int16[row]); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT32: { + const uint32_t value = ToNetworkInt32( + array_view->children[col]->buffer_views[1].data.as_int32[row]); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT64: { + const int64_t value = ToNetworkInt64( + array_view->children[col]->buffer_views[1].data.as_int64[row]); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_FLOAT: { + const uint32_t value = ToNetworkFloat4( + array_view->children[col]->buffer_views[1].data.as_float[row]); + std::memcpy(param_values[col], &value, sizeof(uint32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DOUBLE: { + const uint64_t value = ToNetworkFloat8( + array_view->children[col]->buffer_views[1].data.as_double[row]); + std::memcpy(param_values[col], &value, sizeof(uint64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: + case ArrowType::NANOARROW_TYPE_BINARY: { + const ArrowBufferView view = + ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); + // TODO: overflow check? + param_lengths[col] = static_cast(view.size_bytes); + param_values[col] = const_cast(view.data.as_char); + break; + } + case ArrowType::NANOARROW_TYPE_DATE32: { + // 2000-01-01 + constexpr int32_t kPostgresDateEpoch = 10957; + const int32_t raw_value = + array_view->children[col]->buffer_views[1].data.as_int32[row]; + if (raw_value < INT32_MIN + kPostgresDateEpoch) { + SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, + "('", bind_schema->children[col]->name, "') Row #", row + 1, + "has value which exceeds postgres date limits"); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; + const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_TIMESTAMP: { + int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; + + bool overflow_safe = true; + + auto unit = bind_schema_fields[col].time_unit; + + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + overflow_safe = + val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; + if (overflow_safe) { + val *= 1000000; + } + + break; + case NANOARROW_TIME_UNIT_MILLI: + overflow_safe = + val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; + if (overflow_safe) { + val *= 1000; + } + break; + case NANOARROW_TIME_UNIT_MICRO: + break; + case NANOARROW_TIME_UNIT_NANO: + val /= 1000; + break; } - case ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; + + if (!overflow_safe) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 + "' which exceeds PostgreSQL timestamp limits", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; + + if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 "' which would underflow", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); + + if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { + const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 - "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == - ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), - &months, sizeof(uint32_t)); - break; + } else if (bind_schema_fields[col].type == + ArrowType::NANOARROW_TYPE_DURATION) { + // postgres stores an interval as a 64 bit offset in microsecond + // resolution alongside a 32 bit day and 32 bit month + // for now we just send 0 for the day / month values + const uint64_t value = ToNetworkInt64(val); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + break; + } + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + struct ArrowInterval interval; + ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); + ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); + + const uint32_t months = ToNetworkInt32(interval.months); + const uint32_t days = ToNetworkInt32(interval.days); + const uint64_t ms = ToNetworkInt64(interval.ns / 1000); + + std::memcpy(param_values[col], &ms, sizeof(uint64_t)); + std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); + std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), &months, + sizeof(uint32_t)); + break; } + default: + SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", + bind_schema->children[col]->name, + "') has unsupported type for ingestion ", + ArrowTypeString(bind_schema_fields[col].type)); + return ADBC_STATUS_NOT_IMPLEMENTED; } + } - result = PQexecPrepared(pg_conn, /*stmtName=*/"", - /*nParams=*/bind_schema->n_children, param_values.data(), - param_lengths.data(), param_formats.data(), - /*resultFormat=*/0 /*text*/); - - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError( - error, result, "[libpq] Failed to execute prepared statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); - PQclear(result); - return code; - } + result = PQexecPrepared(pg_conn, /*stmtName=*/"", + /*nParams=*/bind_schema->n_children, param_values.data(), + param_lengths.data(), param_formats.data(), + /*resultFormat=*/0 /*text*/); + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute prepared statement: %s %s", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); PQclear(result); + return code; } - if (rows_affected) *rows_affected += current->length; - if (has_tz_field) { - std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; - PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); - if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", - PQerrorMessage(pg_conn)); - PQclear(reset_tz_result); - return code; - } + PQclear(result); + } + if (rows_affected) *rows_affected += current->length; + + if (has_tz_field) { + std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; + PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); + if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", + PQerrorMessage(pg_conn)); PQclear(reset_tz_result); + return code; + } + PQclear(reset_tz_result); - PGresult* commit_result = PQexec(pg_conn, "COMMIT"); - if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(pg_conn)); - PQclear(commit_result); - return code; - } + PGresult* commit_result = PQexec(pg_conn, "COMMIT"); + if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", + PQerrorMessage(pg_conn)); PQclear(commit_result); + return code; } + PQclear(commit_result); } + return ADBC_STATUS_OK; } From 42fe13f7123a3afa03fbdc48ae7fafca3ba9741b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 8 Aug 2024 22:27:59 -0300 Subject: [PATCH 08/25] factor out binding of one row --- c/driver/postgresql/bind_stream.h | 370 +++++++++++++++--------------- 1 file changed, 191 insertions(+), 179 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 9b1edfb3d9..e1eec683e3 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -306,201 +306,193 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode Execute(PGconn* pg_conn, int64_t* rows_affected, - struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - PGresult* result = nullptr; + AdbcStatusCode BindCurrentRow(PGconn* pg_conn, PGresult** result_out, AdbcError* error) { + int64_t row = current_row; + + for (int64_t col = 0; col < array_view->n_children; col++) { + if (ArrowArrayViewIsNull(array_view->children[col], row)) { + param_values[col] = nullptr; + continue; + } else { + param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + } + switch (bind_schema_fields[col].type) { + case ArrowType::NANOARROW_TYPE_BOOL: { + const int8_t val = + ArrowBitGet(array_view->children[col]->buffer_views[1].data.as_uint8, row); + std::memcpy(param_values[col], &val, sizeof(int8_t)); + break; + } - int64_t row = -1; - while (true) { - RAISE_ADBC(EnsureNextRow(error)); - if (!current->release) break; - row = current_row; - - for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; - } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + case ArrowType::NANOARROW_TYPE_INT8: { + const int16_t val = + array_view->children[col]->buffer_views[1].data.as_int8[row]; + const uint16_t value = ToNetworkInt16(val); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT16: { + const uint16_t value = ToNetworkInt16( + array_view->children[col]->buffer_views[1].data.as_int16[row]); + std::memcpy(param_values[col], &value, sizeof(int16_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT32: { + const uint32_t value = ToNetworkInt32( + array_view->children[col]->buffer_views[1].data.as_int32[row]); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_INT64: { + const int64_t value = ToNetworkInt64( + array_view->children[col]->buffer_views[1].data.as_int64[row]); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_FLOAT: { + const uint32_t value = ToNetworkFloat4( + array_view->children[col]->buffer_views[1].data.as_float[row]); + std::memcpy(param_values[col], &value, sizeof(uint32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DOUBLE: { + const uint64_t value = ToNetworkFloat8( + array_view->children[col]->buffer_views[1].data.as_double[row]); + std::memcpy(param_values[col], &value, sizeof(uint64_t)); + break; + } + case ArrowType::NANOARROW_TYPE_STRING: + case ArrowType::NANOARROW_TYPE_LARGE_STRING: + case ArrowType::NANOARROW_TYPE_BINARY: { + const ArrowBufferView view = + ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); + // TODO: overflow check? + param_lengths[col] = static_cast(view.size_bytes); + param_values[col] = const_cast(view.data.as_char); + break; } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = ArrowBitGet( - array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; + case ArrowType::NANOARROW_TYPE_DATE32: { + // 2000-01-01 + constexpr int32_t kPostgresDateEpoch = 10957; + const int32_t raw_value = + array_view->children[col]->buffer_views[1].data.as_int32[row]; + if (raw_value < INT32_MIN + kPostgresDateEpoch) { + SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, + "('", bind_schema->children[col]->name, "') Row #", row + 1, + "has value which exceeds postgres date limits"); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; + const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); + std::memcpy(param_values[col], &value, sizeof(int32_t)); + break; + } + case ArrowType::NANOARROW_TYPE_DURATION: + case ArrowType::NANOARROW_TYPE_TIMESTAMP: { + int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; + + bool overflow_safe = true; + + auto unit = bind_schema_fields[col].time_unit; + + switch (unit) { + case NANOARROW_TIME_UNIT_SECOND: + overflow_safe = + val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; + if (overflow_safe) { + val *= 1000000; + } + + break; + case NANOARROW_TIME_UNIT_MILLI: + overflow_safe = + val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; + if (overflow_safe) { + val *= 1000; + } + break; + case NANOARROW_TIME_UNIT_MICRO: + break; + case NANOARROW_TIME_UNIT_NANO: + val /= 1000; + break; } - case ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; + + if (!overflow_safe) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 "' which exceeds PostgreSQL timestamp limits", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; + + if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { + SetError(error, + "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 + " has value '%" PRIi64 "' which would underflow", + col + 1, bind_schema->children[col]->name, row + 1, + array_view->children[col]->buffer_views[1].data.as_int64[row]); + return ADBC_STATUS_INVALID_ARGUMENT; } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); + + if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { + const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 - "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == - ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), &months, - sizeof(uint32_t)); - break; + } else if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_DURATION) { + // postgres stores an interval as a 64 bit offset in microsecond + // resolution alongside a 32 bit day and 32 bit month + // for now we just send 0 for the day / month values + const uint64_t value = ToNetworkInt64(val); + std::memcpy(param_values[col], &value, sizeof(int64_t)); + std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + break; } + case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { + struct ArrowInterval interval; + ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); + ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); + + const uint32_t months = ToNetworkInt32(interval.months); + const uint32_t days = ToNetworkInt32(interval.days); + const uint64_t ms = ToNetworkInt64(interval.ns / 1000); + + std::memcpy(param_values[col], &ms, sizeof(uint64_t)); + std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); + std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), &months, + sizeof(uint32_t)); + break; + } + default: + SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", + bind_schema->children[col]->name, + "') has unsupported type for ingestion ", + ArrowTypeString(bind_schema_fields[col].type)); + return ADBC_STATUS_NOT_IMPLEMENTED; } + } - result = PQexecPrepared(pg_conn, /*stmtName=*/"", - /*nParams=*/bind_schema->n_children, param_values.data(), - param_lengths.data(), param_formats.data(), - /*resultFormat=*/0 /*text*/); - - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute prepared statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); - PQclear(result); - return code; - } + PGresult* result = + PQexecPrepared(pg_conn, /*stmtName=*/"", + /*nParams=*/bind_schema->n_children, param_values.data(), + param_lengths.data(), param_formats.data(), + /*resultFormat=*/0 /*text*/); + ExecStatusType pg_status = PQresultStatus(result); + if (pg_status != PGRES_COMMAND_OK) { + AdbcStatusCode code = + SetError(error, result, "[libpq] Failed to execute prepared statement: %s %s", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); PQclear(result); + return code; } - if (rows_affected) *rows_affected += current->length; + *result_out = result; + return ADBC_STATUS_OK; + } + + AdbcStatusCode Cleanup(PGconn* pg_conn, AdbcError* error) { if (has_tz_field) { std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); @@ -527,6 +519,26 @@ struct BindStream { return ADBC_STATUS_OK; } + AdbcStatusCode Execute(PGconn* pg_conn, int64_t* rows_affected, + struct AdbcError* error) { + if (rows_affected) *rows_affected = 0; + PGresult* result = nullptr; + + while (true) { + RAISE_ADBC(EnsureNextRow(error)); + if (!current->release) break; + + RAISE_ADBC(BindCurrentRow(pg_conn, &result, error)); + PQclear(result); + if (rows_affected) { + (*rows_affected)++; + } + } + + RAISE_ADBC(Cleanup(pg_conn, error)); + return ADBC_STATUS_OK; + } + AdbcStatusCode ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, int64_t* rows_affected, struct AdbcError* error) { // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max From f3843654ca5abb0aa8b15fb545445cb90d7f0d36 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 8 Aug 2024 22:35:11 -0300 Subject: [PATCH 09/25] better name and format --- c/driver/postgresql/bind_stream.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index e1eec683e3..f2c7025090 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -306,7 +306,8 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode BindCurrentRow(PGconn* pg_conn, PGresult** result_out, AdbcError* error) { + AdbcStatusCode BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, + AdbcError* error) { int64_t row = current_row; for (int64_t col = 0; col < array_view->n_children; col++) { @@ -528,7 +529,7 @@ struct BindStream { RAISE_ADBC(EnsureNextRow(error)); if (!current->release) break; - RAISE_ADBC(BindCurrentRow(pg_conn, &result, error)); + RAISE_ADBC(BindAndExecuteCurrentRow(pg_conn, &result, error)); PQclear(result); if (rows_affected) { (*rows_affected)++; From 64becac8c3ae28bf3e8ebe14c47c28967a149955 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 10:45:58 -0300 Subject: [PATCH 10/25] include what you use --- c/driver/postgresql/result_reader.cc | 3 +++ c/driver/postgresql/result_reader.h | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 9b68ef66c8..d510e5ef3f 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -17,6 +17,9 @@ #include "result_reader.h" +#include +#include + #include "copy/reader.h" #include "driver/common/utils.h" diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 11429a6902..41e2f3f208 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -17,6 +17,11 @@ #pragma once +#include +#include +#include +#include + #include #include "copy/reader.h" From 74882414477294b93965432e4ef1b57ad606b9d8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 10:51:20 -0300 Subject: [PATCH 11/25] more include what you use --- c/driver/postgresql/bind_stream.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index f2c7025090..753ccd158c 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -17,6 +17,9 @@ #pragma once +#include +#include +#include #include #include From 77562b6ed714b2e532734c8ae1d97e21e524c535 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 12:18:00 -0300 Subject: [PATCH 12/25] tests passing with bind stream included --- c/driver/postgresql/bind_stream.h | 9 ++++- c/driver/postgresql/result_helper.h | 5 +++ c/driver/postgresql/result_reader.cc | 60 ++++++++++++++++++++++++++-- c/driver/postgresql/result_reader.h | 14 ++++++- c/driver/postgresql/statement.cc | 6 ++- 5 files changed, 84 insertions(+), 10 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 753ccd158c..63a58021fb 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -61,11 +61,16 @@ struct BindStream { struct ArrowError na_error; - explicit BindStream(struct ArrowArrayStream&& bind) { - this->bind.value = std::move(bind); + BindStream() { + this->bind->release = nullptr; std::memset(&na_error, 0, sizeof(na_error)); } + void SetBind(struct ArrowArrayStream* stream) { + this->bind.reset(); + ArrowArrayStreamMove(stream, &bind.value); + } + template AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { CHECK_NA_DETAIL(INTERNAL, diff --git a/c/driver/postgresql/result_helper.h b/c/driver/postgresql/result_helper.h index d18ee8222e..18de7958b3 100644 --- a/c/driver/postgresql/result_helper.h +++ b/c/driver/postgresql/result_helper.h @@ -109,6 +109,11 @@ class PqResultHelper { bool HasResult() { return result_ != nullptr; } + void SetResult(PGresult* result) { + ClearResult(); + result_ = result; + } + PGresult* ReleaseResult(); void ClearResult() { diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index d510e5ef3f..8ceb8e1256 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -43,6 +43,7 @@ int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { int PqResultArrayReader::GetNext(struct ArrowArray* out) { ResetErrors(); + AdbcStatusCode status; if (schema_->release == nullptr) { AdbcStatusCode status = Initialize(&error_); if (status != ADBC_STATUS_OK) { @@ -50,9 +51,27 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) { } } + // If don't already have a result, populate it by binding the next row + // in the bind stream. If there is a bind stream and this is the first + // call to GetNext(), we have already populated the result. if (!helper_.HasResult()) { - out->release = nullptr; - return NANOARROW_OK; + // Try to bind the next row. If there was no stream provided, + // we have inserted a dummy stream that contains no more arrays. + status = bind_stream_->EnsureNextRow(&error_); + if (status != ADBC_STATUS_OK) { + return EIO; + } + + // If there is no underlying current array in the bind stream, we are done. + if (bind_stream_->current->release == nullptr) { + out->release = nullptr; + return NANOARROW_OK; + } + + // Otherwise, bind and execute + PGresult* result; + RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, &error_)); + helper_.SetResult(result); } nanoarrow::UniqueArray tmp; @@ -92,7 +111,7 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) { tmp->null_count = 0; NANOARROW_RETURN_NOT_OK(ArrowArrayFinishBuildingDefault(tmp.get(), &na_error_)); - // Ensure that the next call to GetNext() will signal the end of the stream + // Signal that the next call to GetNext() will have to populate the result again helper_.ClearResult(); // Canonically return zero-size results as an empty stream @@ -115,8 +134,41 @@ const char* PqResultArrayReader::GetLastError() { AdbcStatusCode PqResultArrayReader::Initialize(struct AdbcError* error) { helper_.set_output_format(PqResultHelper::Format::kBinary); - RAISE_ADBC(helper_.Execute(error)); + helper_.set_param_format(PqResultHelper::Format::kBinary); + + // If we have to do binding, use Prepare() + DescribePrepared(), ensuring + // that the Oids of the binary we're about to send are passed on and that + // we execute something with a result + if (bind_stream_->bind->release != nullptr) { + RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); + RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); + RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); + + RAISE_ADBC(bind_stream_->EnsureNextRow(error)); + + // If there were no arrays in the bind stream, we can still initialize the schema + if (bind_stream_->current->release == nullptr) { + RAISE_ADBC(helper_.DescribePrepared(error)); + } else { + PGresult* result; + RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, error)); + helper_.SetResult(result); + } + } else { + RAISE_ADBC(helper_.Execute(error)); + + // It is helpful for the purposes of the GetNext() implementation to bind + // a stream with no parameters and no arrays. + nanoarrow::UniqueSchema empty_bind; + ArrowSchemaInitFromType(empty_bind.get(), NANOARROW_TYPE_STRUCT); + + nanoarrow::UniqueArrayStream empty_stream; + nanoarrow::EmptyArrayStream(empty_bind.get()).ToArrayStream(empty_stream.get()); + bind_stream_->SetBind(empty_stream.get()); + RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); + } + // Build the schema we are about to build results for ArrowSchemaInit(schema_.get()); CHECK_NA_DETAIL(INTERNAL, ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns()), &na_error_, error); diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 41e2f3f208..714ac0a0de 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -24,6 +24,7 @@ #include +#include "bind_stream.h" #include "copy/reader.h" #include "result_helper.h" @@ -33,13 +34,18 @@ class PqResultArrayReader { public: PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, std::string query) - : helper_(conn, std::move(query)), type_resolver_(type_resolver) { + : conn_(conn), + helper_(conn, std::move(query)), + bind_stream_(std::make_unique()), + type_resolver_(type_resolver) { ArrowErrorInit(&na_error_); error_ = ADBC_ERROR_INIT; } ~PqResultArrayReader() { ResetErrors(); } + void SetBind(struct ArrowArrayStream* stream) { bind_stream_->SetBind(stream); } + int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); const char* GetLastError(); @@ -50,7 +56,9 @@ class PqResultArrayReader { AdbcStatusCode Initialize(struct AdbcError* error); private: + PGconn* conn_; PqResultHelper helper_; + std::unique_ptr bind_stream_; std::shared_ptr type_resolver_; std::vector> field_readers_; nanoarrow::UniqueSchema schema_; @@ -58,7 +66,9 @@ class PqResultArrayReader { struct ArrowError na_error_; explicit PqResultArrayReader(PqResultArrayReader* other) - : helper_(std::move(other->helper_)), + : conn_(other->conn_), + helper_(std::move(other->helper_)), + bind_stream_(std::move(other->bind_stream_)), type_resolver_(std::move(other->type_resolver_)), field_readers_(std::move(other->field_readers_)), schema_(std::move(other->schema_)) { diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 56c24ccdb3..2d18795a57 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1138,7 +1138,8 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, return ADBC_STATUS_NOT_IMPLEMENTED; } - BindStream bind_stream(std::move(bind_)); + BindStream bind_stream; + bind_stream.SetBind(&bind_); std::memset(&bind_, 0, sizeof(bind_)); RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error)); @@ -1293,7 +1294,8 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, current_schema = (*it)[0].data; } - BindStream bind_stream(std::move(bind_)); + BindStream bind_stream; + bind_stream.SetBind(&bind_); std::memset(&bind_, 0, sizeof(bind_)); std::string escaped_table; std::string escaped_field_list; From cd52ec3d19a9d05d98e90288679eaedd04608198 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 13:07:54 -0300 Subject: [PATCH 13/25] mayyyybe wire in support --- c/driver/postgresql/result_reader.cc | 106 +++++++++++++++------------ c/driver/postgresql/result_reader.h | 14 ++-- c/driver/postgresql/statement.cc | 10 +-- 3 files changed, 71 insertions(+), 59 deletions(-) diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 8ceb8e1256..420e671d6b 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -31,7 +31,7 @@ int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { ResetErrors(); if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(&error_); + AdbcStatusCode status = Initialize(nullptr, &error_); if (status != ADBC_STATUS_OK) { return EINVAL; } @@ -45,33 +45,34 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) { AdbcStatusCode status; if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(&error_); + AdbcStatusCode status = Initialize(nullptr, &error_); if (status != ADBC_STATUS_OK) { return EINVAL; } } // If don't already have a result, populate it by binding the next row - // in the bind stream. If there is a bind stream and this is the first - // call to GetNext(), we have already populated the result. + // in the bind stream. If this is the first call to GetNext(), we have + // already populated the result. if (!helper_.HasResult()) { - // Try to bind the next row. If there was no stream provided, - // we have inserted a dummy stream that contains no more arrays. - status = bind_stream_->EnsureNextRow(&error_); + // If there was no bind stream provided or the existing bind stream has been + // exhausted, we are done. + if (!bind_stream_) { + out->release = nullptr; + return NANOARROW_OK; + } + + // Keep binding and executing until we have a result to return + status = BindNextAndExecute(nullptr, &error_); if (status != ADBC_STATUS_OK) { return EIO; } - // If there is no underlying current array in the bind stream, we are done. - if (bind_stream_->current->release == nullptr) { + // It's possible that there is still nothing to do here + if (!helper_.HasResult()) { out->release = nullptr; return NANOARROW_OK; } - - // Otherwise, bind and execute - PGresult* result; - RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, &error_)); - helper_.SetResult(result); } nanoarrow::UniqueArray tmp; @@ -132,43 +133,34 @@ const char* PqResultArrayReader::GetLastError() { } } -AdbcStatusCode PqResultArrayReader::Initialize(struct AdbcError* error) { +AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, + struct AdbcError* error) { helper_.set_output_format(PqResultHelper::Format::kBinary); helper_.set_param_format(PqResultHelper::Format::kBinary); - // If we have to do binding, use Prepare() + DescribePrepared(), ensuring - // that the Oids of the binary we're about to send are passed on and that - // we execute something with a result - if (bind_stream_->bind->release != nullptr) { + // If we have to do binding, set up the bind stream an execute until + // there is a result with more than zero rows to populate. + if (bind_stream_) { RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); - RAISE_ADBC(bind_stream_->EnsureNextRow(error)); + RAISE_ADBC(BindNextAndExecute(rows_affected, error)); - // If there were no arrays in the bind stream, we can still initialize the schema - if (bind_stream_->current->release == nullptr) { + // If there were no arrays in the bind stream, we still need a result + // to populate the schema. If there were any arrays in the bind stream, + // the last one will still be in helper_ even if it had zero rows. + if (!helper_.HasResult()) { RAISE_ADBC(helper_.DescribePrepared(error)); - } else { - PGresult* result; - RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, error)); - helper_.SetResult(result); } } else { RAISE_ADBC(helper_.Execute(error)); - - // It is helpful for the purposes of the GetNext() implementation to bind - // a stream with no parameters and no arrays. - nanoarrow::UniqueSchema empty_bind; - ArrowSchemaInitFromType(empty_bind.get(), NANOARROW_TYPE_STRUCT); - - nanoarrow::UniqueArrayStream empty_stream; - nanoarrow::EmptyArrayStream(empty_bind.get()).ToArrayStream(empty_stream.get()); - bind_stream_->SetBind(empty_stream.get()); - RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); + if (rows_affected != nullptr) { + *rows_affected = helper_.AffectedRows(); + } } - // Build the schema we are about to build results for + // Build the schema for which we are about to build results ArrowSchemaInit(schema_.get()); CHECK_NA_DETAIL(INTERNAL, ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns()), &na_error_, error); @@ -202,12 +194,12 @@ AdbcStatusCode PqResultArrayReader::Initialize(struct AdbcError* error) { AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, struct AdbcError* error) { - if (out == nullptr) { - // If there is no output requested, we still need to execute and set - // affected_rows if needed. We don't need an output schema or to set - // up a copy reader, so we can skip those steps by going straight - // to Execute(). This also enables us to support queries with multiple - // statements because we can call PQexec() instead of PQexecParams(). + if (out == nullptr && !bind_stream_) { + // If there is no output requested and nothing to bind, we still need to execute and + // set affected_rows if needed. We don't need an output schema or to set up a copy + // reader, so we can skip those steps by going straight to Execute(). This also + // enables us to support queries with multiple statements because we can call PQexec() + // instead of PQexecParams(). RAISE_ADBC(helper_.Execute(error)); if (affected_rows != nullptr) { @@ -221,10 +213,7 @@ AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, // CREATE TABLE queries as well as to provide more informative errors // until this reader class is wired up to provide extended AdbcError // information. - RAISE_ADBC(Initialize(error)); - if (affected_rows != nullptr) { - *affected_rows = helper_.AffectedRows(); - } + RAISE_ADBC(Initialize(affected_rows, error)); nanoarrow::ArrayStreamFactory::InitArrayStream( new PqResultArrayReader(this), out); @@ -232,4 +221,27 @@ AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, return ADBC_STATUS_OK; } +AdbcStatusCode PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows, + AdbcError* error) { + // Keep pulling from the bind stream and executing as long as + // we receive results with zero rows. + do { + RAISE_ADBC(bind_stream_->EnsureNextRow(error)); + if (!bind_stream_->current->release) { + RAISE_ADBC(bind_stream_->Cleanup(conn_, error)); + bind_stream_.reset(); + return ADBC_STATUS_OK; + } + + PGresult* result; + RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, error)); + helper_.SetResult(result); + if (affected_rows) { + (*affected_rows) += helper_.AffectedRows(); + } + } while (helper_.NumRows() == 0); + + return ADBC_STATUS_OK; +} + } // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 714ac0a0de..3a23320e8d 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -34,17 +34,17 @@ class PqResultArrayReader { public: PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, std::string query) - : conn_(conn), - helper_(conn, std::move(query)), - bind_stream_(std::make_unique()), - type_resolver_(type_resolver) { + : conn_(conn), helper_(conn, std::move(query)), type_resolver_(type_resolver) { ArrowErrorInit(&na_error_); error_ = ADBC_ERROR_INIT; } ~PqResultArrayReader() { ResetErrors(); } - void SetBind(struct ArrowArrayStream* stream) { bind_stream_->SetBind(stream); } + void SetBind(struct ArrowArrayStream* stream) { + bind_stream_ = std::make_unique(); + bind_stream_->SetBind(stream); + } int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); @@ -53,7 +53,7 @@ class PqResultArrayReader { AdbcStatusCode ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, struct AdbcError* error); - AdbcStatusCode Initialize(struct AdbcError* error); + AdbcStatusCode Initialize(int64_t* affected_rows, struct AdbcError* error); private: PGconn* conn_; @@ -76,6 +76,8 @@ class PqResultArrayReader { error_ = ADBC_ERROR_INIT; } + AdbcStatusCode BindNextAndExecute(int64_t* affected_rows, AdbcError* error); + void ResetErrors() { ArrowErrorInit(&na_error_); diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 2d18795a57..0ed9af69c3 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1130,12 +1130,10 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { if (stream) { - // TODO: - SetError(error, "%s", - "[libpq] Prepared statements with parameters returning result sets are not " - "implemented"); - - return ADBC_STATUS_NOT_IMPLEMENTED; + PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetBind(&bind_); + RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); + return ADBC_STATUS_OK; } BindStream bind_stream; From 83924163be82c5de39a465ec91ef9c458e1442de Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 13:22:51 -0300 Subject: [PATCH 14/25] maybe unify the excute path --- c/driver/postgresql/result_reader.cc | 40 +++++++++++++++++++--------- c/driver/postgresql/result_reader.h | 1 + c/driver/postgresql/statement.cc | 19 +++---------- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 420e671d6b..3789e23301 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -194,25 +194,19 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, struct AdbcError* error) { - if (out == nullptr && !bind_stream_) { - // If there is no output requested and nothing to bind, we still need to execute and + if (out == nullptr) { + // If there is no output requested, we still need to execute and // set affected_rows if needed. We don't need an output schema or to set up a copy // reader, so we can skip those steps by going straight to Execute(). This also // enables us to support queries with multiple statements because we can call PQexec() // instead of PQexecParams(). - RAISE_ADBC(helper_.Execute(error)); - - if (affected_rows != nullptr) { - *affected_rows = helper_.AffectedRows(); - } - + RAISE_ADBC(ExecuteAll(affected_rows, error)); return ADBC_STATUS_OK; } - // Execute eagerly. We need this to provide row counts for DELETE and - // CREATE TABLE queries as well as to provide more informative errors - // until this reader class is wired up to provide extended AdbcError - // information. + // Otherwise, execute until we have a result to return. We need this to provide row + // counts for DELETE and CREATE TABLE queries as well as to provide more informative + // errors until this reader class is wired up to provide extended AdbcError information. RAISE_ADBC(Initialize(affected_rows, error)); nanoarrow::ArrayStreamFactory::InitArrayStream( @@ -244,4 +238,26 @@ AdbcStatusCode PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows, return ADBC_STATUS_OK; } +AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError* error) { + // For the case where we don't need a result, we either need to exhaust the bind + // stream + if (bind_stream_) { + RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); + RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); + RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); + + do { + RAISE_ADBC(BindNextAndExecute(affected_rows, error)); + } while (bind_stream_); + } else { + RAISE_ADBC(helper_.Execute(error)); + + if (affected_rows != nullptr) { + *affected_rows = helper_.AffectedRows(); + } + } + + return ADBC_STATUS_OK; +} + } // namespace adbcpq diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 3a23320e8d..51da6399ed 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -77,6 +77,7 @@ class PqResultArrayReader { } AdbcStatusCode BindNextAndExecute(int64_t* affected_rows, AdbcError* error); + AdbcStatusCode ExecuteAll(int64_t* affected_rows, AdbcError* error); void ResetErrors() { ArrowErrorInit(&na_error_); diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 0ed9af69c3..4323ce4d00 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1129,22 +1129,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { - if (stream) { - PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); - reader.SetBind(&bind_); - RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); - return ADBC_STATUS_OK; - } - - BindStream bind_stream; - bind_stream.SetBind(&bind_); - std::memset(&bind_, 0, sizeof(bind_)); - - RAISE_ADBC(bind_stream.Begin([&]() { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); - RAISE_ADBC( - bind_stream.Prepare(connection_->conn(), query_, error, connection_->autocommit())); - RAISE_ADBC(bind_stream.Execute(connection_->conn(), rows_affected, error)); + PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetBind(&bind_); + RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); return ADBC_STATUS_OK; } From 948f67844144752b12a9fe3ac84c1a567c94d68b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 13:24:07 -0300 Subject: [PATCH 15/25] remove previous execute --- c/driver/postgresql/bind_stream.h | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 63a58021fb..d96cb315bc 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -528,26 +528,6 @@ struct BindStream { return ADBC_STATUS_OK; } - AdbcStatusCode Execute(PGconn* pg_conn, int64_t* rows_affected, - struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - PGresult* result = nullptr; - - while (true) { - RAISE_ADBC(EnsureNextRow(error)); - if (!current->release) break; - - RAISE_ADBC(BindAndExecuteCurrentRow(pg_conn, &result, error)); - PQclear(result); - if (rows_affected) { - (*rows_affected)++; - } - } - - RAISE_ADBC(Cleanup(pg_conn, error)); - return ADBC_STATUS_OK; - } - AdbcStatusCode ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, int64_t* rows_affected, struct AdbcError* error) { // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max From dfe0585fc4afd7746fd0b3d83ed7b7256b13bb20 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 13:56:05 -0300 Subject: [PATCH 16/25] include limits in bind stream --- c/driver/postgresql/bind_stream.h | 1 + 1 file changed, 1 insertion(+) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index d96cb315bc..5414c5aa59 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include From db0605edffc8442fa6c719858ca909c5a96ee2d3 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 9 Aug 2024 22:33:40 -0300 Subject: [PATCH 17/25] also accept results --- c/driver/postgresql/bind_stream.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 5414c5aa59..fc144e99fe 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -490,7 +490,7 @@ struct BindStream { /*resultFormat=*/0 /*text*/); ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { + if (pg_status != PGRES_COMMAND_OK && pg_status != PGRES_TUPLES_OK) { AdbcStatusCode code = SetError(error, result, "[libpq] Failed to execute prepared statement: %s %s", PQresStatus(pg_status), PQerrorMessage(pg_conn)); From 0937125ad001e4d0776b607007247f07144da605 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 09:37:08 -0300 Subject: [PATCH 18/25] fix merge --- c/driver/postgresql/bind_stream.h | 51 +-- c/driver/postgresql/statement.cc | 594 ------------------------------ 2 files changed, 30 insertions(+), 615 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index fc144e99fe..9344d5f73e 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -531,12 +531,6 @@ struct BindStream { AdbcStatusCode ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, int64_t* rows_affected, struct AdbcError* error) { - // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max - // size for a single message that we need to respect (1 GiB - 1). Since - // the buffer can be chunked up as much as we want, go for 16 MiB as our - // limit. - // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 - constexpr int64_t kMaxCopyBufferSize = 0x1000000; if (rows_affected) *rows_affected = 0; PostgresCopyStreamWriter writer; @@ -564,26 +558,15 @@ struct BindStream { return ADBC_STATUS_IO; } - ArrowBuffer buffer = writer.WriteBuffer(); - { - auto* data = reinterpret_cast(buffer.data); - int64_t remaining = buffer.size_bytes; - while (remaining > 0) { - int64_t to_write = std::min(remaining, kMaxCopyBufferSize); - if (PQputCopyData(pg_conn, data, to_write) <= 0) { - SetError(error, "Error writing tuple field data: %s", - PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; - } - remaining -= to_write; - data += to_write; - } - } + RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); if (rows_affected) *rows_affected += current->length; writer.Rewind(); } + // If there were no arrays in the stream, we haven't flushed yet + RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); + if (PQputCopyEnd(pg_conn, NULL) <= 0) { SetError(error, "Error message returned by PQputCopyEnd: %s", PQerrorMessage(pg_conn)); @@ -603,5 +586,31 @@ struct BindStream { PQclear(result); return ADBC_STATUS_OK; } + + AdbcStatusCode FlushCopyWriterToConn(PGconn* pg_conn, + const PostgresCopyStreamWriter& writer, + struct AdbcError* error) { + // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max + // size for a single message that we need to respect (1 GiB - 1). Since + // the buffer can be chunked up as much as we want, go for 16 MiB as our + // limit. + // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 + constexpr int64_t kMaxCopyBufferSize = 0x1000000; + ArrowBuffer buffer = writer.WriteBuffer(); + + auto* data = reinterpret_cast(buffer.data); + int64_t remaining = buffer.size_bytes; + while (remaining > 0) { + int64_t to_write = std::min(remaining, kMaxCopyBufferSize); + if (PQputCopyData(pg_conn, data, to_write) <= 0) { + SetError(error, "Error writing tuple field data: %s", PQerrorMessage(pg_conn)); + return ADBC_STATUS_IO; + } + remaining -= to_write; + data += to_write; + } + + return ADBC_STATUS_OK; + } }; } // namespace adbcpq diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 4323ce4d00..224472bdc2 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -48,600 +48,6 @@ namespace adbcpq { -namespace { - -/// One-value ArrowArrayStream used to unify the implementations of Bind -struct OneValueStream { - struct ArrowSchema schema; - struct ArrowArray array; - - static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) { - OneValueStream* stream = static_cast(self->private_data); - return ArrowSchemaDeepCopy(&stream->schema, out); - } - static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) { - OneValueStream* stream = static_cast(self->private_data); - *out = stream->array; - stream->array.release = nullptr; - return 0; - } - static const char* GetLastError(struct ArrowArrayStream* self) { return NULL; } - static void Release(struct ArrowArrayStream* self) { - OneValueStream* stream = static_cast(self->private_data); - if (stream->schema.release) { - stream->schema.release(&stream->schema); - stream->schema.release = nullptr; - } - if (stream->array.release) { - stream->array.release(&stream->array); - stream->array.release = nullptr; - } - delete stream; - self->release = nullptr; - } -}; - -/// Helper to manage bind parameters with a prepared statement -struct BindStream { - Handle bind; - Handle bind_schema; - struct ArrowSchemaView bind_schema_view; - std::vector bind_schema_fields; - - // OIDs for parameter types - std::vector param_types; - std::vector param_values; - std::vector param_lengths; - std::vector param_formats; - std::vector param_values_offsets; - std::vector param_values_buffer; - // XXX: this assumes fixed-length fields only - will need more - // consideration to deal with variable-length fields - - bool has_tz_field = false; - std::string tz_setting; - - struct ArrowError na_error; - - explicit BindStream(struct ArrowArrayStream&& bind) { - this->bind.value = std::move(bind); - std::memset(&na_error, 0, sizeof(na_error)); - } - - template - AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { - CHECK_NA(INTERNAL, bind->get_schema(&bind.value, &bind_schema.value), error); - CHECK_NA( - INTERNAL, - ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, /*error*/ nullptr), - error); - - if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { - SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); - return ADBC_STATUS_INVALID_STATE; - } - - bind_schema_fields.resize(bind_schema->n_children); - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], - /*error*/ nullptr), - error); - } - - return std::move(callback)(); - } - - AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver, - struct AdbcError* error) { - param_types.resize(bind_schema->n_children); - param_values.resize(bind_schema->n_children); - param_lengths.resize(bind_schema->n_children); - param_formats.resize(bind_schema->n_children, kPgBinaryFormat); - param_values_offsets.reserve(bind_schema->n_children); - - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - PostgresTypeId type_id; - switch (bind_schema_fields[i].type) { - case ArrowType::NANOARROW_TYPE_BOOL: - type_id = PostgresTypeId::kBool; - param_lengths[i] = 1; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - type_id = PostgresTypeId::kInt2; - param_lengths[i] = 2; - break; - case ArrowType::NANOARROW_TYPE_INT32: - type_id = PostgresTypeId::kInt4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_INT64: - type_id = PostgresTypeId::kInt8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - type_id = PostgresTypeId::kFloat4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - type_id = PostgresTypeId::kFloat8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - type_id = PostgresTypeId::kDate; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - type_id = PostgresTypeId::kTimestamp; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - type_id = PostgresTypeId::kInterval; - param_lengths[i] = 16; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - type_id = PostgresTypeId::kNumeric; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", - bind_schema->children[i]->name, - "') has unsupported dictionary value parameter type ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has unsupported parameter type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - param_types[i] = type_resolver.GetOID(type_id); - if (param_types[i] == 0) { - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has type with no corresponding PostgreSQL type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } - - size_t param_values_length = 0; - for (int length : param_lengths) { - param_values_offsets.push_back(param_values_length); - param_values_length += length; - } - param_values_buffer.resize(param_values_length); - return ADBC_STATUS_OK; - } - - AdbcStatusCode Prepare(const PostgresConnection* conn, const std::string& query, - struct AdbcError* error, const bool autocommit) { - // tz-aware timestamps require special handling to set the timezone to UTC - // prior to sending over the binary protocol; must be reset after execute - const auto pg_conn = conn->conn(); - for (int64_t col = 0; col < bind_schema->n_children; col++) { - if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && - (strcmp("", bind_schema_fields[col].timezone))) { - has_tz_field = true; - - if (autocommit) { - PGresult* begin_result = PQexec(pg_conn, "BEGIN"); - if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, begin_result, - "[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(pg_conn)); - PQclear(begin_result); - return code; - } - PQclear(begin_result); - } - - PGresult* get_tz_result = PQexec(pg_conn, "SELECT current_setting('TIMEZONE')"); - if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - AdbcStatusCode code = SetError(error, get_tz_result, - "[libpq] Could not query current timezone: %s", - PQerrorMessage(pg_conn)); - PQclear(get_tz_result); - return code; - } - - tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); - PQclear(get_tz_result); - - PGresult* set_utc_result = PQexec(pg_conn, "SET TIME ZONE 'UTC'"); - if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError(error, set_utc_result, - "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(pg_conn)); - PQclear(set_utc_result); - return code; - } - PQclear(set_utc_result); - break; - } - } - - PGresult* result = PQprepare(pg_conn, /*stmtName=*/"", query.c_str(), - /*nParams=*/bind_schema->n_children, param_types.data()); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(pg_conn), query.c_str()); - PQclear(result); - return code; - } - PQclear(result); - return ADBC_STATUS_OK; - } - - AdbcStatusCode Execute(const PostgresConnection* conn, int64_t* rows_affected, - struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - PGresult* result = nullptr; - const auto pg_conn = conn->conn(); - - while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if (!array->release) break; - - Handle array_view; - // TODO: include error messages - CHECK_NA( - INTERNAL, - ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr), - error); - CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr), - error); - - for (int64_t row = 0; row < array->length; row++) { - for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; - } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; - } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = ArrowBitGet( - array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; - } - - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 - "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == - ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), - &months, sizeof(uint32_t)); - break; - } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } - - result = PQexecPrepared(pg_conn, /*stmtName=*/"", - /*nParams=*/bind_schema->n_children, param_values.data(), - param_lengths.data(), param_formats.data(), - /*resultFormat=*/0 /*text*/); - - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError( - error, result, "[libpq] Failed to execute prepared statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); - PQclear(result); - return code; - } - - PQclear(result); - } - if (rows_affected) *rows_affected += array->length; - - if (has_tz_field) { - std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; - PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); - if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", - PQerrorMessage(pg_conn)); - PQclear(reset_tz_result); - return code; - } - PQclear(reset_tz_result); - - PGresult* commit_result = PQexec(pg_conn, "COMMIT"); - if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(pg_conn)); - PQclear(commit_result); - return code; - } - PQclear(commit_result); - } - } - return ADBC_STATUS_OK; - } - - AdbcStatusCode ExecuteCopy(const PostgresConnection* conn, int64_t* rows_affected, - struct AdbcError* error) { - if (rows_affected) *rows_affected = 0; - const auto pg_conn = conn->conn(); - - PostgresCopyStreamWriter writer; - CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error); - CHECK_NA(INTERNAL, writer.InitFieldWriters(*conn->type_resolver(), nullptr), error); - - CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error); - - while (true) { - Handle array; - int res = bind->get_next(&bind.value, &array.value); - if (res != 0) { - SetError(error, - "[libpq] Failed to read next batch from stream of bind parameters: " - "(%d) %s %s", - res, std::strerror(res), bind->get_last_error(&bind.value)); - return ADBC_STATUS_IO; - } - if (!array->release) break; - - CHECK_NA(INTERNAL, writer.SetArray(&array.value), error); - - // build writer buffer - int write_result; - do { - write_result = writer.WriteRecord(nullptr); - } while (write_result == NANOARROW_OK); - - // check if not ENODATA at exit - if (write_result != ENODATA) { - SetError(error, "Error occurred writing COPY data: %s", PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; - } - - RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); - - if (rows_affected) *rows_affected += array->length; - writer.Rewind(); - } - - // If there were no arrays in the stream, we haven't flushed yet - RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); - - if (PQputCopyEnd(pg_conn, NULL) <= 0) { - SetError(error, "Error message returned by PQputCopyEnd: %s", - PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; - } - - PGresult* result = PQgetResult(pg_conn); - ExecStatusType pg_status = PQresultStatus(result); - if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); - PQclear(result); - return code; - } - - PQclear(result); - return ADBC_STATUS_OK; - } - - AdbcStatusCode FlushCopyWriterToConn(PGconn* pg_conn, - const PostgresCopyStreamWriter& writer, - struct AdbcError* error) { - // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max - // size for a single message that we need to respect (1 GiB - 1). Since - // the buffer can be chunked up as much as we want, go for 16 MiB as our - // limit. - // https://github.com/postgres/postgres/blob/23c5a0e7d43bc925c6001538f04a458933a11fc1/src/common/stringinfo.c#L28 - constexpr int64_t kMaxCopyBufferSize = 0x1000000; - ArrowBuffer buffer = writer.WriteBuffer(); - - auto* data = reinterpret_cast(buffer.data); - int64_t remaining = buffer.size_bytes; - while (remaining > 0) { - int64_t to_write = std::min(remaining, kMaxCopyBufferSize); - if (PQputCopyData(pg_conn, data, to_write) <= 0) { - SetError(error, "Error writing tuple field data: %s", PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; - } - remaining -= to_write; - data += to_write; - } - - return ADBC_STATUS_OK; - } -}; -} // namespace - int TupleReader::GetSchema(struct ArrowSchema* out) { assert(copy_reader_ != nullptr); From a3f2988dfed37474dcac391de97770336d9a0f47 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 11:04:59 -0300 Subject: [PATCH 19/25] test parameterized queries --- c/driver/postgresql/bind_stream.h | 5 +-- c/driver/postgresql/postgresql_test.cc | 53 ++++++++++++++++++++++++++ c/driver/postgresql/result_reader.cc | 12 +++++- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h index 9344d5f73e..3e440e8005 100644 --- a/c/driver/postgresql/bind_stream.h +++ b/c/driver/postgresql/bind_stream.h @@ -316,7 +316,7 @@ struct BindStream { } AdbcStatusCode BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, - AdbcError* error) { + int result_format, AdbcError* error) { int64_t row = current_row; for (int64_t col = 0; col < array_view->n_children; col++) { @@ -486,8 +486,7 @@ struct BindStream { PGresult* result = PQexecPrepared(pg_conn, /*stmtName=*/"", /*nParams=*/bind_schema->n_children, param_values.data(), - param_lengths.data(), param_formats.data(), - /*resultFormat=*/0 /*text*/); + param_lengths.data(), param_formats.data(), result_format); ExecStatusType pg_status = PQresultStatus(result); if (pg_status != PGRES_COMMAND_OK && pg_status != PGRES_TUPLES_OK) { diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index ff3dc0b70b..ef5cca7969 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1302,6 +1302,59 @@ TEST_F(PostgresStatementTest, ExecuteSchemaParameterizedQuery) { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithResult) { + nanoarrow::UniqueSchema schema_bind; + ArrowSchemaInit(schema_bind.get()); + ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_INT32), + adbc_validation::IsOkErrno()); + + nanoarrow::UniqueArray bind; + ASSERT_THAT(ArrowArrayInitFromSchema(bind.get(), schema_bind.get(), nullptr), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayStartAppending(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 123), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 456), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendNull(bind->children[0], 1), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishBuildingDefault(bind.get(), nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT $1", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, bind.get(), schema_bind.get(), &error), + IsOkStatus()); + + { + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(reader.schema->n_children, 1); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 123); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array_view->children[0]->buffer_views[1].data.as_int32[0], 456); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 1); + ASSERT_EQ(reader.array->children[0]->null_count, 1); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + TEST_F(PostgresStatementTest, BatchSizeHint) { ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error), IsOkStatus(&error)); diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 3789e23301..a565e40209 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -145,7 +145,7 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); - RAISE_ADBC(BindNextAndExecute(rows_affected, error)); + RAISE_ADBC(BindNextAndExecute(nullptr, error)); // If there were no arrays in the bind stream, we still need a result // to populate the schema. If there were any arrays in the bind stream, @@ -153,6 +153,13 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, if (!helper_.HasResult()) { RAISE_ADBC(helper_.DescribePrepared(error)); } + + // We can't provide affected row counts if there is a bind stream and + // an output because we don't know how many future bind arrays/rows there + // might be. + if (rows_affected != nullptr) { + *rows_affected = -1; + } } else { RAISE_ADBC(helper_.Execute(error)); if (rows_affected != nullptr) { @@ -228,7 +235,8 @@ AdbcStatusCode PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows, } PGresult* result; - RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow(conn_, &result, error)); + RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow( + conn_, &result, /*result_format*/ kPgBinaryFormat, error)); helper_.SetResult(result); if (affected_rows) { (*affected_rows) += helper_.AffectedRows(); From 64774d4af7a5549bde857dfc9d2319f513da947a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 11:16:46 -0300 Subject: [PATCH 20/25] test multiple elements in the bind stream with rows affected. --- c/driver/postgresql/postgresql_test.cc | 82 ++++++++++++++++++++++++++ c/driver/postgresql/result_reader.cc | 5 ++ 2 files changed, 87 insertions(+) diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index ef5cca7969..5fe7a81611 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1355,6 +1355,88 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithResult) { } } +TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { + // Check that when executing one or more parameterized queries that the corresponding + // affected row count is added. + ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE adbc_test (ints INT)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, -1); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } + + { + // Use INSERT INTO + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "INSERT INTO adbc_test (ints) VALUES (123), (456)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(reader.rows_affected, 2); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } + + nanoarrow::UniqueSchema schema_bind; + ArrowSchemaInit(schema_bind.get()); + ASSERT_THAT(ArrowSchemaSetTypeStruct(schema_bind.get(), 1), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetType(schema_bind->children[0], NANOARROW_TYPE_INT32), + adbc_validation::IsOkErrno()); + + nanoarrow::UniqueArray bind; + ASSERT_THAT(ArrowArrayInitFromSchema(bind.get(), schema_bind.get(), nullptr), + adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayStartAppending(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 123), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayAppendInt(bind->children[0], 456), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishElement(bind.get()), adbc_validation::IsOkErrno()); + ASSERT_THAT(ArrowArrayFinishBuildingDefault(bind.get(), nullptr), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, + "DELETE FROM adbc_test WHERE ints = $1", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, bind.get(), schema_bind.get(), &error), + IsOkStatus()); + + { + int64_t rows_affected = -2; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_EQ(rows_affected, 2); + } + + { + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * from adbc_test", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + TEST_F(PostgresStatementTest, BatchSizeHint) { ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error), IsOkStatus(&error)); diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index a565e40209..447d1151f0 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -254,6 +254,11 @@ AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); + // Reset affected rows to zero before binding and executing any + if (affected_rows) { + (*affected_rows) = 0; + } + do { RAISE_ADBC(BindNextAndExecute(affected_rows, error)); } while (bind_stream_); From 19c66e2bf57c44c60bf55817dcce24bab07fb326 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 11:28:04 -0300 Subject: [PATCH 21/25] release statement --- c/driver/postgresql/postgresql_test.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 5fe7a81611..054128973b 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1353,6 +1353,8 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithResult) { ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(reader.array->release, nullptr); } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { @@ -1435,6 +1437,8 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(reader.array->release, nullptr); } + + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } TEST_F(PostgresStatementTest, BatchSizeHint) { From 3975e2a68984c334cde0fb8d53fe5b461c294796 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 11:38:47 -0300 Subject: [PATCH 22/25] leak is not that --- c/driver/postgresql/postgresql_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 054128973b..5fe7a81611 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1353,8 +1353,6 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithResult) { ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(reader.array->release, nullptr); } - - ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { @@ -1437,8 +1435,6 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_EQ(reader.array->release, nullptr); } - - ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } TEST_F(PostgresStatementTest, BatchSizeHint) { From 9c7c0b0d549c560a8d369bdc816352011cd27ffb Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 13 Aug 2024 12:11:49 -0300 Subject: [PATCH 23/25] don't recreate statement --- c/driver/postgresql/postgresql_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 5fe7a81611..c45168ef68 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -1410,7 +1410,6 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { ASSERT_THAT(ArrowArrayFinishBuildingDefault(bind.get(), nullptr), adbc_validation::IsOkErrno()); - ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "DELETE FROM adbc_test WHERE ints = $1", &error), IsOkStatus(&error)); From bde9c9a3c2aa7561a489ae6cc96c719d5d503ac0 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 14 Aug 2024 15:32:58 -0300 Subject: [PATCH 24/25] fix meson build, comment --- c/driver/postgresql/meson.build | 1 + c/driver/postgresql/result_reader.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/c/driver/postgresql/meson.build b/c/driver/postgresql/meson.build index 179cae59b0..d2d1c28119 100644 --- a/c/driver/postgresql/meson.build +++ b/c/driver/postgresql/meson.build @@ -25,6 +25,7 @@ adbc_postgres_driver_lib = library( 'database.cc', 'postgresql.cc', 'result_helper.cc', + 'result_builder.cc', 'statement.cc', ], include_directories: [include_dir, c_dir], diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index 447d1151f0..21bc2bdbc4 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -248,7 +248,7 @@ AdbcStatusCode PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows, AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError* error) { // For the case where we don't need a result, we either need to exhaust the bind - // stream + // stream (if there is one) or execute the query without binding. if (bind_stream_) { RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); From 26981f2bfcd628e4a10767aaa7f282c716568609 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 14 Aug 2024 16:21:11 -0300 Subject: [PATCH 25/25] fix filename --- c/driver/postgresql/meson.build | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/driver/postgresql/meson.build b/c/driver/postgresql/meson.build index d2d1c28119..ac075417f5 100644 --- a/c/driver/postgresql/meson.build +++ b/c/driver/postgresql/meson.build @@ -25,7 +25,7 @@ adbc_postgres_driver_lib = library( 'database.cc', 'postgresql.cc', 'result_helper.cc', - 'result_builder.cc', + 'result_reader.cc', 'statement.cc', ], include_directories: [include_dir, c_dir],