diff --git a/client/client.h b/client/client.h index a911d35..a3f3073 100644 --- a/client/client.h +++ b/client/client.h @@ -21,6 +21,7 @@ #include "toolbelt/logging.h" #include "toolbelt/sockets.h" #include "toolbelt/triggerfd.h" +#include #include #include #include @@ -57,43 +58,90 @@ struct ChannelStats { uint64_t max_message_size; }; -template class weak_ptr; +struct DefaultAliaser { + void Set(const void * /*buffer*/, size_t /*length*/) {} + void Reset() {} + template + void AliasTo(const void *buffer, size_t length, T *&dest) const { + assert(length >= sizeof(T)); + static_assert(std::is_standard_layout_v, + "Default aliasing requires a standard layout type! Casting " + "is undefined behavior!"); + dest = reinterpret_cast(buffer); + } +}; -template class shared_ptr { +// The shared_ptr / weak_ptr combination allows you to keep pointers with +// the standard semantics into an active message slot that was recently read +// from a channel subscriber. The pointer type `T` must be aliasable in the +// raw shared-memory of the active slot (and, of course, `const`). +// By default, the aliasing is done with reinterpret_cast, which only works +// if `T` is a standard-layout type (e.g., C-struct, primitive type, etc.). +// An aliaser type can be provided to customize how this works, for example, +// to accommodate a complex zero-copy type. An aliaser type should provide +// the functions seen in DefaultAliaser, that is: +// - Be default-constructible +// - Have a non-const `Set` function to initialize the aliased object for +// a given buffer (if needed, e.g., like finding fields and offsets) +// - Have a `Reset` function to forget whatever was setup in `Set` +// - Have a const `AliasTo` function that can deliver a valid `T*` pointer +// that points to an object that is aliasing the given memory +// See TestAliasedMessageAliaser in client_test.cc as an example. + +template class weak_ptr; + +template class shared_ptr { public: shared_ptr() = default; - shared_ptr(const Message &m) : msg_(m.active_message) { msg_->IncRef(); } shared_ptr(std::shared_ptr m) : msg_(std::move(m)) { - msg_->IncRef(); + if (msg_ != nullptr) { + msg_->IncRef(); + aliaser_.Set(msg_->buffer, msg_->length); + } } - shared_ptr(const shared_ptr &p) : msg_(p.msg_) { msg_->IncRef(); } - shared_ptr(shared_ptr &&p) : msg_(std::move(p.msg_)) {} - shared_ptr(const weak_ptr &p); + shared_ptr(const Message &m) : shared_ptr(m.active_message) {} + shared_ptr(const weak_ptr &p); + shared_ptr(const shared_ptr &p) : msg_(p.msg_), aliaser_(p.aliaser_) { + if (msg_ != nullptr) { + msg_->IncRef(); + } + } shared_ptr &operator=(const shared_ptr &p) { msg_ = p.msg_; + aliaser_ = p.aliaser_; + if (msg_ != nullptr) { + msg_->IncRef(); + } return *this; } - shared_ptr &operator=(shared_ptr &&p) { - msg_ = std::move(p.msg_); - return *this; - } + shared_ptr(shared_ptr &&p) noexcept = default; + shared_ptr &operator=(shared_ptr &&p) noexcept = default; ~shared_ptr() { if (msg_ != nullptr) { + aliaser_.Reset(); msg_->DecRef(); } } subspace::Message GetMessage() const { return Message(msg_); } - T *get() const { return reinterpret_cast(msg_->buffer); } - T &operator*() const { return *reinterpret_cast(msg_->buffer); } + T *get() const { + if (msg_ == nullptr) { + return nullptr; + } + T *result = nullptr; + aliaser_.AliasTo(msg_->buffer, msg_->length, result); + return result; + } + T &operator*() const { return *get(); } T *operator->() const { return get(); } operator bool() const { return msg_ != nullptr; } long use_count() const { return msg_ == nullptr ? 0 : msg_->refs.load(); } void reset() { if (msg_ != nullptr) { + aliaser_.Reset(); msg_->DecRef(); } msg_ = nullptr; @@ -103,15 +151,16 @@ template class shared_ptr { bool operator!=(const shared_ptr &p) const { return msg_ != p.msg_; } private: - template friend class weak_ptr; + template friend class weak_ptr; std::shared_ptr msg_; + [[no_unique_address]] Aliaser aliaser_; }; -template class weak_ptr { +template class weak_ptr { public: weak_ptr() = default; - weak_ptr(const shared_ptr &p) + weak_ptr(const shared_ptr &p) : sub_(p.msg_->sub), slot_(p.msg_->slot), ordinal_(p.msg_->ordinal) {} weak_ptr(const weak_ptr &p) : sub_(p.sub_), slot_(p.slot_), ordinal_(p.ordinal_) {} @@ -166,15 +215,16 @@ template class weak_ptr { } private: - template friend class shared_ptr; + template friend class shared_ptr; std::shared_ptr sub_; MessageSlot *slot_ = nullptr; uint64_t ordinal_ = 0; }; -template -inline shared_ptr::shared_ptr(const weak_ptr &p) : msg_(p.lock().msg_) {} +template +inline shared_ptr::shared_ptr(const weak_ptr &p) + : msg_(p.lock().msg_) {} class Publisher; class Subscriber; @@ -398,8 +448,8 @@ class ClientImpl : public std::enable_shared_from_this { // As ReadMessage above but returns a shared_ptr to the typed message. // NOTE: this is subspace::shared_ptr, not std::shared_ptr. - template - absl::StatusOr> + template + absl::StatusOr> ReadMessage(details::SubscriberImpl *subscriber, ReadMode mode = ReadMode::kReadNext); @@ -408,9 +458,9 @@ class ClientImpl : public std::enable_shared_from_this { uint64_t timestamp); // AsFindMessage above but returns a shared_ptr to the typed message. // NOTE: this is subspace::shared_ptr, not std::shared_ptr. - template - absl::StatusOr> FindMessage(details::SubscriberImpl *subscriber, - uint64_t timestamp); + template + absl::StatusOr> + FindMessage(details::SubscriberImpl *subscriber, uint64_t timestamp); // Gets the PollFd for a publisher and subscriber. PollFds are only // available for reliable publishers but a valid pollfd will be returned for @@ -585,28 +635,28 @@ class ClientImpl : public std::enable_shared_from_this { // prevent the slot referred to by the shared_ptr from being taken // by a publisher. Don't hold onto shared_ptr instances long than // you need to as it may prevent a publisher getting a slot. -template -inline absl::StatusOr<::subspace::shared_ptr> +template +inline absl::StatusOr<::subspace::shared_ptr> ClientImpl::ReadMessage(details::SubscriberImpl *subscriber, ReadMode mode) { absl::StatusOr msg = ReadMessage(subscriber, mode); if (!msg.ok()) { return msg.status(); } if (msg->length == 0) { - return ::subspace::shared_ptr(); + return ::subspace::shared_ptr(); } - return ::subspace::shared_ptr(std::move(*msg)); + return ::subspace::shared_ptr(std::move(*msg)); } -template -inline absl::StatusOr<::subspace::shared_ptr> +template +inline absl::StatusOr<::subspace::shared_ptr> ClientImpl::FindMessage(details::SubscriberImpl *subscriber, uint64_t timestamp) { absl::StatusOr msg = FindMessage(subscriber, timestamp); if (!msg.ok()) { return msg.status(); } - return ::subspace::shared_ptr(std::move(*msg)); + return ::subspace::shared_ptr(std::move(*msg)); } // The Publisher and Subscriber classes are the main interface for sending @@ -979,8 +1029,8 @@ class Subscriber { // As ReadMessage above but returns a shared_ptr to the typed message. // NOTE: this is subspace::shared_ptr, not std::shared_ptr. - template - absl::StatusOr> + template + absl::StatusOr> ReadMessage(ReadMode mode = ReadMode::kReadNext); bool AddActiveMessage(int32_t slot_id) { @@ -997,8 +1047,8 @@ class Subscriber { // AsFindMessage above but returns a shared_ptr to the typed message. // NOTE: this is subspace::shared_ptr, not std::shared_ptr. - template - absl::StatusOr> FindMessage(uint64_t timestamp); + template + absl::StatusOr> FindMessage(uint64_t timestamp); struct pollfd GetPollFd() const { return client_->GetPollFd(impl_.get()); @@ -1189,16 +1239,16 @@ class Subscriber { std::function message_callback_ = nullptr; }; -template -inline absl::StatusOr<::subspace::shared_ptr> +template +inline absl::StatusOr<::subspace::shared_ptr> Subscriber::ReadMessage(ReadMode mode) { - return client_->ReadMessage(impl_.get(), mode); + return client_->ReadMessage(impl_.get(), mode); } -template -inline absl::StatusOr<::subspace::shared_ptr> +template +inline absl::StatusOr<::subspace::shared_ptr> Subscriber::FindMessage(uint64_t timestamp) { - return client_->FindMessage(impl_.get(), timestamp); + return client_->FindMessage(impl_.get(), timestamp); } // This is a wrapper around the ClientImpl that is created as a shared_ptr @@ -1297,9 +1347,10 @@ class Client { std::shared_ptr impl_; }; -// Convenience functions to create a client and publisher/subscriber in one step. -// The Publisher and Subscriber objects hold a shared_ptr to the ClientImpl -// internally, so the client stays alive as long as the returned object does. +// Convenience functions to create a client and publisher/subscriber in one +// step. The Publisher and Subscriber objects hold a shared_ptr to the +// ClientImpl internally, so the client stays alive as long as the returned +// object does. inline absl::StatusOr CreatePublisher(const std::string &channel_name, diff --git a/client/client_test.cc b/client/client_test.cc index eb8efd3..475d828 100644 --- a/client/client_test.cc +++ b/client/client_test.cc @@ -1918,6 +1918,128 @@ TEST_F(ClientTest, Publish2Message2AndReadSharedPtrs) { // Number of active messages: 1 } +class TestAliasedMessage { +public: + TestAliasedMessage(const void *buffer, size_t length) + : base_ptr_(const_cast(buffer)) { + assert(length >= GetSize()); + } + TestAliasedMessage() = default; + + static size_t GetSize() { return 20; } + + void SetName(std::string_view s) { + assert(s.size() <= 12); + std::memcpy(MutableBase(), s.data(), s.size()); + std::memset(MutableBase() + s.size(), 0, 12 - s.size()); + }; + std::string_view Name() const { + std::string_view result{ConstBase(), 12}; + return result.substr(0, result.find('\0')); + } + + void SetId(std::int32_t i) { std::memcpy(MutableBase() + 12, &i, sizeof(i)); } + std::int32_t Id() const { + std::int32_t result = 0; + std::memcpy(&result, ConstBase() + 12, sizeof(result)); + return result; + } + + void SetScore(float f) { std::memcpy(MutableBase() + 16, &f, sizeof(f)); } + float Score() const { + float result = 0; + std::memcpy(&result, ConstBase() + 16, sizeof(result)); + return result; + } + +private: + void *base_ptr_ = nullptr; + char *MutableBase() { return reinterpret_cast(base_ptr_); } + const char *ConstBase() const { + return reinterpret_cast(base_ptr_); + } +}; + +struct TestAliasedMessageAliaser { + void Set(const void *buffer, size_t length) { + msg = TestAliasedMessage(buffer, length); + } + void Reset() { msg = TestAliasedMessage(); } + void AliasTo(const void *buffer, size_t length, + const TestAliasedMessage *&dest) const { + dest = &msg; + } + TestAliasedMessage msg; +}; + +TEST_F(ClientTest, PublishSingleMessageAndReadAliasedSharedPtr) { + subspace::Client pub_client; + subspace::Client sub_client; + ASSERT_OK(pub_client.Init(Socket())); + ASSERT_OK(sub_client.Init(Socket())); + absl::StatusOr pub = + pub_client.CreatePublisher("dave6", TestAliasedMessage::GetSize(), 10); + ASSERT_OK(pub); + absl::StatusOr buffer = pub->GetMessageBuffer(); + ASSERT_OK(buffer); + { + TestAliasedMessage pub_msg(*buffer, TestAliasedMessage::GetSize()); + pub_msg.SetName("foobar"); + pub_msg.SetId(42); + pub_msg.SetScore(3.14F); + } + absl::StatusOr pub_status = + pub->PublishMessage(TestAliasedMessage::GetSize()); + ASSERT_OK(pub_status); + + absl::StatusOr sub = + sub_client.CreateSubscriber("dave6", subspace::SubscriberOptions() + .SetMaxActiveMessages(3) + .SetKeepActiveMessage(false)); + ASSERT_OK(sub); + + using SharedPtrAliasedMsg = + subspace::shared_ptr; + absl::StatusOr p = + sub->ReadMessage(); + ASSERT_OK(p); + const auto &ptr = *p; + ASSERT_TRUE(static_cast(ptr)); + ASSERT_EQ("foobar", ptr->Name()); + ASSERT_EQ(42, ptr->Id()); + ASSERT_EQ(3.14F, ptr->Score()); + + ASSERT_EQ(1, ptr.use_count()); + + // Copy the shared ptr using copy constructor. + SharedPtrAliasedMsg p2(ptr); + ASSERT_EQ(2, ptr.use_count()); + ASSERT_EQ(2, p2.use_count()); + ASSERT_EQ("foobar", p2->Name()); + ASSERT_EQ(42, p2->Id()); + ASSERT_EQ(3.14F, p2->Score()); + + // Copy using copy operator. + SharedPtrAliasedMsg p3 = ptr; + ASSERT_EQ(3, ptr.use_count()); + ASSERT_EQ(3, p2.use_count()); + ASSERT_EQ(3, p3.use_count()); + ASSERT_EQ("foobar", p3->Name()); + ASSERT_EQ(42, p3->Id()); + ASSERT_EQ(3.14F, p3->Score()); + + // Move p3 to p4. + SharedPtrAliasedMsg p4 = std::move(p3); + ASSERT_FALSE(static_cast(p3)); + ASSERT_EQ(3, ptr.use_count()); + ASSERT_EQ(3, p2.use_count()); + ASSERT_EQ(0, p3.use_count()); + ASSERT_EQ(3, p4.use_count()); + ASSERT_EQ("foobar", p4->Name()); + ASSERT_EQ(42, p4->Id()); + ASSERT_EQ(3.14F, p4->Score()); +} + TEST_F(ClientTest, FindMessage) { subspace::Client pub_client; subspace::Client sub_client;