Skip to content

Commit

Permalink
src: Use a proper namespace for this project
Browse files Browse the repository at this point in the history
Additionally:
 - Remove the fs namespace alias from headers
 - Wrap most global functions in the namespace
  • Loading branch information
felipecrv committed Jul 18, 2024
1 parent 65cdb7e commit 7b1c82f
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 131 deletions.
21 changes: 10 additions & 11 deletions src/library/include/sqlflite_library.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ const int DEFAULT_FLIGHT_PORT = 31337;

enum class BackendType { duckdb, sqlite };

namespace fs = std::filesystem;

/**
* @brief Run a SQLFlite Server with the specified configuration.
*
Expand Down Expand Up @@ -57,13 +55,14 @@ namespace fs = std::filesystem;
*/

extern "C" {
int RunFlightSQLServer(const BackendType backend, fs::path &database_filename,
std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT,
std::string username = "", std::string password = "",
std::string secret_key = "", fs::path tls_cert_path = fs::path(),
fs::path tls_key_path = fs::path(),
fs::path mtls_ca_cert_path = fs::path(),
std::string init_sql_commands = "",
fs::path init_sql_commands_file = fs::path(),
const bool &print_queries = false);
int RunFlightSQLServer(
const BackendType backend, std::filesystem::path &database_filename,
std::string hostname = "", const int &port = DEFAULT_FLIGHT_PORT,
std::string username = "", std::string password = "", std::string secret_key = "",
std::filesystem::path tls_cert_path = std::filesystem::path(),
std::filesystem::path tls_key_path = std::filesystem::path(),
std::filesystem::path mtls_ca_cert_path = std::filesystem::path(),
std::string init_sql_commands = "",
std::filesystem::path init_sql_commands_file = std::filesystem::path(),
const bool &print_queries = false);
}
62 changes: 31 additions & 31 deletions src/library/include/sqlflite_security.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,37 @@
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include "flight_sql_fwd.h"

namespace fs = std::filesystem;

