Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 95 additions & 44 deletions client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "toolbelt/logging.h"
#include "toolbelt/sockets.h"
#include "toolbelt/triggerfd.h"
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
Expand Down Expand Up @@ -57,43 +58,90 @@ struct ChannelStats {
uint64_t max_message_size;
};

template <typename T> class weak_ptr;
struct DefaultAliaser {
void Set(const void * /*buffer*/, size_t /*length*/) {}
void Reset() {}
template <typename T>
void AliasTo(const void *buffer, size_t length, T *&dest) const {
assert(length >= sizeof(T));
static_assert(std::is_standard_layout_v<T>,
"Default aliasing requires a standard layout type! Casting "
"is undefined behavior!");
dest = reinterpret_cast<T *>(buffer);
}
};

template <typename T> class shared_ptr {
// The shared_ptr / weak_ptr combination allows you to keep pointers with
// the standard semantics into an active message slot that was recently read
// from a channel subscriber. The pointer type `T` must be aliasable in the
// raw shared-memory of the active slot (and, of course, `const`).
// By default, the aliasing is done with reinterpret_cast, which only works
// if `T` is a standard-layout type (e.g., C-struct, primitive type, etc.).
// An aliaser type can be provided to customize how this works, for example,
// to accommodate a complex zero-copy type. An aliaser type should provide
// the functions seen in DefaultAliaser, that is:
// - Be default-constructible
// - Have a non-const `Set` function to initialize the aliased object for
// a given buffer (if needed, e.g., like finding fields and offsets)
// - Have a `Reset` function to forget whatever was setup in `Set`
// - Have a const `AliasTo` function that can deliver a valid `T*` pointer
// that points to an object that is aliasing the given memory
// See TestAliasedMessageAliaser in client_test.cc as an example.

template <typename T, typename Aliaser = DefaultAliaser> class weak_ptr;

template <typename T, typename Aliaser = DefaultAliaser> class shared_ptr {
public:
shared_ptr() = default;
shared_ptr(const Message &m) : msg_(m.active_message) { msg_->IncRef(); }
shared_ptr(std::shared_ptr<ActiveMessage> m) : msg_(std::move(m)) {
msg_->IncRef();
if (msg_ != nullptr) {
msg_->IncRef();
aliaser_.Set(msg_->buffer, msg_->length);
}
}
shared_ptr(const shared_ptr &p) : msg_(p.msg_) { msg_->IncRef(); }
shared_ptr(shared_ptr &&p) : msg_(std::move(p.msg_)) {}
shared_ptr(const weak_ptr<T> &p);
shared_ptr(const Message &m) : shared_ptr(m.active_message) {}
shared_ptr(const weak_ptr<T, Aliaser> &p);

shared_ptr(const shared_ptr &p) : msg_(p.msg_), aliaser_(p.aliaser_) {
if (msg_ != nullptr) {
msg_->IncRef();
}
}
shared_ptr &operator=(const shared_ptr &p) {
msg_ = p.msg_;
aliaser_ = p.aliaser_;
if (msg_ != nullptr) {
msg_->IncRef();
}
return *this;
}

shared_ptr &operator=(shared_ptr &&p) {
msg_ = std::move(p.msg_);
return *this;
}
shared_ptr(shared_ptr &&p) noexcept = default;
shared_ptr &operator=(shared_ptr &&p) noexcept = default;

~shared_ptr() {
if (msg_ != nullptr) {
aliaser_.Reset();
msg_->DecRef();
}
}
subspace::Message GetMessage() const { return Message(msg_); }

T *get() const { return reinterpret_cast<T *>(msg_->buffer); }
T &operator*() const { return *reinterpret_cast<T *>(msg_->buffer); }
T *get() const {
if (msg_ == nullptr) {
return nullptr;
}
T *result = nullptr;
aliaser_.AliasTo(msg_->buffer, msg_->length, result);
return result;
}
T &operator*() const { return *get(); }
T *operator->() const { return get(); }
operator bool() const { return msg_ != nullptr; }
long use_count() const { return msg_ == nullptr ? 0 : msg_->refs.load(); }
void reset() {
if (msg_ != nullptr) {
aliaser_.Reset();
msg_->DecRef();
}
msg_ = nullptr;
Expand All @@ -103,15 +151,16 @@ template <typename T> class shared_ptr {
bool operator!=(const shared_ptr &p) const { return msg_ != p.msg_; }

private:
template <typename M> friend class weak_ptr;
template <typename M, typename OtherAliaser> friend class weak_ptr;

std::shared_ptr<ActiveMessage> msg_;
[[no_unique_address]] Aliaser aliaser_;
};

template <typename T> class weak_ptr {
template <typename T, typename Aliaser> class weak_ptr {
public:
weak_ptr() = default;
weak_ptr(const shared_ptr<T> &p)
weak_ptr(const shared_ptr<T, Aliaser> &p)
: sub_(p.msg_->sub), slot_(p.msg_->slot), ordinal_(p.msg_->ordinal) {}
weak_ptr(const weak_ptr &p)
: sub_(p.sub_), slot_(p.slot_), ordinal_(p.ordinal_) {}
Expand Down Expand Up @@ -166,15 +215,16 @@ template <typename T> class weak_ptr {
}

private:
template <typename M> friend class shared_ptr;
template <typename M, typename OtherAliaser> friend class shared_ptr;

std::shared_ptr<details::SubscriberImpl> sub_;
MessageSlot *slot_ = nullptr;
uint64_t ordinal_ = 0;
};

template <typename T>
inline shared_ptr<T>::shared_ptr(const weak_ptr<T> &p) : msg_(p.lock().msg_) {}
template <typename T, typename Aliaser>
inline shared_ptr<T, Aliaser>::shared_ptr(const weak_ptr<T, Aliaser> &p)
: msg_(p.lock().msg_) {}

class Publisher;
class Subscriber;
Expand Down Expand Up @@ -398,8 +448,8 @@ class ClientImpl : public std::enable_shared_from_this<ClientImpl> {

// As ReadMessage above but returns a shared_ptr to the typed message.
// NOTE: this is subspace::shared_ptr, not std::shared_ptr.
template <typename T>
absl::StatusOr<shared_ptr<T>>
template <typename T, typename Aliaser = DefaultAliaser>
absl::StatusOr<shared_ptr<T, Aliaser>>
ReadMessage(details::SubscriberImpl *subscriber,
ReadMode mode = ReadMode::kReadNext);

Expand All @@ -408,9 +458,9 @@ class ClientImpl : public std::enable_shared_from_this<ClientImpl> {
uint64_t timestamp);
// AsFindMessage above but returns a shared_ptr to the typed message.
// NOTE: this is subspace::shared_ptr, not std::shared_ptr.
template <typename T>
absl::StatusOr<shared_ptr<T>> FindMessage(details::SubscriberImpl *subscriber,
uint64_t timestamp);
template <typename T, typename Aliaser = DefaultAliaser>
absl::StatusOr<shared_ptr<T, Aliaser>>
FindMessage(details::SubscriberImpl *subscriber, uint64_t timestamp);

// Gets the PollFd for a publisher and subscriber. PollFds are only
// available for reliable publishers but a valid pollfd will be returned for
Expand Down Expand Up @@ -585,28 +635,28 @@ class ClientImpl : public std::enable_shared_from_this<ClientImpl> {
// prevent the slot referred to by the shared_ptr from being taken
// by a publisher. Don't hold onto shared_ptr instances long than
// you need to as it may prevent a publisher getting a slot.
template <typename T>
inline absl::StatusOr<::subspace::shared_ptr<T>>
template <typename T, typename Aliaser>
inline absl::StatusOr<::subspace::shared_ptr<T, Aliaser>>
ClientImpl::ReadMessage(details::SubscriberImpl *subscriber, ReadMode mode) {
absl::StatusOr<Message> msg = ReadMessage(subscriber, mode);
if (!msg.ok()) {
return msg.status();
}
if (msg->length == 0) {
return ::subspace::shared_ptr<T>();
return ::subspace::shared_ptr<T, Aliaser>();
}
return ::subspace::shared_ptr<T>(std::move(*msg));
return ::subspace::shared_ptr<T, Aliaser>(std::move(*msg));
}

template <typename T>
inline absl::StatusOr<::subspace::shared_ptr<T>>
template <typename T, typename Aliaser>
inline absl::StatusOr<::subspace::shared_ptr<T, Aliaser>>
ClientImpl::FindMessage(details::SubscriberImpl *subscriber,
uint64_t timestamp) {
absl::StatusOr<Message> msg = FindMessage(subscriber, timestamp);
if (!msg.ok()) {
return msg.status();
}
return ::subspace::shared_ptr<T>(std::move(*msg));
return ::subspace::shared_ptr<T, Aliaser>(std::move(*msg));
}

// The Publisher and Subscriber classes are the main interface for sending
Expand Down Expand Up @@ -979,8 +1029,8 @@ class Subscriber {

// As ReadMessage above but returns a shared_ptr to the typed message.
// NOTE: this is subspace::shared_ptr, not std::shared_ptr.
template <typename T>
absl::StatusOr<shared_ptr<T>>
template <typename T, typename Aliaser = DefaultAliaser>
absl::StatusOr<shared_ptr<T, Aliaser>>
ReadMessage(ReadMode mode = ReadMode::kReadNext);

bool AddActiveMessage(int32_t slot_id) {
Expand All @@ -997,8 +1047,8 @@ class Subscriber {

// AsFindMessage above but returns a shared_ptr to the typed message.
// NOTE: this is subspace::shared_ptr, not std::shared_ptr.
template <typename T>
absl::StatusOr<shared_ptr<T>> FindMessage(uint64_t timestamp);
template <typename T, typename Aliaser = DefaultAliaser>
absl::StatusOr<shared_ptr<T, Aliaser>> FindMessage(uint64_t timestamp);

struct pollfd GetPollFd() const {
return client_->GetPollFd(impl_.get());
Expand Down Expand Up @@ -1189,16 +1239,16 @@ class Subscriber {
std::function<void(Subscriber *, Message)> message_callback_ = nullptr;
};

template <typename T>
inline absl::StatusOr<::subspace::shared_ptr<T>>
template <typename T, typename Aliaser>
inline absl::StatusOr<::subspace::shared_ptr<T, Aliaser>>
Subscriber::ReadMessage(ReadMode mode) {
return client_->ReadMessage<T>(impl_.get(), mode);
return client_->ReadMessage<T, Aliaser>(impl_.get(), mode);
}

template <typename T>
inline absl::StatusOr<::subspace::shared_ptr<T>>
template <typename T, typename Aliaser>
inline absl::StatusOr<::subspace::shared_ptr<T, Aliaser>>
Subscriber::FindMessage(uint64_t timestamp) {
return client_->FindMessage<T>(impl_.get(), timestamp);
return client_->FindMessage<T, Aliaser>(impl_.get(), timestamp);
}

// This is a wrapper around the ClientImpl that is created as a shared_ptr
Expand Down Expand Up @@ -1297,9 +1347,10 @@ class Client {
std::shared_ptr<ClientImpl> impl_;
};

// Convenience functions to create a client and publisher/subscriber in one step.
// The Publisher and Subscriber objects hold a shared_ptr to the ClientImpl
// internally, so the client stays alive as long as the returned object does.
// Convenience functions to create a client and publisher/subscriber in one
// step. The Publisher and Subscriber objects hold a shared_ptr to the
// ClientImpl internally, so the client stays alive as long as the returned
// object does.

inline absl::StatusOr<Publisher>
CreatePublisher(const std::string &channel_name,
Expand Down
122 changes: 122 additions & 0 deletions client/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,128 @@ TEST_F(ClientTest, Publish2Message2AndReadSharedPtrs) {
// Number of active messages: 1
}

class TestAliasedMessage {
public:
TestAliasedMessage(const void *buffer, size_t length)
: base_ptr_(const_cast<void *>(buffer)) {
assert(length >= GetSize());
}
TestAliasedMessage() = default;

static size_t GetSize() { return 20; }

void SetName(std::string_view s) {
assert(s.size() <= 12);
std::memcpy(MutableBase(), s.data(), s.size());
std::memset(MutableBase() + s.size(), 0, 12 - s.size());
};
std::string_view Name() const {
std::string_view result{ConstBase(), 12};
return result.substr(0, result.find('\0'));
}

void SetId(std::int32_t i) { std::memcpy(MutableBase() + 12, &i, sizeof(i)); }
std::int32_t Id() const {
std::int32_t result = 0;
std::memcpy(&result, ConstBase() + 12, sizeof(result));
return result;
}

void SetScore(float f) { std::memcpy(MutableBase() + 16, &f, sizeof(f)); }
float Score() const {
float result = 0;
std::memcpy(&result, ConstBase() + 16, sizeof(result));
return result;
}

private:
void *base_ptr_ = nullptr;
char *MutableBase() { return reinterpret_cast<char *>(base_ptr_); }
const char *ConstBase() const {
return reinterpret_cast<const char *>(base_ptr_);
}
};

struct TestAliasedMessageAliaser {
void Set(const void *buffer, size_t length) {
msg = TestAliasedMessage(buffer, length);
}
void Reset() { msg = TestAliasedMessage(); }
void AliasTo(const void *buffer, size_t length,
const TestAliasedMessage *&dest) const {
dest = &msg;
}
TestAliasedMessage msg;
};

TEST_F(ClientTest, PublishSingleMessageAndReadAliasedSharedPtr) {
subspace::Client pub_client;
subspace::Client sub_client;
ASSERT_OK(pub_client.Init(Socket()));
ASSERT_OK(sub_client.Init(Socket()));
absl::StatusOr<Publisher> pub =
pub_client.CreatePublisher("dave6", TestAliasedMessage::GetSize(), 10);
ASSERT_OK(pub);
absl::StatusOr<void *> buffer = pub->GetMessageBuffer();
ASSERT_OK(buffer);
{
TestAliasedMessage pub_msg(*buffer, TestAliasedMessage::GetSize());
pub_msg.SetName("foobar");
pub_msg.SetId(42);
pub_msg.SetScore(3.14F);
}
absl::StatusOr<const Message> pub_status =
pub->PublishMessage(TestAliasedMessage::GetSize());
ASSERT_OK(pub_status);

absl::StatusOr<Subscriber> sub =
sub_client.CreateSubscriber("dave6", subspace::SubscriberOptions()
.SetMaxActiveMessages(3)
.SetKeepActiveMessage(false));
ASSERT_OK(sub);

using SharedPtrAliasedMsg =
subspace::shared_ptr<const TestAliasedMessage, TestAliasedMessageAliaser>;
absl::StatusOr<SharedPtrAliasedMsg> p =
sub->ReadMessage<const TestAliasedMessage, TestAliasedMessageAliaser>();
ASSERT_OK(p);
const auto &ptr = *p;
ASSERT_TRUE(static_cast<bool>(ptr));
ASSERT_EQ("foobar", ptr->Name());
ASSERT_EQ(42, ptr->Id());
ASSERT_EQ(3.14F, ptr->Score());

ASSERT_EQ(1, ptr.use_count());

// Copy the shared ptr using copy constructor.
SharedPtrAliasedMsg p2(ptr);
ASSERT_EQ(2, ptr.use_count());
ASSERT_EQ(2, p2.use_count());
ASSERT_EQ("foobar", p2->Name());
ASSERT_EQ(42, p2->Id());
ASSERT_EQ(3.14F, p2->Score());

// Copy using copy operator.
SharedPtrAliasedMsg p3 = ptr;
ASSERT_EQ(3, ptr.use_count());
ASSERT_EQ(3, p2.use_count());
ASSERT_EQ(3, p3.use_count());
ASSERT_EQ("foobar", p3->Name());
ASSERT_EQ(42, p3->Id());
ASSERT_EQ(3.14F, p3->Score());

// Move p3 to p4.
SharedPtrAliasedMsg p4 = std::move(p3);
ASSERT_FALSE(static_cast<bool>(p3));
ASSERT_EQ(3, ptr.use_count());
ASSERT_EQ(3, p2.use_count());
ASSERT_EQ(0, p3.use_count());
ASSERT_EQ(3, p4.use_count());
ASSERT_EQ("foobar", p4->Name());
ASSERT_EQ(42, p4->Id());
ASSERT_EQ(3.14F, p4->Score());
}

TEST_F(ClientTest, FindMessage) {
subspace::Client pub_client;
subspace::Client sub_client;
Expand Down
Loading