Skip to content

Commit

Permalink
feat: introduce an interface for provided buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
romange committed Aug 19, 2024
1 parent 50da303 commit 7c01369
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 30 deletions.
10 changes: 10 additions & 0 deletions util/fiber_socket_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source
: Recv(io::MutableBytes{reinterpret_cast<uint8_t*>(v->iov_base), v->iov_len}, 0);
}

// Waits for input. Returns an error if the socket had an I/O error.
// For raw io_uring sockets with a valid, registered buf_group_id, the socket may provide
// a reference to data owned by the kernel via `mb`. In any case, if no error is returned
// and mb is empty, the socket has data available that can be read by the following Recv
// call.
// WaitForRecv can return no_buffer_space if the io_uring socket does not have provided
// buffers currently available. This error can be ignored and be followed by Recv.
// It is returned so that a caller can track such events.
virtual std::error_code WaitForRecv(uint16_t buf_group_id, io::MutableBytes* mb) = 0;

virtual ::io::Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) = 0;

static bool IsConnClosed(const error_code& ec) {
Expand Down
16 changes: 16 additions & 0 deletions util/fibers/epoll_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,22 @@ auto EpollSocket::RecvMsg(const msghdr& msg, int flags) -> Result<size_t> {
return nonstd::make_unexpected(std::move(ec));
}

std::error_code EpollSocket::WaitForRecv(uint16_t buf_group_id, io::MutableBytes* mb) {
DCHECK(read_context_ == NULL);
*mb = {};

if (epoll_mask_) {
// we may return false positives.
return {};
}

read_context_ = detail::FiberActive();
absl::Cleanup clean = [this]() { read_context_ = nullptr; };
error_code ec;
SuspendMyself(read_context_, &ec);
return ec;
}

io::Result<size_t> EpollSocket::Recv(const io::MutableBytes& mb, int flags) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
Expand Down
2 changes: 2 additions & 0 deletions util/fibers/epoll_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class EpollSocket : public LinuxSocketBase {
void AsyncWriteSome(const iovec* v, uint32_t len, AsyncProgressCb cb) override;

Result<size_t> RecvMsg(const msghdr& msg, int flags) override;

std::error_code WaitForRecv(uint16_t buf_group_id, io::MutableBytes* mb) override;
Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) override;

error_code Shutdown(int how) override;
Expand Down
48 changes: 48 additions & 0 deletions util/fibers/fiber_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,54 @@ TEST_P(FiberSocketTest, UDS) {
LOG(INFO) << "Finished";
}

// Checks that WaitForRecv on the accepted connection completes without an error once
// the peer writes data. For the "uring" parameter it additionally registers a small
// buffer ring and verifies that the returned slice is one full ring entry (kBufLen bytes)
// holding the written payload, then hands the buffer back to the ring.
TEST_P(FiberSocketTest, WaitRecv) {
  constexpr unsigned kBufGroupId = 0;
  constexpr unsigned kBufLen = 40;
#ifdef __linux__
  const bool use_uring = GetParam() == "uring";

  // Downcast only when the fixture really runs on top of a UringProactor.
  UringProactor* up = use_uring ? static_cast<UringProactor*>(proactor_.get()) : nullptr;
  if (use_uring) {
    int register_res = 0;
    up->Await([&] {
      register_res = up->RegisterBufferRing(kBufGroupId, 4 /*nentries*/, kBufLen);
    });

    // Fail right here with the real reason instead of a later no_buffer_space error.
    ASSERT_EQ(0, register_res);
  }
#endif

  unique_ptr<FiberSocketBase> sock;
  error_code ec;
  proactor_->Await([&] {
    sock.reset(proactor_->CreateSocket());
    ec = sock->Connect(listen_ep_);
  });
  ASSERT_FALSE(ec);
  io::MutableBytes mb;

  // The waiting side: stores its outcome in `ec` and `mb`, which are read after Join().
  auto recv_fb = proactor_->LaunchFiber([&] {
    ec = conn_socket_->WaitForRecv(kBufGroupId, &mb);
  });

  uint8_t buf[128];
  memset(buf, 'x', sizeof(buf));

  proactor_->Await([&] {
    auto wrt_ec = sock->Write(io::Bytes(buf));
    ASSERT_FALSE(wrt_ec);
  });
  recv_fb.Join();
  ASSERT_FALSE(ec) << ec;
  proactor_->Await([&] { std::ignore = sock->Close(); });

#ifdef __linux__
  if (use_uring) {
    // 128 bytes were written but a single ring entry holds only kBufLen bytes.
    ASSERT_EQ(mb.size(), kBufLen);
    LOG(INFO) << "MB Size: " << mb.size();
    ASSERT_EQ(0, memcmp(buf, mb.data(), mb.size()));

    // Keep buffer-ring manipulation on the proactor thread, like RegisterBufferRing
    // which DCHECKs InMyThread().
    up->Await([&] { up->ReplenishBuffers(kBufGroupId, mb); });
  }
#endif
}

#ifdef __linux__
TEST_P(FiberSocketTest, NotEmpty) {
bool use_uring = GetParam() == "uring";
Expand Down
89 changes: 65 additions & 24 deletions util/fibers/uring_proactor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ constexpr uint16_t kMsgRingSubmitTag = 1;
constexpr uint16_t kTimeoutSubmitTag = 2;
constexpr uint16_t kCqeBatchLen = 128;

constexpr size_t kBufRingEntriesCnt = 8192;
constexpr size_t kBufRingEntrySize = 64; // TODO: should be configurable.

} // namespace

UringProactor::UringProactor() : ProactorBase() {
Expand All @@ -85,8 +82,8 @@ UringProactor::~UringProactor() {
for (size_t i = 0; i < bufring_groups_.size(); ++i) {
const auto& group = bufring_groups_[i];
if (group.ring != nullptr) {
io_uring_free_buf_ring(&ring_, group.ring, kBufRingEntriesCnt, i);
delete group.buf;
io_uring_free_buf_ring(&ring_, group.ring, group.nentries, i);
delete[] group.buf;
}
}

Expand Down Expand Up @@ -117,6 +114,7 @@ void UringProactor::Init(unsigned pool_index, size_t ring_size, int wq_fd) {
msgring_f_ = 0;
poll_first_ = 0;
buf_ring_f_ = 0;
bundle_f_ = 0;

// If we setup flags that kernel does not recognize, it fails the setup call.
if (kver.kernel > 5 || (kver.kernel == 5 && kver.major >= 19)) {
Expand Down Expand Up @@ -167,7 +165,11 @@ void UringProactor::Init(unsigned pool_index, size_t ring_size, int wq_fd) {
unsigned req_feats = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_FAST_POLL | IORING_FEAT_NODROP;
CHECK_EQ(req_feats, params.features & req_feats)
<< "required feature feature is not present in the kernel";

#ifdef IORING_FEAT_RECVSEND_BUNDLE
if (params.features & IORING_FEAT_RECVSEND_BUNDLE) {
bundle_f_ = 1;
}
#endif
int res = io_uring_register_ring_fd(&ring_);
VLOG_IF(1, res < 0) << "io_uring_register_ring_fd failed: " << -res;

Expand Down Expand Up @@ -349,49 +351,88 @@ void UringProactor::ReturnBuffer(UringBuf buf) {
buf_pool_.segments.Return(segments);
}

int UringProactor::RegisterBufferRing(unsigned group_id) {
int UringProactor::RegisterBufferRing(unsigned group_id, unsigned nentries, unsigned esize) {
CHECK_LT(nentries, 32768u);
CHECK_EQ(0u, nentries & (nentries - 1)); // power of 2.
DCHECK(InMyThread());

if (buf_ring_f_ == 0)
return EOPNOTSUPP;

if (bufring_groups_.size() <= group_id) {
bufring_groups_.resize(group_id + 1);
}

auto& ring_group = bufring_groups_[group_id];
CHECK(ring_group.ring == nullptr);
auto& buf_group = bufring_groups_[group_id];
CHECK(buf_group.ring == nullptr);

int err = 0;

ring_group.ring = io_uring_setup_buf_ring(&ring_, kBufRingEntriesCnt, group_id, 0, &err);
if (ring_group.ring == nullptr) {
buf_group.ring = io_uring_setup_buf_ring(&ring_, nentries, group_id, 0, &err);
if (buf_group.ring == nullptr) {
return -err; // err is negative.
}

unsigned mask = kBufRingEntriesCnt - 1;
ring_group.buf = new uint8_t[kBufRingEntriesCnt * kBufRingEntrySize];
uint8_t* next = ring_group.buf;
for (unsigned i = 0; i < kBufRingEntriesCnt; ++i) {
io_uring_buf_ring_add(ring_group.ring, next, kBufRingEntrySize, i, mask, i);
next += 64;
unsigned mask = io_uring_buf_ring_mask(nentries);
buf_group.buf = new uint8_t[nentries * esize];
buf_group.nentries = nentries;
buf_group.entry_size = esize;
uint8_t* next = buf_group.buf;

// Buffers are initially laid out in sequential order inside a single contiguous range,
// but once they are returned to the bufring they will be reordered, because
// CQEs complete in arbitrary order; moreover, the ownership over buffers is passed back
// to the bufring in arbitrary order inside ConsumeBufRing.
for (unsigned i = 0; i < nentries; ++i) {
io_uring_buf_ring_add(buf_group.ring, next, esize, i, mask, i);
next += esize;
}
io_uring_buf_ring_advance(ring_group.ring, kBufRingEntriesCnt);
// return the ownership to the ring.
io_uring_buf_ring_advance(buf_group.ring, nentries);

return 0;
}

uint8_t* UringProactor::GetBufRingPtr(unsigned group_id, unsigned bufid) {
DCHECK_LT(group_id, bufring_groups_.size());
DCHECK_LT(bufid, kBufRingEntriesCnt);
auto& buf_group = bufring_groups_[group_id];

DCHECK_LT(bufid, buf_group.nentries);
DCHECK(bufring_groups_[group_id].buf);
return bufring_groups_[group_id].buf + bufid * kBufRingEntrySize;
return bufring_groups_[group_id].buf + bufid * buf_group.entry_size;
}

// Returns one or more consecutive buffers back to the buffer ring of `group_id`.
//
// `slice` must start at an entry boundary inside the memory block owned by the group
// and must not extend past the end of that block (both are DCHECKed). Every entry that
// the slice touches, including a partially covered last one, is handed back as a whole
// entry of `entry_size` bytes.
void UringProactor::ReplenishBuffers(unsigned group_id, io::Bytes slice) {
  DCHECK_LT(group_id, bufring_groups_.size());
  DCHECK(!slice.empty());

  auto& buf_group = bufring_groups_[group_id];
  DCHECK(buf_group.ring);

  const size_t entry_size = buf_group.entry_size;
  const size_t total_len = size_t(buf_group.nentries) * entry_size;
  DCHECK(slice.end() <= buf_group.buf + total_len);
  off_t offs = slice.data() - buf_group.buf;
  DCHECK_GE(offs, 0);
  DCHECK(offs % buf_group.entry_size == 0);

  // Add 1 or more buffers back to the ring. ReplenishBuffers calls can come OOO, therefore
  // we expect to see reshuffling of buffers within the ring.
  unsigned bid = offs / buf_group.entry_size;
  uint8_t* cur_buf = buf_group.buf + size_t(bid) * entry_size;
  unsigned mask = io_uring_buf_ring_mask(buf_group.nentries);
  unsigned offset = 0;
  for (size_t replenished = 0; replenished < slice.size(); replenished += entry_size) {
    io_uring_buf_ring_add(buf_group.ring, cur_buf, buf_group.entry_size, bid, mask, offset++);

    // Move on to the next entry; without this the first buffer would be re-added
    // repeatedly and the remaining ones would never return to the ring.
    cur_buf += entry_size;
    ++bid;
  }

  // Publish all the added entries to the kernel at once.
  io_uring_buf_ring_advance(buf_group.ring, offset);
}

void UringProactor::ConsumeBufRing(unsigned group_id, unsigned len) {
unsigned UringProactor::BufRingAvailable(unsigned group_id) const {
DCHECK_LT(group_id, bufring_groups_.size());
DCHECK_LE(len, kBufRingEntriesCnt);
DCHECK(bufring_groups_[group_id].ring);
auto& buf_group = bufring_groups_[group_id];

io_uring_buf_ring_advance(bufring_groups_[group_id].ring, len);
return io_uring_buf_ring_available(const_cast<io_uring*>(&ring_), buf_group.ring, group_id);
}

int UringProactor::CancelRequests(int fd, unsigned flags) {
Expand Down
28 changes: 22 additions & 6 deletions util/fibers/uring_proactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class UringProactor : public ProactorBase {
return poll_first_;
}

// Returns true if the bundle_f_ flag is set, i.e. Init() detected
// IORING_FEAT_RECVSEND_BUNDLE among the features reported by the kernel.
bool HasBundleSupport() const {
  return bundle_f_ != 0;
}

bool HasDirectFD() const {
return !register_fds_.empty();
}
Expand Down Expand Up @@ -112,13 +116,21 @@ class UringProactor : public ProactorBase {
int RegisterBuffers(const struct iovec* iovecs, unsigned nr_vecs);
int UnregisterBuffers();

// Experimental. should not be called in production.
// Registers an iouring buffer ring (see io_uring_register_buf_ring(3)).
// Registers a predefined 16K buffer ring with specified buffer group_id.
// Available from kernel 5.19.
// Registers a buffer ring with specified buffer group_id.
// Returns 0 on success, errno on failure.
int RegisterBufferRing(unsigned group_id);
int RegisterBufferRing(unsigned group_id, unsigned nentries, unsigned esize);
uint8_t* GetBufRingPtr(unsigned group_id, unsigned bufid);
void ConsumeBufRing(unsigned group_id, unsigned len);

// Return 1 or more buffers to the bufring. slice.data() should point to a buffer returned by
// GetBufRingPtr and its length should be within the range of the buffers handled by group_id.
void ReplenishBuffers(unsigned group_id, io::Bytes slice);
// Returns true if a buffer ring has actually been registered for group_id.
// RegisterBufferRing grows bufring_groups_ up to group_id + 1, so smaller ids may
// occupy a slot without having a ring; those are reported as non-existent.
bool BufRingExists(unsigned group_id) const {
  return group_id < bufring_groups_.size() && bufring_groups_[group_id].ring != nullptr;
}

unsigned BufRingAvailable(unsigned group_id) const;

// Returns 0 on success, errno on failure.
// See io_uring_prep_cancel(3) for flags.
Expand Down Expand Up @@ -152,7 +164,8 @@ class UringProactor : public ProactorBase {
uint8_t msgring_f_ : 1;
uint8_t poll_first_ : 1;
uint8_t buf_ring_f_ : 1;
uint8_t : 5;
uint8_t bundle_f_ : 1;
uint8_t : 4;

EventCount sqe_avail_;

Expand All @@ -175,8 +188,11 @@ class UringProactor : public ProactorBase {
// Bookkeeping for a single provided-buffer ring, indexed by its buffer group id.
struct BufRingGroup {
  // Ring returned by io_uring_setup_buf_ring; nullptr while the group is unregistered.
  io_uring_buf_ring* ring = nullptr;

  // Backing storage for all entries: nentries * entry_size bytes.
  uint8_t* buf = nullptr;

  // Number of entries in the ring.
  uint16_t nentries = 0;

  // Explicit padding; initialized so that no member is left indeterminate.
  uint16_t reserved = 0;

  // Size of a single buffer entry in bytes.
  uint32_t entry_size = 0;
};

static_assert(sizeof(BufRingGroup) == 24);
std::vector<BufRingGroup> bufring_groups_;

// Keeps track of requested buffers
Expand Down
33 changes: 33 additions & 0 deletions util/fibers/uring_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,39 @@ auto UringSocket::RecvMsg(const msghdr& msg, int flags) -> Result<size_t> {
return make_unexpected(std::move(ec));
}

// Waits for incoming data using a recv request with IOSQE_BUFFER_SELECT, so that the
// kernel picks a buffer from the ring registered under buf_group_id.
//
// On success `*mb` references the received bytes inside that buffer. On any error
// `*mb` is left empty.
// Returns:
//   - no_buffer_space if no buffer ring exists for buf_group_id;
//   - io_error if the completion carries data but no buffer id;
//   - connection_aborted if the recv completed with 0 bytes;
//   - the system error reported by the completion otherwise.
error_code UringSocket::WaitForRecv(uint16_t buf_group_id, io::MutableBytes* mb) {
  // Never leave a stale range in `*mb` on error paths (same as the epoll implementation).
  *mb = {};

  Proactor* p = GetProactor();

  if (!p->BufRingExists(buf_group_id))
    return make_error_code(errc::no_buffer_space);

  int fd = ShiftedFd();

  FiberCall fc(p, timeout());

  // No user buffer is passed: the buffer comes from the buf_group ring.
  fc->PrepRecv(fd, nullptr, 0, 0);
  fc->sqe()->flags |= (register_flag() | IOSQE_BUFFER_SELECT);
  fc->sqe()->buf_group = buf_group_id;
  IoResult res = fc.Get();

  if (res > 0) {
    uint32_t flags = fc.flags();
    // should not happen unless there is a bug in kernel.
    if (uring_unlikely((IORING_CQE_F_BUFFER & flags) == 0))
      return make_error_code(errc::io_error);
    has_recv_data_ = flags & IORING_CQE_F_SOCK_NONEMPTY ? 1 : 0;

    // The id of the selected buffer is encoded in the upper bits of the CQE flags.
    unsigned bid = flags >> IORING_CQE_BUFFER_SHIFT;
    uint8_t* start = p->GetBufRingPtr(buf_group_id, bid);
    *mb = {start, size_t(res)};
    return {};
  }

  // NOTE(review): if a completion with res <= 0 can still carry IORING_CQE_F_BUFFER,
  // the selected buffer is not returned to the ring here - confirm against the kernel.
  if (res == 0) {
    return make_error_code(errc::connection_aborted);
  }

  return error_code{-res, system_category()};
}

io::Result<size_t> UringSocket::Recv(const io::MutableBytes& mb, int flags) {
int fd = ShiftedFd();
Proactor* p = GetProactor();
Expand Down
1 change: 1 addition & 0 deletions util/fibers/uring_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class UringSocket : public LinuxSocketBase {
void AsyncWriteSome(const iovec* v, uint32_t len, AsyncProgressCb cb) override;

Result<size_t> RecvMsg(const msghdr& msg, int flags) override;
std::error_code WaitForRecv(uint16_t buf_group_id, io::MutableBytes* mb) override;
Result<size_t> Recv(const io::MutableBytes& mb, int flags = 0) override;

using FiberSocketBase::IsConnClosed;
Expand Down
6 changes: 6 additions & 0 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ auto Engine::Read(uint8_t* dest, size_t len) -> OpResult {
RETURN_RESULT(result);
}

// Calls SSL_peek on the session to copy up to `len` bytes into `dest`.
// Per the OpenSSL documentation, SSL_peek behaves like SSL_read except that the data
// is not consumed, so a following Read returns the same bytes.
// The raw return value is converted to OpResult by RETURN_RESULT, in the same way as
// in Read (the macro is defined outside of this view).
auto Engine::Peek(uint8_t* dest, size_t len) -> OpResult {
int result = SSL_peek(ssl_, dest, len);

RETURN_RESULT(result);
}

// returns -1 if failed to load any CA certificates, 0 if loaded successfully
int SslProbeSetDefaultCALocation(SSL_CTX* ctx) {
/* The probe paths are based on:
Expand Down
2 changes: 2 additions & 0 deletions util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class Engine {
// Read bytes from the SSL session.
OpResult Read(uint8_t* dest, size_t len);

OpResult Peek(uint8_t* dest, size_t len);

//! Returns output (read) buffer. This operation is destructive, i.e. after calling
//! this function the buffer is being consumed.
//! See OutputPending() for checking if there is a output buffer to consume.
Expand Down
Loading

0 comments on commit 7c01369

Please sign in to comment.