namespace arrow {
namespace flight {
namespace sqlflite {

class SecurityUtilities {
public:
static Status FlightServerTlsCertificates(const fs::path &cert_path,
const fs::path &key_path,
std::vector<CertKeyPair> *out);
static arrow::Status FlightServerTlsCertificates(const std::filesystem::path &cert_path,
const std::filesystem::path &key_path,
std::vector<flight::CertKeyPair> *out);

static Status FlightServerMtlsCACertificate(const std::string &cert_path,
std::string *out);
static arrow::Status FlightServerMtlsCACertificate(const std::string &cert_path,
std::string *out);

static std::string FindKeyValPrefixInCallHeaders(const CallHeaders &incoming_headers,
const std::string &key,
const std::string &prefix);
static std::string FindKeyValPrefixInCallHeaders(
const flight::CallHeaders &incoming_headers, const std::string &key,
const std::string &prefix);

static Status GetAuthHeaderType(const CallHeaders &incoming_headers, std::string *out);
static arrow::Status GetAuthHeaderType(const flight::CallHeaders &incoming_headers,
std::string *out);

static void ParseBasicHeader(const CallHeaders &incoming_headers, std::string &username,
std::string &password);
static void ParseBasicHeader(const flight::CallHeaders &incoming_headers,
std::string &username, std::string &password);
};

class HeaderAuthServerMiddleware : public ServerMiddleware {
class HeaderAuthServerMiddleware : public flight::ServerMiddleware {
public:
HeaderAuthServerMiddleware(const std::string &username, const std::string &secret_key);

void SendingHeaders(AddCallHeaders *outgoing_headers) override;
void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override;

void CallCompleted(const Status &status) override;
void CallCompleted(const arrow::Status &status) override;

std::string name() const override;

Expand All @@ -70,47 +69,49 @@ class HeaderAuthServerMiddleware : public ServerMiddleware {
std::string CreateJWTToken() const;
};

class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
class HeaderAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory {
public:
HeaderAuthServerMiddlewareFactory(const std::string &username,
const std::string &password,
const std::string &secret_key);

Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers,
std::shared_ptr<ServerMiddleware> *middleware) override;
arrow::Status StartCall(const flight::CallInfo &info,
const flight::CallHeaders &incoming_headers,
std::shared_ptr<flight::ServerMiddleware> *middleware) override;

private:
std::string username_;
std::string password_;
std::string secret_key_;
};

class BearerAuthServerMiddleware : public ServerMiddleware {
class BearerAuthServerMiddleware : public flight::ServerMiddleware {
public:
explicit BearerAuthServerMiddleware(const std::string &secret_key,
const CallHeaders &incoming_headers,
const flight::CallHeaders &incoming_headers,
std::optional<bool> *isValid);

void SendingHeaders(AddCallHeaders *outgoing_headers) override;
void SendingHeaders(flight::AddCallHeaders *outgoing_headers) override;

void CallCompleted(const Status &status) override;
void CallCompleted(const arrow::Status &status) override;

std::string name() const override;

private:
std::string secret_key_;
CallHeaders incoming_headers_;
flight::CallHeaders incoming_headers_;
std::optional<bool> *isValid_;

bool VerifyToken(const std::string &token) const;
};

class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
class BearerAuthServerMiddlewareFactory : public flight::ServerMiddlewareFactory {
public:
explicit BearerAuthServerMiddlewareFactory(const std::string &secret_key);

Status StartCall(const CallInfo &info, const CallHeaders &incoming_headers,
std::shared_ptr<ServerMiddleware> *middleware) override;
arrow::Status StartCall(const flight::CallInfo &info,
const flight::CallHeaders &incoming_headers,
std::shared_ptr<flight::ServerMiddleware> *middleware) override;

std::optional<bool> GetIsValid();

Expand All @@ -119,5 +120,4 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
std::string secret_key_;
};

} // namespace flight
} // namespace arrow
} // namespace sqlflite
63 changes: 34 additions & 29 deletions src/library/sqlflite_library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@

#include "sqlite_server.h"
#include "duckdb_server.h"
#include "include/flight_sql_fwd.h"
#include "include/sqlflite_security.h"

namespace flight = arrow::flight;
namespace flightsql = arrow::flight::sql;
namespace fs = std::filesystem;

namespace sqlflite {

const int port = 31337;

Expand All @@ -58,49 +60,48 @@ const int port = 31337;
} \
} while (false)

arrow::Result<std::shared_ptr<arrow::flight::sql::FlightSqlServerBase>>
FlightSQLServerBuilder(const BackendType backend, const fs::path &database_filename,
const std::string &hostname, const int &port,
const std::string &username, const std::string &password,
const std::string &secret_key, const fs::path &tls_cert_path,
const fs::path &tls_key_path, const fs::path &mtls_ca_cert_path,
const std::string &init_sql_commands, const bool &print_queries) {
arrow::Result<std::shared_ptr<flight::sql::FlightSqlServerBase>> FlightSQLServerBuilder(
const BackendType backend, const fs::path &database_filename,
const std::string &hostname, const int &port, const std::string &username,
const std::string &password, const std::string &secret_key,
const fs::path &tls_cert_path, const fs::path &tls_key_path,
const fs::path &mtls_ca_cert_path, const std::string &init_sql_commands,
const bool &print_queries) {
ARROW_ASSIGN_OR_RAISE(auto location,
(!tls_cert_path.empty())
? arrow::flight::Location::ForGrpcTls(hostname, port)
: arrow::flight::Location::ForGrpcTcp(hostname, port));
? flight::Location::ForGrpcTls(hostname, port)
: flight::Location::ForGrpcTcp(hostname, port));

std::cout << "Apache Arrow version: " << ARROW_VERSION_STRING << std::endl;

arrow::flight::FlightServerOptions options(location);
flight::FlightServerOptions options(location);

if (!tls_cert_path.empty() && !tls_key_path.empty()) {
ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerTlsCertificates(
ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerTlsCertificates(
tls_cert_path, tls_key_path, &options.tls_certificates));
} else {
std::cout << "WARNING - TLS is disabled for the SQLFlite server - this is insecure."
<< std::endl;
}

// Setup authentication middleware (using the same TLS certificate keypair)
auto header_middleware =
std::make_shared<arrow::flight::HeaderAuthServerMiddlewareFactory>(
username, password, secret_key);
auto header_middleware = std::make_shared<sqlflite::HeaderAuthServerMiddlewareFactory>(
username, password, secret_key);
auto bearer_middleware =
std::make_shared<arrow::flight::BearerAuthServerMiddlewareFactory>(secret_key);
std::make_shared<sqlflite::BearerAuthServerMiddlewareFactory>(secret_key);

options.auth_handler = std::make_unique<arrow::flight::NoOpAuthHandler>();
options.auth_handler = std::make_unique<flight::NoOpAuthHandler>();
options.middleware.push_back({"header-auth-server", header_middleware});
options.middleware.push_back({"bearer-auth-server", bearer_middleware});

if (!mtls_ca_cert_path.empty()) {
std::cout << "Using mTLS CA certificate: " << mtls_ca_cert_path << std::endl;
ARROW_CHECK_OK(arrow::flight::SecurityUtilities::FlightServerMtlsCACertificate(
ARROW_CHECK_OK(sqlflite::SecurityUtilities::FlightServerMtlsCACertificate(
mtls_ca_cert_path, &options.root_certificates));
options.verify_client = true;
}

std::shared_ptr<arrow::flight::sql::FlightSqlServerBase> server = nullptr;
std::shared_ptr<flight::sql::FlightSqlServerBase> server = nullptr;

std::string db_type = "";
if (backend == BackendType::sqlite) {
Expand Down Expand Up @@ -155,13 +156,12 @@ std::string SafeGetEnvVarValue(const std::string &env_var_name) {
}
}

arrow::Result<std::shared_ptr<arrow::flight::sql::FlightSqlServerBase>>
CreateFlightSQLServer(const BackendType backend, fs::path &database_filename,
std::string hostname, const int &port, std::string username,
std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path,
fs::path mtls_ca_cert_path, std::string init_sql_commands,
fs::path init_sql_commands_file, const bool &print_queries) {
arrow::Result<std::shared_ptr<flight::sql::FlightSqlServerBase>> CreateFlightSQLServer(
const BackendType backend, fs::path &database_filename, std::string hostname,
const int &port, std::string username, std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path, fs::path mtls_ca_cert_path,
std::string init_sql_commands, fs::path init_sql_commands_file,
const bool &print_queries) {
// Validate and default the arguments to env var values where applicable
if (database_filename.empty()) {
return arrow::Status::Invalid("The database filename was not provided!");
Expand Down Expand Up @@ -255,17 +255,21 @@ CreateFlightSQLServer(const BackendType backend, fs::path &database_filename,
}

arrow::Status StartFlightSQLServer(
std::shared_ptr<arrow::flight::sql::FlightSqlServerBase> server) {
std::shared_ptr<flight::sql::FlightSqlServerBase> server) {
return arrow::Status::OK();
}

} // namespace sqlflite

extern "C" {

int RunFlightSQLServer(const BackendType backend, fs::path &database_filename,
std::string hostname, const int &port, std::string username,
std::string password, std::string secret_key,
fs::path tls_cert_path, fs::path tls_key_path,
fs::path mtls_ca_cert_path, std::string init_sql_commands,
fs::path init_sql_commands_file, const bool &print_queries) {
auto create_server_result = CreateFlightSQLServer(
auto create_server_result = sqlflite::CreateFlightSQLServer(
backend, database_filename, hostname, port, username, password, secret_key,
tls_cert_path, tls_key_path, mtls_ca_cert_path, init_sql_commands,
init_sql_commands_file, print_queries);
Expand All @@ -281,3 +285,4 @@ int RunFlightSQLServer(const BackendType backend, fs::path &database_filename,
return EXIT_FAILURE;
}
}
}
Loading

0 comments on commit 7b1c82f

Please sign in to comment.