diff --git a/CMakeLists.txt b/CMakeLists.txt index 67e758d6..24a8a8ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ include(cmake/Vcpkg.cmake) # needs to happen before project() project( "BeamMP-Server" # replace this - VERSION 3.3.0 + VERSION 3.9.2 ) include(cmake/StandardSettings.cmake) @@ -39,6 +39,7 @@ set(PRJ_HEADERS include/TConfig.h include/TConsole.h include/THeartbeatThread.h + include/TIoPollThread.h include/TLuaEngine.h include/TLuaPlugin.h include/TNetwork.h @@ -52,6 +53,8 @@ set(PRJ_HEADERS include/Settings.h include/Profiling.h include/ChronoWrapper.h + include/TConnectionLimiter.h + include/TLuaResult.h ) # add all source files (.cpp) to this, except the one with main() set(PRJ_SOURCES @@ -65,6 +68,7 @@ set(PRJ_SOURCES src/TConfig.cpp src/TConsole.cpp src/THeartbeatThread.cpp + src/TIoPollThread.cpp src/TLuaEngine.cpp src/TLuaPlugin.cpp src/TNetwork.cpp @@ -78,6 +82,8 @@ set(PRJ_SOURCES src/Settings.cpp src/Profiling.cpp src/ChronoWrapper.cpp + src/TConnectionLimiter.cpp + src/TLuaResult.cpp ) find_package(Lua REQUIRED) diff --git a/cmake/Vcpkg.cmake b/cmake/Vcpkg.cmake index 9d8ecbac..167f1d35 100644 --- a/cmake/Vcpkg.cmake +++ b/cmake/Vcpkg.cmake @@ -15,3 +15,5 @@ if(NOT DEFINED CMAKE_TOOLCHAIN_FILE) set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/vcpkg/scripts/buildsystems/vcpkg.cmake) endif() +# ensure SOL2 safeties are all ON +add_compile_definitions(SOL_ALL_SAFETIES_ON=1) diff --git a/include/Client.h b/include/Client.h index 9ab2dc72..6f29bd57 100644 --- a/include/Client.h +++ b/include/Client.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -68,8 +69,11 @@ class TClient final { std::string GetCarPositionRaw(int Ident); void SetUDPAddr(const ip::udp::endpoint& Addr) { mUDPAddress = Addr; } void SetTCPSock(ip::tcp::socket&& CSock) { mSocket = std::move(CSock); } - void Disconnect(std::string_view Reason); - bool IsDisconnected() const { return !mSocket.is_open(); } + // Returns true only for the thread that actually performs socket shutdown/close. + [[nodiscard]] bool Disconnect(std::string_view Reason); + bool IsDisconnected() const { + return mDisconnectState.load(std::memory_order_acquire) != EDisconnectState::Connected; + } // locks void DeleteCar(int Ident); [[nodiscard]] const std::unordered_map& GetIdentifiers() const { return mIdentifiers; } @@ -106,6 +110,12 @@ class TClient final { [[nodiscard]] const std::vector& GetMagic() const { return mMagic; } private: + enum class EDisconnectState { + Connected, + Disconnecting, + Disconnected + }; + void InsertVehicle(int ID, const std::string& Data); TServer& mServer; @@ -121,6 +131,8 @@ class TClient final { TSetOfVehicleData mVehicleData; SparseArray mVehiclePosition; std::string mName = "Unknown Client"; + // Once disconnect starts, this client is terminal and its socket must be treated as dead. + std::atomic mDisconnectState { EDisconnectState::Connected }; ip::tcp::socket mSocket; ip::udp::endpoint mUDPAddress {}; int mUnicycleID = -1; diff --git a/include/Common.h b/include/Common.h index 7759e437..a6aea214 100644 --- a/include/Common.h +++ b/include/Common.h @@ -132,6 +132,10 @@ class Application final { static inline Version mVersion { 3, 9, 2 }; }; +/// Used to static_assert in std::visit +template +inline constexpr bool AlwaysFalseV = false; + void SplitString(std::string const& str, const char delim, std::vector& out); std::string LowerString(std::string str); diff --git a/include/TConnectionLimiter.h b/include/TConnectionLimiter.h new file mode 100644 index 00000000..c9f38dd1 --- /dev/null +++ b/include/TConnectionLimiter.h @@ -0,0 +1,90 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include +#include +#include +#include +#include + +class TConnectionLimiter { +public: + struct TStats { + size_t CurrentGlobal = 0; + size_t MaxGlobal = 0; + size_t ActiveIpBuckets = 0; + size_t CurrentMaxPerIp = 0; + size_t MaxPerIp = 0; + size_t SaturatedIpBuckets = 0; + }; + + class TGuard { + public: + TGuard() = default; + TGuard(TConnectionLimiter* owner, std::string ip); + + TGuard(const TGuard&) = delete; + TGuard& operator=(const TGuard&) = delete; + + ~TGuard() { Release(); } + + // not threadsafe + TGuard(TGuard&& other) noexcept { *this = std::move(other); } + // not threadsafe + TGuard& operator=(TGuard&& other) noexcept; + + private: + friend class TConnectionLimiter; + + // not threadsafe + void Release(); + + TConnectionLimiter* mOwner { nullptr }; + std::string mIp; + }; + + TConnectionLimiter(size_t maxPerIp, size_t maxGlobal); + + [[nodiscard]] std::optional TryAcquire(const std::string& ip); + [[nodiscard]] TStats GetStats(); + +private: + void Release(const std::string& ip) { + std::unique_lock Lock { mMutex }; + auto It = mPerIp.find(ip); + if (It != mPerIp.end()) { + // this guard exists to avoid underflow in case something goes wrong and this gets called too many times + if (It->second > 0) + --It->second; + if (It->second == 0) + mPerIp.erase(It); + } + // this guard exists to avoid underflow in case something goes wrong and this gets called too many times + if (mGlobal > 0) + --mGlobal; + } + + const size_t mMaxPerIp; + const size_t mMaxGlobal; + + std::mutex mMutex { }; + std::unordered_map mPerIp { }; + size_t mGlobal = 0; +}; diff --git a/include/TIoPollThread.h b/include/TIoPollThread.h new file mode 100644 index 00000000..4c8e805f --- /dev/null +++ b/include/TIoPollThread.h @@ -0,0 +1,36 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include +#include +#include + +class TIoPollThread { +public: + TIoPollThread(); + ~TIoPollThread(); + + boost::asio::io_context& IoCtx() noexcept { return mIoCtx; } + +private: + boost::asio::io_context mIoCtx; + boost::asio::executor_work_guard mWorkGuard; + std::jthread mThread; +}; diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 47a5c760..e0e1f9b8 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -19,13 +19,12 @@ #pragma once #include "Profiling.h" +#include "TLuaResult.h" #include "TNetwork.h" #include "TServer.h" -#include #include #include #include -#include #include #include #include @@ -34,6 +33,8 @@ #include #include #include +#include +#include #include #include #include @@ -58,36 +59,10 @@ namespace fs = std::filesystem; /** * std::variant means, that TLuaArgTypes may be one of the Types listed as template args */ -using TLuaValue = std::variant, float>; -enum TLuaType { - String = 0, - Int = 1, - Json = 2, - Bool = 3, - StringStringMap = 4, - Float = 5, -}; +using TLuaValue = std::variant, float>; class TLuaPlugin; -struct TLuaResult { - bool Ready; - bool Error; - std::string ErrorMessage; - sol::object Result { sol::lua_nil }; - TLuaStateId StateId; - std::string Function; - std::shared_ptr ReadyMutex { - std::make_shared() - }; - std::shared_ptr ReadyCondition { - std::make_shared() - }; - - void MarkAsReady(); - void WaitUntilReady(); -}; - struct TLuaPluginConfig { static inline const std::string FileName = "PluginConfig.toml"; TLuaStateId StateId; @@ -229,7 +204,7 @@ class TLuaEngine : public std::enable_shared_from_this, IThreaded { std::unordered_map /* handlers */> Debug_GetEventsForState(TLuaStateId StateId); std::queue>> Debug_GetStateExecuteQueueForState(TLuaStateId StateId); std::vector Debug_GetStateFunctionQueueForState(TLuaStateId StateId); - std::vector Debug_GetResultsToCheckForState(TLuaStateId StateId); + std::vector Debug_GetResultsToCheckForState(TLuaStateId StateId); private: void CollectAndInitPlugins(); diff --git a/include/TLuaResult.h b/include/TLuaResult.h new file mode 100644 index 00000000..a75da92e --- /dev/null +++ b/include/TLuaResult.h @@ -0,0 +1,115 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include "Common.h" +#include +#include +#include +#include + +using TLuaStateId = std::string; + +struct TDetachedLuaValue { + using Ptr = std::shared_ptr; + using Array = std::vector; + // This weird Ptr indirection is needed because some implementations of libstc++, + // like the GCC 11 one shipped with ubuntu 22.04, need to know the size of the second + // member of the pairs that make up the elements of such a map. It's fine in vector. + using Object = std::unordered_map; + std::variant V; +}; +std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value); + +struct TLuaResult { + struct Snapshot { + bool Error; + bool Ready; + std::string ErrorMessage; + /// CAUTION: Accessing this object causes the Lua stack of the owning state + /// to be accessed. Only call this from the owning Lua state. Otherwise, + /// use `DetachedSnapshot`. + sol::object Result; + TLuaStateId StateId; + std::string Function; + }; + + struct DetachedSnapshot { + bool Error; + bool Ready; + std::string ErrorMessage; + /// Serialized Lua result. Has no reference to the Lua state, and is safe to use + /// from any thread that acquires it. + TDetachedLuaValue Result; + TLuaStateId StateId; + std::string Function; + }; + + + TLuaResult(TLuaStateId StateId, std::string FunctionName) : mStateId(StateId), mFunction(std::move(FunctionName)) {} + + /// Marks this result as success & sets it as ready, waking all threads waiting on it. + void MarkReadySuccess(sol::object Res); + /// Marks this result as erroneous & sets it as ready, waking all threads waiting on it. + void MarkReadyError(sol::protected_function_result Res); + /// Marks this result as erroneous & sets it as ready, waking all threads waiting on it. + void MarkReadyError(std::string Res); + /// Wait (suspend) until this result is ready. Use `GetSnapshot` to get the results. + void WaitUntilReady(); + bool IsReady() const; + bool IsError() const; + TLuaStateId OwnerState() const; + void SetOwnerState(TLuaStateId StateId); + + /// To get the result or error, while in the owning state, use this function. + /// Pass your lua state ID so that the function can verify the safety of the access. + /// If you have no state ID, use `GetDetachedSnapshot`. + /// Accesses the Lua state directly when using the result. + Snapshot GetSnapshot(TLuaStateId CallingState) const; + /// Returns a snapshot with, if not an error, a serialized version of the result + /// object. In contrast to `GetSnapshot`, this does not access the Lua stack when + /// accessing `.Result`. + DetachedSnapshot GetDetachedSnapshot() const; + +private: + mutable std::mutex mMutex {}; + bool mReady { false }; + bool mError; + std::string mErrorMessage; + sol::object mResult { sol::lua_nil }; + TDetachedLuaValue mDetachedResult; + TLuaStateId mStateId; + std::string mFunction; + std::condition_variable mReadyCondition {}; + + TDetachedLuaValue Freeze(const sol::object& o, int depth = 0); + + void SetErrorMessageFromResult(sol::protected_function_result& Res) { + if (Res.valid()) { + beammp_lua_errorf("Error was not an error"); + } + if (Res.get_type() == sol::type::string) { + mErrorMessage = Res.get().what(); + } else { + mErrorMessage = "(unknown error; error object is not inspectable)"; + } + } + + void MarkAsReady(); +}; diff --git a/include/TNetwork.h b/include/TNetwork.h index d5e54731..be55b0df 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -20,9 +20,9 @@ #include "BoostAliases.h" #include "Compat.h" +#include "TConnectionLimiter.h" #include "TResourceManager.h" #include "TServer.h" -#include #include struct TConnection; @@ -35,18 +35,20 @@ class TNetwork { [[nodiscard]] bool SendLarge(TClient& c, std::vector Data, bool isSync = false); [[nodiscard]] bool Respond(TClient& c, const std::vector& MSG, bool Rel, bool isSync = false); std::shared_ptr CreateClient(boost::asio::ip::tcp::socket&& TCPSock); - std::vector TCPRcv(TClient& c); + std::vector TCPRcv(TClient& c, bool WithTimeout = false); void ClientKick(TClient& c, const std::string& R); void DisconnectClient(const std::weak_ptr& c, const std::string& R); void DisconnectClient(TClient& c, const std::string& R); [[nodiscard]] bool SyncClient(const std::weak_ptr& c); - void Identify(TConnection&& client); + void Identify(TConnection&& client, TConnectionLimiter::TGuard&&); std::shared_ptr Authentication(TConnection&& ClientConnection); void SyncResources(TClient& c); [[nodiscard]] bool UDPSend(TClient& Client, std::vector Data); void SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel); void UpdatePlayer(TClient& Client); + boost::system::error_code ReadWithTimeout(boost::asio::ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); boost::system::error_code ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); + [[nodiscard]] TConnectionLimiter::TStats GetConnectionLimiterStats() { return mConnectionLimiter.GetStats(); } TResourceManager& ResourceManager() const { return mResourceManager; } @@ -61,8 +63,7 @@ class TNetwork { std::thread mUDPThread; std::thread mTCPThread; std::mutex mOpenIDMutex; - std::map mClientMap; - std::mutex mClientMapMutex; + TConnectionLimiter mConnectionLimiter; std::vector UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint); void OnConnect(const std::weak_ptr& c); diff --git a/include/TServer.h b/include/TServer.h index 4ee9f969..3202c9ea 100644 --- a/include/TServer.h +++ b/include/TServer.h @@ -20,6 +20,7 @@ #include "IThreaded.h" #include "RWMutex.h" +#include "TIoPollThread.h" #include "TScopedTimer.h" #include #include @@ -50,11 +51,10 @@ class TServer final { const TScopedTimer UptimeTimer; - // asio io context - io_context& IoCtx() { return mIoCtx; } + io_context& IoCtx() { return mIoCtxPoller.IoCtx(); } private: - io_context mIoCtx {}; + TIoPollThread mIoCtxPoller; TClientSet mClients; mutable RWMutex mClientsMutex; static void ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Network); diff --git a/src/Client.cpp b/src/Client.cpp index ed67945f..7d215941 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -76,7 +76,13 @@ std::string TClient::GetCarPositionRaw(int Ident) { } } -void TClient::Disconnect(std::string_view Reason) { +bool TClient::Disconnect(std::string_view Reason) { + // Do not remove this guard: concurrent close() on the same socket can crash in Asio internals. + EDisconnectState Expected = EDisconnectState::Connected; + if (!mDisconnectState.compare_exchange_strong(Expected, EDisconnectState::Disconnecting, std::memory_order_acq_rel, std::memory_order_acquire)) { + return false; + } + beammp_debugf("Disconnecting client {} for reason: {}", GetID(), Reason); boost::system::error_code ec; if (mSocket.is_open()) { @@ -91,6 +97,9 @@ void TClient::Disconnect(std::string_view Reason) { } else { beammp_debug("Socket is already closed."); } + // Terminal state: this client must not perform TCP IO after this point. + mDisconnectState.store(EDisconnectState::Disconnected, std::memory_order_release); + return true; } void TClient::SetCarPosition(int Ident, const std::string& Data) { @@ -169,8 +178,7 @@ std::optional> GetClient(TServer& Server, int ID) { std::optional> MaybeClient { std::nullopt }; Server.ForEachClient([&](std::weak_ptr CPtr) -> bool { ReadLock Lock(Server.GetClientMutex()); - if (!CPtr.expired()) { - auto C = CPtr.lock(); + if (auto C = CPtr.lock()) { if (C->GetID() == ID) { MaybeClient = CPtr; return false; diff --git a/src/Common.cpp b/src/Common.cpp index 1785b9e2..1c769a3a 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -31,6 +31,7 @@ #include #include +#include "TLuaResult.h" #include "Compat.h" #include "CustomAssert.h" #include "Http.h" diff --git a/src/Http.cpp b/src/Http.cpp index 0aba6639..98bb6ce2 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -18,15 +18,12 @@ #include "Http.h" -#include "Client.h" #include "Common.h" #include "CustomAssert.h" -#include "LuaAPI.h" +#include #include #include -#include -#include using json = nlohmann::json; @@ -37,17 +34,107 @@ static size_t CurlWriteCallback(void* contents, size_t size, size_t nmemb, void* return size * nmemb; } +struct CurlDeleter { + void operator()(CURL* c) const noexcept { + if (c) + curl_easy_cleanup(c); + } +}; + +static std::mutex gCurlPoolMutex; +static std::map gCurlPool; // false = free, true = in use +constexpr size_t MAX_CURL_POOL_SIZE = 128; + +static CURL* AcquireCurl() { + std::unique_lock Lock(gCurlPoolMutex); + for (auto& [Curl, InUse] : gCurlPool) { + if (!InUse) { + InUse = true; + curl_easy_reset(Curl); + return Curl; + } + } + + // none found, if we're under the limit add one! + if (gCurlPool.size() >= MAX_CURL_POOL_SIZE) { + beammp_debugf("Ran out of curl handles for network requests, skipping this request"); + return nullptr; + } + + CURL* Curl = curl_easy_init(); + if (!Curl) { + // failed, ignore + return nullptr; + } + + gCurlPool[Curl] = true; + return Curl; +} + +static void ReleaseCurl(CURL* curl) { + if (!curl) { + return; + } + + std::unique_lock Lock(gCurlPoolMutex); + auto It = gCurlPool.find(curl); + if (It != gCurlPool.end()) { + It->second = false; + } +} + +/// RAII container for curl handles to ensure correct release +class CurlLease { +private: + CURL* mHandle = nullptr; +public: + + CurlLease() : mHandle(AcquireCurl()) {} + ~CurlLease() { + ReleaseCurl(mHandle); + } + CurlLease(const CurlLease&) = delete; + CurlLease& operator=(const CurlLease&) = delete; + + CurlLease(CurlLease&& other) noexcept + : mHandle(other.mHandle) { + other.mHandle = nullptr; + } + + CurlLease& operator=(CurlLease&& other) noexcept { + if (this != &other) { + ReleaseCurl(mHandle); + mHandle = other.mHandle; + other.mHandle = nullptr; + } + return *this; + } + + CURL* GetHandle() const noexcept { + return mHandle; + } +}; + std::string Http::GET(const std::string& url, unsigned int* status) { std::string Ret; - static thread_local CURL* curl = curl_easy_init(); + CurlLease Lease{}; + CURL* curl = Lease.GetHandle(); if (curl) { CURLcode res; char errbuf[CURL_ERROR_SIZE]; curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_IPRESOLVE, CURL_IPRESOLVE_V4); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void*)&Ret); - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10); // seconds + curl_easy_setopt(curl, CURLOPT_WRITEDATA, reinterpret_cast(&Ret)); + + // ensure we dont keep connections open for long + curl_easy_setopt(curl, CURLOPT_MAXCONNECTS, 8L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10L); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L); + curl_easy_setopt(curl, CURLOPT_MAXAGE_CONN, 60L); + curl_easy_setopt(curl, CURLOPT_MAXLIFETIME_CONN, 300L); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf); errbuf[0] = 0; @@ -71,15 +158,16 @@ std::string Http::GET(const std::string& url, unsigned int* status) { std::string Http::POST(const std::string& url, const std::string& body, const std::string& ContentType, unsigned int* status, const std::map& headers) { std::string Ret; - static thread_local CURL* curl = curl_easy_init(); + CurlLease Lease{}; + CURL* curl = Lease.GetHandle(); if (curl) { CURLcode res; char errbuf[CURL_ERROR_SIZE]; curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); curl_easy_setopt(curl, CURLOPT_IPRESOLVE, CURL_IPRESOLVE_V4); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void*)&Ret); - curl_easy_setopt(curl, CURLOPT_POST, 1); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, reinterpret_cast(&Ret)); + curl_easy_setopt(curl, CURLOPT_POST, 1L); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.size()); struct curl_slist* list = nullptr; @@ -90,7 +178,15 @@ std::string Http::POST(const std::string& url, const std::string& body, const st } curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list); - curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10); // seconds + + // ensure we dont keep connections open for long + curl_easy_setopt(curl, CURLOPT_MAXCONNECTS, 8L); + curl_easy_setopt(curl, CURLOPT_CONNECTTIMEOUT, 10L); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, 30L); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L); + curl_easy_setopt(curl, CURLOPT_MAXAGE_CONN, 60L); + curl_easy_setopt(curl, CURLOPT_MAXLIFETIME_CONN, 300L); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf); errbuf[0] = 0; diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index 9ee24620..bb16dda0 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -143,22 +143,23 @@ static inline std::pair InternalTriggerClientEvent(int Player return { true, "" }; } else { auto MaybeClient = GetClient(LuaAPI::MP::Engine->Server(), PlayerID); - if (!MaybeClient || MaybeClient.value().expired()) { - beammp_lua_errorf("TriggerClientEvent invalid Player ID '{}'", PlayerID); - return { false, "Invalid Player ID" }; - } - auto c = MaybeClient.value().lock(); + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSyncing() && !c->IsSynced()) { + return { false, "Player hasn't joined yet" }; + } - if (!c->IsSyncing() && !c->IsSynced()) { - return { false, "Player hasn't joined yet" }; + if (!LuaAPI::MP::Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_lua_errorf("Respond failed, dropping client {}", PlayerID); + LuaAPI::MP::Engine->Network().ClientKick(*c, "Disconnected after failing to receive packets"); + return { false, "Respond failed, dropping client" }; + } + return { true, "" }; + } } - if (!LuaAPI::MP::Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_lua_errorf("Respond failed, dropping client {}", PlayerID); - LuaAPI::MP::Engine->Network().ClientKick(*c, "Disconnected after failing to receive packets"); - return { false, "Respond failed, dropping client" }; - } - return { true, "" }; + beammp_lua_errorf("TriggerClientEvent invalid Player ID '{}'", PlayerID); + return { false, "Invalid Player ID" }; } } @@ -169,13 +170,15 @@ std::pair LuaAPI::MP::TriggerClientEvent(int PlayerID, const std::pair LuaAPI::MP::DropPlayer(int ID, std::optional MaybeReason) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (!MaybeClient || MaybeClient.value().expired()) { - beammp_lua_errorf("Tried to drop client with id {}, who doesn't exist", ID); - return { false, "Player does not exist" }; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + LuaAPI::MP::Engine->Network().ClientKick(*c, MaybeReason.value_or("No reason")); + return { true, "" }; + } } - auto c = MaybeClient.value().lock(); - LuaAPI::MP::Engine->Network().ClientKick(*c, MaybeReason.value_or("No reason")); - return { true, "" }; + + beammp_lua_errorf("Tried to drop client with id {}, who doesn't exist", ID); + return { false, "Player does not exist" }; } std::pair LuaAPI::MP::SendChatMessage(int ID, const std::string& Message, const bool& LogChat) { @@ -189,21 +192,26 @@ std::pair LuaAPI::MP::SendChatMessage(int ID, const std::stri Result.first = true; } else { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player still syncing data"; + return Result; + } + if (LogChat) { + LogChatMessage(" (to \"" + c->GetName() + "\")", -1, Message); + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send chat message back to sender (id {}) - did the sender disconnect?", ID); + // TODO: should we return an error here? + } + Result.first = true; + } else { + beammp_lua_error("SendChatMessage invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Player still syncing data"; - return Result; - } - if (LogChat) { - LogChatMessage(" (to \"" + c->GetName() + "\")", -1, Message); + Result.second = "Invalid Player ID"; } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send chat message back to sender (id {}) - did the sender disconnect?", ID); - // TODO: should we return an error here? - } - Result.first = true; } else { beammp_lua_error("SendChatMessage invalid argument [1] invalid ID"); Result.first = false; @@ -223,18 +231,23 @@ std::pair LuaAPI::MP::SendNotification(int ID, const std::str } else { auto MaybeClient = GetClient(Engine->Server(), ID); if (MaybeClient) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { - Result.first = false; - Result.second = "Player is not synced yet"; - return Result; - } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send notification to player (id {}) - did the player disconnect?", ID); + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player is not synced yet"; + return Result; + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send notification to player (id {}) - did the player disconnect?", ID); + Result.first = false; + Result.second = "Failed to send packet"; + } + Result.first = true; + } else { + beammp_lua_error("SendNotification invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Failed to send packet"; + Result.second = "Invalid Player ID"; } - Result.first = true; } else { beammp_lua_error("SendNotification invalid argument [1] invalid ID"); Result.first = false; @@ -265,18 +278,23 @@ std::pair LuaAPI::MP::ConfirmationDialog(int ID, const std::s } else { auto MaybeClient = GetClient(Engine->Server(), ID); if (MaybeClient) { - auto c = MaybeClient.value().lock(); - if (!c->IsSynced()) { - Result.first = false; - Result.second = "Player is not synced yet"; - return Result; - } - if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { - beammp_errorf("Failed to send confirmation dialog to player (id {}) - did the player disconnect?", ID); + if (auto c = MaybeClient.value().lock()) { + if (!c->IsSynced()) { + Result.first = false; + Result.second = "Player is not synced yet"; + return Result; + } + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { + beammp_errorf("Failed to send confirmation dialog to player (id {}) - did the player disconnect?", ID); + Result.first = false; + Result.second = "Failed to send packet"; + } + Result.first = true; + } else { + beammp_lua_error("ConfirmationDialog invalid argument [1] invalid ID"); Result.first = false; - Result.second = "Failed to send packet"; + Result.second = "Invalid Player ID"; } - Result.first = true; } else { beammp_lua_error("ConfirmationDialog invalid argument [1] invalid ID"); Result.first = false; @@ -290,22 +308,27 @@ std::pair LuaAPI::MP::ConfirmationDialog(int ID, const std::s std::pair LuaAPI::MP::RemoveVehicle(int PID, int VID) { std::pair Result; auto MaybeClient = GetClient(Engine->Server(), PID); - if (!MaybeClient || MaybeClient.value().expired()) { + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + if (c->GetCarData(VID) != nlohmann::detail::value_t::null) { + std::string Destroy = "Od:" + std::to_string(PID) + "-" + std::to_string(VID); + LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleDeleted", "", PID, VID)); + Engine->Network().SendToAll(nullptr, StringToVector(Destroy), true, true); + c->DeleteCar(VID); + Result.first = true; + } else { + Result.first = false; + Result.second = "Vehicle does not exist"; + } + } else { + beammp_lua_error("RemoveVehicle invalid Player ID"); + Result.first = false; + Result.second = "Invalid Player ID"; + } + } else { beammp_lua_error("RemoveVehicle invalid Player ID"); Result.first = false; Result.second = "Invalid Player ID"; - return Result; - } - auto c = MaybeClient.value().lock(); - if (c->GetCarData(VID) != nlohmann::detail::value_t::null) { - std::string Destroy = "Od:" + std::to_string(PID) + "-" + std::to_string(VID); - LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleDeleted", "", PID, VID)); - Engine->Network().SendToAll(nullptr, StringToVector(Destroy), true, true); - c->DeleteCar(VID); - Result.first = true; - } else { - Result.first = false; - Result.second = "Vehicle does not exist"; } return Result; } @@ -412,20 +435,22 @@ void LuaAPI::MP::Sleep(size_t Ms) { bool LuaAPI::MP::IsPlayerConnected(int ID) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->IsUDPConnected(); - } else { - return false; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + return c->IsUDPConnected(); + } } + return false; } bool LuaAPI::MP::IsPlayerGuest(int ID) { auto MaybeClient = GetClient(Engine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->IsGuest(); - } else { - return false; + if (MaybeClient) { + if (auto c = MaybeClient.value().lock()) { + return c->IsGuest(); + } } + return false; } void LuaAPI::MP::PrintRaw(sol::variadic_args Args) { @@ -440,7 +465,16 @@ void LuaAPI::MP::PrintRaw(sol::variadic_args Args) { } int LuaAPI::PanicHandler(lua_State* State) { - beammp_lua_error("PANIC: " + sol::stack::get(State, 1)); + // panic path: use raw lua c api only; sol2 conversions can raise lua_error + // and re-enter panic recursively. + const int ErrorType = lua_type(State, -1); + if (ErrorType == LUA_TSTRING) { + const char* Message = lua_tostring(State, -1); + beammp_lua_error(std::string("PANIC: ") + (Message ? Message : "(null)")); + } else { + const char* ErrorTypeName = lua_typename(State, ErrorType); + beammp_lua_error(std::string("PANIC: non-string error object (") + (ErrorTypeName ? ErrorTypeName : "unknown") + ")"); + } return 0; } diff --git a/src/TConnectionLimiter.cpp b/src/TConnectionLimiter.cpp new file mode 100644 index 00000000..97d3fabb --- /dev/null +++ b/src/TConnectionLimiter.cpp @@ -0,0 +1,96 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "TConnectionLimiter.h" +#include +#include +#include "Common.h" + +TConnectionLimiter::TConnectionLimiter(size_t maxPerIp, size_t maxGlobal) + : mMaxPerIp(maxPerIp) + , mMaxGlobal(maxGlobal) { + if (maxPerIp == 0) { + beammp_errorf("Max connections count per IP is set to zero; the server would reject ALL connections"); + throw std::runtime_error("Invalid maximum connections per IP setting"); + } + if (maxGlobal == 0) { + beammp_errorf("Max connection count is set to zero; the server would reject ALL connections"); + throw std::runtime_error("Invalid maximum connections setting"); + } +} + +std::optional TConnectionLimiter::TryAcquire(const std::string& ip) { + std::unique_lock Lock { mMutex }; + if (mGlobal >= mMaxGlobal) return std::nullopt; + // `It` is the inserted element (so if insertion worked, its 0), or its the element that + // was already there, in which case we must check the ip connection limit + auto [It, _] = mPerIp.try_emplace(ip, 0); + if (It->second >= mMaxPerIp) { + return std::nullopt; + } + // now increment the counter finally + ++It->second; + ++mGlobal; + // RAII guard will drop the count once destructed + return TGuard(this, ip); +} + +TConnectionLimiter::TStats TConnectionLimiter::GetStats() { + std::unique_lock Lock { mMutex }; + TStats Stats; + Stats.CurrentGlobal = mGlobal; + Stats.MaxGlobal = mMaxGlobal; + Stats.ActiveIpBuckets = mPerIp.size(); + Stats.MaxPerIp = mMaxPerIp; + for (const auto& [_, Count] : mPerIp) { + if (Count > Stats.CurrentMaxPerIp) { + Stats.CurrentMaxPerIp = Count; + } + if (Count >= mMaxPerIp) { + ++Stats.SaturatedIpBuckets; + } + } + return Stats; +} + +TConnectionLimiter::TGuard::TGuard(TConnectionLimiter* owner, std::string ip) + : mOwner(owner) + , mIp(std::move(ip)) { + beammp_debugf("Acquired connection guard for {}", mIp); +} + +TConnectionLimiter::TGuard& TConnectionLimiter::TGuard::operator=(TGuard&& other) noexcept { + // identity check + if (this != &other) { + Release(); + mOwner = other.mOwner; + mIp = std::move(other.mIp); + other.mOwner = nullptr; + // setting this so its obvious when this happens, instead of being UB or empty string + other.mIp = ""; + } + return *this; +} + +void TConnectionLimiter::TGuard::Release() { + if (mOwner) { + mOwner->Release(mIp); + mOwner = nullptr; + beammp_debugf("Released connection guard for {}", mIp); + } +} diff --git a/src/TConsole.cpp b/src/TConsole.cpp index fb1dc8b0..2aca3a0d 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -22,9 +22,9 @@ #include "Client.h" #include "CustomAssert.h" +#include "Http.h" #include "LuaAPI.h" #include "TLuaEngine.h" -#include "Http.h" #include #include @@ -296,8 +296,7 @@ void TConsole::Command_NetTest(const std::string& cmd, const std::vector& return StringStartsWith(Name1, Name2) || StringStartsWith(Name2, Name1); }; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - if (NameCompare(locked->GetName(), Name)) { - mLuaEngine->Network().ClientKick(*locked, Reason); + if (auto Locked = Client.lock()) { + if (NameCompare(Locked->GetName(), Name)) { + mLuaEngine->Network().ClientKick(*Locked, Reason); Kicked = true; return false; } @@ -356,7 +354,7 @@ std::tuple> TConsole::ParseCommand(const s // It correctly splits arguments, including respecting single and double quotes, as well as backticks auto End_i = CommandWithArgs.find_first_of(' '); std::string Command = CommandWithArgs.substr(0, End_i); - std::string ArgsStr {}; + std::string ArgsStr { }; if (End_i != std::string::npos) { ArgsStr = CommandWithArgs.substr(End_i); } @@ -566,11 +564,10 @@ void TConsole::Command_List(const std::string&, const std::vector& std::stringstream ss; ss << std::left << std::setw(25) << "Name" << std::setw(6) << "ID" << std::setw(6) << "Cars" << std::endl; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - ss << std::left << std::setw(25) << locked->GetName() - << std::setw(6) << locked->GetID() - << std::setw(6) << locked->GetCarCount() << "\n"; + if (auto Locked = Client.lock()) { + ss << std::left << std::setw(25) << Locked->GetName() + << std::setw(6) << Locked->GetID() + << std::setw(6) << Locked->GetCarCount() << "\n"; } return true; }); @@ -593,8 +590,7 @@ void TConsole::Command_Status(const std::string&, const std::vector size_t MissedPacketQueueSum = 0; int LargestSecondsSinceLastPing = 0; mLuaEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto Locked = Client.lock(); + if (auto Locked = Client.lock()) { CarCount += Locked->GetCarCount(); ConnectedCount += Locked->IsUDPConnected() ? 1 : 0; GuestCount += Locked->IsGuest() ? 1 : 0; @@ -613,11 +609,11 @@ void TConsole::Command_Status(const std::string&, const std::vector size_t SystemsBad = 0; size_t SystemsShuttingDown = 0; size_t SystemsShutdown = 0; - std::string SystemsBadList {}; - std::string SystemsGoodList {}; - std::string SystemsStartingList {}; - std::string SystemsShuttingDownList {}; - std::string SystemsShutdownList {}; + std::string SystemsBadList { }; + std::string SystemsGoodList { }; + std::string SystemsStartingList { }; + std::string SystemsShuttingDownList { }; + std::string SystemsShutdownList { }; auto Statuses = Application::GetSubsystemStatuses(); for (const auto& NameStatusPair : Statuses) { switch (NameStatusPair.second) { @@ -653,6 +649,7 @@ void TConsole::Command_Status(const std::string&, const std::vector SystemsShutdownList = SystemsShutdownList.substr(0, SystemsShutdownList.size() - 2); auto ElapsedTime = mLuaEngine->Server().UptimeTimer.GetElapsedTime(); + auto ConnectionLimiterStats = mLuaEngine->Network().GetConnectionLimiterStats(); Status << "BeamMP-Server Status:\n" << "\tTotal Players: " << mLuaEngine->Server().ClientCount() << "\n" @@ -667,6 +664,11 @@ void TConsole::Command_Status(const std::string&, const std::vector << "\t\tStates: " << mLuaEngine->GetLuaStateCount() << "\n" << "\t\tEvent timers: " << mLuaEngine->GetTimedEventsCount() << "\n" << "\t\tEvent handlers: " << mLuaEngine->GetRegisteredEventHandlerCount() << "\n" + << "\tConnection limiter:\n" + << "\t\tActive/Max global: " << ConnectionLimiterStats.CurrentGlobal << "/" << ConnectionLimiterStats.MaxGlobal << "\n" + << "\t\tActive IP buckets: " << ConnectionLimiterStats.ActiveIpBuckets << "\n" + << "\t\tHighest single IP load: " << ConnectionLimiterStats.CurrentMaxPerIp << "/" << ConnectionLimiterStats.MaxPerIp << "\n" + << "\t\tSaturated IP buckets: " << ConnectionLimiterStats.SaturatedIpBuckets << "\n" << "\tSubsystems:\n" << "\t\tGood/Starting/Bad: " << SystemsGood << "/" << SystemsStarting << "/" << SystemsBad << "\n" << "\t\tShutting down/Shut down: " << SystemsShuttingDown << "/" << SystemsShutdown << "\n" @@ -683,9 +685,13 @@ void TConsole::Command_Status(const std::string&, const std::vector void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { auto FutureIsNonNil = [](const std::shared_ptr& Future) { - if (!Future->Error && Future->Result.valid()) { - auto Type = Future->Result.get_type(); - return Type != sol::type::lua_nil && Type != sol::type::none; + if (!Future->IsError()) { + auto Snapshot = Future->GetDetachedSnapshot(); + if (Snapshot.Result.V.valueless_by_exception() || std::get_if(&Snapshot.Result.V) != nullptr) { + // no value contained + return false; + } + return true; } return false; }; @@ -695,7 +701,7 @@ void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { TLuaEngine::WaitForAll(Futures, std::chrono::seconds(5)); size_t Count = 0; for (auto& Future : Futures) { - if (!Future->Error) { + if (!Future->IsError()) { ++Count; } } @@ -713,14 +719,16 @@ void TConsole::RunAsCommand(const std::string& cmd, bool IgnoreNotACommand) { std::stringstream Reply; if (NonNilFutures.size() > 1) { for (size_t i = 0; i < NonNilFutures.size(); ++i) { - Reply << NonNilFutures[i]->StateId << ": \n" - << LuaAPI::LuaToString(NonNilFutures[i]->Result); + auto Snapshot = NonNilFutures[i]->GetDetachedSnapshot(); + Reply << Snapshot.StateId << ": \n" + << Snapshot.Result; if (i < NonNilFutures.size() - 1) { Reply << "\n"; } } } else { - Reply << LuaAPI::LuaToString(NonNilFutures[0]->Result); + auto Snapshot = NonNilFutures[0]->GetDetachedSnapshot(); + Reply << Snapshot.Result; } Application::Console().WriteRaw(Reply.str()); } @@ -801,8 +809,9 @@ void TConsole::InitializeCommandline() { } else { auto Future = mLuaEngine->EnqueueScript(mStateId, { std::make_shared(TrimmedCmd), "", "" }); Future->WaitUntilReady(); - if (Future->Error) { - beammp_lua_error("error in " + mStateId + ": " + Future->ErrorMessage); + if (Future->IsError()) { + auto Snapshot = Future->GetDetachedSnapshot(); + beammp_lua_error("error in " + mStateId + ": " + Snapshot.ErrorMessage); } } } else { @@ -834,7 +843,7 @@ void TConsole::InitializeCommandline() { if (!mLuaEngine) { beammp_info("Lua not started yet, please try again in a second"); } else { - std::string prefix {}; // stores non-table part of input + std::string prefix { }; // stores non-table part of input for (size_t i = stub.length(); i > 0; i--) { // separate table from input if (!std::isalnum(stub[i - 1]) && stub[i - 1] != '_' && stub[i - 1] != '.') { prefix = stub.substr(0, i); diff --git a/src/THeartbeatThread.cpp b/src/THeartbeatThread.cpp index 36cdde1c..347f59a3 100644 --- a/src/THeartbeatThread.cpp +++ b/src/THeartbeatThread.cpp @@ -181,8 +181,8 @@ std::string THeartbeatThread::GetPlayers() { std::string Return; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Return += ClientPtr.lock()->GetName() + ";"; + if (auto Client = ClientPtr.lock()) { + Return += Client->GetName() + ";"; } return true; }); diff --git a/src/TIoPollThread.cpp b/src/TIoPollThread.cpp new file mode 100644 index 00000000..dabf8c89 --- /dev/null +++ b/src/TIoPollThread.cpp @@ -0,0 +1,37 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "TIoPollThread.h" + +TIoPollThread::TIoPollThread() + : mWorkGuard(boost::asio::make_work_guard(mIoCtx)) + , mThread([this](std::stop_token StopToken) { + while (!StopToken.stop_requested()) { + try { + mIoCtx.run(); + break; + } catch (...) { + mIoCtx.restart(); + } + } + }) { } + +TIoPollThread::~TIoPollThread() { + mWorkGuard.reset(); + mIoCtx.stop(); +} diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d7bee3e6..c0e3a28e 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -20,20 +20,26 @@ #include "Client.h" #include "Common.h" #include "CustomAssert.h" +#include "Env.h" #include "Http.h" #include "LuaAPI.h" -#include "Env.h" #include "Profiling.h" #include "TLuaPlugin.h" +#include "TLuaResult.h" #include "sol/object.hpp" #include #include #include +#include #include #include +#include #include #include +#include +#include +#include TLuaEngine* LuaAPI::MP::Engine; @@ -57,12 +63,6 @@ TLuaEngine::TLuaEngine() IThreaded::Start(); } -TEST_CASE("TLuaEngine ctor & dtor") { - Application::Settings.set(Settings::Key::General_ResourceFolder, "beammp_server_test_resources"); - TLuaEngine engine; - Application::GracefullyShutdown(); -} - void TLuaEngine::operator()() { RegisterThread("LuaEngine"); // lua engine main thread @@ -74,8 +74,9 @@ void TLuaEngine::operator()() { auto Futures = TriggerEvent("onInit", ""); WaitForAll(Futures, std::chrono::seconds(5)); for (const auto& Future : Futures) { - if (Future->Error && Future->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error("Calling \"onInit\" on \"" + Future->StateId + "\" failed: " + Future->ErrorMessage); + auto Snapshot = Future->GetDetachedSnapshot(); + if (Snapshot.Error && Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error("Calling \"onInit\" on \"" + Snapshot.StateId + "\" failed: " + Snapshot.ErrorMessage); } } @@ -85,10 +86,11 @@ void TLuaEngine::operator()() { std::unique_lock Lock(mResultsToCheckMutex); if (!mResultsToCheck.empty()) { mResultsToCheck.remove_if([](const std::shared_ptr& Ptr) -> bool { - if (Ptr->Ready) { - if (Ptr->Error) { - if (Ptr->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error(Ptr->Function + ": " + Ptr->ErrorMessage); + if (Ptr->IsReady()) { + auto Snapshot = Ptr->GetDetachedSnapshot(); + if (Snapshot.Error) { + if (Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error(Snapshot.Function + ": " + Snapshot.ErrorMessage); } } return true; @@ -113,7 +115,7 @@ void TLuaEngine::operator()() { std::unique_lock StateLock(mLuaStatesMutex); std::unique_lock Lock2(mResultsToCheckMutex); for (auto& Handler : Handlers) { - auto Res = mLuaStates[Timer.StateId]->EnqueueFunctionCallFromCustomEvent(Handler, {}, Timer.EventName, Timer.Strategy); + auto Res = mLuaStates[Timer.StateId]->EnqueueFunctionCallFromCustomEvent(Handler, { }, Timer.EventName, Timer.Strategy); if (Res) { mResultsToCheck.push_back(Res); mResultsToCheckCond.notify_one(); @@ -128,7 +130,13 @@ void TLuaEngine::operator()() { } } } - if (mLuaStates.empty()) { + bool StatesEmpty = false; + { + std::unique_lock Lock(mLuaStatesMutex); + StatesEmpty = mLuaStates.empty(); + } + + if (StatesEmpty) { beammp_trace("No Lua states, event loop running extremely sparsely"); Application::SleepSafeSeconds(10); } else { @@ -215,16 +223,16 @@ std::vector TLuaEngine::Debug_GetStateFunctionQueueF return Result; } -std::vector TLuaEngine::Debug_GetResultsToCheckForState(TLuaStateId StateId) { +std::vector TLuaEngine::Debug_GetResultsToCheckForState(TLuaStateId StateId) { std::unique_lock Lock(mResultsToCheckMutex); auto ResultsToCheckCopy = mResultsToCheck; Lock.unlock(); - std::vector Result; + std::vector Result; while (!ResultsToCheckCopy.empty()) { auto ResultToCheck = std::move(ResultsToCheckCopy.front()); ResultsToCheckCopy.pop_front(); - if (ResultToCheck->StateId == StateId) { - Result.push_back(*ResultToCheck); + if (ResultToCheck->OwnerState() == StateId) { + Result.push_back(ResultToCheck->GetDetachedSnapshot()); } } return Result; @@ -255,7 +263,7 @@ std::vector TLuaEngine::StateThreadData::GetStateTableKeys(const st auto globals = mStateView.globals(); sol::table current = globals; - std::vector Result {}; + std::vector Result { }; for (const auto& [key, value] : current) { std::string s = key.as(); @@ -311,27 +319,30 @@ void TLuaEngine::WaitForAll(std::vector>& Results, c size_t ms = 0; std::set WarnedResults; - while (!Result->Ready && !Cancelled) { + while (!Result->IsReady() && !Cancelled) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); ms += 10; if (Max.has_value() && std::chrono::milliseconds(ms) > Max.value()) { - beammp_trace("'" + Result->Function + "' in '" + Result->StateId + "' did not finish executing in time (took: " + std::to_string(ms) + "ms)."); + auto Snapshot = Result->GetDetachedSnapshot(); + beammp_trace("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' did not finish executing in time (took: " + std::to_string(ms) + "ms)."); Cancelled = true; } else if (ms > 1000 * 60) { - auto ResultId = Result->StateId + "_" + Result->Function; + auto Snapshot = Result->GetDetachedSnapshot(); + auto ResultId = Snapshot.StateId + "_" + Snapshot.Function; if (WarnedResults.count(ResultId) == 0) { WarnedResults.insert(ResultId); - beammp_lua_warn("'" + Result->Function + "' in '" + Result->StateId + "' is taking very long. The event it's handling is too important to discard the result of this handler, but may block this event and possibly the whole lua state."); + beammp_lua_warn("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' is taking very long. The event it's handling is too important to discard the result of this handler, but may block this event and possibly the whole lua state."); } } } + auto Snapshot = Result->GetDetachedSnapshot(); if (Cancelled) { - beammp_lua_warn("'" + Result->Function + "' in '" + Result->StateId + "' failed to execute in time and was not waited for. It may still finish executing at a later time."); + beammp_lua_warn("'" + Snapshot.Function + "' in '" + Snapshot.StateId + "' failed to execute in time and was not waited for. It may still finish executing at a later time."); LuaAPI::MP::Engine->ReportErrors({ Result }); - } else if (Result->Error) { - if (Result->ErrorMessage != BeamMPFnNotFoundError) { - beammp_lua_error(Result->Function + ": " + Result->ErrorMessage); + } else if (Snapshot.Error) { + if (Snapshot.ErrorMessage != BeamMPFnNotFoundError) { + beammp_lua_error(Snapshot.Function + ": " + Snapshot.ErrorMessage); } } } @@ -431,10 +442,11 @@ void TLuaEngine::EnsureStateExists(TLuaStateId StateId, const std::string& Name, mLuaStates[StateId] = std::move(DataPtr); RegisterEvent("onInit", StateId, "onInit"); if (!DontCallOnInit) { - auto Res = EnqueueFunctionCall(StateId, "onInit", {}, "onInit"); + auto Res = EnqueueFunctionCall(StateId, "onInit", { }, "onInit"); Res->WaitUntilReady(); - if (Res->Error && Res->ErrorMessage != TLuaEngine::BeamMPFnNotFoundError) { - beammp_lua_error("Calling \"onInit\" on \"" + StateId + "\" failed: " + Res->ErrorMessage); + auto Snapshot = Res->GetSnapshot(StateId); + if (Snapshot.Error && Snapshot.ErrorMessage != TLuaEngine::BeamMPFnNotFoundError) { + beammp_lua_error("Calling \"onInit\" on \"" + StateId + "\" failed: " + Snapshot.ErrorMessage); } } } @@ -446,6 +458,7 @@ void TLuaEngine::RegisterEvent(const std::string& EventName, TLuaStateId StateId } std::set TLuaEngine::GetEventHandlersForState(const std::string& EventName, TLuaStateId StateId) { + std::unique_lock Lock(mLuaEventsMutex); return mLuaEvents[EventName][StateId]; } @@ -453,7 +466,7 @@ std::vector TLuaEngine::StateThreadData::JsonStringToArray(JsonStri auto LocalTable = Lua_JsonDecode(Str.value).as>(); for (auto& value : LocalTable) { if (value.is() && value.as() == BEAMMP_INTERNAL_NIL) { - value = sol::object {}; + value = sol::object { }; } } return LocalTable; @@ -495,15 +508,16 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string auto Fn = mStateView[Handler]; if (Fn.valid()) { auto LuaResult = Fn(LocalArgs); - auto Result = std::make_shared(); + auto Result = std::make_shared(mStateId, Handler); if (LuaResult.valid()) { - Result->Error = false; - Result->Result = LuaResult; + try { + Result->MarkReadySuccess(LuaResult); + } catch (const std::exception& e) { + Result->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { - Result->Error = true; - Result->ErrorMessage = "Function result in TriggerGlobalEvent was invalid"; + Result->MarkReadyError("Function result in TriggerGlobalEvent was invalid"); } - Result->MarkAsReady(); Return.push_back(Result); } } @@ -511,26 +525,55 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string sol::table AsyncEventReturn = StateView.create_table(); AsyncEventReturn["ReturnValueImpl"] = Return; AsyncEventReturn.set_function("IsDone", - [&](const sol::table& Self) -> bool { + [](const sol::table& Self) -> bool { auto Vector = Self.get>>("ReturnValueImpl"); for (const auto& Value : Vector) { - if (!Value->Ready) { + if (!Value->IsReady()) { return false; } } return true; }); + TLuaStateId StateId = mStateId; AsyncEventReturn.set_function("GetResults", - [&](const sol::table& Self) -> sol::table { - sol::state_view StateView(mState); + [StateId](const sol::table& Self, sol::this_state State) -> sol::table { + sol::state_view StateView(State); sol::table Result = StateView.create_table(); auto Vector = Self.get>>("ReturnValueImpl"); int i = 1; for (const auto& Value : Vector) { - if (!Value->Ready) { + if (!Value->IsReady()) { return sol::lua_nil; } - Result.set(i, Value->Result); + // event results from this state are valid unserialized + if (Value->OwnerState() == StateId) { + auto Snapshot = Value->GetSnapshot(StateId); + Result.set(i, Snapshot.Result); + } else { + // event result from another state, goes through serialization boundary + auto Snapshot = Value->GetDetachedSnapshot(); + std::visit([i, &Result](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + Result.set(i, arg); + else if constexpr (std::is_same_v) + // monostate means no result value + Result.set(i, sol::lua_nil_t()); + else + static_assert(AlwaysFalseV, "non-exhaustive visitor!"); + }, Snapshot.Result.V); + } + ++i; } return Result; @@ -550,8 +593,13 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& Result.set(i, FnRet); ++i; } else { - sol::error Err = FnRet; - beammp_lua_error(std::string("TriggerLocalEvent: ") + Err.what()); + std::string ErrStr; + if (FnRet.get_type() == sol::type::string) { + ErrStr = FnRet.get().what(); + } else { + ErrStr = "(unknown error; error object is not inspectable)"; + } + beammp_lua_errorf("TriggerLocalEvent: {}", ErrStr); } } } @@ -560,37 +608,37 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& sol::table TLuaEngine::StateThreadData::Lua_GetPlayerIdentifiers(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto IDs = MaybeClient.value().lock()->GetIdentifiers(); - if (IDs.empty()) { - return sol::lua_nil; - } - sol::table Result = mStateView.create_table(); - for (const auto& Pair : IDs) { - Result[Pair.first] = Pair.second; + if (MaybeClient) { + if (std::shared_ptr Locked = MaybeClient.value().lock()) { + auto IDs = Locked->GetIdentifiers(); + if (IDs.empty()) { + return sol::lua_nil; + } + sol::table Result = mStateView.create_table(); + for (const auto& Pair : IDs) { + Result.set(Pair.first, Pair.second); + } + return Result; } - return Result; - } else { - return sol::lua_nil; } + return sol::lua_nil; } std::variant TLuaEngine::StateThreadData::Lua_GetPlayerRole(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); if (MaybeClient) { - return MaybeClient.value().lock()->GetRoles(); - } else { - return sol::nil; + if (auto Locked = MaybeClient.value().lock()) { + return Locked->GetRoles(); + } } + return sol::nil; } - sol::table TLuaEngine::StateThreadData::Lua_GetPlayers() { sol::table Result = mStateView.create_table(); mEngine->Server().ForEachClient([&](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - Result[locked->GetID()] = locked->GetName(); + if (auto Locked = Client.lock()) { + Result[Locked->GetID()] = Locked->GetName(); } return true; }); @@ -600,10 +648,9 @@ sol::table TLuaEngine::StateThreadData::Lua_GetPlayers() { int TLuaEngine::StateThreadData::Lua_GetPlayerIDByName(const std::string& Name) { int Id = -1; mEngine->mServer->ForEachClient([&Id, &Name](std::weak_ptr Client) -> bool { - if (!Client.expired()) { - auto locked = Client.lock(); - if (locked->GetName() == Name) { - Id = locked->GetID(); + if (auto Locked = Client.lock()) { + if (Locked->GetName() == Name) { + Id = Locked->GetID(); return false; } } @@ -640,60 +687,59 @@ sol::table TLuaEngine::StateThreadData::Lua_FS_ListDirectories(const std::string std::string TLuaEngine::StateThreadData::Lua_GetPlayerName(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - return MaybeClient.value().lock()->GetName(); - } else { - return ""; + if (MaybeClient) { + if (auto Locked = MaybeClient.value().lock()) { + return Locked->GetName(); + } } + return ""; } sol::table TLuaEngine::StateThreadData::Lua_GetPlayerVehicles(int ID) { auto MaybeClient = GetClient(mEngine->Server(), ID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto Client = MaybeClient.value().lock(); - TClient::TSetOfVehicleData VehicleData; - { // Vehicle Data Lock Scope - auto LockedData = Client->GetAllCars(); - VehicleData = *LockedData.VehicleData; - } // End Vehicle Data Lock Scope - if (VehicleData.empty()) { - return sol::lua_nil; - } - sol::state_view StateView(mState); - sol::table Result = StateView.create_table(); - for (const auto& v : VehicleData) { - Result[v.ID()] = v.DataAsPacket(Client->GetRoles(), Client->GetName(), Client->GetID()).substr(3); + if (MaybeClient) { + if (auto Client = MaybeClient.value().lock()) { + TClient::TSetOfVehicleData VehicleData; + { // Vehicle Data Lock Scope + auto LockedData = Client->GetAllCars(); + VehicleData = *LockedData.VehicleData; + } // End Vehicle Data Lock Scope + if (VehicleData.empty()) { + return sol::lua_nil; + } + sol::state_view StateView(mState); + sol::table Result = StateView.create_table(); + for (const auto& v : VehicleData) { + Result[v.ID()] = v.DataAsPacket(Client->GetRoles(), Client->GetName(), Client->GetID()).substr(3); + } + return Result; } - return Result; - } else - return sol::lua_nil; + } + return sol::lua_nil; } std::pair TLuaEngine::StateThreadData::Lua_GetPositionRaw(int PID, int VID) { std::pair Result; auto MaybeClient = GetClient(mEngine->Server(), PID); - if (MaybeClient && !MaybeClient.value().expired()) { - auto Client = MaybeClient.value().lock(); - std::string VehiclePos = Client->GetCarPositionRaw(VID); + if (MaybeClient) { + if (auto Client = MaybeClient.value().lock()) { + std::string VehiclePos = Client->GetCarPositionRaw(VID); - if (VehiclePos.empty()) { - // return std::make_tuple(sol::lua_nil, sol::make_object(StateView, "Vehicle not found")); - Result.second = "Vehicle not found"; - return Result; - } + if (VehiclePos.empty()) { + Result.second = "Vehicle not found"; + return Result; + } - sol::table t = Lua_JsonDecode(VehiclePos); - if (t == sol::lua_nil) { - Result.second = "Packet decode failed"; + sol::table t = Lua_JsonDecode(VehiclePos); + if (t == sol::lua_nil) { + Result.second = "Packet decode failed"; + } + Result.first = t; + return Result; } - // return std::make_tuple(Result, sol::make_object(StateView, sol::lua_nil)); - Result.first = t; - return Result; - } else { - // return std::make_tuple(sol::lua_nil, sol::make_object(StateView, "Client expired")); - Result.second = "No such player"; - return Result; } + Result.second = "No such player"; + return Result; } sol::table TLuaEngine::StateThreadData::Lua_HttpCreateConnection(const std::string& host, uint16_t port) { @@ -734,22 +780,31 @@ static bool mDisableMPSet = [] { static auto GetSettingName = [](int id) -> const char* { switch (id) { - case 0: return "Debug"; - case 1: return "Private"; - case 2: return "MaxCars"; - case 3: return "MaxPlayers"; - case 4: return "Map"; - case 5: return "Name"; - case 6: return "Description"; - case 7: return "InformationPacket"; - default: return "Unknown"; + case 0: + return "Debug"; + case 1: + return "Private"; + case 2: + return "MaxCars"; + case 3: + return "MaxPlayers"; + case 4: + return "Map"; + case 5: + return "Name"; + case 6: + return "Description"; + case 7: + return "InformationPacket"; + default: + return "Unknown"; } }; static void JsonDecodeRecursive(sol::state_view& StateView, sol::table& table, const std::string& left, const nlohmann::json& right) { switch (right.type()) { case nlohmann::detail::value_t::null: - AddToTable(table, left, sol::lua_nil_t {}); + AddToTable(table, left, sol::lua_nil_t { }); return; case nlohmann::detail::value_t::object: { auto value = table.create(); @@ -897,12 +952,9 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI beammp_lua_error("SendNotification expects 2, 3 or 4 arguments."); } }); - MPTable.set_function("ConfirmationDialog", sol::overload( - &LuaAPI::MP::ConfirmationDialog, - [&](const int& ID, const std::string& Title, const std::string& Body, const sol::table& Buttons, const std::string& InteractionID) { - LuaAPI::MP::ConfirmationDialog(ID, Title, Body, Buttons, InteractionID); - } - )); + MPTable.set_function("ConfirmationDialog", sol::overload(&LuaAPI::MP::ConfirmationDialog, [&](const int& ID, const std::string& Title, const std::string& Body, const sol::table& Buttons, const std::string& InteractionID) { + LuaAPI::MP::ConfirmationDialog(ID, Title, Body, Buttons, InteractionID); + })); MPTable.set_function("GetPlayers", [&]() -> sol::table { return Lua_GetPlayers(); }); @@ -1075,13 +1127,15 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLuaChunk& Script) { std::unique_lock Lock(mStateExecuteQueueMutex); - auto Result = std::make_shared(); + // explicitly passing empty string as there's no single function being called here + auto Result = std::make_shared(mStateId, std::string()); mStateExecuteQueue.push({ Script, Result }); return Result; } std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy) { // TODO: Document all this + std::unique_lock Lock(mStateFunctionQueueMutex); decltype(mStateFunctionQueue)::iterator Iter = mStateFunctionQueue.end(); if (Strategy == CallStrategy::BestEffort) { Iter = std::find_if(mStateFunctionQueue.begin(), mStateFunctionQueue.end(), @@ -1090,10 +1144,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFrom }); } if (Iter == mStateFunctionQueue.end()) { - auto Result = std::make_shared(); - Result->StateId = mStateId; - Result->Function = FunctionName; - std::unique_lock Lock(mStateFunctionQueueMutex); + auto Result = std::make_shared(mStateId, FunctionName); mStateFunctionQueue.push_back({ FunctionName, Result, Args, EventName }); mStateFunctionQueueCond.notify_all(); return Result; @@ -1103,9 +1154,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFrom } std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args, const std::string& EventName) { - auto Result = std::make_shared(); - Result->StateId = mStateId; - Result->Function = FunctionName; + auto Result = std::make_shared(mStateId, FunctionName); std::unique_lock Lock(mStateFunctionQueueMutex); mStateFunctionQueue.push_back({ FunctionName, Result, Args, EventName }); mStateFunctionQueueCond.notify_all(); @@ -1154,14 +1203,14 @@ void TLuaEngine::StateThreadData::operator()() { sol::state_view StateView(mState); auto Res = StateView.safe_script(*S.first.Content, sol::script_pass_on_error, S.first.FileName); if (Res.valid()) { - S.second->Error = false; - S.second->Result = std::move(Res); + try { + S.second->MarkReadySuccess(std::move(Res)); + } catch (const std::exception& e) { + S.second->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { - S.second->Error = true; - sol::error Err = Res; - S.second->ErrorMessage = Err.what(); + S.second->MarkReadyError(std::move(Res)); } - S.second->MarkAsReady(); } } { // StateFunctionQueue Scope @@ -1178,7 +1227,7 @@ void TLuaEngine::StateThreadData::operator()() { auto& Result = TheQueuedFunction.Result; auto Args = TheQueuedFunction.Args; // TODO: Use TheQueuedFunction.EventName for errors, warnings, etc - Result->StateId = mStateId; + Result->SetOwnerState(mStateId); sol::state_view StateView(mState); auto Fn = StateView[FnName]; if (Fn.valid() && Fn.get_type() == sol::type::function) { @@ -1187,49 +1236,45 @@ void TLuaEngine::StateThreadData::operator()() { if (Arg.valueless_by_exception()) { continue; } - switch (Arg.index()) { - case TLuaType::String: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::Int: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::Json: { - auto LocalArgs = JsonStringToArray(std::get(Arg)); - LuaArgs.insert(LuaArgs.end(), LocalArgs.begin(), LocalArgs.end()); - break; - } - case TLuaType::Bool: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); - break; - case TLuaType::StringStringMap: { - auto Map = std::get>(Arg); - auto Table = StateView.create_table(); - for (const auto& [k, v] : Map) { - Table[k] = v; + std::visit([&LuaArgs, &StateView, this](const auto& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else if constexpr (std::is_same_v) { + auto LocalArgs = JsonStringToArray(arg); + LuaArgs.insert(LuaArgs.end(), LocalArgs.begin(), LocalArgs.end()); + } else if constexpr (std::is_same_v) { + beammp_lua_error("Unknown argument type, passed as nil"); + LuaArgs.push_back(sol::lua_nil_t()); + } else if constexpr (std::is_same_v>) { + auto Table = StateView.create_table(); + for (const auto& [k, v] : arg) { + Table[k] = v; + } + LuaArgs.push_back(Table); + } else if constexpr (std::is_same_v) { + LuaArgs.push_back(sol::make_object(StateView, arg)); + } else { + static_assert(AlwaysFalseV, "unhandled variant"); } - LuaArgs.push_back(sol::make_object(StateView, Table)); - break; - } - default: - beammp_error("Unknown argument type, passed as nil"); - break; - } + }, Arg); } auto Res = Fn(sol::as_args(LuaArgs)); if (Res.valid()) { - Result->Error = false; - Result->Result = std::move(Res); + try { + Result->MarkReadySuccess(std::move(Res)); + } catch (const std::exception& e) { + Result->MarkReadyError(fmt::format("Call was successful, but result could not be serialized")); + } } else { - Result->Error = true; - sol::error Err = Res; - Result->ErrorMessage = Err.what(); + Result->MarkReadyError(std::move(Res)); } - Result->MarkAsReady(); } else { - Result->Error = true; - Result->ErrorMessage = BeamMPFnNotFoundError; // special error kind that we can ignore later - Result->MarkAsReady(); + Result->MarkReadyError(BeamMPFnNotFoundError); } auto ProfEnd = prof::now(); auto ProfDuration = prof::duration(ProfStart, ProfEnd); @@ -1282,21 +1327,6 @@ void TLuaEngine::StateThreadData::AddPath(const fs::path& Path) { mPaths.push(Path); } -void TLuaResult::MarkAsReady() { - { - std::lock_guard readyLock(*this->ReadyMutex); - this->Ready = true; - } - this->ReadyCondition->notify_all(); -} - -void TLuaResult::WaitUntilReady() { - std::unique_lock readyLock(*this->ReadyMutex); - // wait if not ready yet - if (!this->Ready) - this->ReadyCondition->wait(readyLock); -} - TLuaChunk::TLuaChunk(std::shared_ptr Content, std::string FileName, std::string PluginPath) : Content(Content) , FileName(FileName) @@ -1311,3 +1341,79 @@ bool TLuaEngine::TimedEvent::Expired() { void TLuaEngine::TimedEvent::Reset() { LastCompletion = std::chrono::high_resolution_clock::now(); } + +TEST_CASE("TLuaEngine ctor & dtor") { + Application::Settings.set(Settings::Key::General_ResourceFolder, "beammp_server_test_resources"); + TLuaEngine engine; + + const TLuaStateId StateId = "lua_event_contract_test"; + engine.EnsureStateExists(StateId, "LuaEventContractTest", true); + + // LLM generated test code + auto Script = std::make_shared(R"( +function onPlayerAuth(playerName, playerRole, isGuest, identifiers) + if type(playerName) ~= "string" then return "on:bad-playerName-type:" .. type(playerName) end + if type(playerRole) ~= "string" then return "on:bad-playerRole-type:" .. type(playerRole) end + if type(isGuest) ~= "boolean" then return "on:bad-isGuest-type:" .. type(isGuest) end + if type(identifiers) ~= "table" then return "on:bad-identifiers-type:" .. type(identifiers) end + return "on:" .. playerName .. ":" .. playerRole .. ":" .. tostring(isGuest) .. ":" .. tostring(identifiers.ip) .. ":" .. tostring(identifiers.beammp) +end + +function postPlayerAuth(isDenied, reason, playerName, playerRole, isGuest, identifiers) + if type(isDenied) ~= "boolean" then return "post:bad-isDenied-type:" .. type(isDenied) end + if type(reason) ~= "string" then return "post:bad-reason-type:" .. type(reason) end + if type(playerName) ~= "string" then return "post:bad-playerName-type:" .. type(playerName) end + if type(playerRole) ~= "string" then return "post:bad-playerRole-type:" .. type(playerRole) end + if type(isGuest) ~= "boolean" then return "post:bad-isGuest-type:" .. type(isGuest) end + if type(identifiers) ~= "table" then return "post:bad-identifiers-type:" .. type(identifiers) end + return "post:" .. tostring(isDenied) .. ":" .. reason .. ":" .. playerName .. ":" .. playerRole .. ":" .. tostring(isGuest) .. ":" .. tostring(identifiers.ip) .. ":" .. tostring(identifiers.beammp) +end + +MP.RegisterEvent("onPlayerAuth", "onPlayerAuth") +MP.RegisterEvent("postPlayerAuth", "postPlayerAuth") +)"); + + auto LoadResult = engine.EnqueueScript(StateId, TLuaChunk(Script, "event_contract.lua", "beammp_server_test_resources/Server/LuaEventContractTest")); + LoadResult->WaitUntilReady(); + auto LoadSnapshot = LoadResult->GetDetachedSnapshot(); + CHECK(!LoadSnapshot.Error); + + const std::unordered_map Identifiers { + {"ip", "410.0.24.1"}, + {"beammp", "123456"}, + }; + + auto OnPlayerAuthResults = engine.TriggerEvent( + "onPlayerAuth", "", + std::string("guest8133569"), + std::string("USER"), + true, + Identifiers); + REQUIRE(OnPlayerAuthResults.size() == 1); + TLuaEngine::WaitForAll(OnPlayerAuthResults); + + auto OnPlayerAuthSnapshot = OnPlayerAuthResults.front()->GetDetachedSnapshot(); + CHECK(!OnPlayerAuthSnapshot.Error); + const auto* OnPlayerAuthValue = std::get_if(&OnPlayerAuthSnapshot.Result.V); + REQUIRE(OnPlayerAuthValue != nullptr); + CHECK(*OnPlayerAuthValue == "on:guest8133569:USER:true:410.0.24.1:123456"); + + auto PostPlayerAuthResults = engine.TriggerEvent( + "postPlayerAuth", "", + false, + std::string(""), + std::string("guest8133569"), + std::string("USER"), + true, + Identifiers); + REQUIRE(PostPlayerAuthResults.size() == 1); + TLuaEngine::WaitForAll(PostPlayerAuthResults); + + auto PostPlayerAuthSnapshot = PostPlayerAuthResults.front()->GetDetachedSnapshot(); + CHECK(!PostPlayerAuthSnapshot.Error); + const auto* PostPlayerAuthValue = std::get_if(&PostPlayerAuthSnapshot.Result.V); + REQUIRE(PostPlayerAuthValue != nullptr); + CHECK(*PostPlayerAuthValue == "post:false::guest8133569:USER:true:410.0.24.1:123456"); + + Application::GracefullyShutdown(); +} diff --git a/src/TLuaPlugin.cpp b/src/TLuaPlugin.cpp index 50620be4..0d558ac1 100644 --- a/src/TLuaPlugin.cpp +++ b/src/TLuaPlugin.cpp @@ -63,8 +63,9 @@ TLuaPlugin::TLuaPlugin(TLuaEngine& Engine, const TLuaPluginConfig& Config, const } for (auto& Result : ResultsToCheck) { Result.second->WaitUntilReady(); - if (Result.second->Error) { - beammp_lua_error("Failed: \"" + Result.first.string() + "\": " + Result.second->ErrorMessage); + auto Snapshot = Result.second->GetDetachedSnapshot(); + if (Snapshot.Error) { + beammp_lua_error("Failed: \"" + Result.first.string() + "\": " + Snapshot.ErrorMessage); } } } diff --git a/src/TLuaResult.cpp b/src/TLuaResult.cpp new file mode 100644 index 00000000..39f4af64 --- /dev/null +++ b/src/TLuaResult.cpp @@ -0,0 +1,272 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2026 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "TLuaResult.h" +#include +#include +#include +#include +#include + + +std::ostream& operator<<(std::ostream& os, const TDetachedLuaValue& value) { + std::visit([&os](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + size_t i = 0; + for (const auto& val : arg) { + if (i > 0) { + os << ", "; + } + os << val; + ++i; + } + } else if constexpr (std::is_same_v) { + size_t i = 0; + for (const auto& [key, val] : arg) { + if (i > 0) { + os << ", "; + } + os << key << "=" << val; + ++i; + } + } else if constexpr (std::is_same_v) + os << (arg ? "true" : "false"); + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + os << arg; + else if constexpr (std::is_same_v) + // monostate means no result value + os << ""; + else + static_assert(AlwaysFalseV, "non-exhaustive visitor!"); + }, + value.V); + + return os; +} + + +void TLuaResult::MarkReadySuccess(sol::object Res) { + std::unique_lock Lock(mMutex); + mError = false; + mResult = Res; + mDetachedResult = Freeze(Res); + + MarkAsReady(); +} + +void TLuaResult::MarkReadyError(sol::protected_function_result Res) { + std::unique_lock Lock(mMutex); + mError = true; + SetErrorMessageFromResult(Res); + + MarkAsReady(); +} + +void TLuaResult::MarkReadyError(std::string Res) { + std::unique_lock Lock(mMutex); + mError = true; + mErrorMessage = std::move(Res); + + MarkAsReady(); +} + +bool TLuaResult::IsReady() const { + std::unique_lock Lock(mMutex); + return mReady; +} +bool TLuaResult::IsError() const { + std::unique_lock Lock(mMutex); + return mError; +} + +TLuaStateId TLuaResult::OwnerState() const { + std::unique_lock Lock(mMutex); + // copy + return mStateId; +} +void TLuaResult::SetOwnerState(TLuaStateId StateId) { + std::unique_lock Lock(mMutex); + mStateId = std::move(StateId); +} + +TLuaResult::Snapshot TLuaResult::GetSnapshot(TLuaStateId CallingState) const { + std::unique_lock Lock(mMutex); + if (CallingState != mStateId) { + throw std::logic_error("Tried to get snapshot from non-owning state (use a detached snapshot instead)"); + } + Snapshot snapshot { + .Error = mError, + .Ready = mReady, + .ErrorMessage = mErrorMessage, + .Result = mResult, + .StateId = mStateId, + .Function = mFunction, + }; + return snapshot; +} +TLuaResult::DetachedSnapshot TLuaResult::GetDetachedSnapshot() const { + std::unique_lock Lock(mMutex); + DetachedSnapshot snapshot { + .Error = mError, + .Ready = mReady, + .ErrorMessage = mErrorMessage, + .Result = mDetachedResult, + .StateId = mStateId, + .Function = mFunction, + }; + return snapshot; +} + +TDetachedLuaValue TLuaResult::Freeze(const sol::object& o, int depth) { + if (depth > 64) + throw std::runtime_error("max depth (64) reached"); + switch (o.get_type()) { + case sol::type::lua_nil: + return { { std::monostate { } } }; + case sol::type::boolean: + return { { o.as() } }; + case sol::type::number: { + if (o.is()) { + return { { o.as() } }; + } else { + return { { o.as() } }; + } + } + case sol::type::string: + return { { o.as() } }; + case sol::type::table: { + TDetachedLuaValue::Object out; + for (auto&& [k, v] : o.as()) { + if (!k.is()) + continue; // no numeric-key handling, don't need it + out.emplace(k.as(), std::make_shared(std::move(Freeze(v, depth + 1)))); + } + return { { std::move(out) } }; + } + default: + throw std::runtime_error("unsupported Lua type for cross-thread snapshot"); + } +} +void TLuaResult::MarkAsReady() { + mReady = true; + mReadyCondition.notify_all(); +} + +void TLuaResult::WaitUntilReady() { + std::unique_lock Lock(mMutex); + while (!mReady) + mReadyCondition.wait_for(Lock, std::chrono::milliseconds(50), + [this] { + return mReady; + }); +} + +TEST_CASE("TLuaResult MarkReadyError(string) marks ready and wakes waiters") { + TLuaResult result("state_a", "fn_a"); + std::atomic waiterDone { false }; + + auto waiter = std::thread([&] { + result.WaitUntilReady(); + waiterDone.store(true, std::memory_order_release); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + CHECK_FALSE(waiterDone.load(std::memory_order_acquire)); + + result.MarkReadyError(std::string("boom")); + waiter.join(); + + CHECK(result.IsReady()); + const auto snapshot = result.GetDetachedSnapshot(); + CHECK(snapshot.Ready); + CHECK(snapshot.Error); + CHECK(snapshot.ErrorMessage == "boom"); + CHECK(snapshot.StateId == "state_a"); + CHECK(snapshot.Function == "fn_a"); +} + +TEST_CASE("TLuaResult GetSnapshot enforces owner state id") { + sol::state lua; + TLuaResult result("owner_state", "fn_owner"); + lua.open_libraries(sol::lib::base); + result.MarkReadySuccess(sol::make_object(lua.lua_state(), std::string("ok"))); + + CHECK_NOTHROW(result.GetSnapshot("owner_state")); + CHECK_THROWS_AS(result.GetSnapshot("different_state"), std::logic_error); +} + +TEST_CASE("TLuaResult detached snapshot freezes nested string-keyed tables") { + sol::state lua; + TLuaResult result("state_table", "fn_table"); + lua.open_libraries(sol::lib::base); + + auto outer = lua.create_table(); + auto inner = lua.create_table(); + inner["k"] = std::string("v"); + + outer["flag"] = true; + outer["msg"] = std::string("hello"); + outer["inner"] = inner; + outer[1] = std::string("ignored_numeric_key"); + + result.MarkReadySuccess(sol::make_object(lua.lua_state(), outer)); + const auto detached = result.GetDetachedSnapshot(); + + CHECK(detached.Ready); + CHECK_FALSE(detached.Error); + const auto* object = std::get_if(&detached.Result.V); + REQUIRE(object != nullptr); + + CHECK(object->contains("flag")); + CHECK(object->contains("msg")); + CHECK(object->contains("inner")); + CHECK_FALSE(object->contains("1")); + + const auto* flag = std::get_if(&object->at("flag")->V); + REQUIRE(flag != nullptr); + CHECK(*flag); + + const auto* msg = std::get_if(&object->at("msg")->V); + REQUIRE(msg != nullptr); + CHECK(*msg == "hello"); + + const auto* innerObj = std::get_if(&object->at("inner")->V); + REQUIRE(innerObj != nullptr); + REQUIRE(innerObj->contains("k")); + const auto* innerValue = std::get_if(&innerObj->at("k")->V); + REQUIRE(innerValue != nullptr); + CHECK(*innerValue == "v"); +} + +TEST_CASE("TLuaResult MarkReadySuccess throws on unsupported Lua function value") { + sol::state lua; + TLuaResult result("state_fn", "fn_fn"); + lua.open_libraries(sol::lib::base); + lua["f"] = [] { return 1; }; + const sol::table globals = lua.globals(); + const sol::protected_function fn = globals.get("f"); + const sol::object fnObj = sol::make_object(lua.lua_state(), fn); + + CHECK_THROWS_AS(result.MarkReadySuccess(fnObj), std::runtime_error); + CHECK_FALSE(result.IsReady()); +} diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index d0393e89..c7aff281 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -20,6 +20,7 @@ #include "Client.h" #include "Common.h" #include "LuaAPI.h" +#include "TConnectionLimiter.h" #include "THeartbeatThread.h" #include "TLuaEngine.h" #include "TScopedTimer.h" @@ -27,20 +28,26 @@ #include #include #include +#include #include #include #include #include +#include +#include +#include #include +#include #include #include +#include #include typedef boost::asio::detail::socket_option::integer rcv_timeout_option; static constexpr uint8_t MAX_CONCURRENT_CONNECTIONS = 10; static constexpr uint8_t MAX_GLOBAL_CONNECTIONS = 128; -static constexpr uint8_t READ_TIMEOUT_S = 10; //seconds +static constexpr uint8_t READ_TIMEOUT_S = 10; // seconds std::vector StringToVector(const std::string& Str) { return std::vector(Str.data(), Str.data() + Str.size()); @@ -55,18 +62,47 @@ static void CompressProperly(std::vector& Data) { Data = CombinedData; } +// for unit-tests, otherwise unused +static bool OpenLoopbackSocketPair(io_context& IoCtx, ip::tcp::socket& ClientSocket, ip::tcp::socket& ServerSocket, boost::system::error_code& Ec) { + ip::tcp::acceptor Acceptor(IoCtx); + Acceptor.open(ip::tcp::v4(), Ec); + if (Ec) { + return true; + } + Acceptor.bind(ip::tcp::endpoint(ip::address_v4::loopback(), 0), Ec); + if (Ec) { + return true; + } + Acceptor.listen(socket_base::max_listen_connections, Ec); + if (Ec) { + return true; + } + + const auto Port = Acceptor.local_endpoint(Ec).port(); + if (Ec) { + return true; + } + ClientSocket.connect(ip::tcp::endpoint(ip::address_v4::loopback(), Port), Ec); + if (Ec) { + return true; + } + Acceptor.accept(ServerSocket, Ec); + return true; +} + TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager) : mServer(Server) , mPPSMonitor(PPSMonitor) , mUDPSock(Server.IoCtx()) - , mResourceManager(ResourceManager) { + , mResourceManager(ResourceManager) + , mConnectionLimiter(MAX_CONCURRENT_CONNECTIONS, MAX_GLOBAL_CONNECTIONS) { Application::SetSubsystemStatus("TCPNetwork", Application::Status::Starting); Application::SetSubsystemStatus("UDPNetwork", Application::Status::Starting); Application::RegisterShutdownHandler([&] { beammp_debug("Kicking all players due to shutdown"); Server.ForEachClient([&](std::weak_ptr client) -> bool { - if (!client.expired()) { - ClientKick(*client.lock(), "Server shutdown"); + if (auto Locked = client.lock()) { + ClientKick(*Locked, "Server shutdown"); } return true; }); @@ -125,13 +161,13 @@ void TNetwork::UDPServerMain() { + std::to_string(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) + (" Clients")); while (!Application::IsShuttingDown()) { try { - boost::asio::ip::udp::endpoint remote_client_ep {}; + boost::asio::ip::udp::endpoint remote_client_ep { }; std::vector Data = UDPRcvFromClient(remote_client_ep); if (Data.empty()) { continue; } if (Data.size() == 1 && Data.at(0) == 'P') { - mUDPSock.send_to(boost::asio::const_buffer("P", 1), remote_client_ep, {}, ec); + mUDPSock.send_to(boost::asio::const_buffer("P", 1), remote_client_ep, { }, ec); // ignore errors (void)ec; continue; @@ -145,14 +181,14 @@ void TNetwork::UDPServerMain() { std::shared_ptr Client; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Client = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + Client = std::move(Locked); } else return true; } if (Client->GetID() == ID) { - if (Client->GetUDPAddr() == boost::asio::ip::udp::endpoint {} && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { + if (Client->GetUDPAddr() == boost::asio::ip::udp::endpoint { } && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { if (Data.size() != 66) { beammp_debugf("Invalid size for UDP value. IP: {} ID: {}", remote_client_ep.address().to_string(), ID); return false; @@ -165,7 +201,7 @@ void TNetwork::UDPServerMain() { return false; } - Client->SetMagic({}); + Client->SetMagic({ }); Client->SetUDPAddr(remote_client_ep); Client->SetIsUDPConnected(true); return false; @@ -217,7 +253,7 @@ void TNetwork::TCPServerMain() { beammp_warnf("WARNING: On FreeBSD, for IPv4 to work, you must run `sysctl net.inet6.ip6.v6only=0`!"); beammp_debugf("This is due to an annoying detail in the *BSDs: In the name of security, unsetting the IPV6_V6ONLY option does not work by default (but does not fail???), as it allows IPv4 mapped IPv6 like ::ffff:127.0.0.1, which they deem a security issue. For more information, see RFC 2553, section 3.7."); #endif - socket_base::linger LingerOpt {}; + socket_base::linger LingerOpt { }; LingerOpt.enabled(false); Listener.set_option(LingerOpt, ec); if (ec) { @@ -243,28 +279,27 @@ void TNetwork::TCPServerMain() { beammp_debug("shutdown during TCP wait for accept loop"); break; } - boost::asio::ip::tcp::endpoint ClientEp; - boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); - std::string ClientIP = ClientEp.address().to_string(); - if (!ec) { - mClientMapMutex.lock(); - if (mClientMap[ClientIP] >= MAX_CONCURRENT_CONNECTIONS) { - beammp_debugf("The connection was rejected for {}, as it had {} concurrent connections.", ClientIP, mClientMap[ClientIP]); - } - else if (mClientMap.size() >= MAX_GLOBAL_CONNECTIONS) { - beammp_debugf("The connection was rejected for {}, as there are {} global connections.", ClientIP, mClientMap.size()); - } - else { - TConnection Conn { std::move(ClientSocket), ClientEp }; - std::thread ID(&TNetwork::Identify, this, std::move(Conn)); - ID.detach(); // TODO: Add to a queue and attempt to join periodically - mClientMap[ClientIP]++; - } - mClientMapMutex.unlock(); - } - else { + boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ec); + if (ec) { beammp_errorf("Failed to accept() new client: {}", ec.message()); + continue; + } + boost::asio::ip::tcp::endpoint ClientEp = ClientSocket.remote_endpoint(ec); + if (ec) { + beammp_errorf("Accepted socket but failed to query remote endpoint for IP address: {}", ec.message()); + continue; } + std::string ClientIP = ClientEp.address().to_string(); + auto MaybeGuard = mConnectionLimiter.TryAcquire(ClientIP); + if (!MaybeGuard.has_value()) { + beammp_debugf("Connection rejected for {} due to the global or concurrent connection limit", ClientIP); + continue; + } + // move-swap to avoid copy ctor (deleted) + auto Guard = std::move(MaybeGuard.value()); + TConnection Conn { std::move(ClientSocket), ClientEp }; + std::thread ID(&TNetwork::Identify, this, std::move(Conn), std::move(Guard)); + ID.detach(); // TODO: Add to a queue and attempt to join periodically } catch (const std::exception& e) { beammp_errorf("Exception in accept routine: {}", e.what()); } @@ -276,7 +311,7 @@ void TNetwork::TCPServerMain() { #include "Json.h" namespace json = rapidjson; -void TNetwork::Identify(TConnection&& RawConnection) { +void TNetwork::Identify(TConnection&& RawConnection, TConnectionLimiter::TGuard&& Guard) { RegisterThreadAuto(); char Code; @@ -285,17 +320,8 @@ void TNetwork::Identify(TConnection&& RawConnection) { // TODO: is this right?! beammp_debug("Error occured reading code"); RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); - mClientMapMutex.lock(); - { - std::string ClientIP = RawConnection.SockAddr.address().to_string(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - } - mClientMapMutex.unlock(); + // TODO: is this right too? + RawConnection.Socket.close(ec); return; } std::shared_ptr Client { nullptr }; @@ -326,17 +352,6 @@ void TNetwork::Identify(TConnection&& RawConnection) { beammp_errorf("Error during handling of code {} - client left in invalid state, closing socket: {}", Code, e.what()); boost::system::error_code ec; RawConnection.Socket.shutdown(boost::asio::socket_base::shutdown_both, ec); - mClientMapMutex.lock(); - { - std::string ClientIP = RawConnection.SockAddr.address().to_string(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - } - mClientMapMutex.unlock(); if (ec) { beammp_debugf("Failed to shutdown client socket: {}", ec.message()); } @@ -347,8 +362,6 @@ void TNetwork::Identify(TConnection&& RawConnection) { } } - - std::string HashPassword(const std::string& str) { std::stringstream ret; unsigned char* hash = SHA256(reinterpret_cast(str.c_str()), str.length(), nullptr); @@ -376,7 +389,12 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { beammp_info("Identifying new ClientConnection..."); - auto Data = TCPRcv(*Client); + auto Data = TCPRcv(*Client, true); + if (Data.empty()) { + beammp_debug("Authentication failed: did not receive version packet"); + ClientKick(*Client, "Connection closed during version handshake"); + return nullptr; + } constexpr std::string_view VC = "VC"; if (Data.size() > 3 && std::equal(Data.begin(), Data.begin() + VC.size(), VC.begin(), VC.end())) { @@ -398,7 +416,12 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { // TODO: handle } - Data = TCPRcv(*Client); + Data = TCPRcv(*Client, true); + if (Data.empty()) { + beammp_debug("Authentication failed: did not receive auth key packet"); + ClientKick(*Client, "Connection closed during authentication"); + return nullptr; + } if (Data.size() > 50) { ClientKick(*Client, "Invalid Key (too long)!"); @@ -406,11 +429,15 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { } std::string Key(reinterpret_cast(Data.data()), Data.size()); + if (Key.empty()) { + ClientKick(*Client, "Invalid Key (empty)!"); + return nullptr; + } std::string AuthKey = Application::Settings.getAsString(Settings::Key::General_AuthKey); std::string ClientIp = Client->GetIdentifiers().at("ip"); - nlohmann::json AuthReq {}; - std::string AuthResStr {}; + nlohmann::json AuthReq { }; + std::string AuthResStr { }; try { AuthReq = nlohmann::json { { "key", Key }, @@ -462,8 +489,8 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { std::shared_ptr Cl; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - Cl = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + Cl = std::move(Locked); } else return true; } @@ -481,23 +508,31 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { bool BypassLimit = false; for (const auto& Result : Futures) { - if (!Result->Error && Result->Result.is()) { - auto Res = Result->Result.as(); - - if (Res == 1) { - NotAllowed = true; - break; - } else if (Res == 2) { - BypassLimit = true; + auto Snapshot = Result->GetDetachedSnapshot(); + if (!Snapshot.Error) { + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt != nullptr) { + auto Res = *MaybeInt; + + if (Res == 1) { + NotAllowed = true; + break; + } else if (Res == 2) { + BypassLimit = true; + } } } } std::string Reason; bool NotAllowedWithReason = std::any_of(Futures.begin(), Futures.end(), [&Reason](const std::shared_ptr& Result) -> bool { - if (!Result->Error && Result->Result.is()) { - Reason = Result->Result.as(); - return true; + auto Snapshot = Result->GetDetachedSnapshot(); + if (!Snapshot.Error) { + const std::string* MaybeStr = std::get_if(&Snapshot.Result.V); + if (MaybeStr != nullptr) { + Reason = *MaybeStr; + return true; + } } return false; }); @@ -508,8 +543,8 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { } if (!NotAllowed && !NotAllowedWithReason && mServer.ClientCount() >= size_t(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) && !BypassLimit) { - NotAllowedWithReason = true; - Reason = "Server full!"; + NotAllowedWithReason = true; + Reason = "Server full!"; } if (NotAllowedWithReason) { @@ -574,29 +609,33 @@ bool TNetwork::TCPSend(TClient& c, const std::vector& Data, bool IsSync return true; } -std::vector TNetwork::TCPRcv(TClient& c) { +std::vector TNetwork::TCPRcv(TClient& c, bool WithTimeout) { if (c.IsDisconnected()) { beammp_error("Client disconnected, cancelling TCPRcv"); - return {}; + return { }; } - int32_t Header {}; + int32_t Header { }; auto& Sock = c.GetTCPSock(); boost::system::error_code ec; std::array HeaderData; - boost::asio::read(Sock, boost::asio::buffer(HeaderData), ec); + if (WithTimeout) { + ec = ReadWithTimeout(Sock, HeaderData.data(), HeaderData.size(), std::chrono::seconds(READ_TIMEOUT_S)); + } else { + boost::asio::read(Sock, boost::asio::buffer(HeaderData), ec); + } if (ec) { // TODO: handle this case (read failed) beammp_debugf("TCPRcv: Reading header failed: {}", ec.message()); - return {}; + return { }; } Header = *reinterpret_cast(HeaderData.data()); if (Header < 0) { ClientKick(c, "Invalid packet - header negative"); beammp_errorf("Client {} send negative TCP header, ignoring packet", c.GetID()); - return {}; + return { }; } std::vector Data; @@ -608,13 +647,23 @@ std::vector TNetwork::TCPRcv(TClient& c) { } else { ClientKick(c, "Header size limit exceeded"); beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header larger than expected - assuming malicious intent and disconnecting the client."); - return {}; + return { }; + } + std::size_t N = 0; + if (WithTimeout) { + if (!Data.empty()) { + ec = ReadWithTimeout(Sock, Data.data(), Data.size(), std::chrono::seconds(READ_TIMEOUT_S)); + if (!ec) { + N = Data.size(); + } + } + } else { + N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); } - auto N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); if (ec) { // TODO: handle this case properly beammp_debugf("TCPRcv: Reading data failed: {}", ec.message()); - return {}; + return { }; } if (N != Header) { @@ -626,7 +675,7 @@ std::vector TNetwork::TCPRcv(TClient& c) { Data.erase(Data.begin(), Data.begin() + ABG.size()); try { return DeComp(Data); - } catch (const InvalidDataError& ) { + } catch (const InvalidDataError&) { beammp_errorf("Failed to decompress packet from a client. The receive failed and the client may be disconnected as a result"); // return empty -> error return std::vector(); @@ -648,35 +697,26 @@ void TNetwork::ClientKick(TClient& c, const std::string& R) { DisconnectClient(c, "Kicked"); } -void TNetwork::DisconnectClient(const std::weak_ptr &c, const std::string &R) -{ +void TNetwork::DisconnectClient(const std::weak_ptr& c, const std::string& R) { if (auto locked = c.lock()) { DisconnectClient(*locked, R); - } - else { + } else { beammp_debugf("Tried to disconnect a non existant client with reason: {}", R); } } -void TNetwork::DisconnectClient(TClient &c, const std::string &R) -{ - if (c.IsDisconnected()) return; - std::string ClientIP = c.GetTCPSock().remote_endpoint().address().to_string(); - mClientMapMutex.lock(); - if (mClientMap[ClientIP] > 0) { - mClientMap[ClientIP]--; - } - if (mClientMap[ClientIP] == 0) { - mClientMap.erase(ClientIP); - } - mClientMapMutex.unlock(); - c.Disconnect(R); +void TNetwork::DisconnectClient(TClient& c, const std::string& R) { + // Keep this unconditional; TClient::Disconnect() is the single-winner guard. + (void)c.Disconnect(R); } void TNetwork::Looper(const std::weak_ptr& c) { RegisterThreadAuto(); - while (!c.expired()) { + while (true) { auto Client = c.lock(); + if (!Client) { + break; + } if (Client->IsDisconnected()) { beammp_debug("client is disconnected, breaking client loop"); break; @@ -684,7 +724,7 @@ void TNetwork::Looper(const std::weak_ptr& c) { if (!Client->IsSyncing() && Client->IsSynced() && Client->MissedPacketQueueSize() != 0) { // debug("sending " + std::to_string(Client->MissedPacketQueueSize()) + " queued packets"); while (Client->MissedPacketQueueSize() > 0) { - std::vector QData {}; + std::vector QData { }; { // locked context std::unique_lock lock(Client->MissedPacketQueueMutex()); if (Client->MissedPacketQueueSize() <= 0) { @@ -710,20 +750,29 @@ void TNetwork::Looper(const std::weak_ptr& c) { } void TNetwork::TCPClient(const std::weak_ptr& c) { - // TODO: the c.expired() might cause issues here, remove if you end up here with your debugger - if (c.expired() || !c.lock()->GetTCPSock().is_open()) { + if (auto Client = c.lock()) { + if (Client->IsDisconnected()) { + mServer.RemoveClient(c); + return; + } + } else { mServer.RemoveClient(c); return; } OnConnect(c); - RegisterThread("(" + std::to_string(c.lock()->GetID()) + ") \"" + c.lock()->GetName() + "\""); + if (auto Client = c.lock()) { + RegisterThread("(" + std::to_string(Client->GetID()) + ") \"" + Client->GetName() + "\""); + } else { + return; + } - std::thread QueueSync(&TNetwork::Looper, this, c); + std::jthread QueueSync(&TNetwork::Looper, this, c); while (true) { - if (c.expired()) - break; auto Client = c.lock(); + if (!Client) { + break; + } if (Client->IsDisconnected()) { beammp_debug("client status < 0, breaking client loop"); break; @@ -744,11 +793,7 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { } } - if (QueueSync.joinable()) - QueueSync.join(); - - if (!c.expired()) { - auto Client = c.lock(); + if (auto Client = c.lock()) { OnDisconnect(c); } else { beammp_warn("client expired in TCPClient, should never happen"); @@ -759,8 +804,7 @@ void TNetwork::UpdatePlayer(TClient& Client) { std::string Packet = ("Ss") + std::to_string(mServer.ClientCount()) + "/" + std::to_string(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) + ":"; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - auto c = ClientPtr.lock(); + if (auto c = ClientPtr.lock()) { Packet += c->GetName() + ","; } return true; @@ -770,43 +814,144 @@ void TNetwork::UpdatePlayer(TClient& Client) { //(void)Respond(Client, Packet, true); } -boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void *Buf, size_t Len, std::chrono::steady_clock::duration Timeout) -{ - io_context TimerIO; - steady_timer Timer(TimerIO); - Timer.expires_after(Timeout); +static boost::system::error_code ReadSocketWithTimeout( + boost::asio::ip::tcp::socket& socket, + void* buffer, + std::size_t length, + std::chrono::steady_clock::duration timeout); - std::atomic TimedOut = false; +boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { + return ReadWithTimeout(Connection.Socket, Buf, Len, Timeout); +} - Timer.async_wait([&](const boost::system::error_code& ec) { - if (!ec) { - TimedOut = true; - Connection.Socket.cancel(); - } - }); - std::thread TimerThread([&]() { TimerIO.run(); }); +boost::system::error_code TNetwork::ReadWithTimeout(boost::asio::ip::tcp::socket& Socket, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout) { + return ReadSocketWithTimeout(Socket, Buf, Len, Timeout); +} - boost::system::error_code ReadEc; - boost::asio::read(Connection.Socket, boost::asio::buffer(Buf, Len), ReadEc); +static boost::system::error_code ReadSocketWithTimeout( + boost::asio::ip::tcp::socket& Socket, + void* Buffer, + std::size_t Length, + std::chrono::steady_clock::duration Timeout) { + namespace asio = boost::asio; + using boost::system::error_code; + + struct TTimeoutState { + explicit TTimeoutState(const boost::asio::any_io_executor& Executor) + : Timer(Executor) { } + + asio::steady_timer Timer; + std::promise> Promise; + std::atomic_bool Completed { false }; + }; + + auto State = std::make_shared(Socket.get_executor()); + auto Future = State->Promise.get_future(); + + asio::async_read( + Socket, + asio::buffer(Buffer, Length), + [State](error_code ec, std::size_t n) { + if (!State->Completed.exchange(true)) { + State->Timer.cancel(); + State->Promise.set_value({ ec, n }); + } + }); - TimerIO.stop(); - TimerThread.join(); + State->Timer.expires_after(Timeout); + State->Timer.async_wait( + [State, &Socket](error_code ec) { + if (ec == asio::error::operation_aborted) + return; - if (TimedOut.load()) { - return error::timed_out; // synthesize a clean timeout error - } - return ReadEc; //Succes! + if (!State->Completed.exchange(true)) { + error_code IgnoredEc; + Socket.cancel(IgnoredEc); + State->Promise.set_value({ asio::error::timed_out, 0 }); + } + }); + + auto [ec, NRead] = Future.get(); + return ec; +} + +TEST_CASE("ReadSocketWithTimeout returns timed_out when peer sends no data") { + TIoPollThread TimerThread; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + uint8_t ReadByte = 0; + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &ReadByte, 1, std::chrono::milliseconds(50)); + + CHECK(ReadEc == error::timed_out); +} + +TEST_CASE("ReadSocketWithTimeout reads small payload") { + TIoPollThread TimerThread; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + const std::array Sent { 'O', 'K' }; + boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); + REQUIRE(!Ec); + std::array Received { }; + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::milliseconds(200)); + + CHECK(!ReadEc); + CHECK(Received == Sent); +} + +TEST_CASE("ReadSocketWithTimeout reads large payload") { + TIoPollThread TimerThread; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + constexpr size_t PacketSize = 2 * 1024 * 1024; + std::vector Sent(PacketSize, uint8_t(0x7A)); + boost::asio::write(ClientSocket, boost::asio::buffer(Sent), Ec); + REQUIRE(!Ec); + std::vector Received(PacketSize); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, Received.data(), Received.size(), std::chrono::seconds(2)); + + CHECK(!ReadEc); + CHECK(Received == Sent); +} + +TEST_CASE("ReadSocketWithTimeout can timeout then retry successfully") { + TIoPollThread TimerThread; + boost::system::error_code Ec; + ip::tcp::socket ClientSocket(TimerThread.IoCtx()); + ip::tcp::socket ServerSocket(TimerThread.IoCtx()); + OpenLoopbackSocketPair(TimerThread.IoCtx(), ClientSocket, ServerSocket, Ec); + REQUIRE(!Ec); + + uint8_t Received = 0; + CHECK(ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(20)) == error::timed_out); + + const uint8_t Sent = 0x42; + boost::asio::write(ClientSocket, boost::asio::buffer(&Sent, 1), Ec); + REQUIRE(!Ec); + const auto ReadEc = ReadSocketWithTimeout(ServerSocket, &Received, 1, std::chrono::milliseconds(200)); + + CHECK(!ReadEc); + CHECK(Received == Sent); } void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { - std::shared_ptr LockedClientPtr { nullptr }; - try { - LockedClientPtr = ClientPtr.lock(); - } catch (const std::exception&) { + auto LockedClientPtr = ClientPtr.lock(); + if (!LockedClientPtr) { beammp_warn("Client expired in OnDisconnect, this is unexpected"); return; } - beammp_assert(LockedClientPtr != nullptr); TClient& c = *LockedClientPtr; beammp_info(c.GetName() + (" Connection Terminated")); std::string Packet; @@ -837,8 +982,7 @@ int TNetwork::OpenID() { found = true; mServer.ForEachClient([&](const std::weak_ptr& ClientPtr) -> bool { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - auto c = ClientPtr.lock(); + if (auto c = ClientPtr.lock()) { if (c->GetID() == ID) { found = false; ID++; @@ -851,9 +995,11 @@ int TNetwork::OpenID() { } void TNetwork::OnConnect(const std::weak_ptr& c) { - beammp_assert(!c.expired()); - beammp_info("Client connected"); auto LockedClient = c.lock(); + if (!LockedClient) { + return; + } + beammp_info("Client connected"); LockedClient->SetID(OpenID()); beammp_info("Assigned ID " + std::to_string(LockedClient->GetID()) + " to " + LockedClient->GetName()); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onPlayerConnecting", "", LockedClient->GetID())); @@ -885,6 +1031,7 @@ void TNetwork::SyncResources(TClient& c) { while (!c.IsDisconnected()) { Data = TCPRcv(c); if (Data.empty()) { + DisconnectClient(c, "TCPRcv failed during resource sync"); break; } constexpr std::string_view Done = "Done"; @@ -960,9 +1107,9 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { #if defined(BEAMMP_LINUX) #include #include +#include #include #include -#include #endif void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name) { TScopedTimer timer(fmt::format("Download of '{}' for client {}", Name, c.GetID())); @@ -977,17 +1124,40 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name // native handle, needed in order to make native syscalls with it int socket = c.GetTCPSock().native_handle(); - ssize_t ret = 0; - auto ToSendTotal = Size; - auto Start = 0; - while (ret < ssize_t(ToSendTotal)) { - auto SysOffset = off_t(Start + size_t(ret)); - ret = sendfile(socket, fd, &SysOffset, ToSendTotal - size_t(ret)); - if (ret < 0) { - beammp_errorf("Failed to send mod '{}' to client {}: {}", Name, c.GetID(), std::strerror(errno)); + const auto ToSendTotal = Size; + size_t TotalSent = 0; + while (TotalSent < ToSendTotal) { + off_t SysOffset = off_t(TotalSent); + const ssize_t SentNow = sendfile(socket, fd, &SysOffset, ToSendTotal - TotalSent); + if (SentNow > 0) { + TotalSent += size_t(SentNow); + continue; + } + if (SentNow == 0) { + beammp_errorf("Failed to send mod '{}' to client {}: sendfile returned 0 before all bytes were sent", Name, c.GetID()); + ::close(fd); + DisconnectClient(c, "sendfile returned 0 during mod download"); return; } + + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN +#if EWOULDBLOCK != EAGAIN + || errno == EWOULDBLOCK +#endif + ) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + beammp_errorf("Failed to send mod '{}' to client {}: {}", Name, c.GetID(), std::strerror(errno)); + ::close(fd); + DisconnectClient(c, "sendfile failed during mod download"); + return; } + ::close(fd); #else std::ifstream f(Name.c_str(), std::ios::binary); @@ -1056,10 +1226,10 @@ bool TNetwork::Respond(TClient& c, const std::vector& MSG, bool Rel, bo } bool TNetwork::SyncClient(const std::weak_ptr& c) { - if (c.expired()) { + auto LockedClient = c.lock(); + if (!LockedClient) { return false; } - auto LockedClient = c.lock(); if (LockedClient->IsSynced()) return true; // Syncing, later set isSynced @@ -1078,8 +1248,8 @@ bool TNetwork::SyncClient(const std::weak_ptr& c) { std::shared_ptr client; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - client = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + client = std::move(Locked); } else return true; } @@ -1116,14 +1286,14 @@ void TNetwork::SendToAll(TClient* c, const std::vector& Data, bool Self char C = Data.at(0); bool ret = true; mServer.ForEachClient([&](std::weak_ptr ClientPtr) -> bool { - std::shared_ptr Client; - try { + std::shared_ptr Client { nullptr }; + { ReadLock Lock(mServer.GetClientMutex()); - Client = ClientPtr.lock(); - } catch (const std::exception&) { - // continue - beammp_warn("Client expired, shouldn't happen - if a client disconnected recently, you can ignore this"); - return true; + if (auto Locked = ClientPtr.lock()) { + Client = std::move(Locked); + } else { + return true; + } } if (Self || Client.get() != c) { if (Client->IsSynced() || Client->IsSyncing()) { @@ -1178,12 +1348,12 @@ bool TNetwork::UDPSend(TClient& Client, std::vector Data) { } std::vector TNetwork::UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint) { - std::array Ret {}; + std::array Ret { }; boost::system::error_code ec; const auto Rcv = mUDPSock.receive_from(boost::asio::mutable_buffer(Ret.data(), Ret.size()), ClientEndpoint, 0, ec); if (ec) { beammp_errorf("UDP recvfrom() failed: {}", ec.message()); - return {}; + return { }; } beammp_assert(Rcv <= Ret.size()); return std::vector(Ret.begin(), Ret.begin() + Rcv); diff --git a/src/TPPSMonitor.cpp b/src/TPPSMonitor.cpp index ccababf7..859871e7 100644 --- a/src/TPPSMonitor.cpp +++ b/src/TPPSMonitor.cpp @@ -55,8 +55,8 @@ void TPPSMonitor::operator()() { std::shared_ptr c; { ReadLock Lock(mServer.GetClientMutex()); - if (!ClientPtr.expired()) { - c = ClientPtr.lock(); + if (auto Locked = ClientPtr.lock()) { + c = std::move(Locked); } else return true; } @@ -76,7 +76,7 @@ void TPPSMonitor::operator()() { return true; }); for (auto& ClientToKick : TimedOutClients) { - ClientToKick->Disconnect("Timeout"); + Network().DisconnectClient(*ClientToKick, "Timeout"); } TimedOutClients.clear(); if (C == 0 || mInternalPPS == 0) { diff --git a/src/TPluginMonitor.cpp b/src/TPluginMonitor.cpp index 049fd200..d766a70c 100644 --- a/src/TPluginMonitor.cpp +++ b/src/TPluginMonitor.cpp @@ -69,8 +69,9 @@ void TPluginMonitor::operator()() { auto StateID = mEngine->GetStateIDForPlugin(fs::path(Pair.first).parent_path()); auto Res = mEngine->EnqueueScript(StateID, Chunk); Res->WaitUntilReady(); - if (Res->Error) { - beammp_lua_errorf("Error while hot-reloading \"{}\": {}", Pair.first, Res->ErrorMessage); + if (Res->IsError()) { + auto Snapshot = Res->GetDetachedSnapshot(); + beammp_lua_errorf("Error while hot-reloading \"{}\": {}", Pair.first, Snapshot.ErrorMessage); } else { mEngine->ReportErrors(mEngine->TriggerLocalEvent(StateID, "onInit")); mEngine->ReportErrors(mEngine->TriggerEvent("onFileChanged", "", Pair.first)); diff --git a/src/TServer.cpp b/src/TServer.cpp index 587cef67..f4b0a452 100644 --- a/src/TServer.cpp +++ b/src/TServer.cpp @@ -127,19 +127,15 @@ TServer::TServer(const std::vector& Arguments) { } void TServer::RemoveClient(const std::weak_ptr& WeakClientPtr) { - std::shared_ptr LockedClientPtr { nullptr }; - try { - LockedClientPtr = WeakClientPtr.lock(); - } catch (const std::exception&) { - // silently fail, as there's nothing to do + auto LockedClientPtr = WeakClientPtr.lock(); + if (!LockedClientPtr) { return; } - beammp_assert(LockedClientPtr != nullptr); TClient& Client = *LockedClientPtr; beammp_debug("removing client " + Client.GetName() + " (" + std::to_string(ClientCount()) + ")"); Client.ClearCars(); WriteLock Lock(mClientsMutex); - mClients.erase(WeakClientPtr.lock()); + mClients.erase(LockedClientPtr); } void TServer::ForEachClient(const std::function)>& Fn) { @@ -167,14 +163,16 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vectorGetID()); - Network.ClientKick(*LockedClient, "Sent invalid compressed packet (this is likely a bug on your end)"); + if (auto LockedClient = Client.lock()) { + beammp_errorf("Failed to decompress packet from client {}. The client sent invalid data and will now be disconnected.", LockedClient->GetID()); + Network.ClientKick(*LockedClient, "Sent invalid compressed packet (this is likely a bug on your end)"); + } return; } catch (const std::runtime_error& e) { - auto LockedClient = Client.lock(); - beammp_errorf("Failed to decompress packet from client {}: {}. The server might be out of RAM! The client will now be disconnected.", LockedClient->GetID(), e.what()); - Network.ClientKick(*LockedClient, "Decompression failed (likely a server-side problem)"); + if (auto LockedClient = Client.lock()) { + beammp_errorf("Failed to decompress packet from client {}: {}. The server might be out of RAM! The client will now be disconnected.", LockedClient->GetID(), e.what()); + Network.ClientKick(*LockedClient, "Decompression failed (likely a server-side problem)"); + } return; } } @@ -182,10 +180,10 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vector& Client, std::vectorDisconnect("Failed to send ping"); + Network.DisconnectClient(*LockedClient, "Failed to send ping"); } else { Network.UpdatePlayer(*LockedClient); } @@ -265,9 +263,15 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::vectorGetName(), LockedClient->GetID(), PacketAsString.substr(PacketAsString.find(':', 3) + 1)); bool Rejected = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Elem) { - return !Elem->Error - && Elem->Result.is() - && bool(Elem->Result.as()); + auto Snapshot = Elem->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return bool(*MaybeInt); }); if (!Rejected) { std::string SanitizedPacket = fmt::format("C:{}: {}", LockedClient->GetName(), Message); @@ -380,7 +384,15 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ TLuaEngine::WaitForAll(Futures); bool ShouldntSpawn = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Result) { - return !Result->Error && Result->Result.is() && Result->Result.as() != 0; + auto Snapshot = Result->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return *MaybeInt != 0; }); bool SpawnConfirmed = false; @@ -417,7 +429,15 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ TLuaEngine::WaitForAll(Futures); bool ShouldntAllow = std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Result) { - return !Result->Error && Result->Result.is() && Result->Result.as() != 0; + auto Snapshot = Result->GetDetachedSnapshot(); + if (Snapshot.Error) { + return false; + } + const int* MaybeInt = std::get_if(&Snapshot.Result.V); + if (MaybeInt == nullptr) { + return false; + } + return *MaybeInt != 0; }); auto FoundPos = Packet.find('{');