From 7b1c82fc86c265efda9f74a0f2e58bd78b482dd3 Mon Sep 17 00:00:00 2001 From: Felipe Oliveira Carvalho Date: Thu, 18 Jul 2024 18:50:46 -0300 Subject: [PATCH] src: Use a proper namespace for this project Additionally: - Remove the fs namespace alias from headers - Wrap most global functions in the namespace --- src/library/include/sqlflite_library.h | 21 ++++---- src/library/include/sqlflite_security.h | 62 +++++++++++----------- src/library/sqlflite_library.cpp | 63 ++++++++++++----------- src/library/sqlflite_security.cpp | 50 +++++++++--------- src/sqlflite_client.cpp | 68 +++++++++++-------------- src/sqlflite_server.cpp | 1 + 6 files changed, 134 insertions(+), 131 deletions(-) diff --git a/src/library/include/sqlflite_library.h b/src/library/include/sqlflite_library.h index cdf04f4..e189902 100644 --- a/src/library/include/sqlflite_library.h +++ b/src/library/include/sqlflite_library.h @@ -28,8 +28,6 @@ const int DEFAULT_FLIGHT_PORT = 31337; enum class BackendType { duckdb, sqlite }; -namespace fs = std::filesystem; - /** * @brief Run a SQLFlite Server with the specified configuration. * @@ -57,13 +55,14 @@ namespace fs = std::filesystem; */ extern "C" { -int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, - std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT, - std::string username = "", std::string password = "", - std::string secret_key = "", fs::path tls_cert_path = fs::path(), - fs::path tls_key_path = fs::path(), - fs::path mtls_ca_cert_path = fs::path(), - std::string init_sql_commands = "", - fs::path init_sql_commands_file = fs::path(), - const bool &print_queries = false); +int RunFlightSQLServer( + const BackendType backend, std::filesystem::path &database_filename, + std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT, + std::string username = "", std::string password = "", std::string secret_key = "", + std::filesystem::path tls_cert_path = std::filesystem::path(), + std::filesystem::path tls_key_path = std::filesystem::path(), + std::filesystem::path mtls_ca_cert_path = std::filesystem::path(), + std::string init_sql_commands = "", + std::filesystem::path init_sql_commands_file = std::filesystem::path(), + const bool &print_queries = false); } diff --git a/src/library/include/sqlflite_security.h b/src/library/include/sqlflite_security.h index 5a1fa12..f6612c4 100644 --- a/src/library/include/sqlflite_security.h +++ b/src/library/include/sqlflite_security.h @@ -28,38 +28,37 @@ #include #include #include +#include "flight_sql_fwd.h" -namespace fs = std::filesystem; - -namespace arrow { -namespace flight { +namespace sqlflite { class SecurityUtilities { public: - static Status FlightServerTlsCertificates(const fs::path &cert_path, - const fs::path &key_path, - std::vector *out); + static arrow::Status FlightServerTlsCertificates(const std::filesystem::path &cert_path, + const std::filesystem::path &key_path, + std::vector *out); - static Status FlightServerMtlsCACertificate(const std::string &cert_path, - std::string *out); + static arrow::Status FlightServerMtlsCACertificate(const std::string &cert_path, + std::string *out); - static std::string FindKeyValPrefixInCallHeaders(const CallHeaders &incoming_headers, - const std::string &key, - const std::string &prefix); + static std::string FindKeyValPrefixInCallHeaders( + const flight::CallHeaders &incoming_headers, const std::string &key, + const std::string &prefix); - static Status GetAuthHeaderType(const CallHeaders &incoming_headers, std::string *out); + static arrow::Status GetAuthHeaderType(const flight::CallHeaders &incoming_headers, + std::string *out); - static void ParseBasicHeader(const CallHeaders &incoming_headers, std::string &username, - std::string &password); + static void ParseBasicHeader(const flight::CallHeaders &incoming_headers, + std::string &username, std::string &password); }; -class HeaderAuthServerMiddleware : public ServerMiddleware { +class HeaderAuthServerMiddleware : public flight::ServerMiddleware { public: HeaderAuthServerMiddleware(const std::string &username, const std::string &secret_key); - void SendingHeaders(AddCallHeaders *outgoing_headers) override; + void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override; - void CallCompleted(const Status &status) override; + void CallCompleted(const arrow::Status &status) override; std::string name() const override; @@ -70,14 +69,15 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { std::string CreateJWTToken() const; }; -class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { +class HeaderAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory { public: HeaderAuthServerMiddlewareFactory(const std::string &username, const std::string &password, const std::string &secret_key); - Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) override; + arrow::Status StartCall(const flight::CallInfo &info, + const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) override; private: std::string username_; @@ -85,32 +85,33 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { std::string secret_key_; }; -class BearerAuthServerMiddleware : public ServerMiddleware { +class BearerAuthServerMiddleware : public flight::ServerMiddleware { public: explicit BearerAuthServerMiddleware(const std::string &secret_key, - const CallHeaders &incoming_headers, + const flight::CallHeaders &incoming_headers, std::optional *isValid); - void SendingHeaders(AddCallHeaders *outgoing_headers) override; + void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override; - void CallCompleted(const Status &status) override; + void CallCompleted(const arrow::Status &status) override; std::string name() const override; private: std::string secret_key_; - CallHeaders incoming_headers_; + flight::CallHeaders incoming_headers_; std::optional *isValid_; bool VerifyToken(const std::string &token) const; }; -class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { +class BearerAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory { public: explicit BearerAuthServerMiddlewareFactory(const std::string &secret_key); - Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) override; + arrow::Status StartCall(const flight::CallInfo &info, + const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) override; std::optional GetIsValid(); @@ -119,5 +120,4 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { std::string secret_key_; }; -} // namespace flight -} // namespace arrow +} // namespace sqlflite diff --git a/src/library/sqlflite_library.cpp b/src/library/sqlflite_library.cpp index e41e5d9..355ad25 100644 --- a/src/library/sqlflite_library.cpp +++ b/src/library/sqlflite_library.cpp @@ -33,10 +33,12 @@ #include "sqlite_server.h" #include "duckdb_server.h" +#include "include/flight_sql_fwd.h" #include "include/sqlflite_security.h" -namespace flight = arrow::flight; -namespace flightsql = arrow::flight::sql; +namespace fs = std::filesystem; + +namespace sqlflite { const int port = 31337; @@ -58,24 +60,24 @@ const int port = 31337; } \ } while (false) -arrow::Result> -FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filename, - const std::string &hostname, const int &port, - const std::string &username, const std::string &password, - const std::string &secret_key, const fs::path &tls_cert_path, - const fs::path &tls_key_path, const fs::path &mtls_ca_cert_path, - const std::string &init_sql_commands, const bool &print_queries) { +arrow::Result> FlightSQLServerBuilder( + const BackendType backend, const fs::path &database_filename, + const std::string &hostname, const int &port, const std::string &username, + const std::string &password, const std::string &secret_key, + const fs::path &tls_cert_path, const fs::path &tls_key_path, + const fs::path &mtls_ca_cert_path, const std::string &init_sql_commands, + const bool &print_queries) { ARROW_ASSIGN_OR_RAISE(auto location, (!tls_cert_path.empty()) - ? arrow::flight::Location::ForGrpcTls(hostname, port) - : arrow::flight::Location::ForGrpcTcp(hostname, port)); + ? flight::Location::ForGrpcTls(hostname, port) + : flight::Location::ForGrpcTcp(hostname, port)); std::cout << "Apache Arrow version: " << ARROW_VERSION_STRING << std::endl; - arrow::flight::FlightServerOptions options(location); + flight::FlightServerOptions options(location); if (!tls_cert_path.empty() && !tls_key_path.empty()) { - ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerTlsCertificates( + ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerTlsCertificates( tls_cert_path, tls_key_path, &options.tls_certificates)); } else { std::cout << "WARNING - TLS is disabled for the SQLFlite server - this is insecure." @@ -83,24 +85,23 @@ FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filen } // Setup authentication middleware (using the same TLS certificate keypair) - auto header_middleware = - std::make_shared( - username, password, secret_key); + auto header_middleware = std::make_shared( + username, password, secret_key); auto bearer_middleware = - std::make_shared(secret_key); + std::make_shared(secret_key); - options.auth_handler = std::make_unique(); + options.auth_handler = std::make_unique(); options.middleware.push_back({"header-auth-server", header_middleware}); options.middleware.push_back({"bearer-auth-server", bearer_middleware}); if (!mtls_ca_cert_path.empty()) { std::cout << "Using mTLS CA certificate: " << mtls_ca_cert_path << std::endl; - ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerMtlsCACertificate( + ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerMtlsCACertificate( mtls_ca_cert_path, &options.root_certificates)); options.verify_client = true; } - std::shared_ptr server = nullptr; + std::shared_ptr server = nullptr; std::string db_type = ""; if (backend == BackendType::sqlite) { @@ -155,13 +156,12 @@ std::string SafeGetEnvVarValue(const std::string &env_var_name) { } } -arrow::Result> -CreateFlightSQLServer(const BackendType backend, fs::path &database_filename, - std::string hostname, const int &port, std::string username, - std::string password, std::string secret_key, - fs::path tls_cert_path, fs::path tls_key_path, - fs::path mtls_ca_cert_path, std::string init_sql_commands, - fs::path init_sql_commands_file, const bool &print_queries) { +arrow::Result> CreateFlightSQLServer( + const BackendType backend, fs::path &database_filename, std::string hostname, + const int &port, std::string username, std::string password, std::string secret_key, + fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path, + std::string init_sql_commands, fs::path init_sql_commands_file, + const bool &print_queries) { // Validate and default the arguments to env var values where applicable if (database_filename.empty()) { return arrow::Status::Invalid("The database filename was not provided!"); @@ -255,17 +255,21 @@ CreateFlightSQLServer(const BackendType backend, fs::path &database_filename, } arrow::Status StartFlightSQLServer( - std::shared_ptr server) { + std::shared_ptr server) { return arrow::Status::OK(); } +} // namespace sqlflite + +extern "C" { + int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, std::string hostname, const int &port, std::string username, std::string password, std::string secret_key, fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path, std::string init_sql_commands, fs::path init_sql_commands_file, const bool &print_queries) { - auto create_server_result = CreateFlightSQLServer( + auto create_server_result = sqlflite::CreateFlightSQLServer( backend, database_filename, hostname, port, username, password, secret_key, tls_cert_path, tls_key_path, mtls_ca_cert_path, init_sql_commands, init_sql_commands_file, print_queries); @@ -281,3 +285,4 @@ int RunFlightSQLServer(const BackendType backend, fs::path &database_filename, return EXIT_FAILURE; } } +} diff --git a/src/library/sqlflite_security.cpp b/src/library/sqlflite_security.cpp index a29dccc..c8760dc 100644 --- a/src/library/sqlflite_security.cpp +++ b/src/library/sqlflite_security.cpp @@ -19,8 +19,9 @@ namespace fs = std::filesystem; -namespace arrow { -namespace flight { +using arrow::Status; + +namespace sqlflite { const std::string kJWTIssuer = "sqlflite"; const int kJWTExpiration = 24 * 3600; @@ -30,13 +31,13 @@ const std::string kBearerPrefix = "Bearer "; const std::string kAuthHeader = "authorization"; // ---------------------------------------- -Status SecurityUtilities::FlightServerTlsCertificates(const fs::path &cert_path, - const fs::path &key_path, - std::vector *out) { +Status SecurityUtilities::FlightServerTlsCertificates( + const fs::path &cert_path, const fs::path &key_path, + std::vector *out) { std::cout << "Using TLS Cert file: " << cert_path << std::endl; std::cout << "Using TLS Key file: " << key_path << std::endl; - *out = std::vector(); + *out = std::vector(); try { std::ifstream cert_file(cert_path); if (!cert_file) { @@ -52,7 +53,7 @@ Status SecurityUtilities::FlightServerTlsCertificates(const fs::path &cert_path, std::stringstream key; key << key_file.rdbuf(); - out->push_back(CertKeyPair{cert.str(), key.str()}); + out->push_back(flight::CertKeyPair{cert.str(), key.str()}); } catch (const std::ifstream::failure &e) { return Status::IOError(e.what()); } @@ -79,7 +80,7 @@ Status SecurityUtilities::FlightServerMtlsCACertificate(const std::string &cert_ // Function to look in CallHeaders for a key that has a value starting with prefix and // return the rest of the value after the prefix. std::string SecurityUtilities::FindKeyValPrefixInCallHeaders( - const CallHeaders &incoming_headers, const std::string &key, + const flight::CallHeaders &incoming_headers, const std::string &key, const std::string &prefix) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char &char1, const char &char2) { @@ -100,7 +101,7 @@ std::string SecurityUtilities::FindKeyValPrefixInCallHeaders( return ""; } -Status SecurityUtilities::GetAuthHeaderType(const CallHeaders &incoming_headers, +Status SecurityUtilities::GetAuthHeaderType(const flight::CallHeaders &incoming_headers, std::string *out) { if (!FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix) .empty()) { @@ -114,7 +115,7 @@ Status SecurityUtilities::GetAuthHeaderType(const CallHeaders &incoming_headers, return Status::OK(); } -void SecurityUtilities::ParseBasicHeader(const CallHeaders &incoming_headers, +void SecurityUtilities::ParseBasicHeader(const flight::CallHeaders &incoming_headers, std::string &username, std::string &password) { std::string encoded_credentials = FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); @@ -128,7 +129,8 @@ HeaderAuthServerMiddleware::HeaderAuthServerMiddleware(const std::string &userna const std::string &secret_key) : username_(username), secret_key_(secret_key) {} -void HeaderAuthServerMiddleware::SendingHeaders(AddCallHeaders *outgoing_headers) { +void HeaderAuthServerMiddleware::SendingHeaders( + flight::AddCallHeaders *outgoing_headers) { auto token = CreateJWTToken(); outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + token); } @@ -161,8 +163,8 @@ HeaderAuthServerMiddlewareFactory::HeaderAuthServerMiddlewareFactory( : username_(username), password_(password), secret_key_(secret_key) {} Status HeaderAuthServerMiddlewareFactory::StartCall( - const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) { + const flight::CallInfo &info, const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) { std::string auth_header_type; ARROW_RETURN_NOT_OK( SecurityUtilities::GetAuthHeaderType(incoming_headers, &auth_header_type)); @@ -175,7 +177,8 @@ Status HeaderAuthServerMiddlewareFactory::StartCall( if ((username == username_) && (password == password_)) { *middleware = std::make_shared(username, secret_key_); } else { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials"); + return MakeFlightError(flight::FlightStatusCode::Unauthenticated, + "Invalid credentials"); } } return Status::OK(); @@ -183,11 +186,12 @@ Status HeaderAuthServerMiddlewareFactory::StartCall( // ---------------------------------------- BearerAuthServerMiddleware::BearerAuthServerMiddleware( - const std::string &secret_key, const CallHeaders &incoming_headers, + const std::string &secret_key, const flight::CallHeaders &incoming_headers, std::optional *isValid) : secret_key_(secret_key), incoming_headers_(incoming_headers), isValid_(isValid) {} -void BearerAuthServerMiddleware::SendingHeaders(AddCallHeaders *outgoing_headers) { +void BearerAuthServerMiddleware::SendingHeaders( + flight::AddCallHeaders *outgoing_headers) { std::string bearer_token = SecurityUtilities::FindKeyValPrefixInCallHeaders( incoming_headers_, kAuthHeader, kBearerPrefix); *isValid_ = (VerifyToken(bearer_token)); @@ -225,10 +229,11 @@ BearerAuthServerMiddlewareFactory::BearerAuthServerMiddlewareFactory( : secret_key_(secret_key) {} Status BearerAuthServerMiddlewareFactory::StartCall( - const CallInfo &info, const CallHeaders &incoming_headers, - std::shared_ptr *middleware) { - if (const std::pair - &iter_pair = incoming_headers.equal_range(kAuthHeader); + const flight::CallInfo &info, const flight::CallHeaders &incoming_headers, + std::shared_ptr *middleware) { + if (const std::pair &iter_pair = + incoming_headers.equal_range(kAuthHeader); iter_pair.first != iter_pair.second) { std::string auth_header_type; ARROW_RETURN_NOT_OK( @@ -241,7 +246,7 @@ Status BearerAuthServerMiddlewareFactory::StartCall( if (isValid_.has_value() && !*isValid_) { isValid_.reset(); - return MakeFlightError(FlightStatusCode::Unauthenticated, + return MakeFlightError(flight::FlightStatusCode::Unauthenticated, "Invalid bearer token provided"); } @@ -250,5 +255,4 @@ Status BearerAuthServerMiddlewareFactory::StartCall( std::optional BearerAuthServerMiddlewareFactory::GetIsValid() { return isValid_; } -} // namespace flight -} // namespace arrow +} // namespace sqlflite diff --git a/src/sqlflite_client.cpp b/src/sqlflite_client.cpp index b97363d..5cbf83d 100644 --- a/src/sqlflite_client.cpp +++ b/src/sqlflite_client.cpp @@ -33,21 +33,11 @@ #include "arrow/status.h" #include "arrow/table.h" -using arrow::Result; -using arrow::Schema; +#include "library/include/flight_sql_fwd.h" + using arrow::Status; -using arrow::flight::ClientAuthHandler; -using arrow::flight::FlightCallOptions; -using arrow::flight::FlightClient; -using arrow::flight::FlightDescriptor; -using arrow::flight::FlightEndpoint; -using arrow::flight::FlightInfo; -using arrow::flight::FlightStreamChunk; -using arrow::flight::FlightStreamReader; -using arrow::flight::Location; -using arrow::flight::Ticket; -using arrow::flight::sql::FlightSqlClient; -using arrow::flight::sql::TableRef; + +namespace sqlflite { DEFINE_string(host, "localhost", "Host to connect to"); DEFINE_int32(port, 31337, "Port to connect to"); @@ -69,12 +59,12 @@ DEFINE_string(catalog, "", "Catalog"); DEFINE_string(schema, "", "Schema"); DEFINE_string(table, "", "Table"); -Status PrintResultsForEndpoint(FlightSqlClient &client, - const FlightCallOptions &call_options, - const FlightEndpoint &endpoint) { +Status PrintResultsForEndpoint(flight::sql::FlightSqlClient &client, + const flight::FlightCallOptions &call_options, + const flight::FlightEndpoint &endpoint) { ARROW_ASSIGN_OR_RAISE(auto stream, client.DoGet(call_options, endpoint.ticket)); - const arrow::Result> &schema = stream->GetSchema(); + const arrow::Result> &schema = stream->GetSchema(); ARROW_RETURN_NOT_OK(schema); std::cout << "Schema:" << std::endl; @@ -85,7 +75,7 @@ Status PrintResultsForEndpoint(FlightSqlClient &client, int64_t num_rows = 0; while (true) { - ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, stream->Next()); + ARROW_ASSIGN_OR_RAISE(flight::FlightStreamChunk chunk, stream->Next()); if (chunk.data == nullptr) { break; } @@ -98,9 +88,10 @@ Status PrintResultsForEndpoint(FlightSqlClient &client, return Status::OK(); } -Status PrintResults(FlightSqlClient &client, const FlightCallOptions &call_options, - const std::unique_ptr &info) { - const std::vector &endpoints = info->endpoints(); +Status PrintResults(flight::sql::FlightSqlClient &client, + const flight::FlightCallOptions &call_options, + const std::unique_ptr &info) { + const std::vector &endpoints = info->endpoints(); for (size_t i = 0; i < endpoints.size(); i++) { std::cout << "Results from endpoint " << i + 1 << " of " << endpoints.size() @@ -127,11 +118,12 @@ Status getPEMCertFileContents(const std::string &cert_file_path, Status RunMain() { ARROW_ASSIGN_OR_RAISE(auto location, - (FLAGS_use_tls) ? Location::ForGrpcTls(FLAGS_host, FLAGS_port) - : Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); + (FLAGS_use_tls) + ? flight::Location::ForGrpcTls(FLAGS_host, FLAGS_port) + : flight::Location::ForGrpcTcp(FLAGS_host, FLAGS_port)); // Setup our options - arrow::flight::FlightClientOptions options; + flight::FlightClientOptions options; if (!FLAGS_tls_roots.empty()) { ARROW_RETURN_NOT_OK(getPEMCertFileContents(FLAGS_tls_roots, options.tls_root_certs)); @@ -152,19 +144,19 @@ Status RunMain() { } } - ARROW_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(location, options)); + ARROW_ASSIGN_OR_RAISE(auto client, flight::FlightClient::Connect(location, options)); - FlightCallOptions call_options; + flight::FlightCallOptions call_options; if (!FLAGS_username.empty() || !FLAGS_password.empty()) { - Result> bearer_result = + arrow::Result> bearer_result = client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password); ARROW_RETURN_NOT_OK(bearer_result); call_options.headers.push_back(bearer_result.ValueOrDie()); } - FlightSqlClient sql_client(std::move(client)); + flight::sql::FlightSqlClient sql_client(std::move(client)); if (FLAGS_command == "ExecuteUpdate") { ARROW_ASSIGN_OR_RAISE(auto rows, sql_client.ExecuteUpdate(call_options, FLAGS_query)); @@ -174,7 +166,7 @@ Status RunMain() { return Status::OK(); } - std::unique_ptr info; + std::unique_ptr info; std::shared_ptr prepared_statement; @@ -211,16 +203,16 @@ Status RunMain() { info, sql_client.GetTables(call_options, &FLAGS_catalog, &FLAGS_schema, &FLAGS_table, false, nullptr)); } else if (FLAGS_command == "GetExportedKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetExportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetImportedKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetImportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetPrimaryKeys") { - TableRef table_ref = {std::make_optional(FLAGS_catalog), - std::make_optional(FLAGS_schema), FLAGS_table}; + flight::sql::TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetPrimaryKeys(call_options, table_ref)); } else if (FLAGS_command == "GetSqlInfo") { ARROW_ASSIGN_OR_RAISE(info, sql_client.GetSqlInfo(call_options, {})); @@ -238,10 +230,12 @@ Status RunMain() { return print_status; } +} // namespace sqlflite + int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - Status st = RunMain(); + Status st = sqlflite::RunMain(); if (!st.ok()) { std::cerr << st << std::endl; return 1; diff --git a/src/sqlflite_server.cpp b/src/sqlflite_server.cpp index 8a62c9a..04337c3 100644 --- a/src/sqlflite_server.cpp +++ b/src/sqlflite_server.cpp @@ -20,6 +20,7 @@ #include namespace po = boost::program_options; +namespace fs = std::filesystem; int main(int argc, char **argv) { std::vector tls_token_values;