diff --git a/.gitignore b/.gitignore index 96608ef..5ffe2a3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ arrow -duckdb # build is out build diff --git a/CMakeLists.txt b/CMakeLists.txt index 171835f..04ad769 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -152,6 +152,7 @@ set_target_properties(sqlfliteserver PROPERTIES PUBLIC_HEADER ${HEADER_FILES}) target_include_directories(sqlfliteserver PRIVATE src/sqlite src/duckdb + src/library/include ${SQLITE_INCLUDE_DIR} ${DUCKDB_INCLUDE_DIR} ${JWT_CPP_INCLUDE_DIR} diff --git a/src/duckdb/duckdb_server.cpp b/src/duckdb/duckdb_server.cpp index aef8dad..86eb106 100644 --- a/src/duckdb/duckdb_server.cpp +++ b/src/duckdb/duckdb_server.cpp @@ -27,7 +27,9 @@ #include #include +#include #include +#include #include "duckdb_sql_info.h" #include "duckdb_statement.h" @@ -35,15 +37,17 @@ #include "duckdb_tables_schema_batch_reader.h" #include "duckdb/main/prepared_statement.hpp" #include "duckdb/main/prepared_statement_data.hpp" +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +using arrow::Result; +using arrow::Status; +namespace sql = flight::sql; + +namespace sqlflite::ddb { namespace { -std::string PrepareQueryForGetTables(const GetTables &command) { +std::string PrepareQueryForGetTables(const sql::GetTables &command) { std::stringstream table_query; table_query @@ -82,10 +86,10 @@ std::string PrepareQueryForGetTables(const GetTables &command) { } Status SetParametersOnDuckDBStatement(std::shared_ptr stmt, - FlightMessageReader *reader) { + flight::FlightMessageReader *reader) { while (true) { - ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()) - const std::shared_ptr &record_batch = chunk.data; + ARROW_ASSIGN_OR_RAISE(flight::FlightStreamChunk chunk, reader->Next()) + const std::shared_ptr &record_batch = chunk.data; if (record_batch == nullptr) break; const int64_t num_rows = record_batch->num_rows(); @@ -93,8 +97,8 @@ Status SetParametersOnDuckDBStatement(std::shared_ptr stmt, for (int row_index = 0; row_index < num_rows; ++row_index) { for (int column_index = 0; column_index < num_columns; ++column_index) { - const std::shared_ptr &column = record_batch->column(column_index); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, + const std::shared_ptr &column = record_batch->column(column_index); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, column->GetScalar(row_index)) stmt->bind_parameters.push_back(scalar->ToString()); @@ -105,9 +109,9 @@ Status SetParametersOnDuckDBStatement(std::shared_ptr stmt, return Status::OK(); } -arrow::Result> DoGetDuckDBQuery( +Result> DoGetDuckDBQuery( std::shared_ptr db, const std::string &query, - const std::shared_ptr &schema) { + const std::shared_ptr &schema) { std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, DuckDBStatement::Create(db, query)) @@ -115,16 +119,18 @@ arrow::Result> DoGetDuckDBQuery( std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, DuckDBStatementBatchReader::Create(statement, schema)) - return std::make_unique(reader); + return std::make_unique(reader); } -arrow::Result> GetFlightInfoForCommand( - const FlightDescriptor &descriptor, const std::shared_ptr &schema) { - std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; +Result> GetFlightInfoForCommand( + const flight::FlightDescriptor &descriptor, + const std::shared_ptr &schema) { + std::vector endpoints{ + flight::FlightEndpoint{{descriptor.cmd}, {}}}; ARROW_ASSIGN_OR_RAISE(auto result, - FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) - return std::make_unique(result); + return std::make_unique(result); } std::string PrepareQueryForGetImportedOrExportedKeys(const std::string &filter) { @@ -169,7 +175,7 @@ class DuckDBFlightSqlServer::Impl { std::default_random_engine gen_; std::mutex mutex_; - arrow::Result> GetStatementByHandle( + Result> GetStatementByHandle( const std::string &handle) { std::scoped_lock guard(mutex_); auto search = prepared_statements_.find(handle); @@ -179,7 +185,7 @@ class DuckDBFlightSqlServer::Impl { return search->second; } - arrow::Result> GetConnection( + Result> GetConnection( const std::string &transaction_id) { if (transaction_id.empty()) return db_conn_; @@ -192,17 +198,17 @@ class DuckDBFlightSqlServer::Impl { } // Create a Ticket that combines a query and a transaction ID. - arrow::Result EncodeTransactionQuery(const std::string &query, - const std::string &transaction_id) { + Result EncodeTransactionQuery(const std::string &query, + const std::string &transaction_id) { std::string transaction_query = transaction_id; transaction_query += ':'; transaction_query += query; ARROW_ASSIGN_OR_RAISE(auto ticket_string, - CreateStatementQueryTicket(transaction_query)) - return Ticket{std::move(ticket_string)}; + sql::CreateStatementQueryTicket(transaction_query)); + return flight::Ticket{std::move(ticket_string)}; } - arrow::Result> DecodeTransactionQuery( + Result> DecodeTransactionQuery( const std::string &ticket) { auto divider = ticket.find(':'); if (divider == std::string::npos) { @@ -242,23 +248,25 @@ class DuckDBFlightSqlServer::Impl { return randomString; } - arrow::Result> GetFlightInfoStatement( - const ServerCallContext &context, const StatementQuery &command, - const FlightDescriptor &descriptor) { + Result> GetFlightInfoStatement( + const flight::ServerCallContext &context, const sql::StatementQuery &command, + const flight::FlightDescriptor &descriptor) { const std::string &query = command.query; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)) ARROW_ASSIGN_OR_RAISE(auto statement, DuckDBStatement::Create(db, query)) ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()) ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query, command.transaction_id)) - std::vector endpoints{FlightEndpoint{std::move(ticket), {}}}; - ARROW_ASSIGN_OR_RAISE(auto result, - FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) - return std::make_unique(result); + std::vector endpoints{ + flight::FlightEndpoint{std::move(ticket), {}}}; + ARROW_ASSIGN_OR_RAISE( + auto result, flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) + return std::make_unique(result); } - arrow::Result> DoGetStatement( - const ServerCallContext &context, const StatementQueryTicket &command) { + Result> DoGetStatement( + const flight::ServerCallContext &context, + const sql::StatementQueryTicket &command) { ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)) const std::string &sql = pair.first; const std::string transaction_id = pair.second; @@ -266,31 +274,32 @@ class DuckDBFlightSqlServer::Impl { ARROW_ASSIGN_OR_RAISE(auto statement, DuckDBStatement::Create(db, sql)) ARROW_ASSIGN_OR_RAISE(auto reader, DuckDBStatementBatchReader::Create(statement)) - return std::make_unique(reader); + return std::make_unique(reader); } - arrow::Result> GetFlightInfoCatalogs( - const ServerCallContext &context, const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetCatalogsSchema()); + Result> GetFlightInfoCatalogs( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetCatalogsSchema()); } - arrow::Result> DoGetCatalogs( - const ServerCallContext &context) { + Result> DoGetCatalogs( + const flight::ServerCallContext &context) { std::string query = "SELECT DISTINCT catalog_name FROM information_schema.schemata ORDER BY " "catalog_name"; - return DoGetDuckDBQuery(db_conn_, query, SqlSchema::GetCatalogsSchema()); + return DoGetDuckDBQuery(db_conn_, query, sql::SqlSchema::GetCatalogsSchema()); } - arrow::Result> GetFlightInfoSchemas( - const ServerCallContext &context, const GetDbSchemas &command, - const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetDbSchemasSchema()); + Result> GetFlightInfoSchemas( + const flight::ServerCallContext &context, const sql::GetDbSchemas &command, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetDbSchemasSchema()); } - arrow::Result> DoGetDbSchemas( - const ServerCallContext &context, const GetDbSchemas &command) { + Result> DoGetDbSchemas( + const flight::ServerCallContext &context, const sql::GetDbSchemas &command) { std::stringstream query; query << "SELECT catalog_name, schema_name AS db_schema_name FROM " "information_schema.schemata WHERE 1 = 1"; @@ -304,12 +313,12 @@ class DuckDBFlightSqlServer::Impl { } query << " ORDER BY catalog_name, db_schema_name"; - return DoGetDuckDBQuery(db_conn_, query.str(), SqlSchema::GetDbSchemasSchema()); + return DoGetDuckDBQuery(db_conn_, query.str(), sql::SqlSchema::GetDbSchemasSchema()); } - arrow::Result CreatePreparedStatement( - const ServerCallContext &context, - const ActionCreatePreparedStatementRequest &request) { + Result CreatePreparedStatement( + const flight::ServerCallContext &context, + const sql::ActionCreatePreparedStatementRequest &request) { std::scoped_lock guard(mutex_); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, DuckDBStatement::Create(db_conn_, request.query)) @@ -320,7 +329,7 @@ class DuckDBFlightSqlServer::Impl { std::shared_ptr stmt = statement->GetDuckDBStmt(); const id_t parameter_count = stmt->n_param; - FieldVector parameter_fields; + arrow::FieldVector parameter_fields; parameter_fields.reserve(parameter_count); duckdb::shared_ptr parameter_data = stmt->data; @@ -334,11 +343,12 @@ class DuckDBFlightSqlServer::Impl { parameter_fields.push_back(field(parameter_name, parameter_arrow_type)); } - const std::shared_ptr ¶meter_schema = arrow::schema(parameter_fields); + const std::shared_ptr ¶meter_schema = + arrow::schema(parameter_fields); - ActionCreatePreparedStatementResult result{.dataset_schema = dataset_schema, - .parameter_schema = parameter_schema, - .prepared_statement_handle = handle}; + sql::ActionCreatePreparedStatementResult result{.dataset_schema = dataset_schema, + .parameter_schema = parameter_schema, + .prepared_statement_handle = handle}; if (print_queries_) { std::cout << "Client running SQL command: \n" @@ -349,8 +359,8 @@ class DuckDBFlightSqlServer::Impl { return result; } - Status ClosePreparedStatement(const ServerCallContext &context, - const ActionClosePreparedStatementRequest &request) { + Status ClosePreparedStatement(const flight::ServerCallContext &context, + const sql::ActionClosePreparedStatementRequest &request) { std::scoped_lock guard(mutex_); const std::string &prepared_statement_handle = request.prepared_statement_handle; @@ -364,9 +374,10 @@ class DuckDBFlightSqlServer::Impl { return Status::OK(); } - arrow::Result> GetFlightInfoPreparedStatement( - const ServerCallContext &context, const PreparedStatementQuery &command, - const FlightDescriptor &descriptor) { + Result> GetFlightInfoPreparedStatement( + const flight::ServerCallContext &context, + const sql::PreparedStatementQuery &command, + const flight::FlightDescriptor &descriptor) { std::scoped_lock guard(mutex_); const std::string &prepared_statement_handle = command.prepared_statement_handle; @@ -382,8 +393,9 @@ class DuckDBFlightSqlServer::Impl { return GetFlightInfoForCommand(descriptor, schema); } - arrow::Result> DoGetPreparedStatement( - const ServerCallContext &context, const PreparedStatementQuery &command) { + Result> DoGetPreparedStatement( + const flight::ServerCallContext &context, + const sql::PreparedStatementQuery &command) { std::scoped_lock guard(mutex_); const std::string &prepared_statement_handle = command.prepared_statement_handle; @@ -396,13 +408,13 @@ class DuckDBFlightSqlServer::Impl { ARROW_ASSIGN_OR_RAISE(auto reader, DuckDBStatementBatchReader::Create(statement)) - return std::make_unique(reader); + return std::make_unique(reader); } - Status DoPutPreparedStatementQuery(const ServerCallContext &context, - const PreparedStatementQuery &command, - FlightMessageReader *reader, - FlightMetadataWriter *writer) { + Status DoPutPreparedStatementQuery(const flight::ServerCallContext &context, + const sql::PreparedStatementQuery &command, + flight::FlightMessageReader *reader, + flight::FlightMetadataWriter *writer) { const std::string &prepared_statement_handle = command.prepared_statement_handle; ARROW_ASSIGN_OR_RAISE(auto statement, GetStatementByHandle(prepared_statement_handle)) @@ -411,9 +423,9 @@ class DuckDBFlightSqlServer::Impl { return Status::OK(); } - arrow::Result DoPutPreparedStatementUpdate( - const ServerCallContext &context, const PreparedStatementUpdate &command, - FlightMessageReader *reader) { + Result DoPutPreparedStatementUpdate( + const flight::ServerCallContext &context, + const sql::PreparedStatementUpdate &command, flight::FlightMessageReader *reader) { const std::string &prepared_statement_handle = command.prepared_statement_handle; ARROW_ASSIGN_OR_RAISE(auto statement, GetStatementByHandle(prepared_statement_handle)) @@ -422,68 +434,71 @@ class DuckDBFlightSqlServer::Impl { return statement->ExecuteUpdate(); } - arrow::Result> DoGetTables( - const ServerCallContext &context, const GetTables &command) { + Result> DoGetTables( + const flight::ServerCallContext &context, const sql::GetTables &command) { std::string query = PrepareQueryForGetTables(command); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, DuckDBStatement::Create(db_conn_, query)) ARROW_ASSIGN_OR_RAISE(auto reader, DuckDBStatementBatchReader::Create( - statement, SqlSchema::GetTablesSchema())) + statement, sql::SqlSchema::GetTablesSchema())) if (command.include_schema) { auto table_schema_reader = std::make_shared(reader, query, db_conn_); - return std::make_unique(table_schema_reader); + return std::make_unique(table_schema_reader); } else { - return std::make_unique(reader); + return std::make_unique(reader); } } - arrow::Result DoPutCommandStatementUpdate(const ServerCallContext &context, - const StatementUpdate &command) { + Result DoPutCommandStatementUpdate(const flight::ServerCallContext &context, + const sql::StatementUpdate &command) { const std::string &sql = command.query; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)) ARROW_ASSIGN_OR_RAISE(auto statement, DuckDBStatement::Create(db, sql)) return statement->ExecuteUpdate(); } - arrow::Result> GetFlightInfoTables( - const ServerCallContext &context, const GetTables &command, - const FlightDescriptor &descriptor) { - std::vector endpoints{FlightEndpoint{{descriptor.cmd}, {}}}; + Result> GetFlightInfoTables( + const flight::ServerCallContext &context, const sql::GetTables &command, + const flight::FlightDescriptor &descriptor) { + std::vector endpoints{ + flight::FlightEndpoint{{descriptor.cmd}, {}}}; bool include_schema = command.include_schema; ARROW_ASSIGN_OR_RAISE( auto result, - FlightInfo::Make(include_schema ? *SqlSchema::GetTablesSchemaWithIncludedSchema() - : *SqlSchema::GetTablesSchema(), - descriptor, endpoints, -1, -1)) + flight::FlightInfo::Make( + include_schema ? *sql::SqlSchema::GetTablesSchemaWithIncludedSchema() + : *sql::SqlSchema::GetTablesSchema(), + descriptor, endpoints, -1, -1)) - return std::make_unique(result); + return std::make_unique(result); } - arrow::Result> GetFlightInfoTableTypes( - const ServerCallContext &context, const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetTableTypesSchema()); + Result> GetFlightInfoTableTypes( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetTableTypesSchema()); } - arrow::Result> DoGetTableTypes( - const ServerCallContext &context) { + Result> DoGetTableTypes( + const flight::ServerCallContext &context) { std::string query = "SELECT DISTINCT table_type FROM information_schema.tables"; - return DoGetDuckDBQuery(db_conn_, query, SqlSchema::GetTableTypesSchema()); + return DoGetDuckDBQuery(db_conn_, query, sql::SqlSchema::GetTableTypesSchema()); } - arrow::Result> GetFlightInfoPrimaryKeys( - const ServerCallContext &context, const GetPrimaryKeys &command, - const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetPrimaryKeysSchema()); + Result> GetFlightInfoPrimaryKeys( + const flight::ServerCallContext &context, const sql::GetPrimaryKeys &command, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetPrimaryKeysSchema()); } - arrow::Result> DoGetPrimaryKeys( - const ServerCallContext &context, const GetPrimaryKeys &command) { + Result> DoGetPrimaryKeys( + const flight::ServerCallContext &context, const sql::GetPrimaryKeys &command) { std::stringstream table_query; // The field key_name can not be recovered by the sqlite, so it is being set @@ -502,7 +517,7 @@ class DuckDBFlightSqlServer::Impl { " WHERE constraint_type = 'PRIMARY KEY'\n" " ) WHERE 1 = 1"; - const TableRef &table_ref = command.table_ref; + const sql::TableRef &table_ref = command.table_ref; table_query << " AND catalog_name = " << (table_ref.catalog.has_value() ? "'" + table_ref.catalog.value() + "'" : "CURRENT_DATABASE()"); @@ -514,18 +529,18 @@ class DuckDBFlightSqlServer::Impl { table_query << " and table_name LIKE '" << table_ref.table << "'"; return DoGetDuckDBQuery(db_conn_, table_query.str(), - SqlSchema::GetPrimaryKeysSchema()); + sql::SqlSchema::GetPrimaryKeysSchema()); } - arrow::Result> GetFlightInfoImportedKeys( - const ServerCallContext &context, const GetImportedKeys &command, - const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetImportedKeysSchema()); + Result> GetFlightInfoImportedKeys( + const flight::ServerCallContext &context, const sql::GetImportedKeys &command, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetImportedKeysSchema()); } - arrow::Result> DoGetImportedKeys( - const ServerCallContext &context, const GetImportedKeys &command) { - const TableRef &table_ref = command.table_ref; + Result> DoGetImportedKeys( + const flight::ServerCallContext &context, const sql::GetImportedKeys &command) { + const sql::TableRef &table_ref = command.table_ref; std::string filter = "fk_table_name = '" + table_ref.table + "'"; filter += " AND fk_catalog_name = " + (table_ref.catalog.has_value() @@ -536,18 +551,18 @@ class DuckDBFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetDuckDBQuery(db_conn_, query, SqlSchema::GetImportedKeysSchema()); + return DoGetDuckDBQuery(db_conn_, query, sql::SqlSchema::GetImportedKeysSchema()); } - arrow::Result> GetFlightInfoExportedKeys( - const ServerCallContext &context, const GetExportedKeys &command, - const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetExportedKeysSchema()); + Result> GetFlightInfoExportedKeys( + const flight::ServerCallContext &context, const sql::GetExportedKeys &command, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetExportedKeysSchema()); } - arrow::Result> DoGetExportedKeys( - const ServerCallContext &context, const GetExportedKeys &command) { - const TableRef &table_ref = command.table_ref; + Result> DoGetExportedKeys( + const flight::ServerCallContext &context, const sql::GetExportedKeys &command) { + const sql::TableRef &table_ref = command.table_ref; std::string filter = "pk_table_name = '" + table_ref.table + "'"; filter += " AND pk_catalog_name = " + (table_ref.catalog.has_value() ? "'" + table_ref.catalog.value() + "'" @@ -557,18 +572,18 @@ class DuckDBFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetDuckDBQuery(db_conn_, query, SqlSchema::GetExportedKeysSchema()); + return DoGetDuckDBQuery(db_conn_, query, sql::SqlSchema::GetExportedKeysSchema()); } - arrow::Result> GetFlightInfoCrossReference( - const ServerCallContext &context, const GetCrossReference &command, - const FlightDescriptor &descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetCrossReferenceSchema()); + Result> GetFlightInfoCrossReference( + const flight::ServerCallContext &context, const sql::GetCrossReference &command, + const flight::FlightDescriptor &descriptor) { + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetCrossReferenceSchema()); } - arrow::Result> DoGetCrossReference( - const ServerCallContext &context, const GetCrossReference &command) { - const TableRef &pk_table_ref = command.pk_table_ref; + Result> DoGetCrossReference( + const flight::ServerCallContext &context, const sql::GetCrossReference &command) { + const sql::TableRef &pk_table_ref = command.pk_table_ref; std::string filter = "pk_table_name = '" + pk_table_ref.table + "'"; filter += " AND pk_catalog_name = " + (pk_table_ref.catalog.has_value() ? "'" + pk_table_ref.catalog.value() + "'" @@ -577,7 +592,7 @@ class DuckDBFlightSqlServer::Impl { filter += " AND pk_schema_name = '" + pk_table_ref.db_schema.value() + "'"; } - const TableRef &fk_table_ref = command.fk_table_ref; + const sql::TableRef &fk_table_ref = command.fk_table_ref; filter += " AND fk_table_name = '" + fk_table_ref.table + "'"; filter += " AND fk_catalog_name = " + (fk_table_ref.catalog.has_value() ? "'" + fk_table_ref.catalog.value() + "'" @@ -587,22 +602,23 @@ class DuckDBFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetDuckDBQuery(db_conn_, query, SqlSchema::GetCrossReferenceSchema()); + return DoGetDuckDBQuery(db_conn_, query, sql::SqlSchema::GetCrossReferenceSchema()); } - arrow::Result BeginTransaction( - const ServerCallContext &context, const ActionBeginTransactionRequest &request) { + Result BeginTransaction( + const flight::ServerCallContext &context, + const sql::ActionBeginTransactionRequest &request) { std::string handle = GenerateRandomString(); auto new_db = std::make_shared(*db_instance_); ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION")); std::scoped_lock guard(mutex_); open_transactions_[handle] = new_db; - return ActionBeginTransactionResult{std::move(handle)}; + return sql::ActionBeginTransactionResult{std::move(handle)}; } - Status EndTransaction(const ServerCallContext &context, - const ActionEndTransactionRequest &request) { + Status EndTransaction(const flight::ServerCallContext &context, + const sql::ActionEndTransactionRequest &request) { Status status; std::shared_ptr transaction = nullptr; { @@ -612,7 +628,7 @@ class DuckDBFlightSqlServer::Impl { return Status::KeyError("Unknown transaction ID: ", request.transaction_id); } - if (request.action == ActionEndTransactionRequest::kCommit) { + if (request.action == sql::ActionEndTransactionRequest::kCommit) { status = ExecuteSql(it->second, "COMMIT"); } else { status = ExecuteSql(it->second, "ROLLBACK"); @@ -638,7 +654,7 @@ class DuckDBFlightSqlServer::Impl { DuckDBFlightSqlServer::DuckDBFlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} -arrow::Result> DuckDBFlightSqlServer::Create( +Result> DuckDBFlightSqlServer::Create( const std::string &path, const duckdb::DBConfig &config, const bool &print_queries) { std::cout << "DuckDB version: " << duckdb_library_version() << std::endl; @@ -659,166 +675,171 @@ arrow::Result> DuckDBFlightSqlServer::Cre DuckDBFlightSqlServer::~DuckDBFlightSqlServer() = default; -arrow::Status DuckDBFlightSqlServer::ExecuteSql(const std::string &sql) { +Status DuckDBFlightSqlServer::ExecuteSql(const std::string &sql) { return impl_->ExecuteSql(sql); } -arrow::Result> DuckDBFlightSqlServer::GetFlightInfoStatement( - const ServerCallContext &context, const StatementQuery &command, - const FlightDescriptor &descriptor) { +Result> DuckDBFlightSqlServer::GetFlightInfoStatement( + const flight::ServerCallContext &context, const sql::StatementQuery &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoStatement(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetStatement( - const ServerCallContext &context, const StatementQueryTicket &command) { +Result> DuckDBFlightSqlServer::DoGetStatement( + const flight::ServerCallContext &context, const sql::StatementQueryTicket &command) { return impl_->DoGetStatement(context, command); } -arrow::Result> DuckDBFlightSqlServer::GetFlightInfoCatalogs( - const ServerCallContext &context, const FlightDescriptor &descriptor) { +Result> DuckDBFlightSqlServer::GetFlightInfoCatalogs( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoCatalogs(context, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetCatalogs( - const ServerCallContext &context) { +Result> DuckDBFlightSqlServer::DoGetCatalogs( + const flight::ServerCallContext &context) { return impl_->DoGetCatalogs(context); } -arrow::Result> DuckDBFlightSqlServer::GetFlightInfoSchemas( - const ServerCallContext &context, const GetDbSchemas &command, - const FlightDescriptor &descriptor) { +Result> DuckDBFlightSqlServer::GetFlightInfoSchemas( + const flight::ServerCallContext &context, const sql::GetDbSchemas &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoSchemas(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetDbSchemas( - const ServerCallContext &context, const GetDbSchemas &command) { +Result> DuckDBFlightSqlServer::DoGetDbSchemas( + const flight::ServerCallContext &context, const sql::GetDbSchemas &command) { return impl_->DoGetDbSchemas(context, command); } -arrow::Result> DuckDBFlightSqlServer::GetFlightInfoTables( - const ServerCallContext &context, const GetTables &command, - const FlightDescriptor &descriptor) { +Result> DuckDBFlightSqlServer::GetFlightInfoTables( + const flight::ServerCallContext &context, const sql::GetTables &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoTables(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetTables( - const ServerCallContext &context, const GetTables &command) { +Result> DuckDBFlightSqlServer::DoGetTables( + const flight::ServerCallContext &context, const sql::GetTables &command) { return impl_->DoGetTables(context, command); } -arrow::Result> DuckDBFlightSqlServer::GetFlightInfoTableTypes( - const ServerCallContext &context, const FlightDescriptor &descriptor) { +Result> +DuckDBFlightSqlServer::GetFlightInfoTableTypes( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoTableTypes(context, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetTableTypes( - const ServerCallContext &context) { +Result> DuckDBFlightSqlServer::DoGetTableTypes( + const flight::ServerCallContext &context) { return impl_->DoGetTableTypes(context); } -arrow::Result DuckDBFlightSqlServer::DoPutCommandStatementUpdate( - const ServerCallContext &context, const StatementUpdate &command) { +Result DuckDBFlightSqlServer::DoPutCommandStatementUpdate( + const flight::ServerCallContext &context, const sql::StatementUpdate &command) { return impl_->DoPutCommandStatementUpdate(context, command); } -arrow::Result +Result DuckDBFlightSqlServer::CreatePreparedStatement( - const ServerCallContext &context, - const ActionCreatePreparedStatementRequest &request) { + const flight::ServerCallContext &context, + const sql::ActionCreatePreparedStatementRequest &request) { return impl_->CreatePreparedStatement(context, request); } Status DuckDBFlightSqlServer::ClosePreparedStatement( - const ServerCallContext &context, - const ActionClosePreparedStatementRequest &request) { + const flight::ServerCallContext &context, + const sql::ActionClosePreparedStatementRequest &request) { return impl_->ClosePreparedStatement(context, request); } -arrow::Result> +Result> DuckDBFlightSqlServer::GetFlightInfoPreparedStatement( - const ServerCallContext &context, const PreparedStatementQuery &command, - const FlightDescriptor &descriptor) { + const flight::ServerCallContext &context, const sql::PreparedStatementQuery &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoPreparedStatement(context, command, descriptor); } -arrow::Result> -DuckDBFlightSqlServer::DoGetPreparedStatement(const ServerCallContext &context, - const PreparedStatementQuery &command) { +Result> +DuckDBFlightSqlServer::DoGetPreparedStatement( + const flight::ServerCallContext &context, + const sql::PreparedStatementQuery &command) { return impl_->DoGetPreparedStatement(context, command); } Status DuckDBFlightSqlServer::DoPutPreparedStatementQuery( - const ServerCallContext &context, const PreparedStatementQuery &command, - FlightMessageReader *reader, FlightMetadataWriter *writer) { + const flight::ServerCallContext &context, const sql::PreparedStatementQuery &command, + flight::FlightMessageReader *reader, flight::FlightMetadataWriter *writer) { return impl_->DoPutPreparedStatementQuery(context, command, reader, writer); } -arrow::Result DuckDBFlightSqlServer::DoPutPreparedStatementUpdate( - const ServerCallContext &context, const PreparedStatementUpdate &command, - FlightMessageReader *reader) { +Result DuckDBFlightSqlServer::DoPutPreparedStatementUpdate( + const flight::ServerCallContext &context, const sql::PreparedStatementUpdate &command, + flight::FlightMessageReader *reader) { return impl_->DoPutPreparedStatementUpdate(context, command, reader); } -arrow::Result> -DuckDBFlightSqlServer::GetFlightInfoPrimaryKeys(const ServerCallContext &context, - const GetPrimaryKeys &command, - const FlightDescriptor &descriptor) { +Result> +DuckDBFlightSqlServer::GetFlightInfoPrimaryKeys( + const flight::ServerCallContext &context, const sql::GetPrimaryKeys &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoPrimaryKeys(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetPrimaryKeys( - const ServerCallContext &context, const GetPrimaryKeys &command) { +Result> DuckDBFlightSqlServer::DoGetPrimaryKeys( + const flight::ServerCallContext &context, const sql::GetPrimaryKeys &command) { return impl_->DoGetPrimaryKeys(context, command); } -arrow::Result> -DuckDBFlightSqlServer::GetFlightInfoImportedKeys(const ServerCallContext &context, - const GetImportedKeys &command, - const FlightDescriptor &descriptor) { +Result> +DuckDBFlightSqlServer::GetFlightInfoImportedKeys( + const flight::ServerCallContext &context, const sql::GetImportedKeys &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoImportedKeys(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetImportedKeys( - const ServerCallContext &context, const GetImportedKeys &command) { +Result> +DuckDBFlightSqlServer::DoGetImportedKeys(const flight::ServerCallContext &context, + const sql::GetImportedKeys &command) { return impl_->DoGetImportedKeys(context, command); } -arrow::Result> -DuckDBFlightSqlServer::GetFlightInfoExportedKeys(const ServerCallContext &context, - const GetExportedKeys &command, - const FlightDescriptor &descriptor) { +Result> +DuckDBFlightSqlServer::GetFlightInfoExportedKeys( + const flight::ServerCallContext &context, const sql::GetExportedKeys &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoExportedKeys(context, command, descriptor); } -arrow::Result> DuckDBFlightSqlServer::DoGetExportedKeys( - const ServerCallContext &context, const GetExportedKeys &command) { +Result> +DuckDBFlightSqlServer::DoGetExportedKeys(const flight::ServerCallContext &context, + const sql::GetExportedKeys &command) { return impl_->DoGetExportedKeys(context, command); } -arrow::Result> -DuckDBFlightSqlServer::GetFlightInfoCrossReference(const ServerCallContext &context, - const GetCrossReference &command, - const FlightDescriptor &descriptor) { +Result> +DuckDBFlightSqlServer::GetFlightInfoCrossReference( + const flight::ServerCallContext &context, const sql::GetCrossReference &command, + const flight::FlightDescriptor &descriptor) { return impl_->GetFlightInfoCrossReference(context, command, descriptor); } -arrow::Result> -DuckDBFlightSqlServer::DoGetCrossReference(const ServerCallContext &context, - const GetCrossReference &command) { +Result> +DuckDBFlightSqlServer::DoGetCrossReference(const flight::ServerCallContext &context, + const sql::GetCrossReference &command) { return impl_->DoGetCrossReference(context, command); } -arrow::Result DuckDBFlightSqlServer::BeginTransaction( - const ServerCallContext &context, const ActionBeginTransactionRequest &request) { +Result DuckDBFlightSqlServer::BeginTransaction( + const flight::ServerCallContext &context, + const sql::ActionBeginTransactionRequest &request) { return impl_->BeginTransaction(context, request); } -Status DuckDBFlightSqlServer::EndTransaction(const ServerCallContext &context, - const ActionEndTransactionRequest &request) { +Status DuckDBFlightSqlServer::EndTransaction( + const flight::ServerCallContext &context, + const sql::ActionEndTransactionRequest &request) { return impl_->EndTransaction(context, request); } -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_server.h b/src/duckdb/duckdb_server.h index 0cf0ab0..6416c39 100644 --- a/src/duckdb/duckdb_server.h +++ b/src/duckdb/duckdb_server.h @@ -24,20 +24,18 @@ #include #include +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { /// \brief Convert a column type to a ArrowType. /// \param duckdb_type the duckdb type. /// \return The equivalent ArrowType. -std::shared_ptr GetArrowType(const char *duckdb_type); +std::shared_ptr GetArrowType(const char *duckdb_type); /// \brief Example implementation of FlightSqlServerBase backed by an in-memory DuckDB /// database. -class DuckDBFlightSqlServer : public FlightSqlServerBase { +class DuckDBFlightSqlServer : public flight::sql::FlightSqlServerBase { public: ~DuckDBFlightSqlServer() override; @@ -46,102 +44,121 @@ class DuckDBFlightSqlServer : public FlightSqlServerBase { /// \brief Auxiliary method used to execute an arbitrary SQL statement on the underlying /// SQLite database. - Status ExecuteSql(const std::string &sql); + arrow::Status ExecuteSql(const std::string &sql); - arrow::Result> GetFlightInfoStatement( - const ServerCallContext &context, const StatementQuery &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoStatement( + const flight::ServerCallContext &context, + const flight::sql::StatementQuery &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetStatement( - const ServerCallContext &context, const StatementQueryTicket &command) override; + arrow::Result> DoGetStatement( + const flight::ServerCallContext &context, + const flight::sql::StatementQueryTicket &command) override; - arrow::Result> GetFlightInfoCatalogs( - const ServerCallContext &context, const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoCatalogs( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetCatalogs( - const ServerCallContext &context) override; + arrow::Result> DoGetCatalogs( + const flight::ServerCallContext &context) override; - arrow::Result> GetFlightInfoSchemas( - const ServerCallContext &context, const GetDbSchemas &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoSchemas( + const flight::ServerCallContext &context, const flight::sql::GetDbSchemas &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetDbSchemas( - const ServerCallContext &context, const GetDbSchemas &command) override; + arrow::Result> DoGetDbSchemas( + const flight::ServerCallContext &context, + const flight::sql::GetDbSchemas &command) override; arrow::Result DoPutCommandStatementUpdate( - const ServerCallContext &context, const StatementUpdate &update) override; + const flight::ServerCallContext &context, + const flight::sql::StatementUpdate &update) override; - arrow::Result CreatePreparedStatement( - const ServerCallContext &context, - const ActionCreatePreparedStatementRequest &request) override; + arrow::Result CreatePreparedStatement( + const flight::ServerCallContext &context, + const flight::sql::ActionCreatePreparedStatementRequest &request) override; - Status ClosePreparedStatement( - const ServerCallContext &context, - const ActionClosePreparedStatementRequest &request) override; + arrow::Status ClosePreparedStatement( + const flight::ServerCallContext &context, + const flight::sql::ActionClosePreparedStatementRequest &request) override; - arrow::Result> GetFlightInfoPreparedStatement( - const ServerCallContext &context, const PreparedStatementQuery &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoPreparedStatement( + const flight::ServerCallContext &context, + const flight::sql::PreparedStatementQuery &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetPreparedStatement( - const ServerCallContext &context, const PreparedStatementQuery &command) override; + arrow::Result> DoGetPreparedStatement( + const flight::ServerCallContext &context, + const flight::sql::PreparedStatementQuery &command) override; - Status DoPutPreparedStatementQuery(const ServerCallContext &context, - const PreparedStatementQuery &command, - FlightMessageReader *reader, - FlightMetadataWriter *writer) override; + arrow::Status DoPutPreparedStatementQuery( + const flight::ServerCallContext &context, + const flight::sql::PreparedStatementQuery &command, + flight::FlightMessageReader *reader, flight::FlightMetadataWriter *writer) override; arrow::Result DoPutPreparedStatementUpdate( - const ServerCallContext &context, const PreparedStatementUpdate &command, - FlightMessageReader *reader) override; + const flight::ServerCallContext &context, + const flight::sql::PreparedStatementUpdate &command, + flight::FlightMessageReader *reader) override; - arrow::Result> GetFlightInfoTables( - const ServerCallContext &context, const GetTables &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoTables( + const flight::ServerCallContext &context, const flight::sql::GetTables &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetTables( - const ServerCallContext &context, const GetTables &command) override; + arrow::Result> DoGetTables( + const flight::ServerCallContext &context, + const flight::sql::GetTables &command) override; - arrow::Result> GetFlightInfoTableTypes( - const ServerCallContext &context, const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoTableTypes( + const flight::ServerCallContext &context, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetTableTypes( - const ServerCallContext &context) override; + arrow::Result> DoGetTableTypes( + const flight::ServerCallContext &context) override; - arrow::Result> GetFlightInfoImportedKeys( - const ServerCallContext &context, const GetImportedKeys &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoImportedKeys( + const flight::ServerCallContext &context, + const flight::sql::GetImportedKeys &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetImportedKeys( - const ServerCallContext &context, const GetImportedKeys &command) override; + arrow::Result> DoGetImportedKeys( + const flight::ServerCallContext &context, + const flight::sql::GetImportedKeys &command) override; - arrow::Result> GetFlightInfoExportedKeys( - const ServerCallContext &context, const GetExportedKeys &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoExportedKeys( + const flight::ServerCallContext &context, + const flight::sql::GetExportedKeys &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetExportedKeys( - const ServerCallContext &context, const GetExportedKeys &command) override; + arrow::Result> DoGetExportedKeys( + const flight::ServerCallContext &context, + const flight::sql::GetExportedKeys &command) override; - arrow::Result> GetFlightInfoCrossReference( - const ServerCallContext &context, const GetCrossReference &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoCrossReference( + const flight::ServerCallContext &context, + const flight::sql::GetCrossReference &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetCrossReference( - const ServerCallContext &context, const GetCrossReference &command) override; + arrow::Result> DoGetCrossReference( + const flight::ServerCallContext &context, + const flight::sql::GetCrossReference &command) override; - arrow::Result> GetFlightInfoPrimaryKeys( - const ServerCallContext &context, const GetPrimaryKeys &command, - const FlightDescriptor &descriptor) override; + arrow::Result> GetFlightInfoPrimaryKeys( + const flight::ServerCallContext &context, + const flight::sql::GetPrimaryKeys &command, + const flight::FlightDescriptor &descriptor) override; - arrow::Result> DoGetPrimaryKeys( - const ServerCallContext &context, const GetPrimaryKeys &command) override; + arrow::Result> DoGetPrimaryKeys( + const flight::ServerCallContext &context, + const flight::sql::GetPrimaryKeys &command) override; - arrow::Result BeginTransaction( - const ServerCallContext &context, - const ActionBeginTransactionRequest &request) override; + arrow::Result BeginTransaction( + const flight::ServerCallContext &context, + const flight::sql::ActionBeginTransactionRequest &request) override; - Status EndTransaction(const ServerCallContext &context, - const ActionEndTransactionRequest &request) override; + arrow::Status EndTransaction( + const flight::ServerCallContext &context, + const flight::sql::ActionEndTransactionRequest &request) override; private: class Impl; @@ -151,7 +168,4 @@ class DuckDBFlightSqlServer : public FlightSqlServerBase { explicit DuckDBFlightSqlServer(std::shared_ptr impl); }; -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_sql_info.cpp b/src/duckdb/duckdb_sql_info.cpp index a5d5e7a..790e4e7 100644 --- a/src/duckdb/duckdb_sql_info.cpp +++ b/src/duckdb/duckdb_sql_info.cpp @@ -15,21 +15,27 @@ // specific language governing permissions and limitations // under the License. -#include "duckdb_sql_info.h" +#include #include #include #include "arrow/util/config.h" -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +#include "duckdb_sql_info.h" +#include "flight_sql_fwd.h" + +namespace sql = flight::sql; + +namespace sqlflite::ddb { // clang-format off /// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. /// \return the cache. -SqlInfoResultMap GetSqlInfoResultMap() { +sql::SqlInfoResultMap GetSqlInfoResultMap() { + using SqlInfo = sql::SqlInfoOptions::SqlInfo; + using SqlInfoOptions = sql::SqlInfoOptions; + using SqlInfoResult = sql::SqlInfoResult; + return { {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, SqlInfoResult(std::string("db_name"))}, @@ -50,17 +56,17 @@ SqlInfoResultMap GetSqlInfoResultMap() { SqlInfoResult(true)}, {SqlInfoOptions::SqlInfo::SQL_DDL_TABLE, SqlInfoResult(true)}, {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_CASE, - SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SqlInfoResult(static_cast(SqlInfoOptions::SqlSupportedCaseSensitivity:: SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_QUOTE_CHAR, SqlInfoResult(std::string("\""))}, {SqlInfoOptions::SqlInfo::SQL_QUOTED_IDENTIFIER_CASE, - SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SqlInfoResult(static_cast(SqlInfoOptions::SqlSupportedCaseSensitivity:: SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, {SqlInfoOptions::SqlInfo::SQL_ALL_TABLES_ARE_SELECTABLE, SqlInfoResult(true)}, {SqlInfoOptions::SqlInfo::SQL_NULL_ORDERING, SqlInfoResult( - int64_t(SqlInfoOptions::SqlNullOrdering::SQL_NULLS_SORTED_AT_END))}, + static_cast(SqlInfoOptions::SqlNullOrdering::SQL_NULLS_SORTED_AT_END))}, {SqlInfoOptions::SqlInfo::SQL_KEYWORDS, SqlInfoResult(std::vector({"ABORT_P" "ABSOLUTE_P" @@ -654,7 +660,4 @@ SqlInfoResultMap GetSqlInfoResultMap() { } // clang-format on -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_sql_info.h b/src/duckdb/duckdb_sql_info.h index abe903b..1b904ab 100644 --- a/src/duckdb/duckdb_sql_info.h +++ b/src/duckdb/duckdb_sql_info.h @@ -18,17 +18,12 @@ #pragma once #include +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { /// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. /// \return the cache. -SqlInfoResultMap GetSqlInfoResultMap(); +flight::sql::SqlInfoResultMap GetSqlInfoResultMap(); -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_statement.cpp b/src/duckdb/duckdb_statement.cpp index 6899a05..fcf7ba4 100644 --- a/src/duckdb/duckdb_statement.cpp +++ b/src/duckdb/duckdb_statement.cpp @@ -27,44 +27,42 @@ #include #include "duckdb_server.h" +using arrow::Status; using duckdb::QueryResult; -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { -std::shared_ptr GetDataTypeFromDuckDbType( +std::shared_ptr GetDataTypeFromDuckDbType( const duckdb::LogicalType duckdb_type) { const duckdb::LogicalTypeId column_type_id = duckdb_type.id(); switch (column_type_id) { case duckdb::LogicalTypeId::INTEGER: - return int32(); + return arrow::int32(); case duckdb::LogicalTypeId::DECIMAL: { uint8_t width = 0; uint8_t scale = 0; bool dec_properties = duckdb_type.GetDecimalProperties(width, scale); - return decimal(scale, width); + return arrow::decimal(scale, width); } case duckdb::LogicalTypeId::FLOAT: - return float32(); + return arrow::float32(); case duckdb::LogicalTypeId::DOUBLE: - return float64(); + return arrow::float64(); case duckdb::LogicalTypeId::CHAR: case duckdb::LogicalTypeId::VARCHAR: - return utf8(); + return arrow::utf8(); case duckdb::LogicalTypeId::BLOB: - return binary(); + return arrow::binary(); case duckdb::LogicalTypeId::TINYINT: - return int8(); + return arrow::int8(); case duckdb::LogicalTypeId::SMALLINT: - return int16(); + return arrow::int16(); case duckdb::LogicalTypeId::BIGINT: - return int64(); + return arrow::int64(); case duckdb::LogicalTypeId::BOOLEAN: - return boolean(); + return arrow::boolean(); case duckdb::LogicalTypeId::DATE: - return date32(); + return arrow::date32(); case duckdb::LogicalTypeId::TIME: case duckdb::LogicalTypeId::TIMESTAMP_MS: return timestamp(arrow::TimeUnit::MILLI); @@ -78,13 +76,13 @@ std::shared_ptr GetDataTypeFromDuckDbType( return duration( arrow::TimeUnit::MICRO); // ASSUMING MICRO AS DUCKDB's DOCS DOES NOT SPECIFY case duckdb::LogicalTypeId::UTINYINT: - return uint8(); + return arrow::uint8(); case duckdb::LogicalTypeId::USMALLINT: - return uint16(); + return arrow::uint16(); case duckdb::LogicalTypeId::UINTEGER: - return uint32(); + return arrow::uint32(); case duckdb::LogicalTypeId::UBIGINT: - return int64(); + return arrow::int64(); case duckdb::LogicalTypeId::INVALID: case duckdb::LogicalTypeId::SQLNULL: case duckdb::LogicalTypeId::UNKNOWN: @@ -93,7 +91,7 @@ std::shared_ptr GetDataTypeFromDuckDbType( case duckdb::LogicalTypeId::TIMESTAMP_TZ: case duckdb::LogicalTypeId::TIME_TZ: case duckdb::LogicalTypeId::HUGEINT: - return decimal128(38, 0); + return arrow::decimal128(38, 0); case duckdb::LogicalTypeId::POINTER: case duckdb::LogicalTypeId::VALIDITY: case duckdb::LogicalTypeId::UUID: @@ -103,7 +101,7 @@ std::shared_ptr GetDataTypeFromDuckDbType( case duckdb::LogicalTypeId::TABLE: case duckdb::LogicalTypeId::ENUM: default: - return null(); + return arrow::null(); } } @@ -130,8 +128,8 @@ arrow::Result DuckDBStatement::Execute() { return 0; } -arrow::Result> DuckDBStatement::FetchResult() { - std::shared_ptr record_batch; +arrow::Result> DuckDBStatement::FetchResult() { + std::shared_ptr record_batch; ArrowArray res_arr; ArrowSchema res_schema; duckdb::ClientProperties res_options; @@ -165,7 +163,7 @@ arrow::Result DuckDBStatement::ExecuteUpdate() { return result->get()->num_rows(); } -arrow::Result> DuckDBStatement::GetSchema() const { +arrow::Result> DuckDBStatement::GetSchema() const { // get the names and types of the result schema auto names = stmt_->GetNames(); auto types = stmt_->GetTypes(); @@ -181,7 +179,4 @@ arrow::Result> DuckDBStatement::GetSchema() const { return return_value; } -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_statement.h b/src/duckdb/duckdb_statement.h index 8eeedcb..fb8f974 100644 --- a/src/duckdb/duckdb_statement.h +++ b/src/duckdb/duckdb_statement.h @@ -25,22 +25,19 @@ #include #include -// clang-format off -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +#include "flight_sql_fwd.h" -std::shared_ptr GetDataTypeFromDuckDbType( - const duckdb::LogicalType duckdb_type -); +namespace sqlflite::ddb { + +std::shared_ptr GetDataTypeFromDuckDbType( + const duckdb::LogicalType duckdb_type); /// \brief Create an object ColumnMetadata using the column type and /// table name. /// \param column_type The DuckDB type. /// \param table The table name. /// \return A Column Metadata object. -ColumnMetadata GetColumnMetadata(int column_type, const char* table); +flight::sql::ColumnMetadata GetColumnMetadata(int column_type, const char* table); class DuckDBStatement { public: @@ -48,17 +45,17 @@ class DuckDBStatement { /// \param[in] db duckdb database instance. /// \param[in] sql SQL statement. /// \return A DuckDBStatement object. - static arrow::Result> Create(std::shared_ptr con, - const std::string& sql); + static arrow::Result> Create( + std::shared_ptr con, const std::string& sql); ~DuckDBStatement(); /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema() const; + arrow::Result> GetSchema() const; arrow::Result Execute(); - arrow::Result> FetchResult(); + arrow::Result> FetchResult(); // arrow::Result> GetArrowSchema(); std::shared_ptr GetDuckDBStmt() const; @@ -69,20 +66,16 @@ class DuckDBStatement { duckdb::vector bind_parameters; -private: + private: std::shared_ptr con_; std::shared_ptr stmt_; duckdb::unique_ptr query_result_; - DuckDBStatement( - std::shared_ptr con, - std::shared_ptr stmt) { + DuckDBStatement(std::shared_ptr con, + std::shared_ptr stmt) { con_ = con; stmt_ = stmt; } }; -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_statement_batch_reader.cpp b/src/duckdb/duckdb_statement_batch_reader.cpp index 8186edf..2c9108d 100644 --- a/src/duckdb/duckdb_statement_batch_reader.cpp +++ b/src/duckdb/duckdb_statement_batch_reader.cpp @@ -26,19 +26,17 @@ #include "duckdb_statement.h" -// clang-format off -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { // Batch size for SQLite statement results static constexpr int kMaxBatchSize = 1024; -std::shared_ptr DuckDBStatementBatchReader::schema() const { return schema_; } +std::shared_ptr DuckDBStatementBatchReader::schema() const { + return schema_; +} DuckDBStatementBatchReader::DuckDBStatementBatchReader( - std::shared_ptr statement, std::shared_ptr schema) + std::shared_ptr statement, std::shared_ptr schema) : statement_(std::move(statement)), schema_(std::move(schema)), rc_(DuckDBSuccess), @@ -57,25 +55,22 @@ DuckDBStatementBatchReader::Create(const std::shared_ptr& state arrow::Result> DuckDBStatementBatchReader::Create(const std::shared_ptr& statement, - const std::shared_ptr& schema) { + const std::shared_ptr& schema) { std::shared_ptr result( new DuckDBStatementBatchReader(statement, schema)); return result; } -Status DuckDBStatementBatchReader::ReadNext(std::shared_ptr* out) { - +arrow::Status DuckDBStatementBatchReader::ReadNext( + std::shared_ptr* out) { if (!already_executed_) { ARROW_ASSIGN_OR_RAISE(rc_, statement_->Execute()); already_executed_ = true; } ARROW_ASSIGN_OR_RAISE(*out, statement_->FetchResult()); - return Status::OK(); + return arrow::Status::OK(); } -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_statement_batch_reader.h b/src/duckdb/duckdb_statement_batch_reader.h index e0a4902..fa42388 100644 --- a/src/duckdb/duckdb_statement_batch_reader.h +++ b/src/duckdb/duckdb_statement_batch_reader.h @@ -21,14 +21,11 @@ #include #include #include "duckdb_statement.h" +#include "flight_sql_fwd.h" -// clang-format off -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { -class DuckDBStatementBatchReader : public RecordBatchReader { +class DuckDBStatementBatchReader : public arrow::RecordBatchReader { public: /// \brief Creates a RecordBatchReader backed by a duckdb statement. /// \param[in] statement duckdb statement to be read. @@ -42,24 +39,21 @@ class DuckDBStatementBatchReader : public RecordBatchReader { /// \return A DuckDBStatementBatchReader.. static arrow::Result> Create( const std::shared_ptr& statement, - const std::shared_ptr& schema); + const std::shared_ptr& schema); - std::shared_ptr schema() const override; + std::shared_ptr schema() const override; - Status ReadNext(std::shared_ptr* out) override; + arrow::Status ReadNext(std::shared_ptr* out) override; private: std::shared_ptr statement_; - std::shared_ptr schema_; + std::shared_ptr schema_; int rc_; bool already_executed_; bool results_read_; DuckDBStatementBatchReader(std::shared_ptr statement, - std::shared_ptr schema); + std::shared_ptr schema); }; -} // namespace duckdbflight -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_tables_schema_batch_reader.cpp b/src/duckdb/duckdb_tables_schema_batch_reader.cpp index e7da28f..6dd0c12 100644 --- a/src/duckdb/duckdb_tables_schema_batch_reader.cpp +++ b/src/duckdb/duckdb_tables_schema_batch_reader.cpp @@ -27,73 +27,73 @@ #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" -// clang-format off -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { - -std::shared_ptr DuckDBTablesWithSchemaBatchReader::schema() const { - return SqlSchema::GetTablesSchemaWithIncludedSchema(); +#include "flight_sql_fwd.h" + +using arrow::Status; + +namespace sqlflite::ddb { + +std::shared_ptr DuckDBTablesWithSchemaBatchReader::schema() const { + return flight::sql::SqlSchema::GetTablesSchemaWithIncludedSchema(); } -Status DuckDBTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* batch) { +Status DuckDBTablesWithSchemaBatchReader::ReadNext( + std::shared_ptr *batch) { if (already_executed_) { - *batch = NULLPTR; - return Status::OK(); - } - else { - std::shared_ptr schema_statement; - ARROW_ASSIGN_OR_RAISE(schema_statement, - DuckDBStatement::Create(db_conn_, main_query_)); + *batch = NULLPTR; + return Status::OK(); + } else { + std::shared_ptr schema_statement; + ARROW_ASSIGN_OR_RAISE(schema_statement, + DuckDBStatement::Create(db_conn_, main_query_)); - std::shared_ptr first_batch; + std::shared_ptr first_batch; - ARROW_RETURN_NOT_OK(reader_->ReadNext(&first_batch)); + ARROW_RETURN_NOT_OK(reader_->ReadNext(&first_batch)); - if (!first_batch) { - *batch = NULLPTR; - return Status::OK(); - } + if (!first_batch) { + *batch = NULLPTR; + return Status::OK(); + } - const std::shared_ptr table_name_array = - first_batch->GetColumnByName("table_name"); + const std::shared_ptr table_name_array = + first_batch->GetColumnByName("table_name"); - BinaryBuilder schema_builder; + arrow::BinaryBuilder schema_builder; - auto *string_array = reinterpret_cast(table_name_array.get()); + auto *string_array = reinterpret_cast(table_name_array.get()); - for (int table_name_index = 0; table_name_index < table_name_array->length(); table_name_index++) { - const std::string &table_name = string_array->GetString(table_name_index); + for (int table_name_index = 0; table_name_index < table_name_array->length(); + table_name_index++) { + const std::string &table_name = string_array->GetString(table_name_index); - // Just get the schema from a prepared statement - std::shared_ptr table_schema_statement; - ARROW_ASSIGN_OR_RAISE(table_schema_statement, - DuckDBStatement::Create(db_conn_, "SELECT * FROM " + table_name + " WHERE 1 = 0")); + // Just get the schema from a prepared statement + std::shared_ptr table_schema_statement; + ARROW_ASSIGN_OR_RAISE( + table_schema_statement, + DuckDBStatement::Create(db_conn_, + "SELECT * FROM " + table_name + " WHERE 1 = 0")); - ARROW_ASSIGN_OR_RAISE(auto table_schema, table_schema_statement->GetSchema()); + ARROW_ASSIGN_OR_RAISE(auto table_schema, table_schema_statement->GetSchema()); - const arrow::Result> &value = - ipc::SerializeSchema(*table_schema); + const arrow::Result> &value = + arrow::ipc::SerializeSchema(*table_schema); - std::shared_ptr schema_buffer; - ARROW_ASSIGN_OR_RAISE(schema_buffer, value); + std::shared_ptr schema_buffer; + ARROW_ASSIGN_OR_RAISE(schema_buffer, value); - ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer))); - } + ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer))); + } - std::shared_ptr schema_array; - ARROW_RETURN_NOT_OK(schema_builder.Finish(&schema_array)); + std::shared_ptr schema_array; + ARROW_RETURN_NOT_OK(schema_builder.Finish(&schema_array)); - ARROW_ASSIGN_OR_RAISE(*batch, first_batch->AddColumn(4, "table_schema", schema_array)); - already_executed_ = true; + ARROW_ASSIGN_OR_RAISE(*batch, + first_batch->AddColumn(4, "table_schema", schema_array)); + already_executed_ = true; - return Status::OK(); + return Status::OK(); } - } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/duckdb/duckdb_tables_schema_batch_reader.h b/src/duckdb/duckdb_tables_schema_batch_reader.h index f3d2d83..69c382b 100644 --- a/src/duckdb/duckdb_tables_schema_batch_reader.h +++ b/src/duckdb/duckdb_tables_schema_batch_reader.h @@ -26,13 +26,9 @@ #include "duckdb_statement_batch_reader.h" #include "arrow/record_batch.h" -// clang-format off -namespace arrow { -namespace flight { -namespace sql { -namespace duckdbflight { +namespace sqlflite::ddb { -class DuckDBTablesWithSchemaBatchReader : public RecordBatchReader { +class DuckDBTablesWithSchemaBatchReader : public arrow::RecordBatchReader { private: std::shared_ptr reader_; std::string main_query_; @@ -44,17 +40,17 @@ class DuckDBTablesWithSchemaBatchReader : public RecordBatchReader { /// \param reader an shared_ptr from a DuckDBStatementBatchReader. /// \param main_query SQL query that originated reader's data. /// \param db a pointer to the sqlite3 db. - DuckDBTablesWithSchemaBatchReader( - std::shared_ptr reader, std::string main_query, - std::shared_ptr db_conn) - : reader_(std::move(reader)), main_query_(std::move(main_query)), db_conn_(db_conn), already_executed_(false) {} + DuckDBTablesWithSchemaBatchReader(std::shared_ptr reader, + std::string main_query, + std::shared_ptr db_conn) + : reader_(std::move(reader)), + main_query_(std::move(main_query)), + db_conn_(db_conn), + already_executed_(false) {} - std::shared_ptr schema() const override; + std::shared_ptr schema() const override; - Status ReadNext(std::shared_ptr* batch) override; + arrow::Status ReadNext(std::shared_ptr* batch) override; }; -} // namespace duckdb -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::ddb diff --git a/src/library/include/flight_sql_fwd.h b/src/library/include/flight_sql_fwd.h new file mode 100644 index 0000000..4e5557e --- /dev/null +++ b/src/library/include/flight_sql_fwd.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow::flight {} + +namespace flight = arrow::flight; diff --git a/src/library/include/sqlflite_library.h b/src/library/include/sqlflite_library.h index b7c911e..e189902 100644 --- a/src/library/include/sqlflite_library.h +++ b/src/library/include/sqlflite_library.h @@ -1,4 +1,20 @@ -// sqlflite_library.h +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #pragma once #include @@ -12,8 +28,6 @@ const int DEFAULT_FLIGHT_PORT = 31337; enum class BackendType { duckdb, sqlite }; -namespace fs = std::filesystem; - /** * @brief Run a SQLFlite Server with the specified configuration. * @@ -41,13 +55,14 @@ namespace fs = std::filesystem; */ extern "C" { -int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, - std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT, - std::string username = "", std::string password = "", - std::string secret_key = "", fs::path tls_cert_path = fs::path(), - fs::path tls_key_path = fs::path(), - fs::path mtls_ca_cert_path = fs::path(), - std::string init_sql_commands = "", - fs::path init_sql_commands_file = fs::path(), - const bool &print_queries = false); +int RunFlightSQLServer( + const BackendType backend, std::filesystem::path &database_filename, + std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT, + std::string username = "", std::string password = "", std::string secret_key = "", + std::filesystem::path tls_cert_path = std::filesystem::path(), + std::filesystem::path tls_key_path = std::filesystem::path(), + std::filesystem::path mtls_ca_cert_path = std::filesystem::path(), + std::string init_sql_commands = "", + std::filesystem::path init_sql_commands_file = std::filesystem::path(), + const bool &print_queries = false); } diff --git a/src/library/include/sqlflite_security.h b/src/library/include/sqlflite_security.h index d4d70e2..f6612c4 100644 --- a/src/library/include/sqlflite_security.h +++ b/src/library/include/sqlflite_security.h @@ -1,6 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at // -// Created by Philip Moore on 11/14/22. +// http://www.apache.org/licenses/LICENSE-2.0 // +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include #include #include @@ -14,38 +28,37 @@ #include #include #include +#include "flight_sql_fwd.h" -namespace fs = std::filesystem; - -namespace arrow { -namespace flight { +namespace sqlflite { class SecurityUtilities { public: - static Status FlightServerTlsCertificates(const fs::path &cert_path, - const fs::path &key_path, - std::vector *out); + static arrow::Status FlightServerTlsCertificates(const std::filesystem::path &cert_path, + const std::filesystem::path &key_path, + std::vector *out); - static Status FlightServerMtlsCACertificate(const std::string &cert_path, - std::string *out); + static arrow::Status FlightServerMtlsCACertificate(const std::string &cert_path, + std::string *out); - static std::string FindKeyValPrefixInCallHeaders(const CallHeaders &incoming_headers, - const std::string &key, - const std::string &prefix); + static std::string FindKeyValPrefixInCallHeaders( + const flight::CallHeaders &incoming_headers, const std::string &key, + const std::string &prefix); - static Status GetAuthHeaderType(const CallHeaders &incoming_headers, std::string *out); + static arrow::Status GetAuthHeaderType(const flight::CallHeaders &incoming_headers, + std::string *out); - static void ParseBasicHeader(const CallHeaders &incoming_headers, std::string &username, - std::string &password); + static void ParseBasicHeader(const flight::CallHeaders &incoming_headers, + std::string &username, std::string &password); }; -class HeaderAuthServerMiddleware : public ServerMiddleware { +class HeaderAuthServerMiddleware : public flight::ServerMiddleware { public: HeaderAuthServerMiddleware(const std::string &username, const std::string &secret_key); - void SendingHeaders(AddCallHeaders *outgoing_headers) override; + void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override; - void CallCompleted(const Status &status) override; + void CallCompleted(const arrow::Status &status) override; std::string name() const override; @@ -56,14 +69,15 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { std::string CreateJWTToken() const; }; -class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { +class HeaderAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory { public: HeaderAuthServerMiddlewareFactory(const std::string &username, const std::string &password, const std::string &secret_key); - Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) override; + arrow::Status StartCall(const flight::CallInfo &info, + const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) override; private: std::string username_; @@ -71,32 +85,33 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { std::string secret_key_; }; -class BearerAuthServerMiddleware : public ServerMiddleware { +class BearerAuthServerMiddleware : public flight::ServerMiddleware { public: explicit BearerAuthServerMiddleware(const std::string &secret_key, - const CallHeaders &incoming_headers, + const flight::CallHeaders &incoming_headers, std::optional *isValid); - void SendingHeaders(AddCallHeaders *outgoing_headers) override; + void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override; - void CallCompleted(const Status &status) override; + void CallCompleted(const arrow::Status &status) override; std::string name() const override; private: std::string secret_key_; - CallHeaders incoming_headers_; + flight::CallHeaders incoming_headers_; std::optional *isValid_; bool VerifyToken(const std::string &token) const; }; -class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { +class BearerAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory { public: explicit BearerAuthServerMiddlewareFactory(const std::string &secret_key); - Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) override; + arrow::Status StartCall(const flight::CallInfo &info, + const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) override; std::optional GetIsValid(); @@ -105,5 +120,4 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { std::string secret_key_; }; -} // namespace flight -} // namespace arrow +} // namespace sqlflite diff --git a/src/library/sqlflite_library.cpp b/src/library/sqlflite_library.cpp index d9dbc4c..355ad25 100644 --- a/src/library/sqlflite_library.cpp +++ b/src/library/sqlflite_library.cpp @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include "include/sqlflite_library.h" #include @@ -16,10 +33,12 @@ #include "sqlite_server.h" #include "duckdb_server.h" +#include "include/flight_sql_fwd.h" #include "include/sqlflite_security.h" -namespace flight = arrow::flight; -namespace flightsql = arrow::flight::sql; +namespace fs = std::filesystem; + +namespace sqlflite { const int port = 31337; @@ -41,24 +60,24 @@ const int port = 31337; } \ } while (false) -arrow::Result> -FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filename, - const std::string &hostname, const int &port, - const std::string &username, const std::string &password, - const std::string &secret_key, const fs::path &tls_cert_path, - const fs::path &tls_key_path, const fs::path &mtls_ca_cert_path, - const std::string &init_sql_commands, const bool &print_queries) { +arrow::Result> FlightSQLServerBuilder( + const BackendType backend, const fs::path &database_filename, + const std::string &hostname, const int &port, const std::string &username, + const std::string &password, const std::string &secret_key, + const fs::path &tls_cert_path, const fs::path &tls_key_path, + const fs::path &mtls_ca_cert_path, const std::string &init_sql_commands, + const bool &print_queries) { ARROW_ASSIGN_OR_RAISE(auto location, (!tls_cert_path.empty()) - ? arrow::flight::Location::ForGrpcTls(hostname, port) - : arrow::flight::Location::ForGrpcTcp(hostname, port)); + ? flight::Location::ForGrpcTls(hostname, port) + : flight::Location::ForGrpcTcp(hostname, port)); std::cout << "Apache Arrow version: " << ARROW_VERSION_STRING << std::endl; - arrow::flight::FlightServerOptions options(location); + flight::FlightServerOptions options(location); if (!tls_cert_path.empty() && !tls_key_path.empty()) { - ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerTlsCertificates( + ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerTlsCertificates( tls_cert_path, tls_key_path, &options.tls_certificates)); } else { std::cout << "WARNING - TLS is disabled for the SQLFlite server - this is insecure." @@ -66,43 +85,38 @@ FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filen } // Setup authentication middleware (using the same TLS certificate keypair) - auto header_middleware = - std::make_shared( - username, password, secret_key); + auto header_middleware = std::make_shared( + username, password, secret_key); auto bearer_middleware = - std::make_shared(secret_key); + std::make_shared(secret_key); - options.auth_handler = std::make_unique(); + options.auth_handler = std::make_unique(); options.middleware.push_back({"header-auth-server", header_middleware}); options.middleware.push_back({"bearer-auth-server", bearer_middleware}); if (!mtls_ca_cert_path.empty()) { std::cout << "Using mTLS CA certificate: " << mtls_ca_cert_path << std::endl; - ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerMtlsCACertificate( + ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerMtlsCACertificate( mtls_ca_cert_path, &options.root_certificates)); options.verify_client = true; } - std::shared_ptr server = nullptr; + std::shared_ptr server = nullptr; std::string db_type = ""; if (backend == BackendType::sqlite) { db_type = "SQLite"; - std::shared_ptr sqlite_server = - nullptr; - ARROW_ASSIGN_OR_RAISE( - sqlite_server, - arrow::flight::sql::sqlite::SQLiteFlightSqlServer::Create(database_filename)) + std::shared_ptr sqlite_server = nullptr; + ARROW_ASSIGN_OR_RAISE(sqlite_server, sqlflite::sqlite::SQLiteFlightSqlServer::Create( + database_filename)); RUN_INIT_COMMANDS(sqlite_server, init_sql_commands); server = sqlite_server; } else if (backend == BackendType::duckdb) { db_type = "DuckDB"; - std::shared_ptr - duckdb_server = nullptr; + std::shared_ptr duckdb_server = nullptr; duckdb::DBConfig config; - ARROW_ASSIGN_OR_RAISE(duckdb_server, - arrow::flight::sql::duckdbflight::DuckDBFlightSqlServer::Create( - database_filename, config, print_queries)) + ARROW_ASSIGN_OR_RAISE(duckdb_server, sqlflite::ddb::DuckDBFlightSqlServer::Create( + database_filename, config, print_queries)) // Run additional commands (first) for the DuckDB back-end... auto duckdb_init_sql_commands = "SET autoinstall_known_extensions = true; SET autoload_known_extensions = true;" + @@ -142,13 +156,12 @@ std::string SafeGetEnvVarValue(const std::string &env_var_name) { } } -arrow::Result> -CreateFlightSQLServer(const BackendType backend, fs::path &database_filename, - std::string hostname, const int &port, std::string username, - std::string password, std::string secret_key, - fs::path tls_cert_path, fs::path tls_key_path, - fs::path mtls_ca_cert_path, std::string init_sql_commands, - fs::path init_sql_commands_file, const bool &print_queries) { +arrow::Result> CreateFlightSQLServer( + const BackendType backend, fs::path &database_filename, std::string hostname, + const int &port, std::string username, std::string password, std::string secret_key, + fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path, + std::string init_sql_commands, fs::path init_sql_commands_file, + const bool &print_queries) { // Validate and default the arguments to env var values where applicable if (database_filename.empty()) { return arrow::Status::Invalid("The database filename was not provided!"); @@ -242,17 +255,21 @@ CreateFlightSQLServer(const BackendType backend, fs::path &database_filename, } arrow::Status StartFlightSQLServer( - std::shared_ptr server) { + std::shared_ptr server) { return arrow::Status::OK(); } +} // namespace sqlflite + +extern "C" { + int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, std::string hostname, const int &port, std::string username, std::string password, std::string secret_key, fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path, std::string init_sql_commands, fs::path init_sql_commands_file, const bool &print_queries) { - auto create_server_result = CreateFlightSQLServer( + auto create_server_result = sqlflite::CreateFlightSQLServer( backend, database_filename, hostname, port, username, password, secret_key, tls_cert_path, tls_key_path, mtls_ca_cert_path, init_sql_commands, init_sql_commands_file, print_queries); @@ -268,3 +285,4 @@ int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, return EXIT_FAILURE; } } +} diff --git a/src/library/sqlflite_security.cpp b/src/library/sqlflite_security.cpp index a03c376..c8760dc 100644 --- a/src/library/sqlflite_security.cpp +++ b/src/library/sqlflite_security.cpp @@ -1,12 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at // -// Created by Philip Moore on 11/14/22. +// http://www.apache.org/licenses/LICENSE-2.0 // +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include "include/sqlflite_security.h" namespace fs = std::filesystem; -namespace arrow { -namespace flight { +using arrow::Status; + +namespace sqlflite { const std::string kJWTIssuer = "sqlflite"; const int kJWTExpiration = 24 * 3600; @@ -16,13 +31,13 @@ const std::string kBearerPrefix = "Bearer "; const std::string kAuthHeader = "authorization"; // ---------------------------------------- -Status SecurityUtilities::FlightServerTlsCertificates(const fs::path &cert_path, - const fs::path &key_path, - std::vector *out) { +Status SecurityUtilities::FlightServerTlsCertificates( + const fs::path &cert_path, const fs::path &key_path, + std::vector *out) { std::cout << "Using TLS Cert file: " << cert_path << std::endl; std::cout << "Using TLS Key file: " << key_path << std::endl; - *out = std::vector(); + *out = std::vector(); try { std::ifstream cert_file(cert_path); if (!cert_file) { @@ -38,7 +53,7 @@ Status SecurityUtilities::FlightServerTlsCertificates(const fs::path &cert_path, std::stringstream key; key << key_file.rdbuf(); - out->push_back(CertKeyPair{cert.str(), key.str()}); + out->push_back(flight::CertKeyPair{cert.str(), key.str()}); } catch (const std::ifstream::failure &e) { return Status::IOError(e.what()); } @@ -65,7 +80,7 @@ Status SecurityUtilities::FlightServerMtlsCACertificate(const std::string &cert_ // Function to look in CallHeaders for a key that has a value starting with prefix and // return the rest of the value after the prefix. std::string SecurityUtilities::FindKeyValPrefixInCallHeaders( - const CallHeaders &incoming_headers, const std::string &key, + const flight::CallHeaders &incoming_headers, const std::string &key, const std::string &prefix) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char &char1, const char &char2) { @@ -86,7 +101,7 @@ std::string SecurityUtilities::FindKeyValPrefixInCallHeaders( return ""; } -Status SecurityUtilities::GetAuthHeaderType(const CallHeaders &incoming_headers, +Status SecurityUtilities::GetAuthHeaderType(const flight::CallHeaders &incoming_headers, std::string *out) { if (!FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix) .empty()) { @@ -100,7 +115,7 @@ Status SecurityUtilities::GetAuthHeaderType(const CallHeaders &incoming_headers, return Status::OK(); } -void SecurityUtilities::ParseBasicHeader(const CallHeaders &incoming_headers, +void SecurityUtilities::ParseBasicHeader(const flight::CallHeaders &incoming_headers, std::string &username, std::string &password) { std::string encoded_credentials = FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); @@ -114,7 +129,8 @@ HeaderAuthServerMiddleware::HeaderAuthServerMiddleware(const std::string &userna const std::string &secret_key) : username_(username), secret_key_(secret_key) {} -void HeaderAuthServerMiddleware::SendingHeaders(AddCallHeaders *outgoing_headers) { +void HeaderAuthServerMiddleware::SendingHeaders( + flight::AddCallHeaders *outgoing_headers) { auto token = CreateJWTToken(); outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + token); } @@ -147,8 +163,8 @@ HeaderAuthServerMiddlewareFactory::HeaderAuthServerMiddlewareFactory( : username_(username), password_(password), secret_key_(secret_key) {} Status HeaderAuthServerMiddlewareFactory::StartCall( - const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) { + const flight::CallInfo &info, const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) { std::string auth_header_type; ARROW_RETURN_NOT_OK( SecurityUtilities::GetAuthHeaderType(incoming_headers, &auth_header_type)); @@ -161,7 +177,8 @@ Status HeaderAuthServerMiddlewareFactory::StartCall( if ((username == username_) && (password == password_)) { *middleware = std::make_shared(username, secret_key_); } else { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials"); + return MakeFlightError(flight::FlightStatusCode::Unauthenticated, + "Invalid credentials"); } } return Status::OK(); @@ -169,11 +186,12 @@ Status HeaderAuthServerMiddlewareFactory::StartCall( // ---------------------------------------- BearerAuthServerMiddleware::BearerAuthServerMiddleware( - const std::string &secret_key, const CallHeaders &incoming_headers, + const std::string &secret_key, const flight::CallHeaders &incoming_headers, std::optional *isValid) : secret_key_(secret_key), incoming_headers_(incoming_headers), isValid_(isValid) {} -void BearerAuthServerMiddleware::SendingHeaders(AddCallHeaders *outgoing_headers) { +void BearerAuthServerMiddleware::SendingHeaders( + flight::AddCallHeaders *outgoing_headers) { std::string bearer_token = SecurityUtilities::FindKeyValPrefixInCallHeaders( incoming_headers_, kAuthHeader, kBearerPrefix); *isValid_ = (VerifyToken(bearer_token)); @@ -211,10 +229,11 @@ BearerAuthServerMiddlewareFactory::BearerAuthServerMiddlewareFactory( : secret_key_(secret_key) {} Status BearerAuthServerMiddlewareFactory::StartCall( - const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) { - if (const std::pair - &iter_pair = incoming_headers.equal_range(kAuthHeader); + const flight::CallInfo &info, const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) { + if (const std::pair &iter_pair = + incoming_headers.equal_range(kAuthHeader); iter_pair.first != iter_pair.second) { std::string auth_header_type; ARROW_RETURN_NOT_OK( @@ -227,7 +246,7 @@ Status BearerAuthServerMiddlewareFactory::StartCall( if (isValid_.has_value() && !*isValid_) { isValid_.reset(); - return MakeFlightError(FlightStatusCode::Unauthenticated, + return MakeFlightError(flight::FlightStatusCode::Unauthenticated, "Invalid bearer token provided"); } @@ -236,5 +255,4 @@ Status BearerAuthServerMiddlewareFactory::StartCall( std::optional BearerAuthServerMiddlewareFactory::GetIsValid() { return isValid_; } -} // namespace flight -} // namespace arrow +} // namespace sqlflite diff --git a/src/sqlflite_client.cpp b/src/sqlflite_client.cpp index b97363d..5cbf83d 100644 --- a/src/sqlflite_client.cpp +++ b/src/sqlflite_client.cpp @@ -33,21 +33,11 @@ #include "arrow/status.h" #include "arrow/table.h" -using arrow::Result; -using arrow::Schema; +#include "library/include/flight_sql_fwd.h" + using arrow::Status; -using arrow::flight::ClientAuthHandler; -using arrow::flight::FlightCallOptions; -using arrow::flight::FlightClient; -using arrow::flight::FlightDescriptor; -using arrow::flight::FlightEndpoint; -using arrow::flight::FlightInfo; -using arrow::flight::FlightStreamChunk; -using arrow::flight::FlightStreamReader; -using arrow::flight::Location; -using arrow::flight::Ticket; -using arrow::flight::sql::FlightSqlClient; -using arrow::flight::sql::TableRef; + +namespace sqlflite { DEFINE_string(host, "localhost", "Host to connect to"); DEFINE_int32(port, 31337, "Port to connect to"); @@ -69,12 +59,12 @@ DEFINE_string(catalog, "", "Catalog"); DEFINE_string(schema, "", "Schema"); DEFINE_string(table, "", "Table"); -Status PrintResultsForEndpoint(FlightSqlClient &client, - const FlightCallOptions &call_options, - const FlightEndpoint &endpoint) { +Status PrintResultsForEndpoint(flight::sql::FlightSqlClient &client, + const flight::FlightCallOptions &call_options, + const flight::FlightEndpoint &endpoint) { ARROW_ASSIGN_OR_RAISE(auto stream, client.DoGet(call_options, endpoint.ticket)); - const arrow::Result> &schema = stream->GetSchema(); + const arrow::Result> &schema = stream->GetSchema(); ARROW_RETURN_NOT_OK(schema); std::cout << "Schema:" << std::endl; @@ -85,7 +75,7 @@ Status PrintResultsForEndpoint(FlightSqlClient &client, int64_t num_rows = 0; while (true) { - ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next()); + ARROW_ASSIGN_OR_RAISE(flight::FlightStreamChunk chunk, stream->Next()); if (chunk.data == nullptr) { break; } @@ -98,9 +88,10 @@ Status PrintResultsForEndpoint(FlightSqlClient &client, return Status::OK(); } -Status PrintResults(FlightSqlClient &client, const FlightCallOptions &call_options, - const std::unique_ptr &info) { - const std::vector &endpoints = info->endpoints(); +Status PrintResults(flight::sql::FlightSqlClient &client, + const flight::FlightCallOptions &call_options, + const std::unique_ptr &info) { + const std::vector &endpoints = info->endpoints(); for (size_t i = 0; i < endpoints.size(); i++) { std::cout << "Results from endpoint " << i + 1 << " of " << endpoints.size() @@ -127,11 +118,12 @@ Status getPEMCertFileContents(const std::string &cert_file_path, Status RunMain() { ARROW_ASSIGN_OR_RAISE(auto location, - (FLAGS_use_tls) ? Location::ForGrpcTls(FLAGS_host, FLAGS_port) - : Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); + (FLAGS_use_tls) + ? flight::Location::ForGrpcTls(FLAGS_host, FLAGS_port) + : flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); // Setup our options - arrow::flight::FlightClientOptions options; + flight::FlightClientOptions options; if (!FLAGS_tls_roots.empty()) { ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_tls_roots, options.tls_root_certs)); @@ -152,19 +144,19 @@ Status RunMain() { } } - ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location, options)); + ARROW_ASSIGN_OR_RAISE(auto client, flight::FlightClient::Connect(location, options)); - FlightCallOptions call_options; + flight::FlightCallOptions call_options; if (!FLAGS_username.empty() || !FLAGS_password.empty()) { - Result> bearer_result = + arrow::Result> bearer_result = client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password); ARROW_RETURN_NOT_OK(bearer_result); call_options.headers.push_back(bearer_result.ValueOrDie()); } - FlightSqlClient sql_client(std::move(client)); + flight::sql::FlightSqlClient sql_client(std::move(client)); if (FLAGS_command == "ExecuteUpdate") { ARROW_ASSIGN_OR_RAISE(auto rows, sql_client.ExecuteUpdate(call_options, FLAGS_query)); @@ -174,7 +166,7 @@ Status RunMain() { return Status::OK(); } - std::unique_ptr info; + std::unique_ptr info; std::shared_ptr prepared_statement; @@ -211,16 +203,16 @@ Status RunMain() { info, sql_client.GetTables(call_options, &FLAGS_catalog, &FLAGS_schema, &FLAGS_table, false, nullptr)); } else if (FLAGS_command == "GetExportedKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetExportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetImportedKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetImportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetPrimaryKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetPrimaryKeys(call_options, table_ref)); } else if (FLAGS_command == "GetSqlInfo") { ARROW_ASSIGN_OR_RAISE(info, sql_client.GetSqlInfo(call_options, {})); @@ -238,10 +230,12 @@ Status RunMain() { return print_status; } +} // namespace sqlflite + int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - Status st = RunMain(); + Status st = sqlflite::RunMain(); if (!st.ok()) { std::cerr << st << std::endl; return 1; diff --git a/src/sqlflite_server.cpp b/src/sqlflite_server.cpp index 5809b74..04337c3 100644 --- a/src/sqlflite_server.cpp +++ b/src/sqlflite_server.cpp @@ -1,8 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + #include "library/include/sqlflite_library.h" #include #include namespace po = boost::program_options; +namespace fs = std::filesystem; int main(int argc, char **argv) { std::vector tls_token_values; diff --git a/src/sqlite/sqlite_server.cc b/src/sqlite/sqlite_server.cc index ed4f57b..dfccd4f 100644 --- a/src/sqlite/sqlite_server.cc +++ b/src/sqlite/sqlite_server.cc @@ -36,19 +36,17 @@ #include "sqlite_type_info.h" #include "arrow/flight/sql/server.h" #include "arrow/scalar.h" -#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +#include "flight_sql_fwd.h" -using arrow::internal::checked_cast; +using arrow::Status; + +namespace sqlflite::sqlite { namespace { -std::string PrepareQueryForGetTables(const GetTables& command) { +std::string PrepareQueryForGetTables(const flight::sql::GetTables& command) { std::stringstream table_query; table_query << "SELECT 'main' as catalog_name, null as schema_name, name as " @@ -87,10 +85,11 @@ std::string PrepareQueryForGetTables(const GetTables& command) { template Status SetParametersOnSQLiteStatement(SqliteStatement* statement, - FlightMessageReader* reader, Callback callback) { + flight::FlightMessageReader* reader, + Callback callback) { sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); while (true) { - ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next()); + ARROW_ASSIGN_OR_RAISE(flight::FlightStreamChunk chunk, reader->Next()); if (chunk.data == nullptr) break; const int64_t num_rows = chunk.data->num_rows(); @@ -112,8 +111,8 @@ Status SetParametersOnSQLiteStatement(SqliteStatement* statement, return Status::OK(); } -arrow::Result> DoGetSQLiteQuery( - sqlite3* db, const std::string& query, const std::shared_ptr& schema) { +arrow::Result> DoGetSQLiteQuery( + sqlite3* db, const std::string& query, const std::shared_ptr& schema) { std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db, query)); @@ -121,17 +120,18 @@ arrow::Result> DoGetSQLiteQuery( std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement, schema)); - return std::make_unique(reader); + return std::make_unique(reader); } -arrow::Result> GetFlightInfoForCommand( - const FlightDescriptor& descriptor, const std::shared_ptr& schema) { - std::vector endpoints{ - FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; - ARROW_ASSIGN_OR_RAISE(auto result, - FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, false)) +arrow::Result> GetFlightInfoForCommand( + const flight::FlightDescriptor& descriptor, + const std::shared_ptr& schema) { + std::vector endpoints{ + flight::FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; + ARROW_ASSIGN_OR_RAISE(auto result, flight::FlightInfo::Make(*schema, descriptor, + endpoints, -1, -1, false)) - return std::make_unique(result); + return std::make_unique(result); } std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter) { @@ -169,22 +169,22 @@ std::string PrepareQueryForGetImportedOrExportedKeys(const std::string& filter) } // namespace -arrow::Result> GetArrowType(const char* sqlite_type) { +arrow::Result> GetArrowType(const char* sqlite_type) { if (sqlite_type == nullptr || std::strlen(sqlite_type) == 0) { // SQLite may not know the column type yet. - return null(); + return arrow::null(); } if (boost::iequals(sqlite_type, "int") || boost::iequals(sqlite_type, "integer")) { - return int64(); + return arrow::int64(); } else if (boost::iequals(sqlite_type, "REAL")) { - return float64(); + return arrow::float64(); } else if (boost::iequals(sqlite_type, "BLOB")) { - return binary(); + return arrow::binary(); } else if (boost::iequals(sqlite_type, "TEXT") || boost::iequals(sqlite_type, "DATE") || boost::istarts_with(sqlite_type, "char") || boost::istarts_with(sqlite_type, "varchar")) { - return utf8(); + return arrow::utf8(); } return Status::Invalid("Invalid SQLite type: ", sqlite_type); } @@ -245,14 +245,14 @@ class SQLiteFlightSqlServer::Impl { } // Create a Ticket that combines a query and a transaction ID. - arrow::Result EncodeTransactionQuery(const std::string& query, - const std::string& transaction_id) { + arrow::Result EncodeTransactionQuery( + const std::string& query, const std::string& transaction_id) { std::string transaction_query = transaction_id; transaction_query += ':'; transaction_query += query; ARROW_ASSIGN_OR_RAISE(auto ticket_string, - CreateStatementQueryTicket(transaction_query)); - return Ticket{std::move(ticket_string)}; + flight::sql::CreateStatementQueryTicket(transaction_query)); + return flight::Ticket{std::move(ticket_string)}; } arrow::Result> DecodeTransactionQuery( @@ -296,28 +296,31 @@ class SQLiteFlightSqlServer::Impl { return ret; } - arrow::Result> GetFlightInfoStatement( - const ServerCallContext& context, const StatementQuery& command, - const FlightDescriptor& descriptor) { + arrow::Result> GetFlightInfoStatement( + const flight::ServerCallContext& context, + const flight::sql::StatementQuery& command, + const flight::FlightDescriptor& descriptor) { const std::string& query = command.query; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, query)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); ARROW_ASSIGN_OR_RAISE(auto ticket, EncodeTransactionQuery(query, command.transaction_id)); - std::vector endpoints{ - FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; + std::vector endpoints{ + flight::FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; // TODO: Set true only when "ORDER BY" is used in a main "SELECT" // in the given query. const bool ordered = false; ARROW_ASSIGN_OR_RAISE( - auto result, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, ordered)); + auto result, + flight::FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, ordered)); - return std::make_unique(result); + return std::make_unique(result); } - arrow::Result> DoGetStatement( - const ServerCallContext& context, const StatementQueryTicket& command) { + arrow::Result> DoGetStatement( + const flight::ServerCallContext& context, + const flight::sql::StatementQueryTicket& command) { ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)); const std::string& sql = pair.first; const std::string transaction_id = pair.second; @@ -329,45 +332,51 @@ class SQLiteFlightSqlServer::Impl { std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); - return std::make_unique(reader); + return std::make_unique(reader); } - arrow::Result> GetFlightInfoCatalogs( - const ServerCallContext& context, const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetCatalogsSchema()); + arrow::Result> GetFlightInfoCatalogs( + const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetCatalogsSchema()); } - arrow::Result> DoGetCatalogs( - const ServerCallContext& context) { + arrow::Result> DoGetCatalogs( + const flight::ServerCallContext& context) { // https://www.sqlite.org/cli.html // > The ".databases" command shows a list of all databases open // > in the current connection. There will always be at least // > 2. The first one is "main", the original database opened. The // > second is "temp", the database used for temporary tables. // For our purposes, return only "main" and ignore other databases. - const std::shared_ptr& schema = SqlSchema::GetCatalogsSchema(); - StringBuilder catalog_name_builder; + const std::shared_ptr& schema = + flight::sql::SqlSchema::GetCatalogsSchema(); + arrow::StringBuilder catalog_name_builder; ARROW_RETURN_NOT_OK(catalog_name_builder.Append("main")); ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); - std::shared_ptr batch = - RecordBatch::Make(schema, 1, {std::move(catalog_name)}); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - return std::make_unique(reader); + std::shared_ptr batch = + arrow::RecordBatch::Make(schema, 1, {std::move(catalog_name)}); + ARROW_ASSIGN_OR_RAISE(auto reader, arrow::RecordBatchReader::Make({batch})); + return std::make_unique(reader); } - arrow::Result> GetFlightInfoSchemas( - const ServerCallContext& context, const GetDbSchemas& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetDbSchemasSchema()); + arrow::Result> GetFlightInfoSchemas( + const flight::ServerCallContext& context, const flight::sql::GetDbSchemas& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetDbSchemasSchema()); } - arrow::Result> DoGetDbSchemas( - const ServerCallContext& context, const GetDbSchemas& command) { + arrow::Result> DoGetDbSchemas( + const flight::ServerCallContext& context, + const flight::sql::GetDbSchemas& command) { // SQLite doesn't support schemas, so pretend we have a single // unnamed schema. - const std::shared_ptr& schema = SqlSchema::GetDbSchemasSchema(); - StringBuilder catalog_name_builder; - StringBuilder schema_name_builder; + const std::shared_ptr& schema = + flight::sql::SqlSchema::GetDbSchemasSchema(); + arrow::StringBuilder catalog_name_builder; + arrow::StringBuilder schema_name_builder; int64_t length = 0; // XXX: we don't really implement the full pattern match here @@ -380,32 +389,33 @@ class SQLiteFlightSqlServer::Impl { ARROW_ASSIGN_OR_RAISE(auto catalog_name, catalog_name_builder.Finish()); ARROW_ASSIGN_OR_RAISE(auto schema_name, schema_name_builder.Finish()); - std::shared_ptr batch = RecordBatch::Make( + std::shared_ptr batch = arrow::RecordBatch::Make( schema, length, {std::move(catalog_name), std::move(schema_name)}); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - return std::make_unique(reader); + ARROW_ASSIGN_OR_RAISE(auto reader, arrow::RecordBatchReader::Make({batch})); + return std::make_unique(reader); } - arrow::Result> GetFlightInfoTables( - const ServerCallContext& context, const GetTables& command, - const FlightDescriptor& descriptor) { - std::vector endpoints{ - FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; + arrow::Result> GetFlightInfoTables( + const flight::ServerCallContext& context, const flight::sql::GetTables& command, + const flight::FlightDescriptor& descriptor) { + std::vector endpoints{ + flight::FlightEndpoint{{descriptor.cmd}, {}, std::nullopt, ""}}; bool include_schema = command.include_schema; ARROW_LOG(INFO) << "GetTables include_schema=" << include_schema; ARROW_ASSIGN_OR_RAISE( auto result, - FlightInfo::Make(include_schema ? *SqlSchema::GetTablesSchemaWithIncludedSchema() - : *SqlSchema::GetTablesSchema(), - descriptor, endpoints, -1, -1, false)) + flight::FlightInfo::Make( + include_schema ? *flight::sql::SqlSchema::GetTablesSchemaWithIncludedSchema() + : *flight::sql::SqlSchema::GetTablesSchema(), + descriptor, endpoints, -1, -1, false)) - return std::make_unique(std::move(result)); + return std::make_unique(std::move(result)); } - arrow::Result> DoGetTables( - const ServerCallContext& context, const GetTables& command) { + arrow::Result> DoGetTables( + const flight::ServerCallContext& context, const flight::sql::GetTables& command) { std::string query = PrepareQueryForGetTables(command); ARROW_LOG(INFO) << "GetTables: " << query; @@ -413,20 +423,22 @@ class SQLiteFlightSqlServer::Impl { ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, query)); std::shared_ptr reader; - ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create( - statement, SqlSchema::GetTablesSchema())); + ARROW_ASSIGN_OR_RAISE(reader, + SqliteStatementBatchReader::Create( + statement, flight::sql::SqlSchema::GetTablesSchema())); if (command.include_schema) { std::shared_ptr table_schema_reader = std::make_shared(reader, query, db_); - return std::make_unique(table_schema_reader); + return std::make_unique(table_schema_reader); } else { - return std::make_unique(reader); + return std::make_unique(reader); } } - arrow::Result DoPutCommandStatementUpdate(const ServerCallContext& context, - const StatementUpdate& command) { + arrow::Result DoPutCommandStatementUpdate( + const flight::ServerCallContext& context, + const flight::sql::StatementUpdate& command) { const std::string& sql = command.query; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); ARROW_LOG(INFO) << "Executing update: " << sql; @@ -434,9 +446,9 @@ class SQLiteFlightSqlServer::Impl { return statement->ExecuteUpdate(); } - arrow::Result CreatePreparedStatement( - const ServerCallContext& context, - const ActionCreatePreparedStatementRequest& request) { + arrow::Result CreatePreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::ActionCreatePreparedStatementRequest& request) { std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(request.transaction_id)); ARROW_LOG(INFO) << "Creating prepared statement: " << request.query; @@ -452,13 +464,13 @@ class SQLiteFlightSqlServer::Impl { sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); const int parameter_count = sqlite3_bind_parameter_count(stmt); - FieldVector parameter_fields; + arrow::FieldVector parameter_fields; parameter_fields.reserve(parameter_count); // As SQLite doesn't know the parameter types before executing the query, the // example server is accepting any SQLite supported type as input by using a dense // union. - const std::shared_ptr& dense_union_type = GetUnknownColumnDataType(); + const std::shared_ptr& dense_union_type = GetUnknownColumnDataType(); for (int i = 0; i < parameter_count; i++) { const char* parameter_name_chars = sqlite3_bind_parameter_name(stmt, i + 1); @@ -471,13 +483,14 @@ class SQLiteFlightSqlServer::Impl { parameter_fields.push_back(field(parameter_name, dense_union_type)); } - std::shared_ptr parameter_schema = arrow::schema(parameter_fields); - return ActionCreatePreparedStatementResult{ + std::shared_ptr parameter_schema = arrow::schema(parameter_fields); + return flight::sql::ActionCreatePreparedStatementResult{ std::move(dataset_schema), std::move(parameter_schema), std::move(handle)}; } - Status ClosePreparedStatement(const ServerCallContext& context, - const ActionClosePreparedStatementRequest& request) { + Status ClosePreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::ActionClosePreparedStatementRequest& request) { std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = request.prepared_statement_handle; @@ -491,9 +504,10 @@ class SQLiteFlightSqlServer::Impl { return Status::OK(); } - arrow::Result> GetFlightInfoPreparedStatement( - const ServerCallContext& context, const PreparedStatementQuery& command, - const FlightDescriptor& descriptor) { + arrow::Result> GetFlightInfoPreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + const flight::FlightDescriptor& descriptor) { std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; @@ -509,8 +523,9 @@ class SQLiteFlightSqlServer::Impl { return GetFlightInfoForCommand(descriptor, schema); } - arrow::Result> DoGetPreparedStatement( - const ServerCallContext& context, const PreparedStatementQuery& command) { + arrow::Result> DoGetPreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command) { std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; @@ -524,13 +539,13 @@ class SQLiteFlightSqlServer::Impl { std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); - return std::make_unique(reader); + return std::make_unique(reader); } - Status DoPutPreparedStatementQuery(const ServerCallContext& context, - const PreparedStatementQuery& command, - FlightMessageReader* reader, - FlightMetadataWriter* writer) { + Status DoPutPreparedStatementQuery(const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + flight::FlightMessageReader* reader, + flight::FlightMetadataWriter* writer) { const std::string& prepared_statement_handle = command.prepared_statement_handle; ARROW_ASSIGN_OR_RAISE(auto statement, GetStatementByHandle(prepared_statement_handle)); @@ -541,8 +556,9 @@ class SQLiteFlightSqlServer::Impl { } arrow::Result DoPutPreparedStatementUpdate( - const ServerCallContext& context, const PreparedStatementUpdate& command, - FlightMessageReader* reader) { + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementUpdate& command, + flight::FlightMessageReader* reader) { const std::string& prepared_statement_handle = command.prepared_statement_handle; ARROW_ASSIGN_OR_RAISE(std::shared_ptr statement, GetStatementByHandle(prepared_statement_handle)); @@ -560,43 +576,52 @@ class SQLiteFlightSqlServer::Impl { return rows_affected; } - arrow::Result> GetFlightInfoTableTypes( - const ServerCallContext& context, const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetTableTypesSchema()); + arrow::Result> GetFlightInfoTableTypes( + const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetTableTypesSchema()); } - arrow::Result> DoGetTableTypes( - const ServerCallContext& context) { + arrow::Result> DoGetTableTypes( + const flight::ServerCallContext& context) { std::string query = "SELECT DISTINCT type as table_type FROM sqlite_master"; - return DoGetSQLiteQuery(db_, query, SqlSchema::GetTableTypesSchema()); + return DoGetSQLiteQuery(db_, query, flight::sql::SqlSchema::GetTableTypesSchema()); } - arrow::Result> GetFlightInfoTypeInfo( - const ServerCallContext& context, const GetXdbcTypeInfo& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetXdbcTypeInfoSchema()); + arrow::Result> GetFlightInfoTypeInfo( + const flight::ServerCallContext& context, + const flight::sql::GetXdbcTypeInfo& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetXdbcTypeInfoSchema()); } - arrow::Result> DoGetTypeInfo( - const ServerCallContext& context, const GetXdbcTypeInfo& command) { + arrow::Result> DoGetTypeInfo( + const flight::ServerCallContext& context, + const flight::sql::GetXdbcTypeInfo& command) { ARROW_ASSIGN_OR_RAISE(auto type_info_result, command.data_type.has_value() ? DoGetTypeInfoResult(command.data_type.value()) : DoGetTypeInfoResult()); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({type_info_result})); - return std::make_unique(reader); + ARROW_ASSIGN_OR_RAISE(auto reader, + arrow::RecordBatchReader::Make({type_info_result})); + return std::make_unique(reader); } - arrow::Result> GetFlightInfoPrimaryKeys( - const ServerCallContext& context, const GetPrimaryKeys& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetPrimaryKeysSchema()); + arrow::Result> GetFlightInfoPrimaryKeys( + const flight::ServerCallContext& context, + const flight::sql::GetPrimaryKeys& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetPrimaryKeysSchema()); } - arrow::Result> DoGetPrimaryKeys( - const ServerCallContext& context, const GetPrimaryKeys& command) { + arrow::Result> DoGetPrimaryKeys( + const flight::ServerCallContext& context, + const flight::sql::GetPrimaryKeys& command) { std::stringstream table_query; // The field key_name cannot be recovered by the sqlite, so it is being set @@ -608,7 +633,7 @@ class SQLiteFlightSqlServer::Impl { "table_name, type as table_type\n" "FROM sqlite_master) where 1=1 and pk != 0"; - const TableRef& table_ref = command.table_ref; + const flight::sql::TableRef& table_ref = command.table_ref; if (table_ref.catalog.has_value()) { table_query << " and catalog_name LIKE '" << table_ref.catalog.value() << "'"; } @@ -619,18 +644,22 @@ class SQLiteFlightSqlServer::Impl { table_query << " and table_name LIKE '" << table_ref.table << "'"; - return DoGetSQLiteQuery(db_, table_query.str(), SqlSchema::GetPrimaryKeysSchema()); + return DoGetSQLiteQuery(db_, table_query.str(), + flight::sql::SqlSchema::GetPrimaryKeysSchema()); } - arrow::Result> GetFlightInfoImportedKeys( - const ServerCallContext& context, const GetImportedKeys& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetImportedKeysSchema()); + arrow::Result> GetFlightInfoImportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetImportedKeys& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetImportedKeysSchema()); } - arrow::Result> DoGetImportedKeys( - const ServerCallContext& context, const GetImportedKeys& command) { - const TableRef& table_ref = command.table_ref; + arrow::Result> DoGetImportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetImportedKeys& command) { + const flight::sql::TableRef& table_ref = command.table_ref; std::string filter = "fk_table_name = '" + table_ref.table + "'"; if (table_ref.catalog.has_value()) { filter += " AND fk_catalog_name = '" + table_ref.catalog.value() + "'"; @@ -640,18 +669,21 @@ class SQLiteFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetSQLiteQuery(db_, query, SqlSchema::GetImportedKeysSchema()); + return DoGetSQLiteQuery(db_, query, flight::sql::SqlSchema::GetImportedKeysSchema()); } - arrow::Result> GetFlightInfoExportedKeys( - const ServerCallContext& context, const GetExportedKeys& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetExportedKeysSchema()); + arrow::Result> GetFlightInfoExportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetExportedKeys& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetExportedKeysSchema()); } - arrow::Result> DoGetExportedKeys( - const ServerCallContext& context, const GetExportedKeys& command) { - const TableRef& table_ref = command.table_ref; + arrow::Result> DoGetExportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetExportedKeys& command) { + const flight::sql::TableRef& table_ref = command.table_ref; std::string filter = "pk_table_name = '" + table_ref.table + "'"; if (table_ref.catalog.has_value()) { filter += " AND pk_catalog_name = '" + table_ref.catalog.value() + "'"; @@ -661,18 +693,21 @@ class SQLiteFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetSQLiteQuery(db_, query, SqlSchema::GetExportedKeysSchema()); + return DoGetSQLiteQuery(db_, query, flight::sql::SqlSchema::GetExportedKeysSchema()); } - arrow::Result> GetFlightInfoCrossReference( - const ServerCallContext& context, const GetCrossReference& command, - const FlightDescriptor& descriptor) { - return GetFlightInfoForCommand(descriptor, SqlSchema::GetCrossReferenceSchema()); + arrow::Result> GetFlightInfoCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command, + const flight::FlightDescriptor& descriptor) { + return GetFlightInfoForCommand(descriptor, + flight::sql::SqlSchema::GetCrossReferenceSchema()); } - arrow::Result> DoGetCrossReference( - const ServerCallContext& context, const GetCrossReference& command) { - const TableRef& pk_table_ref = command.pk_table_ref; + arrow::Result> DoGetCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command) { + const flight::sql::TableRef& pk_table_ref = command.pk_table_ref; std::string filter = "pk_table_name = '" + pk_table_ref.table + "'"; if (pk_table_ref.catalog.has_value()) { filter += " AND pk_catalog_name = '" + pk_table_ref.catalog.value() + "'"; @@ -681,7 +716,7 @@ class SQLiteFlightSqlServer::Impl { filter += " AND pk_schema_name = '" + pk_table_ref.db_schema.value() + "'"; } - const TableRef& fk_table_ref = command.fk_table_ref; + const flight::sql::TableRef& fk_table_ref = command.fk_table_ref; filter += " AND fk_table_name = '" + fk_table_ref.table + "'"; if (fk_table_ref.catalog.has_value()) { filter += " AND fk_catalog_name = '" + fk_table_ref.catalog.value() + "'"; @@ -691,7 +726,8 @@ class SQLiteFlightSqlServer::Impl { } std::string query = PrepareQueryForGetImportedOrExportedKeys(filter); - return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema()); + return DoGetSQLiteQuery(db_, query, + flight::sql::SqlSchema::GetCrossReferenceSchema()); } Status ExecuteSql(const std::string& sql) { return ExecuteSql(db_, sql); } @@ -711,8 +747,9 @@ class SQLiteFlightSqlServer::Impl { return Status::OK(); } - arrow::Result BeginTransaction( - const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + arrow::Result BeginTransaction( + const flight::ServerCallContext& context, + const flight::sql::ActionBeginTransactionRequest& request) { std::string handle = GenerateRandomString(); sqlite3* new_db = nullptr; if (sqlite3_open_v2(db_uri_.c_str(), &new_db, @@ -732,11 +769,11 @@ class SQLiteFlightSqlServer::Impl { std::lock_guard guard(mutex_); open_transactions_[handle] = new_db; - return ActionBeginTransactionResult{std::move(handle)}; + return flight::sql::ActionBeginTransactionResult{std::move(handle)}; } - Status EndTransaction(const ServerCallContext& context, - const ActionEndTransactionRequest& request) { + Status EndTransaction(const flight::ServerCallContext& context, + const flight::sql::ActionEndTransactionRequest& request) { Status status; sqlite3* transaction = nullptr; { @@ -746,7 +783,7 @@ class SQLiteFlightSqlServer::Impl { return Status::KeyError("Unknown transaction ID: ", request.transaction_id); } - if (request.action == ActionEndTransactionRequest::kCommit) { + if (request.action == flight::sql::ActionEndTransactionRequest::kCommit) { ARROW_LOG(INFO) << "Committing on " << request.transaction_id; status = ExecuteSql(it->second, "COMMIT"); } else { @@ -767,20 +804,21 @@ std::atomic kDbCounter(0); SQLiteFlightSqlServer::SQLiteFlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} -arrow::Result> SQLiteFlightSqlServer::Create(std::string path) { +arrow::Result> SQLiteFlightSqlServer::Create( + std::string path) { std::cout << "SQLite version: " << sqlite3_libversion() << std::endl; sqlite3* db = nullptr; - char* db_location; + char* db_location; - bool in_memory = path == ""; + bool in_memory = path == ""; - if (in_memory) { - db_location = (char*)":memory:"; - } else { - db_location = (char*)path.c_str(); // TODO: validate that the path exists - } + if (in_memory) { + db_location = (char*)":memory:"; + } else { + db_location = (char*)path.c_str(); // TODO: validate that the path exists + } if (sqlite3_open_v2(db_location, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, @@ -813,174 +851,193 @@ Status SQLiteFlightSqlServer::ExecuteSql(const std::string& sql) { return impl_->ExecuteSql(sql); } -arrow::Result> SQLiteFlightSqlServer::GetFlightInfoStatement( - const ServerCallContext& context, const StatementQuery& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoStatement( + const flight::ServerCallContext& context, const flight::sql::StatementQuery& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoStatement(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetStatement( - const ServerCallContext& context, const StatementQueryTicket& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetStatement(const flight::ServerCallContext& context, + const flight::sql::StatementQueryTicket& command) { return impl_->DoGetStatement(context, command); } -arrow::Result> SQLiteFlightSqlServer::GetFlightInfoCatalogs( - const ServerCallContext& context, const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoCatalogs(const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoCatalogs(context, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetCatalogs( - const ServerCallContext& context) { +arrow::Result> +SQLiteFlightSqlServer::DoGetCatalogs(const flight::ServerCallContext& context) { return impl_->DoGetCatalogs(context); } -arrow::Result> SQLiteFlightSqlServer::GetFlightInfoSchemas( - const ServerCallContext& context, const GetDbSchemas& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoSchemas(const flight::ServerCallContext& context, + const flight::sql::GetDbSchemas& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoSchemas(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetDbSchemas( - const ServerCallContext& context, const GetDbSchemas& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetDbSchemas(const flight::ServerCallContext& context, + const flight::sql::GetDbSchemas& command) { return impl_->DoGetDbSchemas(context, command); } -arrow::Result> SQLiteFlightSqlServer::GetFlightInfoTables( - const ServerCallContext& context, const GetTables& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoTables(const flight::ServerCallContext& context, + const flight::sql::GetTables& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoTables(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetTables( - const ServerCallContext& context, const GetTables& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetTables(const flight::ServerCallContext& context, + const flight::sql::GetTables& command) { return impl_->DoGetTables(context, command); } arrow::Result SQLiteFlightSqlServer::DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& command) { + const flight::ServerCallContext& context, + const flight::sql::StatementUpdate& command) { return impl_->DoPutCommandStatementUpdate(context, command); } -arrow::Result +arrow::Result SQLiteFlightSqlServer::CreatePreparedStatement( - const ServerCallContext& context, - const ActionCreatePreparedStatementRequest& request) { + const flight::ServerCallContext& context, + const flight::sql::ActionCreatePreparedStatementRequest& request) { return impl_->CreatePreparedStatement(context, request); } Status SQLiteFlightSqlServer::ClosePreparedStatement( - const ServerCallContext& context, - const ActionClosePreparedStatementRequest& request) { + const flight::ServerCallContext& context, + const flight::sql::ActionClosePreparedStatementRequest& request) { return impl_->ClosePreparedStatement(context, request); } -arrow::Result> +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoPreparedStatement( - const ServerCallContext& context, const PreparedStatementQuery& command, - const FlightDescriptor& descriptor) { + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoPreparedStatement(context, command, descriptor); } -arrow::Result> -SQLiteFlightSqlServer::DoGetPreparedStatement(const ServerCallContext& context, - const PreparedStatementQuery& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetPreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command) { return impl_->DoGetPreparedStatement(context, command); } Status SQLiteFlightSqlServer::DoPutPreparedStatementQuery( - const ServerCallContext& context, const PreparedStatementQuery& command, - FlightMessageReader* reader, FlightMetadataWriter* writer) { + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + flight::FlightMessageReader* reader, flight::FlightMetadataWriter* writer) { return impl_->DoPutPreparedStatementQuery(context, command, reader, writer); } arrow::Result SQLiteFlightSqlServer::DoPutPreparedStatementUpdate( - const ServerCallContext& context, const PreparedStatementUpdate& command, - FlightMessageReader* reader) { + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementUpdate& command, + flight::FlightMessageReader* reader) { return impl_->DoPutPreparedStatementUpdate(context, command, reader); } -arrow::Result> SQLiteFlightSqlServer::GetFlightInfoTableTypes( - const ServerCallContext& context, const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoTableTypes( + const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoTableTypes(context, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetTableTypes( - const ServerCallContext& context) { +arrow::Result> +SQLiteFlightSqlServer::DoGetTableTypes(const flight::ServerCallContext& context) { return impl_->DoGetTableTypes(context); } -arrow::Result> +arrow::Result> SQLiteFlightSqlServer::GetFlightInfoXdbcTypeInfo( - const ServerCallContext& context, const arrow::flight::sql::GetXdbcTypeInfo& command, - const FlightDescriptor& descriptor) { + const flight::ServerCallContext& context, const flight::sql::GetXdbcTypeInfo& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoTypeInfo(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetXdbcTypeInfo( - const ServerCallContext& context, - const arrow::flight::sql::GetXdbcTypeInfo& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetXdbcTypeInfo(const flight::ServerCallContext& context, + const flight::sql::GetXdbcTypeInfo& command) { return impl_->DoGetTypeInfo(context, command); } -arrow::Result> -SQLiteFlightSqlServer::GetFlightInfoPrimaryKeys(const ServerCallContext& context, - const GetPrimaryKeys& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoPrimaryKeys( + const flight::ServerCallContext& context, const flight::sql::GetPrimaryKeys& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoPrimaryKeys(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetPrimaryKeys( - const ServerCallContext& context, const GetPrimaryKeys& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetPrimaryKeys(const flight::ServerCallContext& context, + const flight::sql::GetPrimaryKeys& command) { return impl_->DoGetPrimaryKeys(context, command); } -arrow::Result> -SQLiteFlightSqlServer::GetFlightInfoImportedKeys(const ServerCallContext& context, - const GetImportedKeys& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoImportedKeys( + const flight::ServerCallContext& context, const flight::sql::GetImportedKeys& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoImportedKeys(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetImportedKeys( - const ServerCallContext& context, const GetImportedKeys& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetImportedKeys(const flight::ServerCallContext& context, + const flight::sql::GetImportedKeys& command) { return impl_->DoGetImportedKeys(context, command); } -arrow::Result> -SQLiteFlightSqlServer::GetFlightInfoExportedKeys(const ServerCallContext& context, - const GetExportedKeys& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoExportedKeys( + const flight::ServerCallContext& context, const flight::sql::GetExportedKeys& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoExportedKeys(context, command, descriptor); } -arrow::Result> SQLiteFlightSqlServer::DoGetExportedKeys( - const ServerCallContext& context, const GetExportedKeys& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetExportedKeys(const flight::ServerCallContext& context, + const flight::sql::GetExportedKeys& command) { return impl_->DoGetExportedKeys(context, command); } -arrow::Result> -SQLiteFlightSqlServer::GetFlightInfoCrossReference(const ServerCallContext& context, - const GetCrossReference& command, - const FlightDescriptor& descriptor) { +arrow::Result> +SQLiteFlightSqlServer::GetFlightInfoCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command, + const flight::FlightDescriptor& descriptor) { return impl_->GetFlightInfoCrossReference(context, command, descriptor); } -arrow::Result> -SQLiteFlightSqlServer::DoGetCrossReference(const ServerCallContext& context, - const GetCrossReference& command) { +arrow::Result> +SQLiteFlightSqlServer::DoGetCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command) { return impl_->DoGetCrossReference(context, command); } -arrow::Result SQLiteFlightSqlServer::BeginTransaction( - const ServerCallContext& context, const ActionBeginTransactionRequest& request) { +arrow::Result +SQLiteFlightSqlServer::BeginTransaction( + const flight::ServerCallContext& context, + const flight::sql::ActionBeginTransactionRequest& request) { return impl_->BeginTransaction(context, request); } -Status SQLiteFlightSqlServer::EndTransaction(const ServerCallContext& context, - const ActionEndTransactionRequest& request) { +Status SQLiteFlightSqlServer::EndTransaction( + const flight::ServerCallContext& context, + const flight::sql::ActionEndTransactionRequest& request) { return impl_->EndTransaction(context, request); } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_server.h b/src/sqlite/sqlite_server.h index 3952ed8..d8cb9fb 100644 --- a/src/sqlite/sqlite_server.h +++ b/src/sqlite/sqlite_server.h @@ -26,17 +26,16 @@ #include "sqlite_statement.h" #include "sqlite_statement_batch_reader.h" #include "arrow/flight/sql/server.h" +#include "arrow/flight/types.h" #include "arrow/result.h" +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { /// \brief Convert a column type to a ArrowType. /// \param sqlite_type the sqlite type. /// \return The equivalent ArrowType. -arrow::Result> GetArrowType(const char* sqlite_type); +arrow::Result> GetArrowType(const char* sqlite_type); /// \brief Convert a column type name to SQLite type. /// \param type_name the type name. @@ -45,18 +44,18 @@ int32_t GetSqlTypeFromTypeName(const char* type_name); /// \brief Get the DataType used when parameter type is not known. /// \return DataType used when parameter type is not known. -inline std::shared_ptr GetUnknownColumnDataType() { - return dense_union({ - field("string", utf8()), - field("bytes", binary()), - field("bigint", int64()), - field("double", float64()), +inline std::shared_ptr GetUnknownColumnDataType() { + return arrow::dense_union({ + field("string", arrow::utf8()), + field("bytes", arrow::binary()), + field("bigint", arrow::int64()), + field("double", arrow::float64()), }); } /// \brief Example implementation of FlightSqlServerBase backed by an in-memory SQLite3 /// database. -class SQLiteFlightSqlServer : public FlightSqlServerBase { +class SQLiteFlightSqlServer : public flight::sql::FlightSqlServerBase { public: ~SQLiteFlightSqlServer() override; @@ -64,89 +63,108 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { /// \brief Auxiliary method used to execute an arbitrary SQL statement on the underlying /// SQLite database. - Status ExecuteSql(const std::string& sql); - - arrow::Result> GetFlightInfoStatement( - const ServerCallContext& context, const StatementQuery& command, - const FlightDescriptor& descriptor) override; - - arrow::Result> DoGetStatement( - const ServerCallContext& context, const StatementQueryTicket& command) override; - arrow::Result> GetFlightInfoCatalogs( - const ServerCallContext& context, const FlightDescriptor& descriptor) override; - arrow::Result> DoGetCatalogs( - const ServerCallContext& context) override; - arrow::Result> GetFlightInfoSchemas( - const ServerCallContext& context, const GetDbSchemas& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetDbSchemas( - const ServerCallContext& context, const GetDbSchemas& command) override; + arrow::Status ExecuteSql(const std::string& sql); + + arrow::Result> GetFlightInfoStatement( + const flight::ServerCallContext& context, + const flight::sql::StatementQuery& command, + const flight::FlightDescriptor& descriptor) override; + + arrow::Result> DoGetStatement( + const flight::ServerCallContext& context, + const flight::sql::StatementQueryTicket& command) override; + arrow::Result> GetFlightInfoCatalogs( + const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetCatalogs( + const flight::ServerCallContext& context) override; + arrow::Result> GetFlightInfoSchemas( + const flight::ServerCallContext& context, const flight::sql::GetDbSchemas& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetDbSchemas( + const flight::ServerCallContext& context, + const flight::sql::GetDbSchemas& command) override; arrow::Result DoPutCommandStatementUpdate( - const ServerCallContext& context, const StatementUpdate& update) override; - arrow::Result CreatePreparedStatement( - const ServerCallContext& context, - const ActionCreatePreparedStatementRequest& request) override; - Status ClosePreparedStatement( - const ServerCallContext& context, - const ActionClosePreparedStatementRequest& request) override; - arrow::Result> GetFlightInfoPreparedStatement( - const ServerCallContext& context, const PreparedStatementQuery& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetPreparedStatement( - const ServerCallContext& context, const PreparedStatementQuery& command) override; - Status DoPutPreparedStatementQuery(const ServerCallContext& context, - const PreparedStatementQuery& command, - FlightMessageReader* reader, - FlightMetadataWriter* writer) override; + const flight::ServerCallContext& context, + const flight::sql::StatementUpdate& update) override; + arrow::Result CreatePreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::ActionCreatePreparedStatementRequest& request) override; + arrow::Status ClosePreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::ActionClosePreparedStatementRequest& request) override; + arrow::Result> GetFlightInfoPreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetPreparedStatement( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command) override; + arrow::Status DoPutPreparedStatementQuery( + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementQuery& command, + flight::FlightMessageReader* reader, flight::FlightMetadataWriter* writer) override; arrow::Result DoPutPreparedStatementUpdate( - const ServerCallContext& context, const PreparedStatementUpdate& command, - FlightMessageReader* reader) override; - - arrow::Result> GetFlightInfoTables( - const ServerCallContext& context, const GetTables& command, - const FlightDescriptor& descriptor) override; - - arrow::Result> DoGetTables( - const ServerCallContext& context, const GetTables& command) override; - arrow::Result> GetFlightInfoXdbcTypeInfo( - const ServerCallContext& context, - const arrow::flight::sql::GetXdbcTypeInfo& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetXdbcTypeInfo( - const ServerCallContext& context, - const arrow::flight::sql::GetXdbcTypeInfo& command) override; - arrow::Result> GetFlightInfoTableTypes( - const ServerCallContext& context, const FlightDescriptor& descriptor) override; - arrow::Result> DoGetTableTypes( - const ServerCallContext& context) override; - arrow::Result> GetFlightInfoImportedKeys( - const ServerCallContext& context, const GetImportedKeys& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetImportedKeys( - const ServerCallContext& context, const GetImportedKeys& command) override; - arrow::Result> GetFlightInfoExportedKeys( - const ServerCallContext& context, const GetExportedKeys& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetExportedKeys( - const ServerCallContext& context, const GetExportedKeys& command) override; - arrow::Result> GetFlightInfoCrossReference( - const ServerCallContext& context, const GetCrossReference& command, - const FlightDescriptor& descriptor) override; - arrow::Result> DoGetCrossReference( - const ServerCallContext& context, const GetCrossReference& command) override; - - arrow::Result> GetFlightInfoPrimaryKeys( - const ServerCallContext& context, const GetPrimaryKeys& command, - const FlightDescriptor& descriptor) override; - - arrow::Result> DoGetPrimaryKeys( - const ServerCallContext& context, const GetPrimaryKeys& command) override; - - arrow::Result BeginTransaction( - const ServerCallContext& context, - const ActionBeginTransactionRequest& request) override; - Status EndTransaction(const ServerCallContext& context, - const ActionEndTransactionRequest& request) override; + const flight::ServerCallContext& context, + const flight::sql::PreparedStatementUpdate& command, + flight::FlightMessageReader* reader) override; + + arrow::Result> GetFlightInfoTables( + const flight::ServerCallContext& context, const flight::sql::GetTables& command, + const flight::FlightDescriptor& descriptor) override; + + arrow::Result> DoGetTables( + const flight::ServerCallContext& context, + const flight::sql::GetTables& command) override; + arrow::Result> GetFlightInfoXdbcTypeInfo( + const flight::ServerCallContext& context, + const flight::sql::GetXdbcTypeInfo& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetXdbcTypeInfo( + const flight::ServerCallContext& context, + const flight::sql::GetXdbcTypeInfo& command) override; + arrow::Result> GetFlightInfoTableTypes( + const flight::ServerCallContext& context, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetTableTypes( + const flight::ServerCallContext& context) override; + arrow::Result> GetFlightInfoImportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetImportedKeys& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetImportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetImportedKeys& command) override; + arrow::Result> GetFlightInfoExportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetExportedKeys& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetExportedKeys( + const flight::ServerCallContext& context, + const flight::sql::GetExportedKeys& command) override; + arrow::Result> GetFlightInfoCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command, + const flight::FlightDescriptor& descriptor) override; + arrow::Result> DoGetCrossReference( + const flight::ServerCallContext& context, + const flight::sql::GetCrossReference& command) override; + + arrow::Result> GetFlightInfoPrimaryKeys( + const flight::ServerCallContext& context, + const flight::sql::GetPrimaryKeys& command, + const flight::FlightDescriptor& descriptor) override; + + arrow::Result> DoGetPrimaryKeys( + const flight::ServerCallContext& context, + const flight::sql::GetPrimaryKeys& command) override; + + arrow::Result BeginTransaction( + const flight::ServerCallContext& context, + const flight::sql::ActionBeginTransactionRequest& request) override; + arrow::Status EndTransaction( + const flight::ServerCallContext& context, + const flight::sql::ActionEndTransactionRequest& request) override; private: class Impl; @@ -155,7 +173,4 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { explicit SQLiteFlightSqlServer(std::shared_ptr impl); }; -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_sql_info.cc b/src/sqlite/sqlite_sql_info.cc index eab203c..22ff8e2 100644 --- a/src/sqlite/sqlite_sql_info.cc +++ b/src/sqlite/sqlite_sql_info.cc @@ -19,15 +19,15 @@ #include "arrow/flight/sql/types.h" #include "arrow/util/config.h" +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { /// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. /// \return the cache. -SqlInfoResultMap GetSqlInfoResultMap() { +flight::sql::SqlInfoResultMap GetSqlInfoResultMap() { + using SqlInfoOptions = flight::sql::SqlInfoOptions; + using SqlInfoResult = flight::sql::SqlInfoResult; return { {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, SqlInfoResult(std::string("db_name"))}, @@ -48,17 +48,17 @@ SqlInfoResultMap GetSqlInfoResultMap() { SqlInfoResult(false /* SQLite 3 does not support schemas */)}, {SqlInfoOptions::SqlInfo::SQL_DDL_TABLE, SqlInfoResult(true)}, {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_CASE, - SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: - SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + SqlInfoResult(static_cast(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, {SqlInfoOptions::SqlInfo::SQL_IDENTIFIER_QUOTE_CHAR, SqlInfoResult(std::string("\""))}, {SqlInfoOptions::SqlInfo::SQL_QUOTED_IDENTIFIER_CASE, - SqlInfoResult(int64_t(SqlInfoOptions::SqlSupportedCaseSensitivity:: - SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, + SqlInfoResult(static_cast(SqlInfoOptions::SqlSupportedCaseSensitivity:: + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE))}, {SqlInfoOptions::SqlInfo::SQL_ALL_TABLES_ARE_SELECTABLE, SqlInfoResult(true)}, {SqlInfoOptions::SqlInfo::SQL_NULL_ORDERING, - SqlInfoResult( - int64_t(SqlInfoOptions::SqlNullOrdering::SQL_NULLS_SORTED_AT_START))}, + SqlInfoResult(static_cast( + SqlInfoOptions::SqlNullOrdering::SQL_NULLS_SORTED_AT_START))}, {SqlInfoOptions::SqlInfo::SQL_KEYWORDS, SqlInfoResult(std::vector({"ABORT", "ACTION", @@ -224,7 +224,4 @@ SqlInfoResultMap GetSqlInfoResultMap() { {SqlInfoOptions::SqlSupportsConvert::SQL_CONVERT_INTEGER})}}))}}; } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_sql_info.h b/src/sqlite/sqlite_sql_info.h index f205f68..5464bd8 100644 --- a/src/sqlite/sqlite_sql_info.h +++ b/src/sqlite/sqlite_sql_info.h @@ -19,16 +19,10 @@ #include "arrow/flight/sql/types.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { /// \brief Gets the mapping from SQL info ids to SqlInfoResult instances. /// \return the cache. -SqlInfoResultMap GetSqlInfoResultMap(); +arrow::flight::sql::SqlInfoResultMap GetSqlInfoResultMap(); -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_statement.cc b/src/sqlite/sqlite_statement.cc index 4d3dbdc..fa7754c 100644 --- a/src/sqlite/sqlite_statement.cc +++ b/src/sqlite/sqlite_statement.cc @@ -31,27 +31,25 @@ #include "arrow/table.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" +#include "flight_sql_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { using arrow::internal::checked_cast; -std::shared_ptr GetDataTypeFromSqliteType(const int column_type) { +std::shared_ptr GetDataTypeFromSqliteType(const int column_type) { switch (column_type) { case SQLITE_INTEGER: - return int64(); + return arrow::int64(); case SQLITE_FLOAT: - return float64(); + return arrow::float64(); case SQLITE_BLOB: - return binary(); + return arrow::binary(); case SQLITE_TEXT: - return utf8(); + return arrow::utf8(); case SQLITE_NULL: default: - return null(); + return arrow::null(); } } @@ -67,8 +65,9 @@ int32_t GetPrecisionFromColumn(int column_type) { } } -ColumnMetadata GetColumnMetadata(int column_type, const char* table) { - ColumnMetadata::ColumnMetadataBuilder builder = ColumnMetadata::Builder(); +flight::sql::ColumnMetadata GetColumnMetadata(int column_type, const char* table) { + flight::sql::ColumnMetadata::ColumnMetadataBuilder builder = + flight::sql::ColumnMetadata::Builder(); builder.Scale(15).IsAutoIncrement(false).IsReadOnly(false); if (table == NULLPTR) { @@ -98,15 +97,15 @@ arrow::Result> SqliteStatement::Create( err_msg += std::string(sqlite3_errmsg(db)); } } - return Status::Invalid(err_msg); + return arrow::Status::Invalid(err_msg); } std::shared_ptr result(new SqliteStatement(db, stmt)); return result; } -arrow::Result> SqliteStatement::GetSchema() const { - std::vector> fields; +arrow::Result> SqliteStatement::GetSchema() const { + std::vector> fields; int column_count = sqlite3_column_count(stmt_); for (int i = 0; i < column_count; i++) { const char* column_name = sqlite3_column_name(stmt_, i); @@ -123,19 +122,19 @@ arrow::Result> SqliteStatement::GetSchema() const { // SQLite supports. const int column_type = sqlite3_column_type(stmt_, i); const char* table = sqlite3_column_table_name(stmt_, i); - std::shared_ptr data_type = GetDataTypeFromSqliteType(column_type); - if (data_type->id() == Type::NA) { + std::shared_ptr data_type = GetDataTypeFromSqliteType(column_type); + if (data_type->id() == arrow::Type::NA) { // Try to retrieve column type from sqlite3_column_decltype const char* column_decltype = sqlite3_column_decltype(stmt_, i); if (column_decltype != NULLPTR) { - ARROW_ASSIGN_OR_RAISE(data_type, GetArrowType(column_decltype)); + ARROW_ASSIGN_OR_RAISE(data_type, sqlflite::sqlite::GetArrowType(column_decltype)); } else { // If it cannot determine the actual column type, return a dense_union type // covering any type SQLite supports. - data_type = GetUnknownColumnDataType(); + data_type = sqlflite::sqlite::GetUnknownColumnDataType(); } } - ColumnMetadata column_metadata = GetColumnMetadata(column_type, table); + flight::sql::ColumnMetadata column_metadata = GetColumnMetadata(column_type, table); fields.push_back( arrow::field(column_name, data_type, column_metadata.metadata_map())); @@ -149,8 +148,8 @@ SqliteStatement::~SqliteStatement() { sqlite3_finalize(stmt_); } arrow::Result SqliteStatement::Step() { int rc = sqlite3_step(stmt_); if (rc == SQLITE_ERROR) { - return Status::ExecutionError("A SQLite runtime error has occurred: ", - sqlite3_errmsg(db_)); + return arrow::Status::ExecutionError("A SQLite runtime error has occurred: ", + sqlite3_errmsg(db_)); } return rc; @@ -159,8 +158,8 @@ arrow::Result SqliteStatement::Step() { arrow::Result SqliteStatement::Reset() { int rc = sqlite3_reset(stmt_); if (rc == SQLITE_ERROR) { - return Status::ExecutionError("A SQLite runtime error has occurred: ", - sqlite3_errmsg(db_)); + return arrow::Status::ExecutionError("A SQLite runtime error has occurred: ", + sqlite3_errmsg(db_)); } return rc; @@ -176,42 +175,43 @@ arrow::Result SqliteStatement::ExecuteUpdate() { return sqlite3_changes(db_); } -Status SqliteStatement::SetParameters( +arrow::Status SqliteStatement::SetParameters( std::vector> parameters) { const int num_params = sqlite3_bind_parameter_count(stmt_); for (const auto& batch : parameters) { if (batch->num_columns() != num_params) { - return Status::Invalid("Expected ", num_params, " parameters, but got ", - batch->num_columns()); + return arrow::Status::Invalid("Expected ", num_params, " parameters, but got ", + batch->num_columns()); } } parameters_ = std::move(parameters); - auto end = std::remove_if( - parameters_.begin(), parameters_.end(), - [](const std::shared_ptr& batch) { return batch->num_rows() == 0; }); + auto end = std::remove_if(parameters_.begin(), parameters_.end(), + [](const std::shared_ptr& batch) { + return batch->num_rows() == 0; + }); parameters_.erase(end, parameters_.end()); - return Status::OK(); + return arrow::Status::OK(); } -Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) { +arrow::Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) { if (batch_index >= parameters_.size()) { - return Status::IndexError("Cannot bind to batch ", batch_index); + return arrow::Status::IndexError("Cannot bind to batch ", batch_index); } - const RecordBatch& batch = *parameters_[batch_index]; + const arrow::RecordBatch& batch = *parameters_[batch_index]; if (row_index < 0 || row_index >= batch.num_rows()) { - return Status::IndexError("Cannot bind to row ", row_index, " in batch ", - batch_index); + return arrow::Status::IndexError("Cannot bind to row ", row_index, " in batch ", + batch_index); } if (sqlite3_clear_bindings(stmt_) != SQLITE_OK) { - return Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_)); + return arrow::Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_)); } for (int c = 0; c < batch.num_columns(); ++c) { - Array* column = batch.column(c).get(); + arrow::Array* column = batch.column(c).get(); int64_t column_index = row_index; - if (column->type_id() == Type::DENSE_UNION) { + if (column->type_id() == arrow::Type::DENSE_UNION) { // Allow polymorphic bindings via union - const auto& u = checked_cast(*column); + const auto& u = checked_cast(*column); column_index = u.value_offset(column_index); column = u.field(u.child_id(row_index)).get(); } @@ -222,48 +222,48 @@ Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) { continue; } switch (column->type_id()) { - case Type::INT32: { + case arrow::Type::INT32: { const int32_t value = - checked_cast(*column).Value(column_index); + checked_cast(*column).Value(column_index); rc = sqlite3_bind_int64(stmt_, c + 1, value); break; } - case Type::INT64: { + case arrow::Type::INT64: { const int64_t value = - checked_cast(*column).Value(column_index); + checked_cast(*column).Value(column_index); rc = sqlite3_bind_int64(stmt_, c + 1, value); break; } - case Type::FLOAT: { - const float value = checked_cast(*column).Value(column_index); + case arrow::Type::FLOAT: { + const float value = + checked_cast(*column).Value(column_index); rc = sqlite3_bind_double(stmt_, c + 1, value); break; } - case Type::DOUBLE: { + case arrow::Type::DOUBLE: { const double value = - checked_cast(*column).Value(column_index); + checked_cast(*column).Value(column_index); rc = sqlite3_bind_double(stmt_, c + 1, value); break; } - case Type::STRING: { + case arrow::Type::STRING: { const std::string_view value = - checked_cast(*column).Value(column_index); + checked_cast(*column).Value(column_index); rc = sqlite3_bind_text(stmt_, c + 1, value.data(), static_cast(value.size()), SQLITE_TRANSIENT); break; } default: - return Status::TypeError("Received unsupported data type: ", *column->type()); + return arrow::Status::TypeError("Received unsupported data type: ", + *column->type()); } if (rc != SQLITE_OK) { - return Status::UnknownError("Failed to bind parameter: ", sqlite3_errmsg(db_)); + return arrow::Status::UnknownError("Failed to bind parameter: ", + sqlite3_errmsg(db_)); } } - return Status::OK(); + return arrow::Status::OK(); } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_statement.h b/src/sqlite/sqlite_statement.h index 8b7caf5..53e69cc 100644 --- a/src/sqlite/sqlite_statement.h +++ b/src/sqlite/sqlite_statement.h @@ -25,17 +25,14 @@ #include "arrow/flight/sql/column_metadata.h" #include "arrow/type_fwd.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { /// \brief Create an object ColumnMetadata using the column type and /// table name. /// \param column_type The SQLite type. /// \param table The table name. /// \return A Column Metadata object. -ColumnMetadata GetColumnMetadata(int column_type, const char* table); +arrow::flight::sql::ColumnMetadata GetColumnMetadata(int column_type, const char* table); class SqliteStatement { public: @@ -50,7 +47,7 @@ class SqliteStatement { /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema() const; + arrow::Result> GetSchema() const; /// \brief Steps on underlying sqlite3_stmt. /// \return The resulting return code from SQLite. @@ -73,8 +70,9 @@ class SqliteStatement { const std::vector>& parameters() const { return parameters_; } - Status SetParameters(std::vector> parameters); - Status Bind(size_t batch_index, int64_t row_index); + arrow::Status SetParameters( + std::vector> parameters); + arrow::Status Bind(size_t batch_index, int64_t row_index); private: sqlite3* db_; @@ -84,7 +82,4 @@ class SqliteStatement { SqliteStatement(sqlite3* db, sqlite3_stmt* stmt) : db_(db), stmt_(stmt) {} }; -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_statement_batch_reader.cc b/src/sqlite/sqlite_statement_batch_reader.cc index 3fda44f..eb01964 100644 --- a/src/sqlite/sqlite_statement_batch_reader.cc +++ b/src/sqlite/sqlite_statement_batch_reader.cc @@ -22,64 +22,63 @@ #include "arrow/builder.h" #include "sqlite_statement.h" -#define STRING_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - auto builder = reinterpret_cast(array_builder); \ - const int bytes = sqlite3_column_bytes(STMT, COLUMN); \ - const uint8_t* string = \ - reinterpret_cast(sqlite3_column_text(STMT, COLUMN)); \ - if (string == nullptr) { \ - ARROW_RETURN_NOT_OK(builder->AppendNull()); \ - break; \ - } \ - ARROW_RETURN_NOT_OK(builder->Append(string, bytes)); \ - break; \ +#define STRING_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case arrow::TYPE_CLASS##Type::type_id: { \ + auto builder = reinterpret_cast(array_builder); \ + const int bytes = sqlite3_column_bytes(STMT, COLUMN); \ + const uint8_t* string = \ + reinterpret_cast(sqlite3_column_text(STMT, COLUMN)); \ + if (string == nullptr) { \ + ARROW_RETURN_NOT_OK(builder->AppendNull()); \ + break; \ + } \ + ARROW_RETURN_NOT_OK(builder->Append(string, bytes)); \ + break; \ } -#define BINARY_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - auto builder = reinterpret_cast(array_builder); \ - const int bytes = sqlite3_column_bytes(STMT, COLUMN); \ - const uint8_t* blob = \ - reinterpret_cast(sqlite3_column_blob(STMT, COLUMN)); \ - if (blob == nullptr) { \ - ARROW_RETURN_NOT_OK(builder->AppendNull()); \ - break; \ - } \ - ARROW_RETURN_NOT_OK(builder->Append(blob, bytes)); \ - break; \ +#define BINARY_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case arrow::TYPE_CLASS##Type::type_id: { \ + auto builder = reinterpret_cast(array_builder); \ + const int bytes = sqlite3_column_bytes(STMT, COLUMN); \ + const uint8_t* blob = \ + reinterpret_cast(sqlite3_column_blob(STMT, COLUMN)); \ + if (blob == nullptr) { \ + ARROW_RETURN_NOT_OK(builder->AppendNull()); \ + break; \ + } \ + ARROW_RETURN_NOT_OK(builder->Append(blob, bytes)); \ + break; \ } -#define INT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - using c_type = typename TYPE_CLASS##Type::c_type; \ - auto builder = reinterpret_cast(array_builder); \ - const sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ - ARROW_RETURN_NOT_OK(builder->Append(static_cast(value))); \ - break; \ +#define INT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case arrow::TYPE_CLASS##Type::type_id: { \ + using c_type = typename arrow::TYPE_CLASS##Type::c_type; \ + auto builder = reinterpret_cast(array_builder); \ + const sqlite3_int64 value = sqlite3_column_int64(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK(builder->Append(static_cast(value))); \ + break; \ } -#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ - case TYPE_CLASS##Type::type_id: { \ - auto builder = reinterpret_cast(array_builder); \ - const double value = sqlite3_column_double(STMT, COLUMN); \ - ARROW_RETURN_NOT_OK( \ - builder->Append(static_cast(value))); \ - break; \ +#define FLOAT_BUILDER_CASE(TYPE_CLASS, STMT, COLUMN) \ + case arrow::TYPE_CLASS##Type::type_id: { \ + auto builder = reinterpret_cast(array_builder); \ + const double value = sqlite3_column_double(STMT, COLUMN); \ + ARROW_RETURN_NOT_OK( \ + builder->Append(static_cast(value))); \ + break; \ } -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { // Batch size for SQLite statement results static constexpr int32_t kMaxBatchSize = 16384; -std::shared_ptr SqliteStatementBatchReader::schema() const { return schema_; } +std::shared_ptr SqliteStatementBatchReader::schema() const { + return schema_; +} SqliteStatementBatchReader::SqliteStatementBatchReader( - std::shared_ptr statement, std::shared_ptr schema) + std::shared_ptr statement, std::shared_ptr schema) : statement_(std::move(statement)), schema_(std::move(schema)), rc_(SQLITE_OK), @@ -103,22 +102,24 @@ SqliteStatementBatchReader::Create(const std::shared_ptr& state arrow::Result> SqliteStatementBatchReader::Create(const std::shared_ptr& statement, - const std::shared_ptr& schema) { + const std::shared_ptr& schema) { return std::shared_ptr( new SqliteStatementBatchReader(statement, schema)); } -Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { +arrow::Status SqliteStatementBatchReader::ReadNext( + std::shared_ptr* out) { sqlite3_stmt* stmt_ = statement_->GetSqlite3Stmt(); const int num_fields = schema_->num_fields(); std::vector> builders(num_fields); for (int i = 0; i < num_fields; i++) { - const std::shared_ptr& field = schema_->field(i); - const std::shared_ptr& field_type = field->type(); + const std::shared_ptr& field = schema_->field(i); + const std::shared_ptr& field_type = field->type(); - ARROW_RETURN_NOT_OK(MakeBuilder(default_memory_pool(), field_type, &builders[i])); + ARROW_RETURN_NOT_OK( + MakeBuilder(arrow::default_memory_pool(), field_type, &builders[i])); } int64_t rows = 0; @@ -139,9 +140,9 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { while (rows < kMaxBatchSize && rc_ == SQLITE_ROW) { rows++; for (int i = 0; i < num_fields; i++) { - const std::shared_ptr& field = schema_->field(i); - const std::shared_ptr& field_type = field->type(); - ArrayBuilder* array_builder = builders[i].get(); + const std::shared_ptr& field = schema_->field(i); + const std::shared_ptr& field_type = field->type(); + arrow::ArrayBuilder* array_builder = builders[i].get(); if (sqlite3_column_type(stmt_, i) == SQLITE_NULL) { ARROW_RETURN_NOT_OK(array_builder->AppendNull()); @@ -167,8 +168,8 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { STRING_BUILDER_CASE(String, stmt_, i) STRING_BUILDER_CASE(LargeString, stmt_, i) default: - return Status::NotImplemented("Not implemented SQLite data conversion to ", - field_type->name()); + return arrow::Status::NotImplemented( + "Not implemented SQLite data conversion to ", field_type->name()); } } @@ -176,7 +177,8 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { } // If we still have bind parameters, bind again and retry - const std::vector>& params = statement_->parameters(); + const std::vector>& params = + statement_->parameters(); if (!params.empty() && rc_ == SQLITE_DONE && batch_index_ < params.size()) { row_index_++; if (row_index_ < params[batch_index_]->num_rows()) { @@ -193,18 +195,18 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { } if (rows > 0) { - std::vector> arrays(builders.size()); + std::vector> arrays(builders.size()); for (int i = 0; i < num_fields; i++) { ARROW_RETURN_NOT_OK(builders[i]->Finish(&arrays[i])); } - *out = RecordBatch::Make(schema_, rows, arrays); + *out = arrow::RecordBatch::Make(schema_, rows, arrays); } else { *out = nullptr; } break; } - return Status::OK(); + return arrow::Status::OK(); } #undef STRING_BUILDER_CASE @@ -212,7 +214,4 @@ Status SqliteStatementBatchReader::ReadNext(std::shared_ptr* out) { #undef INT_BUILDER_CASE #undef FLOAT_BUILDER_CASE -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_statement_batch_reader.h b/src/sqlite/sqlite_statement_batch_reader.h index 8106faa..ff1930c 100644 --- a/src/sqlite/sqlite_statement_batch_reader.h +++ b/src/sqlite/sqlite_statement_batch_reader.h @@ -24,12 +24,9 @@ #include "sqlite_statement.h" #include "arrow/record_batch.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { -class SqliteStatementBatchReader : public RecordBatchReader { +class SqliteStatementBatchReader : public arrow::RecordBatchReader { public: /// \brief Creates a RecordBatchReader backed by a SQLite statement. /// \param[in] statement SQLite statement to be read. @@ -43,15 +40,15 @@ class SqliteStatementBatchReader : public RecordBatchReader { /// \return A SqliteStatementBatchReader.. static arrow::Result> Create( const std::shared_ptr& statement, - const std::shared_ptr& schema); + const std::shared_ptr& schema); - std::shared_ptr schema() const override; + std::shared_ptr schema() const override; - Status ReadNext(std::shared_ptr* out) override; + arrow::Status ReadNext(std::shared_ptr* out) override; private: std::shared_ptr statement_; - std::shared_ptr schema_; + std::shared_ptr schema_; int rc_; bool already_executed_; @@ -60,10 +57,7 @@ class SqliteStatementBatchReader : public RecordBatchReader { int64_t row_index_{0}; SqliteStatementBatchReader(std::shared_ptr statement, - std::shared_ptr schema); + std::shared_ptr schema); }; -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_tables_schema_batch_reader.cc b/src/sqlite/sqlite_tables_schema_batch_reader.cc index 02c01ba..92efa7b 100644 --- a/src/sqlite/sqlite_tables_schema_batch_reader.cc +++ b/src/sqlite/sqlite_tables_schema_batch_reader.cc @@ -29,16 +29,16 @@ #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +#include "flight_sql_fwd.h" -std::shared_ptr SqliteTablesWithSchemaBatchReader::schema() const { - return SqlSchema::GetTablesSchemaWithIncludedSchema(); +namespace sqlflite::sqlite { + +std::shared_ptr SqliteTablesWithSchemaBatchReader::schema() const { + return flight::sql::SqlSchema::GetTablesSchemaWithIncludedSchema(); } -Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* batch) { +arrow::Status SqliteTablesWithSchemaBatchReader::ReadNext( + std::shared_ptr* batch) { std::stringstream schema_query; schema_query @@ -49,23 +49,23 @@ Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* ARROW_ASSIGN_OR_RAISE(schema_statement, sqlite::SqliteStatement::Create(db_, schema_query.str())) - std::shared_ptr first_batch; + std::shared_ptr first_batch; ARROW_RETURN_NOT_OK(reader_->ReadNext(&first_batch)); if (!first_batch) { *batch = NULLPTR; - return Status::OK(); + return arrow::Status::OK(); } - const std::shared_ptr table_name_array = + const std::shared_ptr table_name_array = first_batch->GetColumnByName("table_name"); - BinaryBuilder schema_builder; + arrow::BinaryBuilder schema_builder; - auto* string_array = reinterpret_cast(table_name_array.get()); + auto* string_array = reinterpret_cast(table_name_array.get()); - std::vector> column_fields; + std::vector> column_fields; for (int i = 0; i < table_name_array->length(); i++) { const std::string& table_name = string_array->GetString(i); @@ -79,35 +79,33 @@ Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr* sqlite3_column_text(schema_statement->GetSqlite3Stmt(), 2)); int nullable = sqlite3_column_int(schema_statement->GetSqlite3Stmt(), 3); - const ColumnMetadata& column_metadata = GetColumnMetadata( - GetSqlTypeFromTypeName(column_type), sqlite_table_name.c_str()); - std::shared_ptr arrow_type; - auto status = GetArrowType(column_type).Value(&arrow_type); + const flight::sql::ColumnMetadata& column_metadata = + GetColumnMetadata(sqlflite::sqlite::GetSqlTypeFromTypeName(column_type), + sqlite_table_name.c_str()); + std::shared_ptr arrow_type; + auto status = sqlflite::sqlite::GetArrowType(column_type).Value(&arrow_type); if (!status.ok()) { - return Status::NotImplemented("Unknown SQLite type '", column_type, - "' for column '", column_name, "' in table '", - table_name, "': ", status); + return arrow::Status::NotImplemented("Unknown SQLite type '", column_type, + "' for column '", column_name, + "' in table '", table_name, "': ", status); } column_fields.push_back(arrow::field(column_name, arrow_type, nullable == 0, column_metadata.metadata_map())); } } - ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema_buffer, - ipc::SerializeSchema(*arrow::schema(column_fields))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema_buffer, + arrow::ipc::SerializeSchema(*arrow::schema(column_fields))); column_fields.clear(); ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer))); } - std::shared_ptr schema_array; + std::shared_ptr schema_array; ARROW_RETURN_NOT_OK(schema_builder.Finish(&schema_array)); ARROW_ASSIGN_OR_RAISE(*batch, first_batch->AddColumn(4, "table_schema", schema_array)); - return Status::OK(); + return arrow::Status::OK(); } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_tables_schema_batch_reader.h b/src/sqlite/sqlite_tables_schema_batch_reader.h index c795160..fd6b523 100644 --- a/src/sqlite/sqlite_tables_schema_batch_reader.h +++ b/src/sqlite/sqlite_tables_schema_batch_reader.h @@ -26,12 +26,9 @@ #include "sqlite_statement_batch_reader.h" #include "arrow/record_batch.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { -class SqliteTablesWithSchemaBatchReader : public RecordBatchReader { +class SqliteTablesWithSchemaBatchReader : public arrow::RecordBatchReader { private: std::shared_ptr reader_; std::string main_query_; @@ -47,12 +44,9 @@ class SqliteTablesWithSchemaBatchReader : public RecordBatchReader { sqlite3* db) : reader_(std::move(reader)), main_query_(std::move(main_query)), db_(db) {} - std::shared_ptr schema() const override; + std::shared_ptr schema() const override; - Status ReadNext(std::shared_ptr* batch) override; + arrow::Status ReadNext(std::shared_ptr* batch) override; }; -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_type_info.cc b/src/sqlite/sqlite_type_info.cc index 79c9921..c421b0e 100644 --- a/src/sqlite/sqlite_type_info.cc +++ b/src/sqlite/sqlite_type_info.cc @@ -25,33 +25,37 @@ #include "arrow/record_batch.h" #include "arrow/util/rows_to_batches.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +#include "flight_sql_fwd.h" -arrow::Result> DoGetTypeInfoResult() { - auto schema = SqlSchema::GetXdbcTypeInfoSchema(); +namespace sql = flight::sql; + +namespace sqlflite::sqlite { + +arrow::Result> DoGetTypeInfoResult() { + auto schema = sql::SqlSchema::GetXdbcTypeInfoSchema(); using ValueType = std::variant>; - auto VariantConverter = [](ArrayBuilder& array_builder, const ValueType& value) { + auto VariantConverter = [](arrow::ArrayBuilder& array_builder, const ValueType& value) { if (std::holds_alternative(value)) { - return dynamic_cast(array_builder).Append(std::get(value)); + return dynamic_cast(array_builder) + .Append(std::get(value)); } else if (std::holds_alternative(value)) { - return dynamic_cast(array_builder).Append(std::get(value)); + return dynamic_cast(array_builder) + .Append(std::get(value)); } else if (std::holds_alternative(value)) { return array_builder.AppendNull(); } else if (std::holds_alternative(value)) { - return dynamic_cast(array_builder) + return dynamic_cast(array_builder) .Append(std::get(value)); } else { - auto& list_builder = dynamic_cast(array_builder); + auto& list_builder = dynamic_cast(array_builder); ARROW_RETURN_NOT_OK(list_builder.Append()); - auto value_builder = dynamic_cast(list_builder.value_builder()); + auto value_builder = + dynamic_cast(list_builder.value_builder()); for (const auto& v : std::get>(value)) { ARROW_RETURN_NOT_OK(value_builder->Append(v)); } - return Status::OK(); + return arrow::Status::OK(); } }; std::vector> rows = { @@ -193,7 +197,8 @@ arrow::Result> DoGetTypeInfoResult() { return reader->Next(); } -arrow::Result> DoGetTypeInfoResult(int data_type_filter) { +arrow::Result> DoGetTypeInfoResult( + int data_type_filter) { ARROW_ASSIGN_OR_RAISE(auto record_batch, DoGetTypeInfoResult()); std::vector data_type_vector{-7, -6, -5, -4, -3, -1, -1, 1, 4, @@ -207,7 +212,5 @@ arrow::Result> DoGetTypeInfoResult(int data_type_fi return record_batch->Slice(pair.first - data_type_vector.begin(), pair.second - pair.first); } -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow + +} // namespace sqlflite::sqlite diff --git a/src/sqlite/sqlite_type_info.h b/src/sqlite/sqlite_type_info.h index 3fd6da5..bd31904 100644 --- a/src/sqlite/sqlite_type_info.h +++ b/src/sqlite/sqlite_type_info.h @@ -19,20 +19,16 @@ #include "arrow/record_batch.h" -namespace arrow { -namespace flight { -namespace sql { -namespace sqlite { +namespace sqlflite::sqlite { /// \brief Gets the hard-coded type info from Sqlite for all data types. /// \return A record batch. -arrow::Result> DoGetTypeInfoResult(); +arrow::Result> DoGetTypeInfoResult(); /// \brief Gets the hard-coded type info from Sqlite filtering /// for a specific data type. /// \return A record batch. -arrow::Result> DoGetTypeInfoResult(int data_type_filter); -} // namespace sqlite -} // namespace sql -} // namespace flight -} // namespace arrow +arrow::Result> DoGetTypeInfoResult( + int data_type_filter); + +} // namespace sqlflite::sqlite