Skip to content

Commit

Permalink
Merge pull request #310 from MikhailBurdukov/multiple_endpoints
Browse files Browse the repository at this point in the history
Multiple endpoints for connection.
  • Loading branch information
Enmk authored Jul 10, 2023
2 parents 45680f2 + 87804a0 commit de6f56a
Show file tree
Hide file tree
Showing 10 changed files with 366 additions and 35 deletions.
1 change: 1 addition & 0 deletions clickhouse/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ SET ( clickhouse-cpp-lib-src
base/platform.cpp
base/socket.cpp
base/wire_format.cpp
base/endpoints_iterator.cpp

columns/array.cpp
columns/column.cpp
Expand Down
20 changes: 20 additions & 0 deletions clickhouse/base/endpoints_iterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "endpoints_iterator.h"
#include <clickhouse/client.h>

namespace clickhouse {

RoundRobinEndpointsIterator::RoundRobinEndpointsIterator(const std::vector<Endpoint>& _endpoints)
: endpoints (_endpoints)
, current_index (endpoints.size() - 1ull)
{
}

Endpoint RoundRobinEndpointsIterator::Next()
{
current_index = (current_index + 1ull) % endpoints.size();
return endpoints[current_index];
}

RoundRobinEndpointsIterator::~RoundRobinEndpointsIterator() = default;

}
34 changes: 34 additions & 0 deletions clickhouse/base/endpoints_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include "clickhouse/client.h"
#include <vector>

namespace clickhouse {

struct ClientOptions;

/**
* Base class for iterating through endpoints.
*/
class EndpointsIteratorBase
{
public:
virtual ~EndpointsIteratorBase() = default;

virtual Endpoint Next() = 0;
};

class RoundRobinEndpointsIterator : public EndpointsIteratorBase
{
public:
explicit RoundRobinEndpointsIterator(const std::vector<Endpoint>& opts);
Endpoint Next() override;

~RoundRobinEndpointsIterator() override;

private:
const std::vector<Endpoint>& endpoints;
size_t current_index;
};

}
4 changes: 2 additions & 2 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ std::unique_ptr<OutputStream> Socket::makeOutputStream() const {

NonSecureSocketFactory::~NonSecureSocketFactory() {}

std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts) {
const auto address = NetworkAddress(opts.host, std::to_string(opts.port));
std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts, const Endpoint& endpoint) {

const auto address = NetworkAddress(endpoint.host, std::to_string(endpoint.port));
auto socket = doConnect(address, opts);
setSocketOptions(*socket, opts);

Expand Down
5 changes: 3 additions & 2 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "platform.h"
#include "input.h"
#include "output.h"
#include "endpoints_iterator.h"

#include <cstddef>
#include <string>
Expand Down Expand Up @@ -88,7 +89,7 @@ class SocketFactory {

// TODO: move connection-related options to ConnectionOptions structure.

virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts) = 0;
virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const Endpoint& endpoint) = 0;

virtual void sleepFor(const std::chrono::milliseconds& duration);
};
Expand Down Expand Up @@ -135,7 +136,7 @@ class NonSecureSocketFactory : public SocketFactory {
public:
~NonSecureSocketFactory() override;

std::unique_ptr<SocketBase> connect(const ClientOptions& opts) override;
std::unique_ptr<SocketBase> connect(const ClientOptions& opts, const Endpoint& endpoint) override;

protected:
virtual std::unique_ptr<Socket> doConnect(const NetworkAddress& address, const ClientOptions& opts);
Expand Down
155 changes: 128 additions & 27 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ struct ClientInfo {

std::ostream& operator<<(std::ostream& os, const ClientOptions& opt) {
os << "Client(" << opt.user << '@' << opt.host << ":" << opt.port
<< " ping_before_query:" << opt.ping_before_query
<< "Endpoints :";
for (size_t i = 0; i < opt.endpoints.size(); i++)
os << opt.user << '@' << opt.endpoints[i].host << ":" << opt.endpoints[i].port
<< ((i == opt.endpoints.size() - 1) ? "" : ", ");

os << " ping_before_query:" << opt.ping_before_query
<< " send_retries:" << opt.send_retries
<< " retry_timeout:" << opt.retry_timeout.count()
<< " compression_method:"
Expand Down Expand Up @@ -111,6 +116,15 @@ std::unique_ptr<SocketFactory> GetSocketFactory(const ClientOptions& opts) {
return std::make_unique<NonSecureSocketFactory>();
}

std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const ClientOptions& opts) {
if (opts.endpoints.empty())
{
throw ValidationError("The list of endpoints is empty");
}

return std::make_unique<RoundRobinEndpointsIterator>(opts.endpoints);
}

}

class Client::Impl {
Expand All @@ -130,8 +144,12 @@ class Client::Impl {

void ResetConnection();

void ResetConnectionEndpoint();

const ServerInfo& GetServerInfo() const;

const std::optional<Endpoint>& GetCurrentEndpoint() const;

private:
bool Handshake();

Expand All @@ -155,13 +173,22 @@ class Client::Impl {

void WriteBlock(const Block& block, OutputStream& output);

void CreateConnection();

void InitializeStreams(std::unique_ptr<SocketBase>&& socket);

inline size_t GetConnectionAttempts() const
{
return options_.endpoints.size() * options_.send_retries;
}

private:
/// In case of network errors tries to reconnect to server and
/// call fuc several times.
void RetryGuard(std::function<void()> func);

void RetryConnectToTheEndpoint(std::function<void()>& func);

private:
class EnsureNull {
public:
Expand Down Expand Up @@ -194,32 +221,34 @@ class Client::Impl {
std::unique_ptr<InputStream> input_;
std::unique_ptr<OutputStream> output_;
std::unique_ptr<SocketBase> socket_;
std::unique_ptr<EndpointsIteratorBase> endpoints_iterator;

std::optional<Endpoint> current_endpoint_;

ServerInfo server_info_;
};

ClientOptions modifyClientOptions(ClientOptions opts)
{
if (opts.host.empty())
return opts;

Endpoint default_endpoint({opts.host, opts.port});
opts.endpoints.emplace(opts.endpoints.begin(), default_endpoint);
return opts;
}

Client::Impl::Impl(const ClientOptions& opts)
: Impl(opts, GetSocketFactory(opts)) {}

Client::Impl::Impl(const ClientOptions& opts,
std::unique_ptr<SocketFactory> socket_factory)
: options_(opts)
: options_(modifyClientOptions(opts))
, events_(nullptr)
, socket_factory_(std::move(socket_factory))
, endpoints_iterator(GetEndpointsIterator(options_))
{
for (unsigned int i = 0; ; ) {
try {
ResetConnection();
break;
} catch (const std::system_error&) {
if (++i > options_.send_retries) {
throw;
}

socket_factory_->sleepFor(options_.retry_timeout);
}
}
CreateConnection();

if (options_.compression_method != CompressionMethod::None) {
compression_ = CompressionState::Enable;
Expand Down Expand Up @@ -329,17 +358,57 @@ void Client::Impl::Ping() {
}

void Client::Impl::ResetConnection() {
InitializeStreams(socket_factory_->connect(options_));
InitializeStreams(socket_factory_->connect(options_, current_endpoint_.value()));

if (!Handshake()) {
throw ProtocolError("fail to connect to " + options_.host);
}
}

void Client::Impl::ResetConnectionEndpoint() {
current_endpoint_.reset();
for (size_t i = 0; i < options_.endpoints.size();)
{
try
{
current_endpoint_ = endpoints_iterator->Next();
ResetConnection();
return;
} catch (const std::system_error&) {
if (++i == options_.endpoints.size())
{
current_endpoint_.reset();
throw;
}
}
}
}

void Client::Impl::CreateConnection() {
for (size_t i = 0; i < options_.send_retries;)
{
try
{
ResetConnectionEndpoint();
return;
} catch (const std::system_error&) {
if (++i == options_.send_retries)
{
throw;
}
}
}
}

const ServerInfo& Client::Impl::GetServerInfo() const {
return server_info_;
}


const std::optional<Endpoint>& Client::Impl::GetCurrentEndpoint() const {
return current_endpoint_;
}

bool Client::Impl::Handshake() {
if (!SendHello()) {
return false;
Expand Down Expand Up @@ -859,21 +928,45 @@ bool Client::Impl::ReceiveHello() {
}

void Client::Impl::RetryGuard(std::function<void()> func) {
for (unsigned int i = 0; ; ++i) {
try {
func();
return;
} catch (const std::system_error&) {
bool ok = true;

if (current_endpoint_)
{
for (unsigned int i = 0; ; ++i) {
try {
socket_factory_->sleepFor(options_.retry_timeout);
ResetConnection();
} catch (...) {
ok = false;
func();
return;
} catch (const std::system_error&) {
bool ok = true;

try {
socket_factory_->sleepFor(options_.retry_timeout);
ResetConnection();
} catch (...) {
ok = false;
}

if (!ok && i == options_.send_retries) {
break;
}
}

if (!ok && i == options_.send_retries) {
}
}
// Connectiong with current_endpoint_ are broken.
// Trying to establish with the another one from the list.
size_t connection_attempts_count = GetConnectionAttempts();
for (size_t i = 0; i < connection_attempts_count;)
{
try
{
socket_factory_->sleepFor(options_.retry_timeout);
current_endpoint_ = endpoints_iterator->Next();
ResetConnection();
func();
return;
} catch (const std::system_error&) {
if (++i == connection_attempts_count)
{
current_endpoint_.reset();
throw;
}
}
Expand Down Expand Up @@ -936,6 +1029,14 @@ void Client::ResetConnection() {
impl_->ResetConnection();
}

void Client::ResetConnectionEndpoint() {
impl_->ResetConnectionEndpoint();
}

const std::optional<Endpoint>& Client::GetCurrentEndpoint() const {
return impl_->GetCurrentEndpoint();
}

const ServerInfo& Client::GetServerInfo() const {
return impl_->GetServerInfo();
}
Expand Down
28 changes: 27 additions & 1 deletion clickhouse/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ enum class CompressionMethod {
LZ4 = 1,
};

struct Endpoint {
std::string host;
uint16_t port = 9000;
inline bool operator==(const Endpoint& right) const {
return host == right.host && port == right.port;
}
};

enum class EndpointsIterationAlgorithm {
RoundRobin = 0,
};

struct ClientOptions {
// Setter goes first, so it is possible to apply 'deprecated' annotation safely.
#define DECLARE_FIELD(name, type, setter, default_value) \
Expand All @@ -56,7 +68,15 @@ struct ClientOptions {
/// Hostname of the server.
DECLARE_FIELD(host, std::string, SetHost, std::string());
/// Service port.
DECLARE_FIELD(port, unsigned int, SetPort, 9000);
DECLARE_FIELD(port, uint16_t, SetPort, 9000);

/** Set endpoints (host+port), only one is used.
* Client tries to connect to those endpoints one by one, on the round-robin basis:
* first default enpoint (set via SetHost() + SetPort()), then each of endpoints, from begin() to end(),
* the first one to establish connection is used for the rest of the session.
* If port isn't specified, default(9000) value will be used.
*/
DECLARE_FIELD(endpoints, std::vector<Endpoint>, SetEndpoints, {});

/// Default database.
DECLARE_FIELD(default_database, std::string, SetDefaultDatabase, "default");
Expand Down Expand Up @@ -240,6 +260,12 @@ class Client {

const ServerInfo& GetServerInfo() const;

/// Get current connected endpoint.
/// In case when client is not connected to any endpoint, nullopt will returned.
const std::optional<Endpoint>& GetCurrentEndpoint() const;

// Try to connect to different endpoints one by one only one time. If it doesn't work, throw an exception.
void ResetConnectionEndpoint();
private:
const ClientOptions options_;

Expand Down
Loading

0 comments on commit de6f56a

Please sign in to comment.