From 8cdee376c6507bdf5ea187ac798687c03994fc1d Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Tue, 7 Apr 2026 20:25:17 +0000 Subject: [PATCH 01/36] use jthread to join thread on scope exit --- src/TNetwork.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index d0393e89..accc8c67 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -718,7 +718,7 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { OnConnect(c); RegisterThread("(" + std::to_string(c.lock()->GetID()) + ") \"" + c.lock()->GetName() + "\""); - std::thread QueueSync(&TNetwork::Looper, this, c); + std::jthread QueueSync(&TNetwork::Looper, this, c); while (true) { if (c.expired()) @@ -744,9 +744,6 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { } } - if (QueueSync.joinable()) - QueueSync.join(); - if (!c.expired()) { auto Client = c.lock(); OnDisconnect(c); From 7925caf1a2df1399e77875266343beba9b588013 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Tue, 7 Apr 2026 21:19:33 +0000 Subject: [PATCH 02/36] catch boost tcp's remote_endpoint() throwing and crashing the server this happens when, somehow, the client disconnects before we get here. I had this happen when breaking in the debugger and continuing, which leads to clients timing out (client-side timeouts). --- src/TNetwork.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index accc8c67..b76829a2 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -661,15 +661,19 @@ void TNetwork::DisconnectClient(const std::weak_ptr &c, const std::stri void TNetwork::DisconnectClient(TClient &c, const std::string &R) { if (c.IsDisconnected()) return; - std::string ClientIP = c.GetTCPSock().remote_endpoint().address().to_string(); - mClientMapMutex.lock(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); + try { + std::string ClientIP = c.GetTCPSock().remote_endpoint().address().to_string(); + mClientMapMutex.lock(); + if (mClientMap[ClientIP] > 0) { + mClientMap[ClientIP]--; + } + if (mClientMap[ClientIP] == 0) { + mClientMap.erase(ClientIP); + } + mClientMapMutex.unlock(); + } catch (const std::exception& e) { + beammp_debugf("Failed to disconnect client (already disconnected?). There might be a lingering client in the IP-to-client map. This is not an error."); } - mClientMapMutex.unlock(); c.Disconnect(R); } From 6fca901aa2e4f19c200dcfac1938d52fa904fbcf Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:44:37 +0200 Subject: [PATCH 03/36] implement connection limiter to replace manual limiting code --- CMakeLists.txt | 2 ++ include/TConnectionLimiter.h | 62 ++++++++++++++++++++++++++++++++++ include/TNetwork.h | 6 ++-- src/TConnectionLimiter.cpp | 61 ++++++++++++++++++++++++++++++++++ src/TNetwork.cpp | 64 ++++++++---------------------------- 5 files changed, 142 insertions(+), 53 deletions(-) create mode 100644 include/TConnectionLimiter.h create mode 100644 src/TConnectionLimiter.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 67e758d6..cdec565b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ set(PRJ_HEADERS include/Settings.h include/Profiling.h include/ChronoWrapper.h + include/TConnectionLimiter.h ) # add all source files (.cpp) to this, except the one with main() set(PRJ_SOURCES @@ -78,6 +79,7 @@ set(PRJ_SOURCES src/Settings.cpp src/Profiling.cpp src/ChronoWrapper.cpp + src/TConnectionLimiter.cpp ) find_package(Lua REQUIRED) diff --git a/include/TConnectionLimiter.h b/include/TConnectionLimiter.h new file mode 100644 index 00000000..ab70fa50 --- /dev/null +++ b/include/TConnectionLimiter.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include +#include + +class TConnectionLimiter { +public: + class TGuard { + public: + TGuard() = default; + TGuard(TConnectionLimiter* owner, std::string ip); + + TGuard(const TGuard&) = delete; + TGuard& operator=(const TGuard&) = delete; + + ~TGuard() { Release(); } + + // not threadsafe + TGuard(TGuard&& other) noexcept { *this = std::move(other); } + // not threadsafe + TGuard& operator=(TGuard&& other) noexcept; + + private: + friend class TConnectionLimiter; + + // not threadsafe + void Release(); + + TConnectionLimiter* mOwner { nullptr }; + std::string mIp; + }; + + TConnectionLimiter(size_t maxPerIp, size_t maxGlobal); + + [[nodiscard]] std::optional TryAcquire(const std::string& ip); + +private: + void Release(const std::string& ip) { + std::unique_lock Lock { mMutex }; + auto It = mPerIp.find(ip); + if (It != mPerIp.end()) { + // this guard exists to avoid underflow in case something goes wrong and this gets called too many times + if (It->second > 0) + --It->second; + if (It->second == 0) + mPerIp.erase(It); + } + // this guard exists to avoid underflow in case something goes wrong and this gets called too many times + if (mGlobal > 0) + --mGlobal; + } + + const size_t mMaxPerIp; + const size_t mMaxGlobal; + + std::mutex mMutex { }; + std::unordered_map mPerIp { }; + size_t mGlobal = 0; +}; diff --git a/include/TNetwork.h b/include/TNetwork.h index d5e54731..d1391b45 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -20,6 +20,7 @@ #include "BoostAliases.h" #include "Compat.h" +#include "TConnectionLimiter.h" #include "TResourceManager.h" #include "TServer.h" #include @@ -40,7 +41,7 @@ class TNetwork { void DisconnectClient(const std::weak_ptr& c, const std::string& R); void DisconnectClient(TClient& c, const std::string& R); [[nodiscard]] bool SyncClient(const std::weak_ptr& c); - void Identify(TConnection&& client); + void Identify(TConnection&& client, TConnectionLimiter::TGuard&&); std::shared_ptr Authentication(TConnection&& ClientConnection); void SyncResources(TClient& c); [[nodiscard]] bool UDPSend(TClient& Client, std::vector Data); @@ -61,8 +62,7 @@ class TNetwork { std::thread mUDPThread; std::thread mTCPThread; std::mutex mOpenIDMutex; - std::map mClientMap; - std::mutex mClientMapMutex; + TConnectionLimiter mConnectionLimiter; std::vector UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint); void OnConnect(const std::weak_ptr& c); diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp new file mode 100644 index 00000000..c8b8c4b2 --- /dev/null +++ b/src/TConnectionLimiter.cpp @@ -0,0 +1,61 @@ +#include "TConnectionLimiter.h" +#include +#include +#include "Common.h" + +TConnectionLimiter::TConnectionLimiter(size_t maxPerIp, size_t maxGlobal) + : mMaxPerIp(maxPerIp) + , mMaxGlobal(maxGlobal) { + if (maxPerIp == 0) { + beammp_errorf("Max connections count per IP is set to zero; the server would reject ALL connections"); + throw std::runtime_error("Invalid maximum connections per IP setting"); + } + if (maxGlobal == 0) { + beammp_errorf("Max connection count is set to zero; the server would reject ALL connections"); + throw std::runtime_error("Invalid maximum connections setting"); + } +} + +std::optional TConnectionLimiter::TryAcquire(const std::string& ip) { + std::unique_lock Lock { mMutex }; + if (mGlobal >= mMaxGlobal) return std::nullopt; + // `It` is the inserted element (so if insertion worked, its 0), or its the element that + // was already there, in which case we must check the ip connection limit + auto [It, _] = mPerIp.try_emplace(ip, 0); + if (It->second >= mMaxPerIp) { + return std::nullopt; + } + // now increment the counter finally + ++It->second; + ++mGlobal; + // RAII guard will drop the count once destructed + return TGuard(this, ip); +} + +TConnectionLimiter::TGuard::TGuard(TConnectionLimiter* owner, std::string ip) + : mOwner(owner) + , mIp(std::move(ip)) { + beammp_debugf("Acquired connection guard for {}", ip); +} + +TConnectionLimiter::TGuard& TConnectionLimiter::TGuard::operator=(TGuard&& other) noexcept { + // identity check + if (this != &other) { + Release(); + mOwner = other.mOwner; + mIp = std::move(other.mIp); + other.mOwner = nullptr; + } + return *this; +} + +void TConnectionLimiter::TGuard::Release() { + beammp_debugf("Trying to release connection guard for {} ...", mIp); + if (mOwner) { + mOwner->Release(mIp); + mOwner = nullptr; + beammp_debugf("... Released connection guard for {}", mIp); + } else { + beammp_debugf("... Connection guard for {} was already released (nothing happened)", mIp); + } +} diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index b76829a2..71e8f44c 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -20,6 +20,7 @@ #include "Client.h" #include "Common.h" #include "LuaAPI.h" +#include "TConnectionLimiter.h" #include "THeartbeatThread.h" #include "TLuaEngine.h" #include "TScopedTimer.h" @@ -59,7 +60,8 @@ TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& R : mServer(Server) , mPPSMonitor(PPSMonitor) , mUDPSock(Server.IoCtx()) - , mResourceManager(ResourceManager) { + , mResourceManager(ResourceManager) + , mConnectionLimiter(MAX_CONCURRENT_CONNECTIONS, MAX_GLOBAL_CONNECTIONS){ Application::SetSubsystemStatus("TCPNetwork", Application::Status::Starting); Application::SetSubsystemStatus("UDPNetwork", Application::Status::Starting); Application::RegisterShutdownHandler([&] { @@ -247,22 +249,17 @@ void TNetwork::TCPServerMain() { boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); std::string ClientIP = ClientEp.address().to_string(); if (!ec) { - mClientMapMutex.lock(); - if (mClientMap[ClientIP] >= MAX_CONCURRENT_CONNECTIONS) { - beammp_debugf("The connection was rejected for {}, as it had {} concurrent connections.", ClientIP, mClientMap[ClientIP]); - } - else if (mClientMap.size() >= MAX_GLOBAL_CONNECTIONS) { - beammp_debugf("The connection was rejected for {}, as there are {} global connections.", ClientIP, mClientMap.size()); - } - else { + auto MaybeGuard = mConnectionLimiter.TryAcquire(ClientEp.address().to_v6().to_string()); + if (MaybeGuard.has_value()) { + // move-swap to avoid copy ctor (deleted) + auto Guard = std::move(MaybeGuard.value()); TConnection Conn { std::move(ClientSocket), ClientEp }; - std::thread ID(&TNetwork::Identify, this, std::move(Conn)); + std::thread ID(&TNetwork::Identify, this, std::move(Conn), std::move(Guard)); ID.detach(); // TODO: Add to a queue and attempt to join periodically - mClientMap[ClientIP]++; + } else { + beammp_errorf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); } - mClientMapMutex.unlock(); - } - else { + } else { beammp_errorf("Failed to accept() new client: {}", ec.message()); } } catch (const std::exception& e) { @@ -276,7 +273,7 @@ void TNetwork::TCPServerMain() { #include "Json.h" namespace json = rapidjson; -void TNetwork::Identify(TConnection&& RawConnection) { +void TNetwork::Identify(TConnection&& RawConnection, TConnectionLimiter::TGuard&& Guard) { RegisterThreadAuto(); char Code; @@ -285,17 +282,8 @@ void TNetwork::Identify(TConnection&& RawConnection) { // TODO: is this right?! beammp_debug("Error occured reading code"); RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); - mClientMapMutex.lock(); - { - std::string ClientIP = RawConnection.SockAddr.address().to_string(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - } - mClientMapMutex.unlock(); + // TODO: is this right too? + RawConnection.Socket.close(ec); return; } std::shared_ptr Client { nullptr }; @@ -326,17 +314,6 @@ void TNetwork::Identify(TConnection&& RawConnection) { beammp_errorf("Error during handling of code {} - client left in invalid state, closing socket: {}", Code, e.what()); boost::system::error_code ec; RawConnection.Socket.shutdown(boost::asio::socket_base::shutdown_both, ec); - mClientMapMutex.lock(); - { - std::string ClientIP = RawConnection.SockAddr.address().to_string(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - } - mClientMapMutex.unlock(); if (ec) { beammp_debugf("Failed to shutdown client socket: {}", ec.message()); } @@ -661,19 +638,6 @@ void TNetwork::DisconnectClient(const std::weak_ptr &c, const std::stri void TNetwork::DisconnectClient(TClient &c, const std::string &R) { if (c.IsDisconnected()) return; - try { - std::string ClientIP = c.GetTCPSock().remote_endpoint().address().to_string(); - mClientMapMutex.lock(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - mClientMapMutex.unlock(); - } catch (const std::exception& e) { - beammp_debugf("Failed to disconnect client (already disconnected?). There might be a lingering client in the IP-to-client map. This is not an error."); - } c.Disconnect(R); } From a59f7296d0df557ba208fba87424f1ce45ba6582 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:06:58 +0200 Subject: [PATCH 04/36] avoid logging moved-from ip --- src/TConnectionLimiter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp index c8b8c4b2..4f3c1792 100644 --- a/src/TConnectionLimiter.cpp +++ b/src/TConnectionLimiter.cpp @@ -35,7 +35,7 @@ std::optional TConnectionLimiter::TryAcquire(const s TConnectionLimiter::TGuard::TGuard(TConnectionLimiter* owner, std::string ip) : mOwner(owner) , mIp(std::move(ip)) { - beammp_debugf("Acquired connection guard for {}", ip); + beammp_debugf("Acquired connection guard for {}", mIp); } TConnectionLimiter::TGuard& TConnectionLimiter::TGuard::operator=(TGuard&& other) noexcept { From a46f2d3204f57f44ef352d210a54e55fb9904e90 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:07:10 +0200 Subject: [PATCH 05/36] set moved-from ip to explicitly obvious moved-from string --- src/TConnectionLimiter.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp index 4f3c1792..86e686b7 100644 --- a/src/TConnectionLimiter.cpp +++ b/src/TConnectionLimiter.cpp @@ -45,6 +45,8 @@ TConnectionLimiter::TGuard& TConnectionLimiter::TGuard::operator=(TGuard&& other mOwner = other.mOwner; mIp = std::move(other.mIp); other.mOwner = nullptr; + // setting this so its obvious when this happens, instead of being UB or empty string + other.mIp = ""; } return *this; } From b5642247a65ed4848d7dd733e35667858f9778d1 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:07:21 +0200 Subject: [PATCH 06/36] accept without EP, explicitly (and fallibly) get the endpoint --- src/TNetwork.cpp | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 71e8f44c..a698918c 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -245,23 +245,27 @@ void TNetwork::TCPServerMain() { beammp_debug("shutdown during TCP wait for accept loop"); break; } - boost::asio::ip::tcp::endpoint ClientEp; - boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); - std::string ClientIP = ClientEp.address().to_string(); - if (!ec) { - auto MaybeGuard = mConnectionLimiter.TryAcquire(ClientEp.address().to_v6().to_string()); - if (MaybeGuard.has_value()) { - // move-swap to avoid copy ctor (deleted) - auto Guard = std::move(MaybeGuard.value()); - TConnection Conn { std::move(ClientSocket), ClientEp }; - std::thread ID(&TNetwork::Identify, this, std::move(Conn), std::move(Guard)); - ID.detach(); // TODO: Add to a queue and attempt to join periodically - } else { - beammp_errorf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); - } - } else { + boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ec); + if (ec) { beammp_errorf("Failed to accept() new client: {}", ec.message()); + continue; + } + boost::asio::ip::tcp::endpoint ClientEp = ClientSocket.remote_endpoint(ec); + if (ec) { + beammp_errorf("Accepted socket but failed to query remote endpoint for IP address: {}", ec.message()); + continue; + } + std::string ClientIP = ClientEp.address().to_string(); + auto MaybeGuard = mConnectionLimiter.TryAcquire(ClientIP); + if (!MaybeGuard.has_value()) { + beammp_errorf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); + continue; } + // move-swap to avoid copy ctor (deleted) + auto Guard = std::move(MaybeGuard.value()); + TConnection Conn { std::move(ClientSocket), ClientEp }; + std::thread ID(&TNetwork::Identify, this, std::move(Conn), std::move(Guard)); + ID.detach(); // TODO: Add to a queue and attempt to join periodically } catch (const std::exception& e) { beammp_errorf("Exception in accept routine: {}", e.what()); } From 0ac5da23669c8677089b59eecc9aceb34864028d Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:29:40 +0200 Subject: [PATCH 07/36] add connection limiter stats to `status` command --- include/TConnectionLimiter.h | 10 ++++++++++ include/TNetwork.h | 1 + src/TConnectionLimiter.cpp | 18 ++++++++++++++++++ src/TConsole.cpp | 6 ++++++ 4 files changed, 35 insertions(+) diff --git a/include/TConnectionLimiter.h b/include/TConnectionLimiter.h index ab70fa50..62df05f5 100644 --- a/include/TConnectionLimiter.h +++ b/include/TConnectionLimiter.h @@ -8,6 +8,15 @@ class TConnectionLimiter { public: + struct TStats { + size_t CurrentGlobal = 0; + size_t MaxGlobal = 0; + size_t ActiveIpBuckets = 0; + size_t CurrentMaxPerIp = 0; + size_t MaxPerIp = 0; + size_t SaturatedIpBuckets = 0; + }; + class TGuard { public: TGuard() = default; @@ -36,6 +45,7 @@ class TConnectionLimiter { TConnectionLimiter(size_t maxPerIp, size_t maxGlobal); [[nodiscard]] std::optional TryAcquire(const std::string& ip); + [[nodiscard]] TStats GetStats(); private: void Release(const std::string& ip) { diff --git a/include/TNetwork.h b/include/TNetwork.h index d1391b45..0d3df814 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -48,6 +48,7 @@ class TNetwork { void SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel); void UpdatePlayer(TClient& Client); boost::system::error_code ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); + [[nodiscard]] TConnectionLimiter::TStats GetConnectionLimiterStats() { return mConnectionLimiter.GetStats(); } TResourceManager& ResourceManager() const { return mResourceManager; } diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp index 86e686b7..86969633 100644 --- a/src/TConnectionLimiter.cpp +++ b/src/TConnectionLimiter.cpp @@ -32,6 +32,24 @@ std::optional TConnectionLimiter::TryAcquire(const s return TGuard(this, ip); } +TConnectionLimiter::TStats TConnectionLimiter::GetStats() { + std::unique_lock Lock { mMutex }; + TStats Stats; + Stats.CurrentGlobal = mGlobal; + Stats.MaxGlobal = mMaxGlobal; + Stats.ActiveIpBuckets = mPerIp.size(); + Stats.MaxPerIp = mMaxPerIp; + for (const auto& [_, Count] : mPerIp) { + if (Count > Stats.CurrentMaxPerIp) { + Stats.CurrentMaxPerIp = Count; + } + if (Count >= mMaxPerIp) { + ++Stats.SaturatedIpBuckets; + } + } + return Stats; +} + TConnectionLimiter::TGuard::TGuard(TConnectionLimiter* owner, std::string ip) : mOwner(owner) , mIp(std::move(ip)) { diff --git a/src/TConsole.cpp b/src/TConsole.cpp index fb1dc8b0..317ec160 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -653,6 +653,7 @@ void TConsole::Command_Status(const std::string&, const std::vector SystemsShutdownList = SystemsShutdownList.substr(0, SystemsShutdownList.size() - 2); auto ElapsedTime = mLuaEngine->Server().UptimeTimer.GetElapsedTime(); + auto ConnectionLimiterStats = mLuaEngine->Network().GetConnectionLimiterStats(); Status << "BeamMP-Server Status:\n" << "\tTotal Players: " << mLuaEngine->Server().ClientCount() << "\n" @@ -667,6 +668,11 @@ void TConsole::Command_Status(const std::string&, const std::vector << "\t\tStates: " << mLuaEngine->GetLuaStateCount() << "\n" << "\t\tEvent timers: " << mLuaEngine->GetTimedEventsCount() << "\n" << "\t\tEvent handlers: " << mLuaEngine->GetRegisteredEventHandlerCount() << "\n" + << "\tConnection limiter:\n" + << "\t\tActive/Max global: " << ConnectionLimiterStats.CurrentGlobal << "/" << ConnectionLimiterStats.MaxGlobal << "\n" + << "\t\tActive IP buckets: " << ConnectionLimiterStats.ActiveIpBuckets << "\n" + << "\t\tHighest single IP load: " << ConnectionLimiterStats.CurrentMaxPerIp << "/" << ConnectionLimiterStats.MaxPerIp << "\n" + << "\t\tSaturated IP buckets: " << ConnectionLimiterStats.SaturatedIpBuckets << "\n" << "\tSubsystems:\n" << "\t\tGood/Starting/Bad: " << SystemsGood << "/" << SystemsStarting << "/" << SystemsBad << "\n" << "\t\tShutting down/Shut down: " << SystemsShuttingDown << "/" << SystemsShutdown << "\n" From be6f3a2fd1210da53a9996212fdbaaf97af516b5 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 18:37:58 +0200 Subject: [PATCH 08/36] fix spammy logs on guard release --- src/TConnectionLimiter.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp index 86969633..b2c681e2 100644 --- a/src/TConnectionLimiter.cpp +++ b/src/TConnectionLimiter.cpp @@ -70,12 +70,9 @@ TConnectionLimiter::TGuard& TConnectionLimiter::TGuard::operator=(TGuard&& other } void TConnectionLimiter::TGuard::Release() { - beammp_debugf("Trying to release connection guard for {} ...", mIp); if (mOwner) { mOwner->Release(mIp); mOwner = nullptr; - beammp_debugf("... Released connection guard for {}", mIp); - } else { - beammp_debugf("... Connection guard for {} was already released (nothing happened)", mIp); + beammp_debugf("Released connection guard for {}", mIp); } } From 58e5317687b3c524b4255d4847af9b6a7ab34f73 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 8 Apr 2026 20:01:37 +0000 Subject: [PATCH 09/36] refactor ReadWithTimeout to not spawn a thread + use an fd each read this was exhausting file descriptors with enough concurrent reads, from what I can tell. Either way, spawning a new OS thread per read is not the way. Because this is so critical, I added unit-tests for that behavior. --- src/TNetwork.cpp | 156 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 137 insertions(+), 19 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index a698918c..0ba9e764 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,75 @@ static void CompressProperly(std::vector& Data) { Data = CombinedData; } +static boost::system::error_code ReadSocketWithTimeout(ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { + boost::system::error_code Ec; + const bool WasNonBlocking = Socket.non_blocking(); + Socket.non_blocking(true, Ec); + if (Ec) { + return Ec; + } + + auto RestoreBlockingMode = [&]() { + boost::system::error_code IgnoreEc; + Socket.non_blocking(WasNonBlocking, IgnoreEc); + }; + + const auto Deadline = std::chrono::steady_clock::now() + Timeout; + auto* Data = static_cast(Buf); + size_t TotalRead = 0; + + while (TotalRead < Len) { + const size_t BytesRead = Socket.read_some(boost::asio::buffer(Data + TotalRead, Len - TotalRead), Ec); + if (!Ec) { + TotalRead += BytesRead; + continue; + } + + if (Ec == error::would_block || Ec == error::try_again) { + if (std::chrono::steady_clock::now() >= Deadline) { + RestoreBlockingMode(); + return error::timed_out; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + RestoreBlockingMode(); + return Ec; + } + + RestoreBlockingMode(); + return Ec; +} + +// for unit-tests, otherwise unused +static bool OpenLoopbackSocketPair(io_context& IoCtx, ip::tcp::socket& ClientSocket, ip::tcp::socket& ServerSocket, boost::system::error_code& Ec) { + ip::tcp::acceptor Acceptor(IoCtx); + Acceptor.open(ip::tcp::v4(), Ec); + if (Ec) { + return true; + } + Acceptor.bind(ip::tcp::endpoint(ip::address_v4::loopback(), 0), Ec); + if (Ec) { + return true; + } + Acceptor.listen(socket_base::max_listen_connections, Ec); + if (Ec) { + return true; + } + + const auto Port = Acceptor.local_endpoint(Ec).port(); + if (Ec) { + return true; + } + ClientSocket.connect(ip::tcp::endpoint(ip::address_v4::loopback(), Port), Ec); + if (Ec) { + return true; + } + Acceptor.accept(ServerSocket, Ec); + return true; +} + TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager) : mServer(Server) , mPPSMonitor(PPSMonitor) @@ -741,30 +811,78 @@ void TNetwork::UpdatePlayer(TClient& Client) { boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void *Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { - io_context TimerIO; - steady_timer Timer(TimerIO); - Timer.expires_after(Timeout); + return ReadSocketWithTimeout(Connection.Socket, Buf, Len, Timeout); +} - std::atomic TimedOut = false; +TEST_CASE("ReadSocketWithTimeout returns timed_out when peer sends no data") { + io_context IoCtx; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(IoCtx); + ip::tcp::socket ServerSocket(IoCtx); + OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); - Timer.async_wait([&](const boost::system::error_code& ec) { - if (!ec) { - TimedOut = true; - Connection.Socket.cancel(); - } - }); - std::thread TimerThread([&]() { TimerIO.run(); }); + uint8_t ReadByte = 0; + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); - boost::system::error_code ReadEc; - boost::asio::read(Connection.Socket, boost::asio::buffer(Buf, Len), ReadEc); + CHECK(ReadEc == error::timed_out); +} - TimerIO.stop(); - TimerThread.join(); +TEST_CASE("ReadSocketWithTimeout reads small payload") { + io_context IoCtx; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(IoCtx); + ip::tcp::socket ServerSocket(IoCtx); + OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + const std::array Sent { 'O', 'K' }; + boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); + REQUIRE(!Ec); + std::array Received {}; + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); + + CHECK(!ReadEc); + CHECK(Received == Sent); +} - if (TimedOut.load()) { - return error::timed_out; // synthesize a clean timeout error - } - return ReadEc; //Succes! +TEST_CASE("ReadSocketWithTimeout reads large payload") { + io_context IoCtx; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(IoCtx); + ip::tcp::socket ServerSocket(IoCtx); + OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + constexpr size_t PacketSize = 2 * 1024 * 1024; + std::vector Sent(PacketSize, uint8_t(0x7A)); + boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); + REQUIRE(!Ec); + std::vector Received(PacketSize); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); + + CHECK(!ReadEc); + CHECK(Received == Sent); +} + +TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { + io_context IoCtx; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(IoCtx); + ip::tcp::socket ServerSocket(IoCtx); + OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + uint8_t Received = 0; + CHECK(ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); + + const uint8_t Sent = 0x42; + boost::asio::write(ClientSocket, boost::asio::buffer(&Sent, 1), Ec); + REQUIRE(!Ec); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(200)); + + CHECK(!ReadEc); + CHECK(Received == Sent); } void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { From a3cfe47e7828edf637deced73e80497cbf231e7a Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:08:54 +0200 Subject: [PATCH 10/36] fix http connections eating all fds with this many http connnections, we were exhausting all available file descriptors, leading to a dead server that keeps CLOSE_WAIT tcp sockets. Because we want to retain the behavior that we keep connections open for reuse, we instead make a pool of 8 curl instances now, shared between all the different requests. --- src/Http.cpp | 118 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 107 insertions(+), 11 deletions(-) diff --git a/src/Http.cpp b/src/Http.cpp index 0aba6639..ed075d62 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -18,15 +18,12 @@ #include "Http.h" -#include "Client.h" #include "Common.h" #include "CustomAssert.h" -#include "LuaAPI.h" +#include #include #include -#include -#include using json = nlohmann::json; @@ -37,17 +34,107 @@ static size_t CurlWriteCallback(void* contents, size_t size, size_t nmemb, void* return size * nmemb; } +struct CurlDeleter { + void operator()(CURL* c) const noexcept { + if (c) + curl_easy_cleanup(c); + } +}; + +static std::mutex gCurlPoolMutex; +static std::map gCurlPool; // false = free, true = in use +constexpr size_t MAX_CURL_POOL_SIZE = 8; + +static CURL* AcquireCurl() { + std::unique_lock Lock(gCurlPoolMutex); + for (auto& [Curl, InUse] : gCurlPool) { + if (!InUse) { + InUse = true; + curl_easy_reset(Curl); + return Curl; + } + } + + // none found, if we're under the limit add one! + if (gCurlPool.size() >= MAX_CURL_POOL_SIZE) { + beammp_debugf("Ran out of curl handles for network requests, skipping this request"); + return nullptr; + } + + CURL* Curl = curl_easy_init(); + if (!Curl) { + // failed, ignore + return nullptr; + } + + gCurlPool[Curl] = true; + return Curl; +} + +static void ReleaseCurl(CURL* curl) { + if (!curl) { + return; + } + + std::unique_lock Lock(gCurlPoolMutex); + auto It = gCurlPool.find(curl); + if (It != gCurlPool.end()) { + It->second = false; + } +} + +/// RAII container for curl handles to ensure correct release +class CurlLease { +private: + CURL* mHandle = nullptr; +public: + + CurlLease() : mHandle(AcquireCurl()) {} + ~CurlLease() { + ReleaseCurl(mHandle); + } + CurlLease(const CurlLease&) = delete; + CurlLease& operator=(const CurlLease&) = delete; + + CurlLease(CurlLease&& other) noexcept + : mHandle(other.mHandle) { + other.mHandle = nullptr; + } + + CurlLease& operator=(CurlLease&& other) noexcept { + if (this != &other) { + ReleaseCurl(mHandle); + mHandle = other.mHandle; + other.mHandle = nullptr; + } + return *this; + } + + CURL* GetHandle() const noexcept { + return mHandle; + } +}; + std::string Http::GET(const std::string& url, unsigned int* status) { std::string Ret; - static thread_local CURL* curl = curl_easy_init(); + CurlLease Lease{}; + CURL* curl = Lease.GetHandle(); if (curl) { CURLcode res; char errbuf[CURL_ERROR_SIZE]; curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_IPRESOLVE, CURL_IPRESOLVE_V4); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void*)&Ret); - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10); // seconds + curl_easy_setopt(curl, CURLOPT_WRITEDATA, reinterpret_cast(&Ret)); + + // ensure we dont keep connections open for long + curl_easy_setopt(curl, CURLOPT_MAXCONNECTS, 8L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10L); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L); + curl_easy_setopt(curl, CURLOPT_MAXAGE_CONN, 60L); + curl_easy_setopt(curl, CURLOPT_MAXLIFETIME_CONN, 300L); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf); errbuf[0] = 0; @@ -71,15 +158,16 @@ std::string Http::GET(const std::string& url, unsigned int* status) { std::string Http::POST(const std::string& url, const std::string& body, const std::string& ContentType, unsigned int* status, const std::map& headers) { std::string Ret; - static thread_local CURL* curl = curl_easy_init(); + CurlLease Lease{}; + CURL* curl = Lease.GetHandle(); if (curl) { CURLcode res; char errbuf[CURL_ERROR_SIZE]; curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_IPRESOLVE, CURL_IPRESOLVE_V4); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void*)&Ret); - curl_easy_setopt(curl, CURLOPT_POST, 1); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, reinterpret_cast(&Ret)); + curl_easy_setopt(curl, CURLOPT_POST, 1L); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.size()); struct curl_slist* list = nullptr; @@ -90,7 +178,15 @@ std::string Http::POST(const std::string& url, const std::string& body, const st } curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list); - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10); // seconds + + // ensure we dont keep connections open for long + curl_easy_setopt(curl, CURLOPT_MAXCONNECTS, 8L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10L); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L); + curl_easy_setopt(curl, CURLOPT_MAXAGE_CONN, 60L); + curl_easy_setopt(curl, CURLOPT_MAXLIFETIME_CONN, 300L); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf); errbuf[0] = 0; From 1f55c35f2b3dea6965a40832e970b084a1d05306 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:54:21 +0200 Subject: [PATCH 11/36] fix ReadSocketWithTimeout to use dedicated io context polled on a new jthread jthread so it's cancellable --- include/TNetwork.h | 14 ++++ src/TNetwork.cpp | 172 +++++++++++++++++++++++++++------------------ 2 files changed, 118 insertions(+), 68 deletions(-) diff --git a/include/TNetwork.h b/include/TNetwork.h index 0d3df814..bd5ba580 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -28,6 +28,19 @@ struct TConnection; +class TIoPollThread { +public: + TIoPollThread(); + ~TIoPollThread(); + + boost::asio::io_context& IoCtx() noexcept { return mIoCtx; } + +private: + boost::asio::io_context mIoCtx; + boost::asio::executor_work_guard mWorkGuard; + std::jthread mThread; +}; + class TNetwork { public: TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager); @@ -63,6 +76,7 @@ class TNetwork { std::thread mUDPThread; std::thread mTCPThread; std::mutex mOpenIDMutex; + TIoPollThread mIoCtxPoller; TConnectionLimiter mConnectionLimiter; std::vector UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint); diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 0ba9e764..c3bdc5eb 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -28,12 +28,16 @@ #include #include #include +#include #include #include #include #include #include +#include +#include #include +#include #include #include #include @@ -57,47 +61,6 @@ static void CompressProperly(std::vector& Data) { Data = CombinedData; } -static boost::system::error_code ReadSocketWithTimeout(ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { - boost::system::error_code Ec; - const bool WasNonBlocking = Socket.non_blocking(); - Socket.non_blocking(true, Ec); - if (Ec) { - return Ec; - } - - auto RestoreBlockingMode = [&]() { - boost::system::error_code IgnoreEc; - Socket.non_blocking(WasNonBlocking, IgnoreEc); - }; - - const auto Deadline = std::chrono::steady_clock::now() + Timeout; - auto* Data = static_cast(Buf); - size_t TotalRead = 0; - - while (TotalRead < Len) { - const size_t BytesRead = Socket.read_some(boost::asio::buffer(Data + TotalRead, Len - TotalRead), Ec); - if (!Ec) { - TotalRead += BytesRead; - continue; - } - - if (Ec == error::would_block || Ec == error::try_again) { - if (std::chrono::steady_clock::now() >= Deadline) { - RestoreBlockingMode(); - return error::timed_out; - } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - - RestoreBlockingMode(); - return Ec; - } - - RestoreBlockingMode(); - return Ec; -} - // for unit-tests, otherwise unused static bool OpenLoopbackSocketPair(io_context& IoCtx, ip::tcp::socket& ClientSocket, ip::tcp::socket& ServerSocket, boost::system::error_code& Ec) { ip::tcp::acceptor Acceptor(IoCtx); @@ -131,7 +94,8 @@ TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& R , mPPSMonitor(PPSMonitor) , mUDPSock(Server.IoCtx()) , mResourceManager(ResourceManager) - , mConnectionLimiter(MAX_CONCURRENT_CONNECTIONS, MAX_GLOBAL_CONNECTIONS){ + , mIoCtxPoller() + , mConnectionLimiter(MAX_CONCURRENT_CONNECTIONS, MAX_GLOBAL_CONNECTIONS) { Application::SetSubsystemStatus("TCPNetwork", Application::Status::Starting); Application::SetSubsystemStatus("UDPNetwork", Application::Status::Starting); Application::RegisterShutdownHandler([&] { @@ -809,49 +773,103 @@ void TNetwork::UpdatePlayer(TClient& Client) { //(void)Respond(Client, Packet, true); } -boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void *Buf, size_t Len, std::chrono::steady_clock::duration Timeout) -{ - return ReadSocketWithTimeout(Connection.Socket, Buf, Len, Timeout); +static boost::system::error_code ReadSocketWithTimeout( + boost::asio::io_context& io, + boost::asio::ip::tcp::socket& socket, + void* buffer, + std::size_t length, + std::chrono::steady_clock::duration timeout); + +boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { + return ReadSocketWithTimeout(mIoCtxPoller.IoCtx(), Connection.Socket, Buf, Len, Timeout); +} + +static boost::system::error_code ReadSocketWithTimeout( + boost::asio::io_context& IoCtx, + boost::asio::ip::tcp::socket& Socket, + void* Buffer, + std::size_t Length, + std::chrono::steady_clock::duration Timeout) { + namespace asio = boost::asio; + using boost::system::error_code; + + struct TTimeoutState { + explicit TTimeoutState(asio::io_context& IoCtx) + : Timer(IoCtx) { } + + asio::steady_timer Timer; + std::promise> Promise; + std::atomic_bool Completed { false }; + }; + + auto State = std::make_shared(IoCtx); + auto Future = State->Promise.get_future(); + + asio::async_read( + Socket, + asio::buffer(Buffer, Length), + [State](error_code ec, std::size_t n) { + if (!State->Completed.exchange(true)) { + State->Timer.cancel(); + State->Promise.set_value({ ec, n }); + } + }); + + State->Timer.expires_after(Timeout); + State->Timer.async_wait( + [State, &Socket](error_code ec) { + if (ec == asio::error::operation_aborted) + return; + + if (!State->Completed.exchange(true)) { + error_code IgnoredEc; + Socket.cancel(IgnoredEc); + State->Promise.set_value({ asio::error::timed_out, 0 }); + } + }); + + auto [ec, NRead] = Future.get(); + return ec; } TEST_CASE("ReadSocketWithTimeout returns timed_out when peer sends no data") { - io_context IoCtx; + TIoPollThread TimerThread; boost::system::error_code Ec; - ip::tcp::socket ClientSocket(IoCtx); - ip::tcp::socket ServerSocket(IoCtx); - OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); REQUIRE(!Ec); uint8_t ReadByte = 0; - const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); + const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); CHECK(ReadEc == error::timed_out); } TEST_CASE("ReadSocketWithTimeout reads small payload") { - io_context IoCtx; + TIoPollThread TimerThread; boost::system::error_code Ec; - ip::tcp::socket ClientSocket(IoCtx); - ip::tcp::socket ServerSocket(IoCtx); - OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); REQUIRE(!Ec); const std::array Sent { 'O', 'K' }; boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); REQUIRE(!Ec); - std::array Received {}; - const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); + std::array Received { }; + const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); CHECK(!ReadEc); CHECK(Received == Sent); } TEST_CASE("ReadSocketWithTimeout reads large payload") { - io_context IoCtx; + TIoPollThread TimerThread; boost::system::error_code Ec; - ip::tcp::socket ClientSocket(IoCtx); - ip::tcp::socket ServerSocket(IoCtx); - OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); REQUIRE(!Ec); constexpr size_t PacketSize = 2 * 1024 * 1024; @@ -859,27 +877,27 @@ TEST_CASE("ReadSocketWithTimeout reads large payload") { boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); REQUIRE(!Ec); std::vector Received(PacketSize); - const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); + const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); CHECK(!ReadEc); CHECK(Received == Sent); } TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { - io_context IoCtx; + TIoPollThread TimerThread; boost::system::error_code Ec; - ip::tcp::socket ClientSocket(IoCtx); - ip::tcp::socket ServerSocket(IoCtx); - OpenLoopbackSocketPair(IoCtx, ClientSocket, ServerSocket, Ec); + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); REQUIRE(!Ec); uint8_t Received = 0; - CHECK(ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); + CHECK(ReadSocketWithTimeout(TimerThread.get(), ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); const uint8_t Sent = 0x42; boost::asio::write(ClientSocket, boost::asio::buffer(&Sent, 1), Ec); REQUIRE(!Ec); - const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(200)); + const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &Received, 1, std::chrono::milliseconds(200)); CHECK(!ReadEc); CHECK(Received == Sent); @@ -1047,9 +1065,9 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { #if defined(BEAMMP_LINUX) #include #include +#include #include #include -#include #endif void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name) { TScopedTimer timer(fmt::format("Download of '{}' for client {}", Name, c.GetID())); @@ -1275,3 +1293,21 @@ std::vector TNetwork::UDPRcvFromClient(boost::asio::ip::udp::endpoint& beammp_assert(Rcv <= Ret.size()); return std::vector(Ret.begin(), Ret.begin() + Rcv); } + +TIoPollThread::TIoPollThread() + : mWorkGuard(boost::asio::make_work_guard(mIoCtx)) + , mThread([this](std::stop_token StopToken) { + while (!StopToken.stop_requested()) { + try { + mIoCtx.run(); + break; + } catch (...) { + mIoCtx.restart(); + } + } + }) { } + +TIoPollThread::~TIoPollThread() { + mWorkGuard.reset(); + mIoCtx.stop(); +} From be0d8d53340e288d783bfa0b1d51669909607b5f Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:00:15 +0200 Subject: [PATCH 12/36] make connection reject msg a debug message, avoiding spam on ddos --- src/TNetwork.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index c3bdc5eb..6d153178 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -292,7 +292,7 @@ void TNetwork::TCPServerMain() { std::string ClientIP = ClientEp.address().to_string(); auto MaybeGuard = mConnectionLimiter.TryAcquire(ClientIP); if (!MaybeGuard.has_value()) { - beammp_errorf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); + beammp_debugf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); continue; } // move-swap to avoid copy ctor (deleted) From c08eefdf69772076fbee4f87fe6eb67cacde121b Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:49:10 +0200 Subject: [PATCH 13/36] move TIoPollThread out and use it in TServer the previous IoCtx was never being polled --- CMakeLists.txt | 2 ++ include/TIoPollThread.h | 36 +++++++++++++++++++ include/TNetwork.h | 15 -------- include/TServer.h | 6 ++-- src/TIoPollThread.cpp | 37 ++++++++++++++++++++ src/TNetwork.cpp | 77 +++++++++++++++-------------------------- 6 files changed, 105 insertions(+), 68 deletions(-) create mode 100644 include/TIoPollThread.h create mode 100644 src/TIoPollThread.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cdec565b..00ca619f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ set(PRJ_HEADERS include/TConfig.h include/TConsole.h include/THeartbeatThread.h + include/TIoPollThread.h include/TLuaEngine.h include/TLuaPlugin.h include/TNetwork.h @@ -66,6 +67,7 @@ set(PRJ_SOURCES src/TConfig.cpp src/TConsole.cpp src/THeartbeatThread.cpp + src/TIoPollThread.cpp src/TLuaEngine.cpp src/TLuaPlugin.cpp src/TNetwork.cpp diff --git a/include/TIoPollThread.h b/include/TIoPollThread.h new file mode 100644 index 00000000..4c8e805f --- /dev/null +++ b/include/TIoPollThread.h @@ -0,0 +1,36 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include +#include +#include + +class TIoPollThread { +public: + TIoPollThread(); + ~TIoPollThread(); + + boost::asio::io_context& IoCtx() noexcept { return mIoCtx; } + +private: + boost::asio::io_context mIoCtx; + boost::asio::executor_work_guard mWorkGuard; + std::jthread mThread; +}; diff --git a/include/TNetwork.h b/include/TNetwork.h index bd5ba580..4f5fb34f 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -23,24 +23,10 @@ #include "TConnectionLimiter.h" #include "TResourceManager.h" #include "TServer.h" -#include #include struct TConnection; -class TIoPollThread { -public: - TIoPollThread(); - ~TIoPollThread(); - - boost::asio::io_context& IoCtx() noexcept { return mIoCtx; } - -private: - boost::asio::io_context mIoCtx; - boost::asio::executor_work_guard mWorkGuard; - std::jthread mThread; -}; - class TNetwork { public: TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager); @@ -76,7 +62,6 @@ class TNetwork { std::thread mUDPThread; std::thread mTCPThread; std::mutex mOpenIDMutex; - TIoPollThread mIoCtxPoller; TConnectionLimiter mConnectionLimiter; std::vector UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint); diff --git a/include/TServer.h b/include/TServer.h index 4ee9f969..3202c9ea 100644 --- a/include/TServer.h +++ b/include/TServer.h @@ -20,6 +20,7 @@ #include "IThreaded.h" #include "RWMutex.h" +#include "TIoPollThread.h" #include "TScopedTimer.h" #include #include @@ -50,11 +51,10 @@ class TServer final { const TScopedTimer UptimeTimer; - // asio io context - io_context& IoCtx() { return mIoCtx; } + io_context& IoCtx() { return mIoCtxPoller.IoCtx(); } private: - io_context mIoCtx {}; + TIoPollThread mIoCtxPoller; TClientSet mClients; mutable RWMutex mClientsMutex; static void ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Network); diff --git a/src/TIoPollThread.cpp b/src/TIoPollThread.cpp new file mode 100644 index 00000000..dabf8c89 --- /dev/null +++ b/src/TIoPollThread.cpp @@ -0,0 +1,37 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "TIoPollThread.h" + +TIoPollThread::TIoPollThread() + : mWorkGuard(boost::asio::make_work_guard(mIoCtx)) + , mThread([this](std::stop_token StopToken) { + while (!StopToken.stop_requested()) { + try { + mIoCtx.run(); + break; + } catch (...) { + mIoCtx.restart(); + } + } + }) { } + +TIoPollThread::~TIoPollThread() { + mWorkGuard.reset(); + mIoCtx.stop(); +} diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 6d153178..31492001 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -46,7 +46,7 @@ typedef boost::asio::detail::socket_option::integer rcv static constexpr uint8_t MAX_CONCURRENT_CONNECTIONS = 10; static constexpr uint8_t MAX_GLOBAL_CONNECTIONS = 128; -static constexpr uint8_t READ_TIMEOUT_S = 10; //seconds +static constexpr uint8_t READ_TIMEOUT_S = 10; // seconds std::vector StringToVector(const std::string& Str) { return std::vector(Str.data(), Str.data() + Str.size()); @@ -94,7 +94,6 @@ TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& R , mPPSMonitor(PPSMonitor) , mUDPSock(Server.IoCtx()) , mResourceManager(ResourceManager) - , mIoCtxPoller() , mConnectionLimiter(MAX_CONCURRENT_CONNECTIONS, MAX_GLOBAL_CONNECTIONS) { Application::SetSubsystemStatus("TCPNetwork", Application::Status::Starting); Application::SetSubsystemStatus("UDPNetwork", Application::Status::Starting); @@ -161,13 +160,13 @@ void TNetwork::UDPServerMain() { + std::to_string(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) + (" Clients")); while (!Application::IsShuttingDown()) { try { - boost::asio::ip::udp::endpoint remote_client_ep {}; + boost::asio::ip::udp::endpoint remote_client_ep { }; std::vector Data = UDPRcvFromClient(remote_client_ep); if (Data.empty()) { continue; } if (Data.size() == 1 && Data.at(0) == 'P') { - mUDPSock.send_to(boost::asio::const_buffer("P", 1), remote_client_ep, {}, ec); + mUDPSock.send_to(boost::asio::const_buffer("P", 1), remote_client_ep, { }, ec); // ignore errors (void)ec; continue; @@ -188,7 +187,7 @@ void TNetwork::UDPServerMain() { } if (Client->GetID() == ID) { - if (Client->GetUDPAddr() == boost::asio::ip::udp::endpoint {} && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { + if (Client->GetUDPAddr() == boost::asio::ip::udp::endpoint { } && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { if (Data.size() != 66) { beammp_debugf("Invalid size for UDP value. IP: {} ID: {}", remote_client_ep.address().to_string(), ID); return false; @@ -201,7 +200,7 @@ void TNetwork::UDPServerMain() { return false; } - Client->SetMagic({}); + Client->SetMagic({ }); Client->SetUDPAddr(remote_client_ep); Client->SetIsUDPConnected(true); return false; @@ -253,7 +252,7 @@ void TNetwork::TCPServerMain() { beammp_warnf("WARNING: On FreeBSD, for IPv4 to work, you must run `sysctl net.inet6.ip6.v6only=0`!"); beammp_debugf("This is due to an annoying detail in the *BSDs: In the name of security, unsetting the IPV6_V6ONLY option does not work by default (but does not fail???), as it allows IPv4 mapped IPv6 like ::ffff:127.0.0.1, which they deem a security issue. For more information, see RFC 2553, section 3.7."); #endif - socket_base::linger LingerOpt {}; + socket_base::linger LingerOpt { }; LingerOpt.enabled(false); Listener.set_option(LingerOpt, ec); if (ec) { @@ -362,8 +361,6 @@ void TNetwork::Identify(TConnection&& RawConnection, TConnectionLimiter::TGuard& } } - - std::string HashPassword(const std::string& str) { std::stringstream ret; unsigned char* hash = SHA256(reinterpret_cast(str.c_str()), str.length(), nullptr); @@ -424,8 +421,8 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { std::string AuthKey = Application::Settings.getAsString(Settings::Key::General_AuthKey); std::string ClientIp = Client->GetIdentifiers().at("ip"); - nlohmann::json AuthReq {}; - std::string AuthResStr {}; + nlohmann::json AuthReq { }; + std::string AuthResStr { }; try { AuthReq = nlohmann::json { { "key", Key }, @@ -523,8 +520,8 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { } if (!NotAllowed && !NotAllowedWithReason && mServer.ClientCount() >= size_t(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) && !BypassLimit) { - NotAllowedWithReason = true; - Reason = "Server full!"; + NotAllowedWithReason = true; + Reason = "Server full!"; } if (NotAllowedWithReason) { @@ -592,10 +589,10 @@ bool TNetwork::TCPSend(TClient& c, const std::vector& Data, bool IsSync std::vector TNetwork::TCPRcv(TClient& c) { if (c.IsDisconnected()) { beammp_error("Client disconnected, cancelling TCPRcv"); - return {}; + return { }; } - int32_t Header {}; + int32_t Header { }; auto& Sock = c.GetTCPSock(); boost::system::error_code ec; @@ -604,14 +601,14 @@ std::vector TNetwork::TCPRcv(TClient& c) { if (ec) { // TODO: handle this case (read failed) beammp_debugf("TCPRcv: Reading header failed: {}", ec.message()); - return {}; + return { }; } Header = *reinterpret_cast(HeaderData.data()); if (Header < 0) { ClientKick(c, "Invalid packet - header negative"); beammp_errorf("Client {} send negative TCP header, ignoring packet", c.GetID()); - return {}; + return { }; } std::vector Data; @@ -623,13 +620,13 @@ std::vector TNetwork::TCPRcv(TClient& c) { } else { ClientKick(c, "Header size limit exceeded"); beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header larger than expected - assuming malicious intent and disconnecting the client."); - return {}; + return { }; } auto N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); if (ec) { // TODO: handle this case properly beammp_debugf("TCPRcv: Reading data failed: {}", ec.message()); - return {}; + return { }; } if (N != Header) { @@ -641,7 +638,7 @@ std::vector TNetwork::TCPRcv(TClient& c) { Data.erase(Data.begin(), Data.begin() + ABG.size()); try { return DeComp(Data); - } catch (const InvalidDataError& ) { + } catch (const InvalidDataError&) { beammp_errorf("Failed to decompress packet from a client. The receive failed and the client may be disconnected as a result"); // return empty -> error return std::vector(); @@ -663,19 +660,17 @@ void TNetwork::ClientKick(TClient& c, const std::string& R) { DisconnectClient(c, "Kicked"); } -void TNetwork::DisconnectClient(const std::weak_ptr &c, const std::string &R) -{ +void TNetwork::DisconnectClient(const std::weak_ptr& c, const std::string& R) { if (auto locked = c.lock()) { DisconnectClient(*locked, R); - } - else { + } else { beammp_debugf("Tried to disconnect a non existant client with reason: {}", R); } } -void TNetwork::DisconnectClient(TClient &c, const std::string &R) -{ - if (c.IsDisconnected()) return; +void TNetwork::DisconnectClient(TClient& c, const std::string& R) { + if (c.IsDisconnected()) + return; c.Disconnect(R); } @@ -690,7 +685,7 @@ void TNetwork::Looper(const std::weak_ptr& c) { if (!Client->IsSyncing() && Client->IsSynced() && Client->MissedPacketQueueSize() != 0) { // debug("sending " + std::to_string(Client->MissedPacketQueueSize()) + " queued packets"); while (Client->MissedPacketQueueSize() > 0) { - std::vector QData {}; + std::vector QData { }; { // locked context std::unique_lock lock(Client->MissedPacketQueueMutex()); if (Client->MissedPacketQueueSize() <= 0) { @@ -781,7 +776,7 @@ static boost::system::error_code ReadSocketWithTimeout( std::chrono::steady_clock::duration timeout); boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { - return ReadSocketWithTimeout(mIoCtxPoller.IoCtx(), Connection.Socket, Buf, Len, Timeout); + return ReadSocketWithTimeout(mServer.IoCtx(), Connection.Socket, Buf, Len, Timeout); } static boost::system::error_code ReadSocketWithTimeout( @@ -892,7 +887,7 @@ TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { REQUIRE(!Ec); uint8_t Received = 0; - CHECK(ReadSocketWithTimeout(TimerThread.get(), ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); + CHECK(ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); const uint8_t Sent = 0x42; boost::asio::write(ClientSocket, boost::asio::buffer(&Sent, 1), Ec); @@ -1283,31 +1278,13 @@ bool TNetwork::UDPSend(TClient& Client, std::vector Data) { } std::vector TNetwork::UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint) { - std::array Ret {}; + std::array Ret { }; boost::system::error_code ec; const auto Rcv = mUDPSock.receive_from(boost::asio::mutable_buffer(Ret.data(), Ret.size()), ClientEndpoint, 0, ec); if (ec) { beammp_errorf("UDP recvfrom() failed: {}", ec.message()); - return {}; + return { }; } beammp_assert(Rcv <= Ret.size()); return std::vector(Ret.begin(), Ret.begin() + Rcv); } - -TIoPollThread::TIoPollThread() - : mWorkGuard(boost::asio::make_work_guard(mIoCtx)) - , mThread([this](std::stop_token StopToken) { - while (!StopToken.stop_requested()) { - try { - mIoCtx.run(); - break; - } catch (...) { - mIoCtx.restart(); - } - } - }) { } - -TIoPollThread::~TIoPollThread() { - mWorkGuard.reset(); - mIoCtx.stop(); -} From b1946ef1d980949b08eded73fc3c407398ebeed0 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:03:07 +0200 Subject: [PATCH 14/36] use ReadWithTimeout until fully completed auth --- include/TNetwork.h | 3 ++- src/TNetwork.cpp | 30 ++++++++++++++++++++++++------ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/include/TNetwork.h b/include/TNetwork.h index 4f5fb34f..be55b0df 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -35,7 +35,7 @@ class TNetwork { [[nodiscard]] bool SendLarge(TClient& c, std::vector Data, bool isSync = false); [[nodiscard]] bool Respond(TClient& c, const std::vector& MSG, bool Rel, bool isSync = false); std::shared_ptr CreateClient(boost::asio::ip::tcp::socket&& TCPSock); - std::vector TCPRcv(TClient& c); + std::vector TCPRcv(TClient& c, bool WithTimeout = false); void ClientKick(TClient& c, const std::string& R); void DisconnectClient(const std::weak_ptr& c, const std::string& R); void DisconnectClient(TClient& c, const std::string& R); @@ -46,6 +46,7 @@ class TNetwork { [[nodiscard]] bool UDPSend(TClient& Client, std::vector Data); void SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel); void UpdatePlayer(TClient& Client); + boost::system::error_code ReadWithTimeout(boost::asio::ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); boost::system::error_code ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); [[nodiscard]] TConnectionLimiter::TStats GetConnectionLimiterStats() { return mConnectionLimiter.GetStats(); } diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 31492001..3055b133 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -388,7 +388,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { beammp_info("Identifying new ClientConnection..."); - auto Data = TCPRcv(*Client); + auto Data = TCPRcv(*Client, true); constexpr std::string_view VC = "VC"; if (Data.size() > 3 && std::equal(Data.begin(), Data.begin() + VC.size(), VC.begin(), VC.end())) { @@ -410,7 +410,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { // TODO: handle } - Data = TCPRcv(*Client); + Data = TCPRcv(*Client, true); if (Data.size() > 50) { ClientKick(*Client, "Invalid Key (too long)!"); @@ -586,7 +586,7 @@ bool TNetwork::TCPSend(TClient& c, const std::vector& Data, bool IsSync return true; } -std::vector TNetwork::TCPRcv(TClient& c) { +std::vector TNetwork::TCPRcv(TClient& c, bool WithTimeout) { if (c.IsDisconnected()) { beammp_error("Client disconnected, cancelling TCPRcv"); return { }; @@ -597,7 +597,11 @@ std::vector TNetwork::TCPRcv(TClient& c) { boost::system::error_code ec; std::array HeaderData; - boost::asio::read(Sock, boost::asio::buffer(HeaderData), ec); + if (WithTimeout) { + ec = ReadWithTimeout(Sock, HeaderData.data(), HeaderData.size(), std::chrono::seconds(READ_TIMEOUT_S)); + } else { + boost::asio::read(Sock, boost::asio::buffer(HeaderData), ec); + } if (ec) { // TODO: handle this case (read failed) beammp_debugf("TCPRcv: Reading header failed: {}", ec.message()); @@ -622,7 +626,17 @@ std::vector TNetwork::TCPRcv(TClient& c) { beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header larger than expected - assuming malicious intent and disconnecting the client."); return { }; } - auto N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); + std::size_t N = 0; + if (WithTimeout) { + if (!Data.empty()) { + ec = ReadWithTimeout(Sock, Data.data(), Data.size(), std::chrono::seconds(READ_TIMEOUT_S)); + if (!ec) { + N = Data.size(); + } + } + } else { + N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); + } if (ec) { // TODO: handle this case properly beammp_debugf("TCPRcv: Reading data failed: {}", ec.message()); @@ -776,7 +790,11 @@ static boost::system::error_code ReadSocketWithTimeout( std::chrono::steady_clock::duration timeout); boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { - return ReadSocketWithTimeout(mServer.IoCtx(), Connection.Socket, Buf, Len, Timeout); + return ReadWithTimeout(Connection.Socket, Buf, Len, Timeout); +} + +boost::system::error_code TNetwork::ReadWithTimeout(boost::asio::ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { + return ReadSocketWithTimeout(mServer.IoCtx(), Socket, Buf, Len, Timeout); } static boost::system::error_code ReadSocketWithTimeout( From 4da5a7440a06418237c73a83ddfa01bfc9999c8c Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:08:53 +0200 Subject: [PATCH 15/36] increase http curl pool to 128 --- src/Http.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Http.cpp b/src/Http.cpp index ed075d62..98bb6ce2 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -43,7 +43,7 @@ struct CurlDeleter { static std::mutex gCurlPoolMutex; static std::map gCurlPool; // false = free, true = in use -constexpr size_t MAX_CURL_POOL_SIZE = 8; +constexpr size_t MAX_CURL_POOL_SIZE = 128; static CURL* AcquireCurl() { std::unique_lock Lock(gCurlPoolMutex); From d26e53a975b151c37f071ee2c4caae425fc3312b Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:32:35 +0200 Subject: [PATCH 16/36] handle error cases early --- src/TNetwork.cpp | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 3055b133..a027f932 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -389,6 +389,11 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { beammp_info("Identifying new ClientConnection..."); auto Data = TCPRcv(*Client, true); + if (Data.empty()) { + beammp_debug("Authentication failed: did not receive version packet"); + ClientKick(*Client, "Connection closed during version handshake"); + return nullptr; + } constexpr std::string_view VC = "VC"; if (Data.size() > 3 && std::equal(Data.begin(), Data.begin() + VC.size(), VC.begin(), VC.end())) { @@ -411,6 +416,11 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { } Data = TCPRcv(*Client, true); + if (Data.empty()) { + beammp_debug("Authentication failed: did not receive auth key packet"); + ClientKick(*Client, "Connection closed during authentication"); + return nullptr; + } if (Data.size() > 50) { ClientKick(*Client, "Invalid Key (too long)!"); @@ -418,6 +428,10 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { } std::string Key(reinterpret_cast(Data.data()), Data.size()); + if (Key.empty()) { + ClientKick(*Client, "Invalid Key (empty)!"); + return nullptr; + } std::string AuthKey = Application::Settings.getAsString(Settings::Key::General_AuthKey); std::string ClientIp = Client->GetIdentifiers().at("ip"); @@ -783,7 +797,6 @@ void TNetwork::UpdatePlayer(TClient& Client) { } static boost::system::error_code ReadSocketWithTimeout( - boost::asio::io_context& io, boost::asio::ip::tcp::socket& socket, void* buffer, std::size_t length, @@ -794,11 +807,10 @@ boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, voi } boost::system::error_code TNetwork::ReadWithTimeout(boost::asio::ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { - return ReadSocketWithTimeout(mServer.IoCtx(), Socket, Buf, Len, Timeout); + return ReadSocketWithTimeout(Socket, Buf, Len, Timeout); } static boost::system::error_code ReadSocketWithTimeout( - boost::asio::io_context& IoCtx, boost::asio::ip::tcp::socket& Socket, void* Buffer, std::size_t Length, @@ -807,15 +819,15 @@ static boost::system::error_code ReadSocketWithTimeout( using boost::system::error_code; struct TTimeoutState { - explicit TTimeoutState(asio::io_context& IoCtx) - : Timer(IoCtx) { } + explicit TTimeoutState(const boost::asio::any_io_executor& Executor) + : Timer(Executor) { } asio::steady_timer Timer; std::promise> Promise; std::atomic_bool Completed { false }; }; - auto State = std::make_shared(IoCtx); + auto State = std::make_shared(Socket.get_executor()); auto Future = State->Promise.get_future(); asio::async_read( @@ -854,7 +866,7 @@ TEST_CASE("ReadSocketWithTimeout returns timed_out when peer sends no data") { REQUIRE(!Ec); uint8_t ReadByte = 0; - const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); CHECK(ReadEc == error::timed_out); } @@ -871,7 +883,7 @@ TEST_CASE("ReadSocketWithTimeout reads small payload") { boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); REQUIRE(!Ec); std::array Received { }; - const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); CHECK(!ReadEc); CHECK(Received == Sent); @@ -890,7 +902,7 @@ TEST_CASE("ReadSocketWithTimeout reads large payload") { boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); REQUIRE(!Ec); std::vector Received(PacketSize); - const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); CHECK(!ReadEc); CHECK(Received == Sent); @@ -905,12 +917,12 @@ TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { REQUIRE(!Ec); uint8_t Received = 0; - CHECK(ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); + CHECK(ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); const uint8_t Sent = 0x42; boost::asio::write(ClientSocket, boost::asio::buffer(&Sent, 1), Ec); REQUIRE(!Ec); - const auto ReadEc = ReadSocketWithTimeout(TimerThread.IoCtx(), ServerSocket, &Received, 1, std::chrono::milliseconds(200)); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(200)); CHECK(!ReadEc); CHECK(Received == Sent); @@ -1003,6 +1015,7 @@ void TNetwork::SyncResources(TClient& c) { while (!c.IsDisconnected()) { Data = TCPRcv(c); if (Data.empty()) { + DisconnectClient(c, "TCPRcv failed during resource sync"); break; } constexpr std::string_view Done = "Done"; From e9ce71d39aa2c208e9df29e8360ca76148e147a9 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 9 Apr 2026 18:38:38 +0200 Subject: [PATCH 17/36] make send file accept a non-blocking socket --- src/TNetwork.cpp | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index a027f932..c49a8de6 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -1108,17 +1108,40 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name // native handle, needed in order to make native syscalls with it int socket = c.GetTCPSock().native_handle(); - ssize_t ret = 0; - auto ToSendTotal = Size; - auto Start = 0; - while (ret < ssize_t(ToSendTotal)) { - auto SysOffset = off_t(Start + size_t(ret)); - ret = sendfile(socket, fd, &SysOffset, ToSendTotal - size_t(ret)); - if (ret < 0) { - beammp_errorf("Failed to send mod '{}' to client {}: {}", Name, c.GetID(), std::strerror(errno)); + const auto ToSendTotal = Size; + size_t TotalSent = 0; + while (TotalSent < ToSendTotal) { + off_t SysOffset = off_t(TotalSent); + const ssize_t SentNow = sendfile(socket, fd, &SysOffset, ToSendTotal - TotalSent); + if (SentNow > 0) { + TotalSent += size_t(SentNow); + continue; + } + if (SentNow == 0) { + beammp_errorf("Failed to send mod '{}' to client {}: sendfile returned 0 before all bytes were sent", Name, c.GetID()); + ::close(fd); + DisconnectClient(c, "sendfile returned 0 during mod download"); return; } + + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN +#if EWOULDBLOCK != EAGAIN + || errno == EWOULDBLOCK +#endif + ) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + beammp_errorf("Failed to send mod '{}' to client {}: {}", Name, c.GetID(), std::strerror(errno)); + ::close(fd); + DisconnectClient(c, "sendfile failed during mod download"); + return; } + ::close(fd); #else std::ifstream f(Name.c_str(), std::ios::binary); From 7089e5da9ad4c8b95afb42c4fc449cb9c05bbadd Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Fri, 10 Apr 2026 18:09:57 +0200 Subject: [PATCH 18/36] fix race on disconnect between the time we check for `is_open` and the actual disconnect, the socket could already have been disconnected by another thread (TOCTOU). Furthermore, the disconnects can race causing a segfault or similar issue in the asio's internals. --- include/Client.h | 16 ++++++++++++++-- src/Client.cpp | 11 ++++++++++- src/TNetwork.cpp | 7 +++---- src/TPPSMonitor.cpp | 2 +- src/TServer.cpp | 2 +- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/include/Client.h b/include/Client.h index 9ab2dc72..6f29bd57 100644 --- a/include/Client.h +++ b/include/Client.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -68,8 +69,11 @@ class TClient final { std::string GetCarPositionRaw(int Ident); void SetUDPAddr(const ip::udp::endpoint& Addr) { mUDPAddress = Addr; } void SetTCPSock(ip::tcp::socket&& CSock) { mSocket = std::move(CSock); } - void Disconnect(std::string_view Reason); - bool IsDisconnected() const { return !mSocket.is_open(); } + // Returns true only for the thread that actually performs socket shutdown/close. + [[nodiscard]] bool Disconnect(std::string_view Reason); + bool IsDisconnected() const { + return mDisconnectState.load(std::memory_order_acquire) != EDisconnectState::Connected; + } // locks void DeleteCar(int Ident); [[nodiscard]] const std::unordered_map& GetIdentifiers() const { return mIdentifiers; } @@ -106,6 +110,12 @@ class TClient final { [[nodiscard]] const std::vector& GetMagic() const { return mMagic; } private: + enum class EDisconnectState { + Connected, + Disconnecting, + Disconnected + }; + void InsertVehicle(int ID, const std::string& Data); TServer& mServer; @@ -121,6 +131,8 @@ class TClient final { TSetOfVehicleData mVehicleData; SparseArray mVehiclePosition; std::string mName = "Unknown Client"; + // Once disconnect starts, this client is terminal and its socket must be treated as dead. + std::atomic mDisconnectState { EDisconnectState::Connected }; ip::tcp::socket mSocket; ip::udp::endpoint mUDPAddress {}; int mUnicycleID = -1; diff --git a/src/Client.cpp b/src/Client.cpp index ed67945f..f9ae40eb 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -76,7 +76,13 @@ std::string TClient::GetCarPositionRaw(int Ident) { } } -void TClient::Disconnect(std::string_view Reason) { +bool TClient::Disconnect(std::string_view Reason) { + // Do not remove this guard: concurrent close() on the same socket can crash in Asio internals. + EDisconnectState Expected = EDisconnectState::Connected; + if (!mDisconnectState.compare_exchange_strong(Expected, EDisconnectState::Disconnecting, std::memory_order_acq_rel, std::memory_order_acquire)) { + return false; + } + beammp_debugf("Disconnecting client {} for reason: {}", GetID(), Reason); boost::system::error_code ec; if (mSocket.is_open()) { @@ -91,6 +97,9 @@ void TClient::Disconnect(std::string_view Reason) { } else { beammp_debug("Socket is already closed."); } + // Terminal state: this client must not perform TCP IO after this point. + mDisconnectState.store(EDisconnectState::Disconnected, std::memory_order_release); + return true; } void TClient::SetCarPosition(int Ident, const std::string& Data) { diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index c49a8de6..e5f50e21 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -697,9 +697,8 @@ void TNetwork::DisconnectClient(const std::weak_ptr& c, const std::stri } void TNetwork::DisconnectClient(TClient& c, const std::string& R) { - if (c.IsDisconnected()) - return; - c.Disconnect(R); + // Keep this unconditional; TClient::Disconnect() is the single-winner guard. + (void)c.Disconnect(R); } void TNetwork::Looper(const std::weak_ptr& c) { @@ -740,7 +739,7 @@ void TNetwork::Looper(const std::weak_ptr& c) { void TNetwork::TCPClient(const std::weak_ptr& c) { // TODO: the c.expired() might cause issues here, remove if you end up here with your debugger - if (c.expired() || !c.lock()->GetTCPSock().is_open()) { + if (c.expired() || c.lock()->IsDisconnected()) { mServer.RemoveClient(c); return; } diff --git a/src/TPPSMonitor.cpp b/src/TPPSMonitor.cpp index ccababf7..7fe5c6fc 100644 --- a/src/TPPSMonitor.cpp +++ b/src/TPPSMonitor.cpp @@ -76,7 +76,7 @@ void TPPSMonitor::operator()() { return true; }); for (auto& ClientToKick : TimedOutClients) { - ClientToKick->Disconnect("Timeout"); + Network().DisconnectClient(*ClientToKick, "Timeout"); } TimedOutClients.clear(); if (C == 0 || mInternalPPS == 0) { diff --git a/src/TServer.cpp b/src/TServer.cpp index 587cef67..d9d98efa 100644 --- a/src/TServer.cpp +++ b/src/TServer.cpp @@ -224,7 +224,7 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vectorDisconnect("Failed to send ping"); + Network.DisconnectClient(*LockedClient, "Failed to send ping"); } else { Network.UpdatePlayer(*LockedClient); } From 7096fe058af45420c4a5b79ce516addd96a172b6 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Mon, 13 Apr 2026 18:23:36 +0200 Subject: [PATCH 19/36] fix PanicHandler crashing itself with another panic inside sol2 When sol2 does stack::get, it can panic, which causes the stack to explode, corrupt it, and then any subsequent action crashes the server. --- src/LuaAPI.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index 9ee24620..e7fab3ff 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -440,7 +440,16 @@ void LuaAPI::MP::PrintRaw(sol::variadic_args Args) { } int LuaAPI::PanicHandler(lua_State* State) { - beammp_lua_error("PANIC: " + sol::stack::get(State, 1)); + // panic path: use raw lua c api only; sol2 conversions can raise lua_error + // and re-enter panic recursively. + const int ErrorType = lua_type(State, -1); + if (ErrorType == LUA_TSTRING) { + const char* Message = lua_tostring(State, -1); + beammp_lua_error(std::string("PANIC: ") + (Message ? Message : "(null)")); + } else { + const char* ErrorTypeName = lua_typename(State, ErrorType); + beammp_lua_error(std::string("PANIC: non-string error object (") + (ErrorTypeName ? ErrorTypeName : "unknown") + ")"); + } return 0; } From e260de55afeef67f4a8d7cb04327f3950551ee44 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 15 Apr 2026 18:14:31 +0200 Subject: [PATCH 20/36] fix sol::error crash --- include/TLuaEngine.h | 6 ++++++ src/TLuaEngine.cpp | 10 ++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 47a5c760..96479ddc 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -84,6 +85,11 @@ struct TLuaResult { std::make_shared() }; + void SetErrorMessageFromResult(sol::protected_function_result& Res) { + auto error = Res.get>(); + ErrorMessage = error ? *error : "(unknown error; error object is not a string value)"; + } + void MarkAsReady(); void WaitUntilReady(); }; diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d7bee3e6..7481dd11 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -550,8 +550,8 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& Result.set(i, FnRet); ++i; } else { - sol::error Err = FnRet; - beammp_lua_error(std::string("TriggerLocalEvent: ") + Err.what()); + auto ErrStr = FnRet.get>(); + beammp_lua_error(std::string("TriggerLocalEvent: ") + (ErrStr ? *ErrStr : "(unknown error; error object is not a string)")); } } } @@ -1158,8 +1158,7 @@ void TLuaEngine::StateThreadData::operator()() { S.second->Result = std::move(Res); } else { S.second->Error = true; - sol::error Err = Res; - S.second->ErrorMessage = Err.what(); + S.second->SetErrorMessageFromResult(Res); } S.second->MarkAsReady(); } @@ -1222,8 +1221,7 @@ void TLuaEngine::StateThreadData::operator()() { Result->Result = std::move(Res); } else { Result->Error = true; - sol::error Err = Res; - Result->ErrorMessage = Err.what(); + Result->SetErrorMessageFromResult(Res); } Result->MarkAsReady(); } else { From 66f5f2b8b65904e2e5ca093180ca9877d0c6dabe Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 16 Apr 2026 18:32:54 +0200 Subject: [PATCH 21/36] fix error handling from SetErrorMessageFromResult --- include/TLuaEngine.h | 11 +++++++++-- src/TLuaEngine.cpp | 9 +++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 96479ddc..b11a928c 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -86,8 +87,14 @@ struct TLuaResult { }; void SetErrorMessageFromResult(sol::protected_function_result& Res) { - auto error = Res.get>(); - ErrorMessage = error ? *error : "(unknown error; error object is not a string value)"; + if (Res.valid()) { + beammp_lua_errorf("Error was not an error"); + } + if (Res.get_type() == sol::type::string) { + ErrorMessage = Res.get().what(); + } else { + ErrorMessage = "(unknown error; error object is not inspectable)"; + } } void MarkAsReady(); diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 7481dd11..d5e89cfd 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -550,8 +550,13 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& Result.set(i, FnRet); ++i; } else { - auto ErrStr = FnRet.get>(); - beammp_lua_error(std::string("TriggerLocalEvent: ") + (ErrStr ? *ErrStr : "(unknown error; error object is not a string)")); + std::string ErrStr; + if (FnRet.get_type() == sol::type::string) { + ErrStr = FnRet.get().what(); + } else { + ErrStr = "(unknown error; error object is not inspectable)"; + } + beammp_lua_errorf("TriggerLocalEvent: {}", ErrStr); } } } From 58d9f2e98053fb9ed5ad80210d14d64b0af351b3 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 16 Apr 2026 18:37:39 +0200 Subject: [PATCH 22/36] fix GetIdentifiers() --- src/TLuaEngine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d5e89cfd..c4ccda52 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -572,7 +572,7 @@ sol::table TLuaEngine::StateThreadData::Lua_GetPlayerIdentifiers(int ID) { } sol::table Result = mStateView.create_table(); for (const auto& Pair : IDs) { - Result[Pair.first] = Pair.second; + Result.add(Pair.first, Pair.second); } return Result; } else { From 4bb7232a4168fe8efd900f241cb4d5b254c60318 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 16 Apr 2026 18:44:33 +0200 Subject: [PATCH 23/36] fix GetIdentifiers building the wrong kind of table --- src/TLuaEngine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index c4ccda52..cd3562aa 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -572,7 +572,7 @@ sol::table TLuaEngine::StateThreadData::Lua_GetPlayerIdentifiers(int ID) { } sol::table Result = mStateView.create_table(); for (const auto& Pair : IDs) { - Result.add(Pair.first, Pair.second); + Result.set(Pair.first, Pair.second) } return Result; } else { From 0c864665bb52848ffb0525209944ecaf2edf19ad Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Thu, 16 Apr 2026 18:46:48 +0200 Subject: [PATCH 24/36] fix missing semicolon --- src/TLuaEngine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index cd3562aa..7b33284e 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -572,7 +572,7 @@ sol::table TLuaEngine::StateThreadData::Lua_GetPlayerIdentifiers(int ID) { } sol::table Result = mStateView.create_table(); for (const auto& Pair : IDs) { - Result.set(Pair.first, Pair.second) + Result.set(Pair.first, Pair.second); } return Result; } else { From b3e8d86cef3646ea32878d89f67c01480887775f Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sat, 18 Apr 2026 18:16:16 +0200 Subject: [PATCH 25/36] refactor Lua result handling for safety this massively improves thread safety and cleanly serializes accesses into the lua engine's result objects where accesses before were extremely unsafe and could access a corrupt/invalid stack. this fixes various obscure crashes related to accessing results, without changing any observable behavior. --- CMakeLists.txt | 2 + include/TLuaEngine.h | 34 +------ include/TLuaResult.h | 92 ++++++++++++++++++ src/Common.cpp | 39 ++++++++ src/TConsole.cpp | 27 ++++-- src/TLuaEngine.cpp | 156 ++++++++++++++++-------------- src/TLuaPlugin.cpp | 5 +- src/TLuaResult.cpp | 210 +++++++++++++++++++++++++++++++++++++++++ src/TNetwork.cpp | 31 +++--- src/TPluginMonitor.cpp | 5 +- src/TServer.cpp | 32 ++++++- 11 files changed, 499 insertions(+), 134 deletions(-) create mode 100644 include/TLuaResult.h create mode 100644 src/TLuaResult.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 00ca619f..b2a76866 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,6 +54,7 @@ set(PRJ_HEADERS include/Profiling.h include/ChronoWrapper.h include/TConnectionLimiter.h + include/TLuaResult.h ) # add all source files (.cpp) to this, except the one with main() set(PRJ_SOURCES @@ -82,6 +83,7 @@ set(PRJ_SOURCES src/Profiling.cpp src/ChronoWrapper.cpp src/TConnectionLimiter.cpp + src/TLuaResult.cpp ) find_package(Lua REQUIRED) diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index b11a928c..1aee7e83 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -19,13 +19,12 @@ #pragma once #include "Profiling.h" +#include "TLuaResult.h" #include "TNetwork.h" #include "TServer.h" -#include #include #include #include -#include #include #include #include @@ -72,35 +71,6 @@ enum TLuaType { class TLuaPlugin; -struct TLuaResult { - bool Ready; - bool Error; - std::string ErrorMessage; - sol::object Result { sol::lua_nil }; - TLuaStateId StateId; - std::string Function; - std::shared_ptr ReadyMutex { - std::make_shared() - }; - std::shared_ptr ReadyCondition { - std::make_shared() - }; - - void SetErrorMessageFromResult(sol::protected_function_result& Res) { - if (Res.valid()) { - beammp_lua_errorf("Error was not an error"); - } - if (Res.get_type() == sol::type::string) { - ErrorMessage = Res.get().what(); - } else { - ErrorMessage = "(unknown error; error object is not inspectable)"; - } - } - - void MarkAsReady(); - void WaitUntilReady(); -}; - struct TLuaPluginConfig { static inline const std::string FileName = "PluginConfig.toml"; TLuaStateId StateId; @@ -242,7 +212,7 @@ class TLuaEngine : public std::enable_shared_from_this, IThreaded { std::unordered_map /* handlers */> Debug_GetEventsForState(TLuaStateId StateId); std::queue>> Debug_GetStateExecuteQueueForState(TLuaStateId StateId); std::vector Debug_GetStateFunctionQueueForState(TLuaStateId StateId); - std::vector Debug_GetResultsToCheckForState(TLuaStateId StateId); + std::vector Debug_GetResultsToCheckForState(TLuaStateId StateId); private: void CollectAndInitPlugins(); diff --git a/include/TLuaResult.h b/include/TLuaResult.h new file mode 100644 index 00000000..81ee223e --- /dev/null +++ b/include/TLuaResult.h @@ -0,0 +1,92 @@ +#pragma once + +#include "Common.h" +#include +#include +#include + +using TLuaStateId = std::string; + +struct TDetachedLuaValue { + using Array = std::vector; + using Object = std::unordered_map; + std::variant V; +}; +std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value); + +struct TLuaResult { + struct Snapshot { + bool Error; + bool Ready; + std::string ErrorMessage; + /// CAUTION: Accessing this object causes the Lua stack of the owning state + /// to be accessed. Only call this from the owning Lua state. Otherwise, + /// use `DetachedSnapshot`. + sol::object Result; + TLuaStateId StateId; + std::string Function; + }; + + struct DetachedSnapshot { + bool Error; + bool Ready; + std::string ErrorMessage; + /// Serialized Lua result. Has no reference to the Lua state, and is safe to use + /// from any thread that acquires it. + TDetachedLuaValue Result; + TLuaStateId StateId; + std::string Function; + }; + + + TLuaResult(TLuaStateId StateId, std::string FunctionName) : mStateId(StateId), mFunction(std::move(FunctionName)) {} + + /// Marks this result as success & sets it as ready, waking all threads waiting on it. + void MarkReadySuccess(sol::object Res); + /// Marks this result as erroneous & sets it as ready, waking all threads waiting on it. + void MarkReadyError(sol::protected_function_result Res); + /// Marks this result as erroneous & sets it as ready, waking all threads waiting on it. + void MarkReadyError(std::string Res); + /// Wait (suspend) until this result is ready. Use `GetSnapshot` to get the results. + void WaitUntilReady(); + bool IsReady() const; + bool IsError() const; + TLuaStateId OwnerState() const; + void SetOwnerState(TLuaStateId StateId); + + /// To get the result or error, while in the owning state, use this function. + /// Pass your lua state ID so that the function can verify the safety of the access. + /// If you have no state ID, use `GetDetachedSnapshot`. + /// Accesses the Lua state directly when using the result. + Snapshot GetSnapshot(TLuaStateId CallingState) const; + /// Returns a snapshot with, if not an error, a serialized version of the result + /// object. In contrast to `GetSnapshot`, this does not access the Lua stack when + /// accessing `.Result`. + DetachedSnapshot GetDetachedSnapshot() const; + +private: + mutable std::mutex mMutex {}; + bool mReady { false }; + bool mError; + std::string mErrorMessage; + sol::object mResult { sol::lua_nil }; + TDetachedLuaValue mDetachedResult; + TLuaStateId mStateId; + std::string mFunction; + std::condition_variable mReadyCondition {}; + + TDetachedLuaValue Freeze(const sol::object& o, int depth = 0); + + void SetErrorMessageFromResult(sol::protected_function_result& Res) { + if (Res.valid()) { + beammp_lua_errorf("Error was not an error"); + } + if (Res.get_type() == sol::type::string) { + mErrorMessage = Res.get().what(); + } else { + mErrorMessage = "(unknown error; error object is not inspectable)"; + } + } + + void MarkAsReady(); +}; diff --git a/src/Common.cpp b/src/Common.cpp index 1785b9e2..7fff4c60 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -31,6 +31,7 @@ #include #include +#include "TLuaResult.h" #include "Compat.h" #include "CustomAssert.h" #include "Http.h" @@ -373,6 +374,44 @@ std::string GetPlatformAgnosticErrorString() { return "(no human-readable errors on this platform)"; #endif } +std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { + + std::visit([&os](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + size_t i = 0; + for (auto val : arg) { + if (i > 0) { + os << ", "; + } + os << val; + } + } else if constexpr (std::is_same_v>) { + size_t i = 0; + for (auto [key, val] : arg) { + if (i > 0) { + os << ", "; + } + os << key << "=" << val; + } + } else if constexpr (std::is_same_v) + os << (arg ? "true" : "false"); + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + // monostate means no result value + os << ""; + else + static_assert(false, "non-exhaustive visitor!"); + }, + value.V); + + return os; +} // TODO: add unit tests to SplitString void SplitString(const std::string& str, const char delim, std::vector& out) { diff --git a/src/TConsole.cpp b/src/TConsole.cpp index 317ec160..5e41f821 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -22,9 +22,9 @@ #include "Client.h" #include "CustomAssert.h" +#include "Http.h" #include "LuaAPI.h" #include "TLuaEngine.h" -#include "Http.h" #include #include @@ -689,9 +689,13 @@ void TConsole::Command_Status(const std::string&, const std::vector void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { auto FutureIsNonNil = [](const std::shared_ptr& Future) { - if (!Future->Error && Future->Result.valid()) { - auto Type = Future->Result.get_type(); - return Type != sol::type::lua_nil && Type != sol::type::none; + if (!Future->IsError()) { + auto Snapshot = Future->GetDetachedSnapshot(); + if (Snapshot.Result.V.valueless_by_exception() || std::get_if(&Snapshot.Result.V) != nullptr) { + // no value contained + return false; + } + return true; } return false; }; @@ -701,7 +705,7 @@ void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { TLuaEngine::WaitForAll(Futures, std::chrono::seconds(5)); size_t Count = 0; for (auto& Future : Futures) { - if (!Future->Error) { + if (!Future->IsError()) { ++Count; } } @@ -719,14 +723,16 @@ void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { std::stringstream Reply; if (NonNilFutures.size() > 1) { for (size_t i = 0; i < NonNilFutures.size(); ++i) { - Reply << NonNilFutures[i]->StateId << ": \n" - << LuaAPI::LuaToString(NonNilFutures[i]->Result); + auto Snapshot = NonNilFutures[i]->GetDetachedSnapshot(); + Reply << Snapshot.StateId << ": \n" + << Snapshot.Result; if (i < NonNilFutures.size() - 1) { Reply << "\n"; } } } else { - Reply << LuaAPI::LuaToString(NonNilFutures[0]->Result); + auto Snapshot = NonNilFutures[0]->GetDetachedSnapshot(); + Reply << Snapshot.Result; } Application::Console().WriteRaw(Reply.str()); } @@ -807,8 +813,9 @@ void TConsole::InitializeCommandline() { } else { auto Future = mLuaEngine->EnqueueScript(mStateId, { std::make_shared(TrimmedCmd), "", "" }); Future->WaitUntilReady(); - if (Future->Error) { - beammp_lua_error("error in " + mStateId + ": " + Future->ErrorMessage); + if (Future->IsError()) { + auto Snapshot = Future->GetDetachedSnapshot(); + beammp_lua_error("error in " + mStateId + ": " + Snapshot.ErrorMessage); } } } else { diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 7b33284e..27f5e120 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -20,20 +20,25 @@ #include "Client.h" #include "Common.h" #include "CustomAssert.h" +#include "Env.h" #include "Http.h" #include "LuaAPI.h" -#include "Env.h" #include "Profiling.h" #include "TLuaPlugin.h" +#include "TLuaResult.h" #include "sol/object.hpp" #include #include #include +#include #include #include +#include #include #include +#include +#include TLuaEngine* LuaAPI::MP::Engine; @@ -74,8 +79,9 @@ void TLuaEngine::operator()() { auto Futures = TriggerEvent("onInit", ""); WaitForAll(Futures, std::chrono::seconds(5)); for (const auto& Future : Futures) { - if (Future->Error && Future->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error("Calling \"onInit\" on \"" + Future->StateId + "\" failed: " + Future->ErrorMessage); + auto Snapshot = Future->GetDetachedSnapshot(); + if (Snapshot.Error && Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error("Calling \"onInit\" on \"" + Snapshot.StateId + "\" failed: " + Snapshot.ErrorMessage); } } @@ -85,10 +91,11 @@ void TLuaEngine::operator()() { std::unique_lock Lock(mResultsToCheckMutex); if (!mResultsToCheck.empty()) { mResultsToCheck.remove_if([](const std::shared_ptr& Ptr) -> bool { - if (Ptr->Ready) { - if (Ptr->Error) { - if (Ptr->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error(Ptr->Function + ": " + Ptr->ErrorMessage); + if (Ptr->IsReady()) { + auto Snapshot = Ptr->GetDetachedSnapshot(); + if (Snapshot.Error) { + if (Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error(Snapshot.Function + ": " + Snapshot.ErrorMessage); } } return true; @@ -128,6 +135,7 @@ void TLuaEngine::operator()() { } } } + std::unique_lock Lock(mLuaStatesMutex); if (mLuaStates.empty()) { beammp_trace("No Lua states, event loop running extremely sparsely"); Application::SleepSafeSeconds(10); @@ -215,16 +223,16 @@ std::vector TLuaEngine::Debug_GetStateFunctionQueueF return Result; } -std::vector TLuaEngine::Debug_GetResultsToCheckForState(TLuaStateId StateId) { +std::vector TLuaEngine::Debug_GetResultsToCheckForState(TLuaStateId StateId) { std::unique_lock Lock(mResultsToCheckMutex); auto ResultsToCheckCopy = mResultsToCheck; Lock.unlock(); - std::vector Result; + std::vector Result; while (!ResultsToCheckCopy.empty()) { auto ResultToCheck = std::move(ResultsToCheckCopy.front()); ResultsToCheckCopy.pop_front(); - if (ResultToCheck->StateId == StateId) { - Result.push_back(*ResultToCheck); + if (ResultToCheck->OwnerState() == StateId) { + Result.push_back(ResultToCheck->GetDetachedSnapshot()); } } return Result; @@ -311,27 +319,30 @@ void TLuaEngine::WaitForAll(std::vector>& Results, c size_t ms = 0; std::set WarnedResults; - while (!Result->Ready && !Cancelled) { + while (!Result->IsReady() && !Cancelled) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); ms += 10; if (Max.has_value() && std::chrono::milliseconds(ms) > Max.value()) { - beammp_trace("'" + Result->Function + "' in '" + Result->StateId + "' did not finish executing in time (took: " + std::to_string(ms) + "ms)."); + auto Snapshot = Result->GetDetachedSnapshot(); + beammp_trace("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' did not finish executing in time (took: " + std::to_string(ms) + "ms)."); Cancelled = true; } else if (ms > 1000 * 60) { - auto ResultId = Result->StateId + "_" + Result->Function; + auto Snapshot = Result->GetDetachedSnapshot(); + auto ResultId = Snapshot.StateId + "_" + Snapshot.Function; if (WarnedResults.count(ResultId) == 0) { WarnedResults.insert(ResultId); - beammp_lua_warn("'" + Result->Function + "' in '" + Result->StateId + "' is taking very long. The event it's handling is too important to discard the result of this handler, but may block this event and possibly the whole lua state."); + beammp_lua_warn("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' is taking very long. The event it's handling is too important to discard the result of this handler, but may block this event and possibly the whole lua state."); } } } + auto Snapshot = Result->GetDetachedSnapshot(); if (Cancelled) { - beammp_lua_warn("'" + Result->Function + "' in '" + Result->StateId + "' failed to execute in time and was not waited for. It may still finish executing at a later time."); + beammp_lua_warn("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' failed to execute in time and was not waited for. It may still finish executing at a later time."); LuaAPI::MP::Engine->ReportErrors({ Result }); - } else if (Result->Error) { - if (Result->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error(Result->Function + ": " + Result->ErrorMessage); + } else if (Snapshot.Error) { + if (Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error(Snapshot.Function + ": " + Snapshot.ErrorMessage); } } } @@ -431,10 +442,11 @@ void TLuaEngine::EnsureStateExists(TLuaStateId StateId, const std::string& Name, mLuaStates[StateId] = std::move(DataPtr); RegisterEvent("onInit", StateId, "onInit"); if (!DontCallOnInit) { - auto Res = EnqueueFunctionCall(StateId, "onInit", {}, "onInit"); + auto Res = EnqueueFunctionCall(StateId, "onInit", { }, "onInit"); Res->WaitUntilReady(); - if (Res->Error && Res->ErrorMessage != TLuaEngine::BeamMPFnNotFoundError) { - beammp_lua_error("Calling \"onInit\" on \"" + StateId + "\" failed: " + Res->ErrorMessage); + auto Snapshot = Res->GetSnapshot(StateId); + if (Snapshot.Error && Snapshot.ErrorMessage != TLuaEngine::BeamMPFnNotFoundError) { + beammp_lua_error("Calling \"onInit\" on \"" + StateId + "\" failed: " + Snapshot.ErrorMessage); } } } @@ -446,6 +458,7 @@ void TLuaEngine::RegisterEvent(const std::string& EventName, TLuaStateId StateId } std::set TLuaEngine::GetEventHandlersForState(const std::string& EventName, TLuaStateId StateId) { + std::unique_lock Lock(mLuaEventsMutex); return mLuaEvents[EventName][StateId]; } @@ -495,15 +508,12 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string auto Fn = mStateView[Handler]; if (Fn.valid()) { auto LuaResult = Fn(LocalArgs); - auto Result = std::make_shared(); + auto Result = std::make_shared(mStateId, Handler); if (LuaResult.valid()) { - Result->Error = false; - Result->Result = LuaResult; + Result->MarkReadySuccess(LuaResult); } else { - Result->Error = true; - Result->ErrorMessage = "Function result in TriggerGlobalEvent was invalid"; + Result->MarkReadyError("Function result in TriggerGlobalEvent was invalid"); } - Result->MarkAsReady(); Return.push_back(Result); } } @@ -511,26 +521,55 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string sol::table AsyncEventReturn = StateView.create_table(); AsyncEventReturn["ReturnValueImpl"] = Return; AsyncEventReturn.set_function("IsDone", - [&](const sol::table& Self) -> bool { + [](const sol::table& Self) -> bool { auto Vector = Self.get>>("ReturnValueImpl"); for (const auto& Value : Vector) { - if (!Value->Ready) { + if (!Value->IsReady()) { return false; } } return true; }); + TLuaStateId StateId = mStateId; AsyncEventReturn.set_function("GetResults", - [&](const sol::table& Self) -> sol::table { - sol::state_view StateView(mState); + [StateId](const sol::table& Self, sol::this_state State) -> sol::table { + sol::state_view StateView(State); sol::table Result = StateView.create_table(); auto Vector = Self.get>>("ReturnValueImpl"); int i = 1; for (const auto& Value : Vector) { - if (!Value->Ready) { + if (!Value->IsReady()) { return sol::lua_nil; } - Result.set(i, Value->Result); + // event results from this state are valid unserialized + if (Value->OwnerState() == StateId) { + auto Snapshot = Value->GetSnapshot(StateId); + Result.set(i, Snapshot.Result); + } else { + // event result from another state, goes through serialization boundary + auto Snapshot = Value->GetDetachedSnapshot(); + std::visit([i, &Result](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v>) + Result.set(i, arg); + else if constexpr (std::is_same_v>) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + // monostate means no result value + Result.set(i, sol::lua_nil_t()); + else + static_assert(false, "non-exhaustive visitor!"); + }, Snapshot.Result.V); + } + ++i; } return Result; @@ -589,7 +628,6 @@ std::variant TLuaEngine::StateThreadData::Lua_GetPlayer } } - sol::table TLuaEngine::StateThreadData::Lua_GetPlayers() { sol::table Result = mStateView.create_table(); mEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { @@ -1080,13 +1118,15 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLuaChunk& Script) { std::unique_lock Lock(mStateExecuteQueueMutex); - auto Result = std::make_shared(); + // explicitly passing empty string as there's no single function being called here + auto Result = std::make_shared(mStateId, std::string()); mStateExecuteQueue.push({ Script, Result }); return Result; } std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy) { // TODO: Document all this + std::unique_lock Lock(mStateFunctionQueueMutex); decltype(mStateFunctionQueue)::iterator Iter = mStateFunctionQueue.end(); if (Strategy == CallStrategy::BestEffort) { Iter = std::find_if(mStateFunctionQueue.begin(), mStateFunctionQueue.end(), @@ -1095,10 +1135,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFrom }); } if (Iter == mStateFunctionQueue.end()) { - auto Result = std::make_shared(); - Result->StateId = mStateId; - Result->Function = FunctionName; - std::unique_lock Lock(mStateFunctionQueueMutex); + auto Result = std::make_shared(mStateId, FunctionName); mStateFunctionQueue.push_back({ FunctionName, Result, Args, EventName }); mStateFunctionQueueCond.notify_all(); return Result; @@ -1108,9 +1145,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFrom } std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args, const std::string& EventName) { - auto Result = std::make_shared(); - Result->StateId = mStateId; - Result->Function = FunctionName; + auto Result = std::make_shared(mStateId, FunctionName); std::unique_lock Lock(mStateFunctionQueueMutex); mStateFunctionQueue.push_back({ FunctionName, Result, Args, EventName }); mStateFunctionQueueCond.notify_all(); @@ -1159,13 +1194,10 @@ void TLuaEngine::StateThreadData::operator()() { sol::state_view StateView(mState); auto Res = StateView.safe_script(*S.first.Content, sol::script_pass_on_error, S.first.FileName); if (Res.valid()) { - S.second->Error = false; - S.second->Result = std::move(Res); + S.second->MarkReadySuccess(std::move(Res)); } else { - S.second->Error = true; - S.second->SetErrorMessageFromResult(Res); + S.second->MarkReadyError(std::move(Res)); } - S.second->MarkAsReady(); } } { // StateFunctionQueue Scope @@ -1182,7 +1214,7 @@ void TLuaEngine::StateThreadData::operator()() { auto& Result = TheQueuedFunction.Result; auto Args = TheQueuedFunction.Args; // TODO: Use TheQueuedFunction.EventName for errors, warnings, etc - Result->StateId = mStateId; + Result->SetOwnerState(mStateId); sol::state_view StateView(mState); auto Fn = StateView[FnName]; if (Fn.valid() && Fn.get_type() == sol::type::function) { @@ -1222,17 +1254,12 @@ void TLuaEngine::StateThreadData::operator()() { } auto Res = Fn(sol::as_args(LuaArgs)); if (Res.valid()) { - Result->Error = false; - Result->Result = std::move(Res); + Result->MarkReadySuccess(std::move(Res)); } else { - Result->Error = true; - Result->SetErrorMessageFromResult(Res); + Result->MarkReadyError(std::move(Res)); } - Result->MarkAsReady(); } else { - Result->Error = true; - Result->ErrorMessage = BeamMPFnNotFoundError; // special error kind that we can ignore later - Result->MarkAsReady(); + Result->MarkReadyError(BeamMPFnNotFoundError); } auto ProfEnd = prof::now(); auto ProfDuration = prof::duration(ProfStart, ProfEnd); @@ -1285,21 +1312,6 @@ void TLuaEngine::StateThreadData::AddPath(const fs::path& Path) { mPaths.push(Path); } -void TLuaResult::MarkAsReady() { - { - std::lock_guard readyLock(*this->ReadyMutex); - this->Ready = true; - } - this->ReadyCondition->notify_all(); -} - -void TLuaResult::WaitUntilReady() { - std::unique_lock readyLock(*this->ReadyMutex); - // wait if not ready yet - if (!this->Ready) - this->ReadyCondition->wait(readyLock); -} - TLuaChunk::TLuaChunk(std::shared_ptr Content, std::string FileName, std::string PluginPath) : Content(Content) , FileName(FileName) diff --git a/src/TLuaPlugin.cpp b/src/TLuaPlugin.cpp index 50620be4..0d558ac1 100644 --- a/src/TLuaPlugin.cpp +++ b/src/TLuaPlugin.cpp @@ -63,8 +63,9 @@ TLuaPlugin::TLuaPlugin(TLuaEngine& Engine, const TLuaPluginConfig& Config, const } for (auto& Result : ResultsToCheck) { Result.second->WaitUntilReady(); - if (Result.second->Error) { - beammp_lua_error("Failed: \"" + Result.first.string() + "\": " + Result.second->ErrorMessage); + auto Snapshot = Result.second->GetDetachedSnapshot(); + if (Snapshot.Error) { + beammp_lua_error("Failed: \"" + Result.first.string() + "\": " + Snapshot.ErrorMessage); } } } diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp new file mode 100644 index 00000000..253a4783 --- /dev/null +++ b/src/TLuaResult.cpp @@ -0,0 +1,210 @@ +#include "TLuaResult.h" +#include +#include +#include +#include + +void TLuaResult::MarkReadySuccess(sol::object Res) { + std::unique_lock Lock(mMutex); + mError = false; + mResult = Res; + mDetachedResult = Freeze(Res); + + MarkAsReady(); +} + +void TLuaResult::MarkReadyError(sol::protected_function_result Res) { + std::unique_lock Lock(mMutex); + SetErrorMessageFromResult(Res); + + MarkAsReady(); +} + +void TLuaResult::MarkReadyError(std::string Res) { + std::unique_lock Lock(mMutex); + mError = true; + mErrorMessage = std::move(Res); + + MarkAsReady(); +} + +bool TLuaResult::IsReady() const { + std::unique_lock Lock(mMutex); + return mReady; +} +bool TLuaResult::IsError() const { + std::unique_lock Lock(mMutex); + return mError; +} + +TLuaStateId TLuaResult::OwnerState() const { + std::unique_lock Lock(mMutex); + // copy + return mStateId; +} +void TLuaResult::SetOwnerState(TLuaStateId StateId) { + std::unique_lock Lock(mMutex); + mStateId = std::move(StateId); +} + +TLuaResult::Snapshot TLuaResult::GetSnapshot(TLuaStateId CallingState) const { + std::unique_lock Lock(mMutex); + if (CallingState != mStateId) { + throw std::logic_error("Tried to get snapshot from non-owning state (use a detached snapshot instead)"); + } + Snapshot snapshot { + .Error = mError, + .Ready = mReady, + .ErrorMessage = mErrorMessage, + .Result = mResult, + .StateId = mStateId, + .Function = mFunction, + }; + return snapshot; +} +TLuaResult::DetachedSnapshot TLuaResult::GetDetachedSnapshot() const { + std::unique_lock Lock(mMutex); + DetachedSnapshot snapshot { + .Error = mError, + .Ready = mReady, + .ErrorMessage = mErrorMessage, + .Result = mDetachedResult, + .StateId = mStateId, + .Function = mFunction, + }; + return snapshot; +} + +TDetachedLuaValue TLuaResult::Freeze(const sol::object& o, int depth) { + if (depth > 64) + throw std::runtime_error("max depth (64) reached"); + switch (o.get_type()) { + case sol::type::lua_nil: + return { { std::monostate { } } }; + case sol::type::boolean: + return { { o.as() } }; + case sol::type::number: { + if (o.is()) { + return { { o.as() } }; + } else { + return { { o.as() } }; + } + } + case sol::type::string: + return { { o.as() } }; + case sol::type::table: { + TDetachedLuaValue::Object out; + for (auto&& [k, v] : o.as()) { + if (!k.is()) + continue; // no numeric-key handling, don't need it + out.emplace(k.as(), Freeze(v, depth + 1)); + } + return { { std::move(out) } }; + } + default: + throw std::runtime_error("unsupported Lua type for cross-thread snapshot"); + } +} +void TLuaResult::MarkAsReady() { + mReady = true; + mReadyCondition.notify_all(); +} + +void TLuaResult::WaitUntilReady() { + std::unique_lock Lock(mMutex); + while (!mReady) + mReadyCondition.wait_for(Lock, std::chrono::milliseconds(50), + [this] { + return mReady; + }); +} + +TEST_CASE("TLuaResult MarkReadyError(string) marks ready and wakes waiters") { + TLuaResult result("state_a", "fn_a"); + std::atomic waiterDone { false }; + + auto waiter = std::thread([&] { + result.WaitUntilReady(); + waiterDone.store(true, std::memory_order_release); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + CHECK_FALSE(waiterDone.load(std::memory_order_acquire)); + + result.MarkReadyError(std::string("boom")); + waiter.join(); + + CHECK(result.IsReady()); + const auto snapshot = result.GetDetachedSnapshot(); + CHECK(snapshot.Ready); + CHECK(snapshot.Error); + CHECK(snapshot.ErrorMessage == "boom"); + CHECK(snapshot.StateId == "state_a"); + CHECK(snapshot.Function == "fn_a"); +} + +TEST_CASE("TLuaResult GetSnapshot enforces owner state id") { + TLuaResult result("owner_state", "fn_owner"); + sol::state lua; + lua.open_libraries(sol::lib::base); + result.MarkReadySuccess(sol::make_object(lua.lua_state(), std::string("ok"))); + + CHECK_NOTHROW(result.GetSnapshot("owner_state")); + CHECK_THROWS_AS(result.GetSnapshot("different_state"), std::logic_error); +} + +TEST_CASE("TLuaResult detached snapshot freezes nested string-keyed tables") { + TLuaResult result("state_table", "fn_table"); + sol::state lua; + lua.open_libraries(sol::lib::base); + + auto outer = lua.create_table(); + auto inner = lua.create_table(); + inner["k"] = std::string("v"); + + outer["flag"] = true; + outer["msg"] = std::string("hello"); + outer["inner"] = inner; + outer[1] = std::string("ignored_numeric_key"); + + result.MarkReadySuccess(sol::make_object(lua.lua_state(), outer)); + const auto detached = result.GetDetachedSnapshot(); + + CHECK(detached.Ready); + CHECK_FALSE(detached.Error); + const auto* object = std::get_if(&detached.Result.V); + REQUIRE(object != nullptr); + + CHECK(object->contains("flag")); + CHECK(object->contains("msg")); + CHECK(object->contains("inner")); + CHECK_FALSE(object->contains("1")); + + const auto* flag = std::get_if(&object->at("flag").V); + REQUIRE(flag != nullptr); + CHECK(*flag); + + const auto* msg = std::get_if(&object->at("msg").V); + REQUIRE(msg != nullptr); + CHECK(*msg == "hello"); + + const auto* innerObj = std::get_if(&object->at("inner").V); + REQUIRE(innerObj != nullptr); + REQUIRE(innerObj->contains("k")); + const auto* innerValue = std::get_if(&innerObj->at("k").V); + REQUIRE(innerValue != nullptr); + CHECK(*innerValue == "v"); +} + +TEST_CASE("TLuaResult MarkReadySuccess throws on unsupported Lua function value") { + TLuaResult result("state_fn", "fn_fn"); + sol::state lua; + lua.open_libraries(sol::lib::base); + lua["f"] = [] { return 1; }; + const sol::table globals = lua.globals(); + const sol::protected_function fn = globals.get("f"); + const sol::object fnObj = sol::make_object(lua.lua_state(), fn); + + CHECK_THROWS_AS(result.MarkReadySuccess(fnObj), std::runtime_error); + CHECK_FALSE(result.IsReady()); +} diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index e5f50e21..764584cc 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include typedef boost::asio::detail::socket_option::integer rcv_timeout_option; @@ -507,23 +508,31 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { bool BypassLimit = false; for (const auto& Result : Futures) { - if (!Result->Error && Result->Result.is()) { - auto Res = Result->Result.as(); - - if (Res == 1) { - NotAllowed = true; - break; - } else if (Res == 2) { - BypassLimit = true; + auto Snapshot = Result->GetDetachedSnapshot(); + if (!Snapshot.Error) { + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt != nullptr) { + auto Res = *MaybeInt; + + if (Res == 1) { + NotAllowed = true; + break; + } else if (Res == 2) { + BypassLimit = true; + } } } } std::string Reason; bool NotAllowedWithReason = std::any_of(Futures.begin(), Futures.end(), [&Reason](const std::shared_ptr& Result) -> bool { - if (!Result->Error && Result->Result.is()) { - Reason = Result->Result.as(); - return true; + auto Snapshot = Result->GetDetachedSnapshot(); + if (!Snapshot.Error) { + const std::string* MaybeStr = std::get_if(&Snapshot.Result.V); + if (MaybeStr != nullptr) { + Reason = *MaybeStr; + return true; + } } return false; }); diff --git a/src/TPluginMonitor.cpp b/src/TPluginMonitor.cpp index 049fd200..d766a70c 100644 --- a/src/TPluginMonitor.cpp +++ b/src/TPluginMonitor.cpp @@ -69,8 +69,9 @@ void TPluginMonitor::operator()() { auto StateID = mEngine->GetStateIDForPlugin(fs::path(Pair.first).parent_path()); auto Res = mEngine->EnqueueScript(StateID, Chunk); Res->WaitUntilReady(); - if (Res->Error) { - beammp_lua_errorf("Error while hot-reloading \"{}\": {}", Pair.first, Res->ErrorMessage); + if (Res->IsError()) { + auto Snapshot = Res->GetDetachedSnapshot(); + beammp_lua_errorf("Error while hot-reloading \"{}\": {}", Pair.first, Snapshot.ErrorMessage); } else { mEngine->ReportErrors(mEngine->TriggerLocalEvent(StateID, "onInit")); mEngine->ReportErrors(mEngine->TriggerEvent("onFileChanged", "", Pair.first)); diff --git a/src/TServer.cpp b/src/TServer.cpp index d9d98efa..2c475424 100644 --- a/src/TServer.cpp +++ b/src/TServer.cpp @@ -265,9 +265,15 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vectorGetName(), LockedClient->GetID(), PacketAsString.substr(PacketAsString.find(':', 3) + 1)); bool Rejected = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Elem) { - return !Elem->Error - && Elem->Result.is() - && bool(Elem->Result.as()); + auto Snapshot = Elem->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return bool(*MaybeInt); }); if (!Rejected) { std::string SanitizedPacket = fmt::format("C:{}: {}", LockedClient->GetName(), Message); @@ -380,7 +386,15 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ TLuaEngine::WaitForAll(Futures); bool ShouldntSpawn = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Result) { - return !Result->Error && Result->Result.is() && Result->Result.as() != 0; + auto Snapshot = Result->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return *MaybeInt != 0; }); bool SpawnConfirmed = false; @@ -417,7 +431,15 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ TLuaEngine::WaitForAll(Futures); bool ShouldntAllow = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Result) { - return !Result->Error && Result->Result.is() && Result->Result.as() != 0; + auto Snapshot = Result->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return *MaybeInt != 0; }); auto FoundPos = Packet.find('{'); From 3e12e487abf950ae86684c1153368137bdb65fdd Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sat, 18 Apr 2026 18:24:02 +0200 Subject: [PATCH 26/36] fix success/error handling in lua engine --- src/TLuaEngine.cpp | 18 +++++++++++++++--- src/TLuaResult.cpp | 1 + 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 27f5e120..6006375b 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -510,7 +510,11 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string auto LuaResult = Fn(LocalArgs); auto Result = std::make_shared(mStateId, Handler); if (LuaResult.valid()) { - Result->MarkReadySuccess(LuaResult); + try { + Result->MarkReadySuccess(LuaResult); + } catch (const std::exception& e) { + Result->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { Result->MarkReadyError("Function result in TriggerGlobalEvent was invalid"); } @@ -1194,7 +1198,11 @@ void TLuaEngine::StateThreadData::operator()() { sol::state_view StateView(mState); auto Res = StateView.safe_script(*S.first.Content, sol::script_pass_on_error, S.first.FileName); if (Res.valid()) { - S.second->MarkReadySuccess(std::move(Res)); + try { + S.second->MarkReadySuccess(std::move(Res)); + } catch (const std::exception& e) { + S.second->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { S.second->MarkReadyError(std::move(Res)); } @@ -1254,7 +1262,11 @@ void TLuaEngine::StateThreadData::operator()() { } auto Res = Fn(sol::as_args(LuaArgs)); if (Res.valid()) { - Result->MarkReadySuccess(std::move(Res)); + try { + Result->MarkReadySuccess(std::move(Res)); + } catch (const std::exception& e) { + Result->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { Result->MarkReadyError(std::move(Res)); } diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp index 253a4783..75b86b06 100644 --- a/src/TLuaResult.cpp +++ b/src/TLuaResult.cpp @@ -15,6 +15,7 @@ void TLuaResult::MarkReadySuccess(sol::object Res) { void TLuaResult::MarkReadyError(sol::protected_function_result Res) { std::unique_lock Lock(mMutex); + mError = true; SetErrorMessageFromResult(Res); MarkAsReady(); From 98fb12bbcc2d119b65d267b1b911836deb789ab7 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sat, 18 Apr 2026 18:46:02 +0200 Subject: [PATCH 27/36] fix holding lock during sleep oops --- src/TLuaEngine.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 6006375b..d1d86673 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -135,8 +135,13 @@ void TLuaEngine::operator()() { } } } - std::unique_lock Lock(mLuaStatesMutex); - if (mLuaStates.empty()) { + bool StatesEmpty = false; + { + std::unique_lock Lock(mLuaStatesMutex); + StatesEmpty = mLuaStates.empty(); + } + + if (StatesEmpty) { beammp_trace("No Lua states, event loop running extremely sparsely"); Application::SleepSafeSeconds(10); } else { From 9ca12fc7a683a46ac2fd574f74013484b8fafab0 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sat, 18 Apr 2026 18:52:26 +0200 Subject: [PATCH 28/36] fix std::weak_ptr locking and expiry checks You're supposed to .lock() instead of TOCTOU checking, of course. Not sure what I was thinking when I built that. .lock() returns a default constructed std::shared_ptr on error, which is `false` via `operator bool`. --- src/Client.cpp | 3 +- src/LuaAPI.cpp | 171 ++++++++++++++++++++++----------------- src/TConsole.cpp | 36 ++++----- src/THeartbeatThread.cpp | 4 +- src/TLuaEngine.cpp | 163 +++++++++++++++++++------------------ src/TNetwork.cpp | 80 +++++++++--------- src/TPPSMonitor.cpp | 4 +- src/TServer.cpp | 28 +++---- 8 files changed, 260 insertions(+), 229 deletions(-) diff --git a/src/Client.cpp b/src/Client.cpp index f9ae40eb..7d215941 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -178,8 +178,7 @@ std::optional> GetClient(TServer& Server, int ID) { std::optional> MaybeClient { std::nullopt }; Server.ForEachClient([&](std::weak_ptr CPtr) -> bool { ReadLock Lock(Server.GetClientMutex()); - if (!CPtr.expired()) { - auto C = CPtr.lock(); + if (auto C = CPtr.lock()) { if (C->GetID() == ID) { MaybeClient = CPtr; return false; diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index e7fab3ff..bb16dda0 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -143,22 +143,23 @@ static inline std::pair InternalTriggerClientEvent(int Player return { true, "" }; } else { auto MaybeClient = GetClient(LuaAPI::MP::Engine->Server(), PlayerID); - if (!MaybeClient || MaybeClient.value().expired()) { - beammp_lua_errorf("TriggerClientEvent invalid Player ID '{}'", PlayerID); - return { false, "Invalid Player ID" }; - } - auto c = MaybeClient.value().lock(); + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSyncing() && !c->IsSynced()) { + return { false, "Player hasn't joined yet" }; + } - if (!c->IsSyncing() && !c->IsSynced()) { - return { false, "Player hasn't joined yet" }; + if (!LuaAPI::MP::Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_lua_errorf("Respond failed, dropping client {}", PlayerID); + LuaAPI::MP::Engine->Network().ClientKick(*c, "Disconnected after failing to receive packets"); + return { false, "Respond failed, dropping client" }; + } + return { true, "" }; + } } - if (!LuaAPI::MP::Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_lua_errorf("Respond failed, dropping client {}", PlayerID); - LuaAPI::MP::Engine->Network().ClientKick(*c, "Disconnected after failing to receive packets"); - return { false, "Respond failed, dropping client" }; - } - return { true, "" }; + beammp_lua_errorf("TriggerClientEvent invalid Player ID '{}'", PlayerID); + return { false, "Invalid Player ID" }; } } @@ -169,13 +170,15 @@ std::pair LuaAPI::MP::TriggerClientEvent(int PlayerID, const std::pair LuaAPI::MP::DropPlayer(int ID, std::optional MaybeReason) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (!MaybeClient || MaybeClient.value().expired()) { - beammp_lua_errorf("Tried to drop client with id {}, who doesn't exist", ID); - return { false, "Player does not exist" }; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + LuaAPI::MP::Engine->Network().ClientKick(*c, MaybeReason.value_or("No reason")); + return { true, "" }; + } } - auto c = MaybeClient.value().lock(); - LuaAPI::MP::Engine->Network().ClientKick(*c, MaybeReason.value_or("No reason")); - return { true, "" }; + + beammp_lua_errorf("Tried to drop client with id {}, who doesn't exist", ID); + return { false, "Player does not exist" }; } std::pair LuaAPI::MP::SendChatMessage(int ID, const std::string& Message, const bool& LogChat) { @@ -189,21 +192,26 @@ std::pair LuaAPI::MP::SendChatMessage(int ID, const std::stri Result.first = true; } else { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player still syncing data"; + return Result; + } + if (LogChat) { + LogChatMessage(" (to \"" + c->GetName() + "\")", -1, Message); + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send chat message back to sender (id {}) - did the sender disconnect?", ID); + // TODO: should we return an error here? + } + Result.first = true; + } else { + beammp_lua_error("SendChatMessage invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Player still syncing data"; - return Result; - } - if (LogChat) { - LogChatMessage(" (to \"" + c->GetName() + "\")", -1, Message); - } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send chat message back to sender (id {}) - did the sender disconnect?", ID); - // TODO: should we return an error here? + Result.second = "Invalid Player ID"; } - Result.first = true; } else { beammp_lua_error("SendChatMessage invalid argument [1] invalid ID"); Result.first = false; @@ -223,18 +231,23 @@ std::pair LuaAPI::MP::SendNotification(int ID, const std::str } else { auto MaybeClient = GetClient(Engine->Server(), ID); if (MaybeClient) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { - Result.first = false; - Result.second = "Player is not synced yet"; - return Result; - } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send notification to player (id {}) - did the player disconnect?", ID); + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player is not synced yet"; + return Result; + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send notification to player (id {}) - did the player disconnect?", ID); + Result.first = false; + Result.second = "Failed to send packet"; + } + Result.first = true; + } else { + beammp_lua_error("SendNotification invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Failed to send packet"; + Result.second = "Invalid Player ID"; } - Result.first = true; } else { beammp_lua_error("SendNotification invalid argument [1] invalid ID"); Result.first = false; @@ -265,18 +278,23 @@ std::pair LuaAPI::MP::ConfirmationDialog(int ID, const std::s } else { auto MaybeClient = GetClient(Engine->Server(), ID); if (MaybeClient) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { - Result.first = false; - Result.second = "Player is not synced yet"; - return Result; - } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send confirmation dialog to player (id {}) - did the player disconnect?", ID); + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player is not synced yet"; + return Result; + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send confirmation dialog to player (id {}) - did the player disconnect?", ID); + Result.first = false; + Result.second = "Failed to send packet"; + } + Result.first = true; + } else { + beammp_lua_error("ConfirmationDialog invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Failed to send packet"; + Result.second = "Invalid Player ID"; } - Result.first = true; } else { beammp_lua_error("ConfirmationDialog invalid argument [1] invalid ID"); Result.first = false; @@ -290,22 +308,27 @@ std::pair LuaAPI::MP::ConfirmationDialog(int ID, const std::s std::pair LuaAPI::MP::RemoveVehicle(int PID, int VID) { std::pair Result; auto MaybeClient = GetClient(Engine->Server(), PID); - if (!MaybeClient || MaybeClient.value().expired()) { + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (c->GetCarData(VID) != nlohmann::detail::value_t::null) { + std::string Destroy = "Od:" + std::to_string(PID) + "-" + std::to_string(VID); + LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleDeleted", "", PID, VID)); + Engine->Network().SendToAll(nullptr, StringToVector(Destroy), true, true); + c->DeleteCar(VID); + Result.first = true; + } else { + Result.first = false; + Result.second = "Vehicle does not exist"; + } + } else { + beammp_lua_error("RemoveVehicle invalid Player ID"); + Result.first = false; + Result.second = "Invalid Player ID"; + } + } else { beammp_lua_error("RemoveVehicle invalid Player ID"); Result.first = false; Result.second = "Invalid Player ID"; - return Result; - } - auto c = MaybeClient.value().lock(); - if (c->GetCarData(VID) != nlohmann::detail::value_t::null) { - std::string Destroy = "Od:" + std::to_string(PID) + "-" + std::to_string(VID); - LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleDeleted", "", PID, VID)); - Engine->Network().SendToAll(nullptr, StringToVector(Destroy), true, true); - c->DeleteCar(VID); - Result.first = true; - } else { - Result.first = false; - Result.second = "Vehicle does not exist"; } return Result; } @@ -412,20 +435,22 @@ void LuaAPI::MP::Sleep(size_t Ms) { bool LuaAPI::MP::IsPlayerConnected(int ID) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->IsUDPConnected(); - } else { - return false; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + return c->IsUDPConnected(); + } } + return false; } bool LuaAPI::MP::IsPlayerGuest(int ID) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->IsGuest(); - } else { - return false; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + return c->IsGuest(); + } } + return false; } void LuaAPI::MP::PrintRaw(sol::variadic_args Args) { diff --git a/src/TConsole.cpp b/src/TConsole.cpp index 5e41f821..2aca3a0d 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -296,8 +296,7 @@ void TConsole::Command_NetTest(const std::string& cmd, const std::vector& return StringStartsWith(Name1, Name2) || StringStartsWith(Name2, Name1); }; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - if (NameCompare(locked->GetName(), Name)) { - mLuaEngine->Network().ClientKick(*locked, Reason); + if (auto Locked = Client.lock()) { + if (NameCompare(Locked->GetName(), Name)) { + mLuaEngine->Network().ClientKick(*Locked, Reason); Kicked = true; return false; } @@ -356,7 +354,7 @@ std::tuple> TConsole::ParseCommand(const s // It correctly splits arguments, including respecting single and double quotes, as well as backticks auto End_i = CommandWithArgs.find_first_of(' '); std::string Command = CommandWithArgs.substr(0, End_i); - std::string ArgsStr {}; + std::string ArgsStr { }; if (End_i != std::string::npos) { ArgsStr = CommandWithArgs.substr(End_i); } @@ -566,11 +564,10 @@ void TConsole::Command_List(const std::string&, const std::vector& std::stringstream ss; ss << std::left << std::setw(25) << "Name" << std::setw(6) << "ID" << std::setw(6) << "Cars" << std::endl; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - ss << std::left << std::setw(25) << locked->GetName() - << std::setw(6) << locked->GetID() - << std::setw(6) << locked->GetCarCount() << "\n"; + if (auto Locked = Client.lock()) { + ss << std::left << std::setw(25) << Locked->GetName() + << std::setw(6) << Locked->GetID() + << std::setw(6) << Locked->GetCarCount() << "\n"; } return true; }); @@ -593,8 +590,7 @@ void TConsole::Command_Status(const std::string&, const std::vector size_t MissedPacketQueueSum = 0; int LargestSecondsSinceLastPing = 0; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto Locked = Client.lock(); + if (auto Locked = Client.lock()) { CarCount += Locked->GetCarCount(); ConnectedCount += Locked->IsUDPConnected() ? 1 : 0; GuestCount += Locked->IsGuest() ? 1 : 0; @@ -613,11 +609,11 @@ void TConsole::Command_Status(const std::string&, const std::vector size_t SystemsBad = 0; size_t SystemsShuttingDown = 0; size_t SystemsShutdown = 0; - std::string SystemsBadList {}; - std::string SystemsGoodList {}; - std::string SystemsStartingList {}; - std::string SystemsShuttingDownList {}; - std::string SystemsShutdownList {}; + std::string SystemsBadList { }; + std::string SystemsGoodList { }; + std::string SystemsStartingList { }; + std::string SystemsShuttingDownList { }; + std::string SystemsShutdownList { }; auto Statuses = Application::GetSubsystemStatuses(); for (const auto& NameStatusPair : Statuses) { switch (NameStatusPair.second) { @@ -847,7 +843,7 @@ void TConsole::InitializeCommandline() { if (!mLuaEngine) { beammp_info("Lua not started yet, please try again in a second"); } else { - std::string prefix {}; // stores non-table part of input + std::string prefix { }; // stores non-table part of input for (size_t i = stub.length(); i > 0; i--) { // separate table from input if (!std::isalnum(stub[i - 1]) && stub[i - 1] != '_' && stub[i - 1] != '.') { prefix = stub.substr(0, i); diff --git a/src/THeartbeatThread.cpp b/src/THeartbeatThread.cpp index 36cdde1c..347f59a3 100644 --- a/src/THeartbeatThread.cpp +++ b/src/THeartbeatThread.cpp @@ -181,8 +181,8 @@ std::string THeartbeatThread::GetPlayers() { std::string Return; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Return += ClientPtr.lock()->GetName() + ";"; + if (auto Client = ClientPtr.lock()) { + Return += Client->GetName() + ";"; } return true; }); diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d1d86673..7fdd91a3 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -120,7 +120,7 @@ void TLuaEngine::operator()() { std::unique_lock StateLock(mLuaStatesMutex); std::unique_lock Lock2(mResultsToCheckMutex); for (auto& Handler : Handlers) { - auto Res = mLuaStates[Timer.StateId]->EnqueueFunctionCallFromCustomEvent(Handler, {}, Timer.EventName, Timer.Strategy); + auto Res = mLuaStates[Timer.StateId]->EnqueueFunctionCallFromCustomEvent(Handler, { }, Timer.EventName, Timer.Strategy); if (Res) { mResultsToCheck.push_back(Res); mResultsToCheckCond.notify_one(); @@ -268,7 +268,7 @@ std::vector TLuaEngine::StateThreadData::GetStateTableKeys(const st auto globals = mStateView.globals(); sol::table current = globals; - std::vector Result {}; + std::vector Result { }; for (const auto& [key, value] : current) { std::string s = key.as(); @@ -471,7 +471,7 @@ std::vector TLuaEngine::StateThreadData::JsonStringToArray(JsonStri auto LocalTable = Lua_JsonDecode(Str.value).as>(); for (auto& value : LocalTable) { if (value.is() && value.as() == BEAMMP_INTERNAL_NIL) { - value = sol::object {}; + value = sol::object { }; } } return LocalTable; @@ -613,36 +613,37 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& sol::table TLuaEngine::StateThreadData::Lua_GetPlayerIdentifiers(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto IDs = MaybeClient.value().lock()->GetIdentifiers(); - if (IDs.empty()) { - return sol::lua_nil; - } - sol::table Result = mStateView.create_table(); - for (const auto& Pair : IDs) { - Result.set(Pair.first, Pair.second); + if (MaybeClient) { + if (std::shared_ptr Locked = MaybeClient.value().lock()) { + auto IDs = Locked->GetIdentifiers(); + if (IDs.empty()) { + return sol::lua_nil; + } + sol::table Result = mStateView.create_table(); + for (const auto& Pair : IDs) { + Result.set(Pair.first, Pair.second); + } + return Result; } - return Result; - } else { - return sol::lua_nil; } + return sol::lua_nil; } std::variant TLuaEngine::StateThreadData::Lua_GetPlayerRole(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); if (MaybeClient) { - return MaybeClient.value().lock()->GetRoles(); - } else { - return sol::nil; + if (auto Locked = MaybeClient.value().lock()) { + return Locked->GetRoles(); + } } + return sol::nil; } sol::table TLuaEngine::StateThreadData::Lua_GetPlayers() { sol::table Result = mStateView.create_table(); mEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - Result[locked->GetID()] = locked->GetName(); + if (auto Locked = Client.lock()) { + Result[Locked->GetID()] = Locked->GetName(); } return true; }); @@ -652,10 +653,9 @@ sol::table TLuaEngine::StateThreadData::Lua_GetPlayers() { int TLuaEngine::StateThreadData::Lua_GetPlayerIDByName(const std::string& Name) { int Id = -1; mEngine->mServer->ForEachClient([&Id, &Name](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - if (locked->GetName() == Name) { - Id = locked->GetID(); + if (auto Locked = Client.lock()) { + if (Locked->GetName() == Name) { + Id = Locked->GetID(); return false; } } @@ -692,60 +692,59 @@ sol::table TLuaEngine::StateThreadData::Lua_FS_ListDirectories(const std::string std::string TLuaEngine::StateThreadData::Lua_GetPlayerName(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->GetName(); - } else { - return ""; + if (MaybeClient) { + if (auto Locked = MaybeClient.value().lock()) { + return Locked->GetName(); + } } + return ""; } sol::table TLuaEngine::StateThreadData::Lua_GetPlayerVehicles(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto Client = MaybeClient.value().lock(); - TClient::TSetOfVehicleData VehicleData; - { // Vehicle Data Lock Scope - auto LockedData = Client->GetAllCars(); - VehicleData = *LockedData.VehicleData; - } // End Vehicle Data Lock Scope - if (VehicleData.empty()) { - return sol::lua_nil; - } - sol::state_view StateView(mState); - sol::table Result = StateView.create_table(); - for (const auto& v : VehicleData) { - Result[v.ID()] = v.DataAsPacket(Client->GetRoles(), Client->GetName(), Client->GetID()).substr(3); + if (MaybeClient) { + if (auto Client = MaybeClient.value().lock()) { + TClient::TSetOfVehicleData VehicleData; + { // Vehicle Data Lock Scope + auto LockedData = Client->GetAllCars(); + VehicleData = *LockedData.VehicleData; + } // End Vehicle Data Lock Scope + if (VehicleData.empty()) { + return sol::lua_nil; + } + sol::state_view StateView(mState); + sol::table Result = StateView.create_table(); + for (const auto& v : VehicleData) { + Result[v.ID()] = v.DataAsPacket(Client->GetRoles(), Client->GetName(), Client->GetID()).substr(3); + } + return Result; } - return Result; - } else - return sol::lua_nil; + } + return sol::lua_nil; } std::pair TLuaEngine::StateThreadData::Lua_GetPositionRaw(int PID, int VID) { std::pair Result; auto MaybeClient = GetClient(mEngine->Server(), PID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto Client = MaybeClient.value().lock(); - std::string VehiclePos = Client->GetCarPositionRaw(VID); + if (MaybeClient) { + if (auto Client = MaybeClient.value().lock()) { + std::string VehiclePos = Client->GetCarPositionRaw(VID); - if (VehiclePos.empty()) { - // return std::make_tuple(sol::lua_nil, sol::make_object(StateView, "Vehicle not found")); - Result.second = "Vehicle not found"; - return Result; - } + if (VehiclePos.empty()) { + Result.second = "Vehicle not found"; + return Result; + } - sol::table t = Lua_JsonDecode(VehiclePos); - if (t == sol::lua_nil) { - Result.second = "Packet decode failed"; + sol::table t = Lua_JsonDecode(VehiclePos); + if (t == sol::lua_nil) { + Result.second = "Packet decode failed"; + } + Result.first = t; + return Result; } - // return std::make_tuple(Result, sol::make_object(StateView, sol::lua_nil)); - Result.first = t; - return Result; - } else { - // return std::make_tuple(sol::lua_nil, sol::make_object(StateView, "Client expired")); - Result.second = "No such player"; - return Result; } + Result.second = "No such player"; + return Result; } sol::table TLuaEngine::StateThreadData::Lua_HttpCreateConnection(const std::string& host, uint16_t port) { @@ -786,22 +785,31 @@ static bool mDisableMPSet = [] { static auto GetSettingName = [](int id) -> const char* { switch (id) { - case 0: return "Debug"; - case 1: return "Private"; - case 2: return "MaxCars"; - case 3: return "MaxPlayers"; - case 4: return "Map"; - case 5: return "Name"; - case 6: return "Description"; - case 7: return "InformationPacket"; - default: return "Unknown"; + case 0: + return "Debug"; + case 1: + return "Private"; + case 2: + return "MaxCars"; + case 3: + return "MaxPlayers"; + case 4: + return "Map"; + case 5: + return "Name"; + case 6: + return "Description"; + case 7: + return "InformationPacket"; + default: + return "Unknown"; } }; static void JsonDecodeRecursive(sol::state_view& StateView, sol::table& table, const std::string& left, const nlohmann::json& right) { switch (right.type()) { case nlohmann::detail::value_t::null: - AddToTable(table, left, sol::lua_nil_t {}); + AddToTable(table, left, sol::lua_nil_t { }); return; case nlohmann::detail::value_t::object: { auto value = table.create(); @@ -949,12 +957,9 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI beammp_lua_error("SendNotification expects 2, 3 or 4 arguments."); } }); - MPTable.set_function("ConfirmationDialog", sol::overload( - &LuaAPI::MP::ConfirmationDialog, - [&](const int& ID, const std::string& Title, const std::string& Body, const sol::table& Buttons, const std::string& InteractionID) { - LuaAPI::MP::ConfirmationDialog(ID, Title, Body, Buttons, InteractionID); - } - )); + MPTable.set_function("ConfirmationDialog", sol::overload(&LuaAPI::MP::ConfirmationDialog, [&](const int& ID, const std::string& Title, const std::string& Body, const sol::table& Buttons, const std::string& InteractionID) { + LuaAPI::MP::ConfirmationDialog(ID, Title, Body, Buttons, InteractionID); + })); MPTable.set_function("GetPlayers", [&]() -> sol::table { return Lua_GetPlayers(); }); diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 764584cc..c7aff281 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -101,8 +101,8 @@ TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& R Application::RegisterShutdownHandler([&] { beammp_debug("Kicking all players due to shutdown"); Server.ForEachClient([&](std::weak_ptr client) -> bool { - if (!client.expired()) { - ClientKick(*client.lock(), "Server shutdown"); + if (auto Locked = client.lock()) { + ClientKick(*Locked, "Server shutdown"); } return true; }); @@ -181,8 +181,8 @@ void TNetwork::UDPServerMain() { std::shared_ptr Client; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Client = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + Client = std::move(Locked); } else return true; } @@ -489,8 +489,8 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { std::shared_ptr Cl; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Cl = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + Cl = std::move(Locked); } else return true; } @@ -712,8 +712,11 @@ void TNetwork::DisconnectClient(TClient& c, const std::string& R) { void TNetwork::Looper(const std::weak_ptr& c) { RegisterThreadAuto(); - while (!c.expired()) { + while (true) { auto Client = c.lock(); + if (!Client) { + break; + } if (Client->IsDisconnected()) { beammp_debug("client is disconnected, breaking client loop"); break; @@ -747,20 +750,29 @@ void TNetwork::Looper(const std::weak_ptr& c) { } void TNetwork::TCPClient(const std::weak_ptr& c) { - // TODO: the c.expired() might cause issues here, remove if you end up here with your debugger - if (c.expired() || c.lock()->IsDisconnected()) { + if (auto Client = c.lock()) { + if (Client->IsDisconnected()) { + mServer.RemoveClient(c); + return; + } + } else { mServer.RemoveClient(c); return; } OnConnect(c); - RegisterThread("(" + std::to_string(c.lock()->GetID()) + ") \"" + c.lock()->GetName() + "\""); + if (auto Client = c.lock()) { + RegisterThread("(" + std::to_string(Client->GetID()) + ") \"" + Client->GetName() + "\""); + } else { + return; + } std::jthread QueueSync(&TNetwork::Looper, this, c); while (true) { - if (c.expired()) - break; auto Client = c.lock(); + if (!Client) { + break; + } if (Client->IsDisconnected()) { beammp_debug("client status < 0, breaking client loop"); break; @@ -781,8 +793,7 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { } } - if (!c.expired()) { - auto Client = c.lock(); + if (auto Client = c.lock()) { OnDisconnect(c); } else { beammp_warn("client expired in TCPClient, should never happen"); @@ -793,8 +804,7 @@ void TNetwork::UpdatePlayer(TClient& Client) { std::string Packet = ("Ss") + std::to_string(mServer.ClientCount()) + "/" + std::to_string(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) + ":"; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - auto c = ClientPtr.lock(); + if (auto c = ClientPtr.lock()) { Packet += c->GetName() + ","; } return true; @@ -937,14 +947,11 @@ TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { } void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { - std::shared_ptr LockedClientPtr { nullptr }; - try { - LockedClientPtr = ClientPtr.lock(); - } catch (const std::exception&) { + auto LockedClientPtr = ClientPtr.lock(); + if (!LockedClientPtr) { beammp_warn("Client expired in OnDisconnect, this is unexpected"); return; } - beammp_assert(LockedClientPtr != nullptr); TClient& c = *LockedClientPtr; beammp_info(c.GetName() + (" Connection Terminated")); std::string Packet; @@ -975,8 +982,7 @@ int TNetwork::OpenID() { found = true; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - auto c = ClientPtr.lock(); + if (auto c = ClientPtr.lock()) { if (c->GetID() == ID) { found = false; ID++; @@ -989,9 +995,11 @@ int TNetwork::OpenID() { } void TNetwork::OnConnect(const std::weak_ptr& c) { - beammp_assert(!c.expired()); - beammp_info("Client connected"); auto LockedClient = c.lock(); + if (!LockedClient) { + return; + } + beammp_info("Client connected"); LockedClient->SetID(OpenID()); beammp_info("Assigned ID " + std::to_string(LockedClient->GetID()) + " to " + LockedClient->GetName()); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onPlayerConnecting", "", LockedClient->GetID())); @@ -1218,10 +1226,10 @@ bool TNetwork::Respond(TClient& c, const std::vector& MSG, bool Rel, bo } bool TNetwork::SyncClient(const std::weak_ptr& c) { - if (c.expired()) { + auto LockedClient = c.lock(); + if (!LockedClient) { return false; } - auto LockedClient = c.lock(); if (LockedClient->IsSynced()) return true; // Syncing, later set isSynced @@ -1240,8 +1248,8 @@ bool TNetwork::SyncClient(const std::weak_ptr& c) { std::shared_ptr client; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - client = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + client = std::move(Locked); } else return true; } @@ -1278,14 +1286,14 @@ void TNetwork::SendToAll(TClient* c, const std::vector& Data, bool Self char C = Data.at(0); bool ret = true; mServer.ForEachClient([&](std::weak_ptr ClientPtr) -> bool { - std::shared_ptr Client; - try { + std::shared_ptr Client { nullptr }; + { ReadLock Lock(mServer.GetClientMutex()); - Client = ClientPtr.lock(); - } catch (const std::exception&) { - // continue - beammp_warn("Client expired, shouldn't happen - if a client disconnected recently, you can ignore this"); - return true; + if (auto Locked = ClientPtr.lock()) { + Client = std::move(Locked); + } else { + return true; + } } if (Self || Client.get() != c) { if (Client->IsSynced() || Client->IsSyncing()) { diff --git a/src/TPPSMonitor.cpp b/src/TPPSMonitor.cpp index 7fe5c6fc..859871e7 100644 --- a/src/TPPSMonitor.cpp +++ b/src/TPPSMonitor.cpp @@ -55,8 +55,8 @@ void TPPSMonitor::operator()() { std::shared_ptr c; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - c = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + c = std::move(Locked); } else return true; } diff --git a/src/TServer.cpp b/src/TServer.cpp index 2c475424..f4b0a452 100644 --- a/src/TServer.cpp +++ b/src/TServer.cpp @@ -127,19 +127,15 @@ TServer::TServer(const std::vector& Arguments) { } void TServer::RemoveClient(const std::weak_ptr& WeakClientPtr) { - std::shared_ptr LockedClientPtr { nullptr }; - try { - LockedClientPtr = WeakClientPtr.lock(); - } catch (const std::exception&) { - // silently fail, as there's nothing to do + auto LockedClientPtr = WeakClientPtr.lock(); + if (!LockedClientPtr) { return; } - beammp_assert(LockedClientPtr != nullptr); TClient& Client = *LockedClientPtr; beammp_debug("removing client " + Client.GetName() + " (" + std::to_string(ClientCount()) + ")"); Client.ClearCars(); WriteLock Lock(mClientsMutex); - mClients.erase(WeakClientPtr.lock()); + mClients.erase(LockedClientPtr); } void TServer::ForEachClient(const std::function)>& Fn) { @@ -167,14 +163,16 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vectorGetID()); - Network.ClientKick(*LockedClient, "Sent invalid compressed packet (this is likely a bug on your end)"); + if (auto LockedClient = Client.lock()) { + beammp_errorf("Failed to decompress packet from client {}. The client sent invalid data and will now be disconnected.", LockedClient->GetID()); + Network.ClientKick(*LockedClient, "Sent invalid compressed packet (this is likely a bug on your end)"); + } return; } catch (const std::runtime_error& e) { - auto LockedClient = Client.lock(); - beammp_errorf("Failed to decompress packet from client {}: {}. The server might be out of RAM! The client will now be disconnected.", LockedClient->GetID(), e.what()); - Network.ClientKick(*LockedClient, "Decompression failed (likely a server-side problem)"); + if (auto LockedClient = Client.lock()) { + beammp_errorf("Failed to decompress packet from client {}: {}. The server might be out of RAM! The client will now be disconnected.", LockedClient->GetID(), e.what()); + Network.ClientKick(*LockedClient, "Decompression failed (likely a server-side problem)"); + } return; } } @@ -182,10 +180,10 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vector Date: Sun, 19 Apr 2026 18:14:04 +0000 Subject: [PATCH 29/36] fix TLuaValue handling to be less odd --- include/TLuaEngine.h | 10 +--- src/TLuaEngine.cpp | 135 ++++++++++++++++++++++++++++++++----------- 2 files changed, 103 insertions(+), 42 deletions(-) diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 1aee7e83..e0e1f9b8 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -59,15 +59,7 @@ namespace fs = std::filesystem; /** * std::variant means, that TLuaArgTypes may be one of the Types listed as template args */ -using TLuaValue = std::variant, float>; -enum TLuaType { - String = 0, - Int = 1, - Json = 2, - Bool = 3, - StringStringMap = 4, - Float = 5, -}; +using TLuaValue = std::variant, float>; class TLuaPlugin; diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 7fdd91a3..eddff6ff 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -62,12 +63,6 @@ TLuaEngine::TLuaEngine() IThreaded::Start(); } -TEST_CASE("TLuaEngine ctor & dtor") { - Application::Settings.set(Settings::Key::General_ResourceFolder, "beammp_server_test_resources"); - TLuaEngine engine; - Application::GracefullyShutdown(); -} - void TLuaEngine::operator()() { RegisterThread("LuaEngine"); // lua engine main thread @@ -1241,34 +1236,32 @@ void TLuaEngine::StateThreadData::operator()() { if (Arg.valueless_by_exception()) { continue; } - switch (Arg.index()) { - case TLuaType::String: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::Int: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::Json: { - auto LocalArgs = JsonStringToArray(std::get(Arg)); - LuaArgs.insert(LuaArgs.end(), LocalArgs.begin(), LocalArgs.end()); - break; - } - case TLuaType::Bool: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::StringStringMap: { - auto Map = std::get>(Arg); - auto Table = StateView.create_table(); - for (const auto& [k, v] : Map) { - Table[k] = v; + std::visit([&LuaArgs, &StateView, this](const auto& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + auto LocalArgs = JsonStringToArray(arg); + LuaArgs.insert(LuaArgs.end(), LocalArgs.begin(), LocalArgs.end()); + } else if constexpr (std::is_same_v) { + beammp_lua_error("Unknown argument type, passed as nil"); + LuaArgs.push_back(sol::lua_nil_t()); + } else if constexpr (std::is_same_v>) { + auto Table = StateView.create_table(); + for (const auto& [k, v] : arg) { + Table[k] = v; + } + LuaArgs.push_back(Table); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else { + static_assert(false, "unhandled variant"); } - LuaArgs.push_back(sol::make_object(StateView, Table)); - break; - } - default: - beammp_error("Unknown argument type, passed as nil"); - break; - } + }, Arg); } auto Res = Fn(sol::as_args(LuaArgs)); if (Res.valid()) { @@ -1348,3 +1341,79 @@ bool TLuaEngine::TimedEvent::Expired() { void TLuaEngine::TimedEvent::Reset() { LastCompletion = std::chrono::high_resolution_clock::now(); } + +TEST_CASE("TLuaEngine ctor & dtor") { + Application::Settings.set(Settings::Key::General_ResourceFolder, "beammp_server_test_resources"); + TLuaEngine engine; + + const TLuaStateId StateId = "lua_event_contract_test"; + engine.EnsureStateExists(StateId, "LuaEventContractTest", true); + + // LLM generated test code + auto Script = std::make_shared(R"( +function onPlayerAuth(playerName, playerRole, isGuest, identifiers) + if type(playerName) ~= "string" then return "on:bad-playerName-type:" .. type(playerName) end + if type(playerRole) ~= "string" then return "on:bad-playerRole-type:" .. type(playerRole) end + if type(isGuest) ~= "boolean" then return "on:bad-isGuest-type:" .. type(isGuest) end + if type(identifiers) ~= "table" then return "on:bad-identifiers-type:" .. type(identifiers) end + return "on:" .. playerName .. ":" .. playerRole .. ":" .. tostring(isGuest) .. ":" .. tostring(identifiers.ip) .. ":" .. tostring(identifiers.beammp) +end + +function postPlayerAuth(isDenied, reason, playerName, playerRole, isGuest, identifiers) + if type(isDenied) ~= "boolean" then return "post:bad-isDenied-type:" .. type(isDenied) end + if type(reason) ~= "string" then return "post:bad-reason-type:" .. type(reason) end + if type(playerName) ~= "string" then return "post:bad-playerName-type:" .. type(playerName) end + if type(playerRole) ~= "string" then return "post:bad-playerRole-type:" .. type(playerRole) end + if type(isGuest) ~= "boolean" then return "post:bad-isGuest-type:" .. type(isGuest) end + if type(identifiers) ~= "table" then return "post:bad-identifiers-type:" .. type(identifiers) end + return "post:" .. tostring(isDenied) .. ":" .. reason .. ":" .. playerName .. ":" .. playerRole .. ":" .. tostring(isGuest) .. ":" .. tostring(identifiers.ip) .. ":" .. tostring(identifiers.beammp) +end + +MP.RegisterEvent("onPlayerAuth", "onPlayerAuth") +MP.RegisterEvent("postPlayerAuth", "postPlayerAuth") +)"); + + auto LoadResult = engine.EnqueueScript(StateId, TLuaChunk(Script, "event_contract.lua", "beammp_server_test_resources/Server/LuaEventContractTest")); + LoadResult->WaitUntilReady(); + auto LoadSnapshot = LoadResult->GetDetachedSnapshot(); + CHECK(!LoadSnapshot.Error); + + const std::unordered_map Identifiers { + {"ip", "410.0.24.1"}, + {"beammp", "123456"}, + }; + + auto OnPlayerAuthResults = engine.TriggerEvent( + "onPlayerAuth", "", + std::string("guest8133569"), + std::string("USER"), + true, + Identifiers); + REQUIRE(OnPlayerAuthResults.size() == 1); + TLuaEngine::WaitForAll(OnPlayerAuthResults); + + auto OnPlayerAuthSnapshot = OnPlayerAuthResults.front()->GetDetachedSnapshot(); + CHECK(!OnPlayerAuthSnapshot.Error); + const auto* OnPlayerAuthValue = std::get_if(&OnPlayerAuthSnapshot.Result.V); + REQUIRE(OnPlayerAuthValue != nullptr); + CHECK(*OnPlayerAuthValue == "on:guest8133569:USER:true:410.0.24.1:123456"); + + auto PostPlayerAuthResults = engine.TriggerEvent( + "postPlayerAuth", "", + false, + std::string(""), + std::string("guest8133569"), + std::string("USER"), + true, + Identifiers); + REQUIRE(PostPlayerAuthResults.size() == 1); + TLuaEngine::WaitForAll(PostPlayerAuthResults); + + auto PostPlayerAuthSnapshot = PostPlayerAuthResults.front()->GetDetachedSnapshot(); + CHECK(!PostPlayerAuthSnapshot.Error); + const auto* PostPlayerAuthValue = std::get_if(&PostPlayerAuthSnapshot.Result.V); + REQUIRE(PostPlayerAuthValue != nullptr); + CHECK(*PostPlayerAuthValue == "post:false::guest8133569:USER:true:410.0.24.1:123456"); + + Application::GracefullyShutdown(); +} From 8f2b8b25e83bb0abef3a11d51d51cbc393a2fc5b Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 18:41:06 +0000 Subject: [PATCH 30/36] add AGPL headers to more new files --- include/TConnectionLimiter.h | 18 ++++++++++++++++++ include/TLuaResult.h | 18 ++++++++++++++++++ src/TConnectionLimiter.cpp | 18 ++++++++++++++++++ src/TLuaResult.cpp | 18 ++++++++++++++++++ 4 files changed, 72 insertions(+) diff --git a/include/TConnectionLimiter.h b/include/TConnectionLimiter.h index 62df05f5..c9f38dd1 100644 --- a/include/TConnectionLimiter.h +++ b/include/TConnectionLimiter.h @@ -1,3 +1,21 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + #pragma once #include diff --git a/include/TLuaResult.h b/include/TLuaResult.h index 81ee223e..18afb2e9 100644 --- a/include/TLuaResult.h +++ b/include/TLuaResult.h @@ -1,3 +1,21 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + #pragma once #include "Common.h" diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp index b2c681e2..97d3fabb 100644 --- a/src/TConnectionLimiter.cpp +++ b/src/TConnectionLimiter.cpp @@ -1,3 +1,21 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + #include "TConnectionLimiter.h" #include #include diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp index 75b86b06..7c6e449d 100644 --- a/src/TLuaResult.cpp +++ b/src/TLuaResult.cpp @@ -1,3 +1,21 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + #include "TLuaResult.h" #include #include From dd6d90a9864860d62c1ae6e7dbaaa74532e23105 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 19:16:37 +0000 Subject: [PATCH 31/36] force enable SOL_ALL_SAFETIES_ON --- cmake/Vcpkg.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/Vcpkg.cmake b/cmake/Vcpkg.cmake index 9d8ecbac..167f1d35 100644 --- a/cmake/Vcpkg.cmake +++ b/cmake/Vcpkg.cmake @@ -15,3 +15,5 @@ if(NOT DEFINED CMAKE_TOOLCHAIN_FILE) set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/vcpkg/scripts/buildsystems/vcpkg.cmake) endif() +# ensure SOL2 safeties are all ON +add_compile_definitions(SOL_ALL_SAFETIES_ON=1) From c97d7b1b7326ae896bde64cd6af8529e06fe1b13 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 19:27:17 +0000 Subject: [PATCH 32/36] fix server cmake version --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b2a76866..24a8a8ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ include(cmake/Vcpkg.cmake) # needs to happen before project() project( "BeamMP-Server" # replace this - VERSION 3.3.0 + VERSION 3.9.2 ) include(cmake/StandardSettings.cmake) From 4045727c8b3f260b3f16bd2ed62639213cb3380b Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 20:15:58 +0000 Subject: [PATCH 33/36] fix unit-tests crashing due to sol::object::~object --- src/TLuaResult.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp index 7c6e449d..a57d20e9 100644 --- a/src/TLuaResult.cpp +++ b/src/TLuaResult.cpp @@ -163,8 +163,8 @@ TEST_CASE("TLuaResult MarkReadyError(string) marks ready and wakes waiters") { } TEST_CASE("TLuaResult GetSnapshot enforces owner state id") { - TLuaResult result("owner_state", "fn_owner"); sol::state lua; + TLuaResult result("owner_state", "fn_owner"); lua.open_libraries(sol::lib::base); result.MarkReadySuccess(sol::make_object(lua.lua_state(), std::string("ok"))); @@ -173,8 +173,8 @@ TEST_CASE("TLuaResult GetSnapshot enforces owner state id") { } TEST_CASE("TLuaResult detached snapshot freezes nested string-keyed tables") { - TLuaResult result("state_table", "fn_table"); sol::state lua; + TLuaResult result("state_table", "fn_table"); lua.open_libraries(sol::lib::base); auto outer = lua.create_table(); @@ -216,8 +216,8 @@ TEST_CASE("TLuaResult detached snapshot freezes nested string-keyed tables") { } TEST_CASE("TLuaResult MarkReadySuccess throws on unsupported Lua function value") { - TLuaResult result("state_fn", "fn_fn"); sol::state lua; + TLuaResult result("state_fn", "fn_fn"); lua.open_libraries(sol::lib::base); lua["f"] = [] { return 1; }; const sol::table globals = lua.globals(); From 06bd719dbf98d6a9d283f80514ba7cfb111488df Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 20:34:20 +0000 Subject: [PATCH 34/36] clarify/fix std::visit compile time check --- include/Common.h | 4 ++++ src/Common.cpp | 6 +++--- src/TLuaEngine.cpp | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/include/Common.h b/include/Common.h index 7759e437..a6aea214 100644 --- a/include/Common.h +++ b/include/Common.h @@ -132,6 +132,10 @@ class Application final { static inline Version mVersion { 3, 9, 2 }; }; +/// Used to static_assert in std::visit +template +inline constexpr bool AlwaysFalseV = false; + void SplitString(std::string const& str, const char delim, std::vector& out); std::string LowerString(std::string str); diff --git a/src/Common.cpp b/src/Common.cpp index 7fff4c60..a20141ff 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -378,7 +378,7 @@ std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { std::visit([&os](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v>) { + if constexpr (std::is_same_v) { size_t i = 0; for (auto val : arg) { if (i > 0) { @@ -386,7 +386,7 @@ std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { } os << val; } - } else if constexpr (std::is_same_v>) { + } else if constexpr (std::is_same_v) { size_t i = 0; for (auto [key, val] : arg) { if (i > 0) { @@ -406,7 +406,7 @@ std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { // monostate means no result value os << ""; else - static_assert(false, "non-exhaustive visitor!"); + static_assert(AlwaysFalseV, "non-exhaustive visitor!"); }, value.V); diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index eddff6ff..84dd5013 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -570,7 +570,7 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string // monostate means no result value Result.set(i, sol::lua_nil_t()); else - static_assert(false, "non-exhaustive visitor!"); + static_assert(AlwaysFalseV, "non-exhaustive visitor!"); }, Snapshot.Result.V); } @@ -1259,7 +1259,7 @@ void TLuaEngine::StateThreadData::operator()() { } else if constexpr (std::is_same_v) { LuaArgs.push_back(sol::make_object(StateView, arg)); } else { - static_assert(false, "unhandled variant"); + static_assert(AlwaysFalseV, "unhandled variant"); } }, Arg); } From 7c6acfdc860d021424c55fdb08fb3e6837e9af78 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 20:35:15 +0000 Subject: [PATCH 35/36] fix never incrementing i in lua result print --- src/Common.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Common.cpp b/src/Common.cpp index a20141ff..be772752 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -385,6 +385,7 @@ std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { os << ", "; } os << val; + ++i; } } else if constexpr (std::is_same_v) { size_t i = 0; @@ -393,6 +394,7 @@ std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { os << ", "; } os << key << "=" << val; + ++i; } } else if constexpr (std::is_same_v) os << (arg ? "true" : "false"); From 57fe7cb055f40433142864adc6c5b30fc82047b5 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 19 Apr 2026 21:16:43 +0000 Subject: [PATCH 36/36] fix GCC 11 compiler/libstdc++ error GCC 11's C++ stdlib does a weird maneuver here where it needs to know the size of the std::pair<>::second's type. So we wrap it in a ptr. --- include/TLuaResult.h | 7 +++++- src/Common.cpp | 40 --------------------------------- src/TLuaEngine.cpp | 4 ++-- src/TLuaResult.cpp | 53 +++++++++++++++++++++++++++++++++++++++----- 4 files changed, 56 insertions(+), 48 deletions(-) diff --git a/include/TLuaResult.h b/include/TLuaResult.h index 18afb2e9..a75da92e 100644 --- a/include/TLuaResult.h +++ b/include/TLuaResult.h @@ -21,13 +21,18 @@ #include "Common.h" #include #include +#include #include using TLuaStateId = std::string; struct TDetachedLuaValue { + using Ptr = std::shared_ptr; using Array = std::vector; - using Object = std::unordered_map; + // This weird Ptr indirection is needed because some implementations of libstc++, + // like the GCC 11 one shipped with ubuntu 22.04, need to know the size of the second + // member of the pairs that make up the elements of such a map. It's fine in vector. + using Object = std::unordered_map; std::variant V; }; std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value); diff --git a/src/Common.cpp b/src/Common.cpp index be772752..1c769a3a 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -374,46 +374,6 @@ std::string GetPlatformAgnosticErrorString() { return "(no human-readable errors on this platform)"; #endif } -std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { - - std::visit([&os](auto&& arg) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - size_t i = 0; - for (auto val : arg) { - if (i > 0) { - os << ", "; - } - os << val; - ++i; - } - } else if constexpr (std::is_same_v) { - size_t i = 0; - for (auto [key, val] : arg) { - if (i > 0) { - os << ", "; - } - os << key << "=" << val; - ++i; - } - } else if constexpr (std::is_same_v) - os << (arg ? "true" : "false"); - else if constexpr (std::is_same_v) - os << arg; - else if constexpr (std::is_same_v) - os << arg; - else if constexpr (std::is_same_v) - os << arg; - else if constexpr (std::is_same_v) - // monostate means no result value - os << ""; - else - static_assert(AlwaysFalseV, "non-exhaustive visitor!"); - }, - value.V); - - return os; -} // TODO: add unit tests to SplitString void SplitString(const std::string& str, const char delim, std::vector& out) { diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 84dd5013..c0e3a28e 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -554,9 +554,9 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string auto Snapshot = Value->GetDetachedSnapshot(); std::visit([i, &Result](auto&& arg) { using T = std::decay_t; - if constexpr (std::is_same_v>) + if constexpr (std::is_same_v) Result.set(i, arg); - else if constexpr (std::is_same_v>) + else if constexpr (std::is_same_v) Result.set(i, arg); else if constexpr (std::is_same_v) Result.set(i, arg); diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp index a57d20e9..39f4af64 100644 --- a/src/TLuaResult.cpp +++ b/src/TLuaResult.cpp @@ -19,9 +19,52 @@ #include "TLuaResult.h" #include #include +#include #include #include + +std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { + std::visit([&os](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + size_t i = 0; + for (const auto& val : arg) { + if (i > 0) { + os << ", "; + } + os << val; + ++i; + } + } else if constexpr (std::is_same_v) { + size_t i = 0; + for (const auto& [key, val] : arg) { + if (i > 0) { + os << ", "; + } + os << key << "=" << val; + ++i; + } + } else if constexpr (std::is_same_v) + os << (arg ? "true" : "false"); + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + // monostate means no result value + os << ""; + else + static_assert(AlwaysFalseV, "non-exhaustive visitor!"); + }, + value.V); + + return os; +} + + void TLuaResult::MarkReadySuccess(sol::object Res) { std::unique_lock Lock(mMutex); mError = false; @@ -116,7 +159,7 @@ TDetachedLuaValue TLuaResult::Freeze(const sol::object& o, int depth) { for (auto&& [k, v] : o.as()) { if (!k.is()) continue; // no numeric-key handling, don't need it - out.emplace(k.as(), Freeze(v, depth + 1)); + out.emplace(k.as(), std::make_shared(std::move(Freeze(v, depth + 1)))); } return { { std::move(out) } }; } @@ -199,18 +242,18 @@ TEST_CASE("TLuaResult detached snapshot freezes nested string-keyed tables") { CHECK(object->contains("inner")); CHECK_FALSE(object->contains("1")); - const auto* flag = std::get_if(&object->at("flag").V); + const auto* flag = std::get_if(&object->at("flag")->V); REQUIRE(flag != nullptr); CHECK(*flag); - const auto* msg = std::get_if(&object->at("msg").V); + const auto* msg = std::get_if(&object->at("msg")->V); REQUIRE(msg != nullptr); CHECK(*msg == "hello"); - const auto* innerObj = std::get_if(&object->at("inner").V); + const auto* innerObj = std::get_if(&object->at("inner")->V); REQUIRE(innerObj != nullptr); REQUIRE(innerObj->contains("k")); - const auto* innerValue = std::get_if(&innerObj->at("k").V); + const auto* innerValue = std::get_if(&innerObj->at("k")->V); REQUIRE(innerValue != nullptr); CHECK(*innerValue == "v"); }