diff --git a/CMakeLists.txt b/CMakeLists.txt index 67e758d6..62496c64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ set(PRJ_HEADERS include/Defer.h include/Environment.h include/Http.h + include/HttpAsync.h include/IThreaded.h include/Json.h include/LuaAPI.h @@ -60,6 +61,7 @@ set(PRJ_SOURCES src/Common.cpp src/Compat.cpp src/Http.cpp + src/HttpAsync.cpp src/LuaAPI.cpp src/SignalHandling.cpp src/TConfig.cpp diff --git a/include/HttpAsync.h b/include/HttpAsync.h new file mode 100644 index 00000000..72932325 --- /dev/null +++ b/include/HttpAsync.h @@ -0,0 +1,155 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Forward declaration for WebSocket client to keep header lean +namespace httplib { + namespace ws { + class WebSocketClient; + } +} + +namespace HttpAsync { + + struct HttpResult { + enum class Type { COMPLETE, PROGRESS } type; + + uint64_t requestId; + + // Progress data + long long current = 0; + long long total = 0; + + // Response data + int status = 0; + std::string body; + std::map> headers; + }; + + class AsyncHttpProxy : public std::enable_shared_from_this { + public: + AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders); + ~AsyncHttpProxy() = default; + + // Configuration + void SetConnectTimeout(int seconds); + void SetReadTimeout(int seconds); + void VerifySSL(bool verify); + void SetDefaultHeaders(sol::table headers); + + // HTTP Methods + sol::table Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); + sol::table Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Put(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Patch(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Delete(std::string endpoint, sol::object headers, sol::function cb); + sol::table Head(std::string endpoint, sol::object headers, sol::function cb); + + // File Operations + sol::table Download(std::string endpoint, std::string savePath, sol::function cb, sol::object prog); + sol::table PostFile(std::string endpoint, std::string fieldName, std::string filePath, sol::object headers, sol::function cb); + + private: + std::map PrepareHeaders(sol::object overrides); + void PreparePayload(sol::object data, sol::object overrides, std::string& outBody, std::map& outHeaders); + + std::string mBaseUrl; + std::map mDefaultHeaders; + int mConnectTimeoutSeconds = 5; + int mReadTimeoutSeconds = 30; + bool mVerifySSL = true; + }; + + enum class WSEventType { OPEN, MESSAGE, CLOSE, ERROR_EVENT }; + + struct WSEvent { + WSEventType type; + std::string payload; + int closeCode; + }; + + class AsyncWebSocket : public std::enable_shared_from_this { + public: + static sol::object Create(sol::this_state s, std::string url, sol::object headers); + + AsyncWebSocket(std::string url, sol::table headers, lua_State* state); + ~AsyncWebSocket(); + + void Connect(); + void Send(const std::string& data); + void Close(); + void VerifySSL(bool verify); + + // Lua Callback Registration + void OnOpen(sol::object cb); + void OnMessage(sol::object cb); + void OnClose(sol::object cb); + void OnError(sol::object cb); + + void ProcessEvents(); + void Abandon(); + + [[nodiscard]] lua_State* GetLuaState() const { return L; } + + private: + void PushEvent(WSEvent ev); + + std::string mUrl; + lua_State* L; + std::map mHeaders; + bool mVerifySSL = true; + + std::thread mThread; + std::atomic mIsRunning{false}; + std::atomic mAbandoned{false}; + + // Internal httplib pointer and sync + httplib::ws::WebSocketClient* mClient = nullptr; + std::mutex mClientMutex; + + // Event Queue + std::queue mEvents; + std::mutex mMutex; + + // Lua Registry References + int mOnOpenRef = LUA_REFNIL; + int mOnMessageRef = LUA_REFNIL; + int mOnCloseRef = LUA_REFNIL; + int mOnErrorRef = LUA_REFNIL; + }; + + // Module Lifecycle + void Init(); + void Shutdown(); + void Update(sol::state_view& lua); + void RegisterBindings(sol::state_view& lua); + void CleanupState(lua_State* L); + +} // namespace HttpAsync \ No newline at end of file diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 47a5c760..d4a2a2ae 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -241,7 +241,7 @@ class TLuaEngine : public std::enable_shared_from_this, IThreaded { public: StateThreadData(const std::string& Name, TLuaStateId StateId, TLuaEngine& Engine); StateThreadData(const StateThreadData&) = delete; - virtual ~StateThreadData() noexcept { beammp_debug("\"" + mStateId + "\" destroyed"); } + virtual ~StateThreadData() noexcept; [[nodiscard]] std::shared_ptr EnqueueScript(const TLuaChunk& Script); [[nodiscard]] std::shared_ptr EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args, const std::string& EventName); [[nodiscard]] std::shared_ptr EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp new file mode 100644 index 00000000..2f0471b9 --- /dev/null +++ b/src/HttpAsync.cpp @@ -0,0 +1,932 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "HttpAsync.h" +#include "httplib.h" +#include "Common.h" +#include "LuaAPI.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace HttpAsync { + +struct PendingRequest { + lua_State* L; + int callbackRef = LUA_REFNIL; + int progressRef = LUA_REFNIL; + std::atomic abandoned{false}; +}; + +struct PluginContext { + std::deque results; + std::mutex resultsMutex; + + std::vector> webSockets; + std::mutex wsMutex; + + std::chrono::steady_clock::time_point lastLimitWarning{}; +}; + +// --- Centralized Global Context --- +static struct GlobalContext { + std::atomic shuttingDown{false}; + std::unique_ptr threadPool; + std::atomic nextRequestId{1}; + + std::map> pendingRequests; + std::mutex pendingRequestsMutex; + + std::map stateRequestCount; + std::mutex limitMutex; + + std::map stateWsCount; + + std::map> pluginContexts; + std::mutex pluginContextsMutex; + + int actualPoolSize = 16; + int maxRequestsPerPlugin = 8; + int maxWsPerPlugin = 4; + int maxWsGlobal = 32; + int currentWsGlobal = 0; +} ctx; + +static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; + +// --- Utilities --- + +static bool ShouldWarn(std::shared_ptr pCtx) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - pCtx->lastLimitWarning).count() >= 10) { + pCtx->lastLimitWarning = now; + return true; + } + return false; +} + +static std::shared_ptr GetPluginContext(lua_State* L) { + std::lock_guard lock(ctx.pluginContextsMutex); + auto it = ctx.pluginContexts.find(L); + if (it == ctx.pluginContexts.end()) { + auto newCtx = std::make_shared(); + ctx.pluginContexts[L] = newCtx; + return newCtx; + } + return it->second; +} + + +static void ToLowerInPlace(std::string& s) { + for (char& c : s) { + if (c >= 'A' && c <= 'Z') c += 32; + } +} + +static std::string ToLower(std::string s) { + ToLowerInPlace(s); + return s; +} + +static void UnrefCallback(lua_State* L, int& ref) { + if (ref != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, ref); + ref = LUA_REFNIL; + } +} + +static void ReleasePendingRequest(const std::shared_ptr& info) { + if (!info) return; + UnrefCallback(info->L, info->callbackRef); + UnrefCallback(info->L, info->progressRef); +} + +static int MakeRef(sol::object obj) { + if (!obj.is()) return LUA_REFNIL; + lua_State* L = obj.lua_state(); + obj.push(); + return luaL_ref(L, LUA_REGISTRYINDEX); +} + +template +static void InvokeLuaCallback(lua_State* L, int ref, const char* errorContext, Args&&... args) { + if (ref == LUA_REFNIL) return; + lua_rawgeti(L, LUA_REGISTRYINDEX, ref); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { + auto r = cb(std::forward(args)...); + if (!r.valid()) beammp_lua_errorf("%s: %s", errorContext, sol::error(r).what()); + } +} + +static void PushResult(lua_State* L, HttpResult res) { + if (ctx.shuttingDown.load()) return; + auto pCtx = GetPluginContext(L); + + std::lock_guard lock(pCtx->resultsMutex); + pCtx->results.push_back(std::move(res)); +} + +static void ExtractHeaders(const httplib::Headers& source, std::map>& dest) { + for (const auto& [k, v] : source) dest[k].push_back(v); +} + +static bool ParseUrl(const std::string& url, std::string& base, std::string& path) { + if (url.rfind("http://", 0) != 0 && url.rfind("https://", 0) != 0) return false; + + auto pos = url.find("://"); + if (pos == std::string::npos) return false; + + auto pathPos = url.find_first_of("/?#", pos + 3); + if (pathPos == std::string::npos) { + base = url; + path = "/"; + } else { + base = url.substr(0, pathPos); + path = url.substr(pathPos); + } + + if (base.length() <= pos + 3) return false; + + return true; +} + +static bool IsValidWsUrl(const std::string& url) { + return url.rfind("ws://", 0) == 0 || url.rfind("wss://", 0) == 0; +} + +// --- HTTP Implementation --- + +static sol::table CreateHandle(lua_State* L, std::shared_ptr info) { + sol::state_view lua(L); + sol::table handle = lua.create_table(); + + if (info) { + handle["Cancel"] = [info]() { + info->abandoned.store(true); + ReleasePendingRequest(info); + }; + handle["IsActive"] = [info]() { + return !info->abandoned.load(); + }; + handle["OnProgress"] = [info](sol::object func) { + if (func.is()) { + UnrefCallback(info->L, info->progressRef); + info->progressRef = MakeRef(func); + } + }; + } else { + handle["Error"] = "Rate limited or Shutdown"; + } + return handle; +} + +static bool SetupClient(const std::string& url, int connectTimeout, int readTimeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { + std::string base; + if (!ParseUrl(url, base, outPath)) return false; + + outClient = std::make_unique(base); + outClient->set_connection_timeout(connectTimeout, 0); + outClient->set_read_timeout(readTimeout, 0); + outClient->set_write_timeout(readTimeout, 0); + outClient->set_follow_location(true); + outClient->enable_server_certificate_verification(verifySSL); + return true; +} + +static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int progRef, std::function)> task) { + if (ctx.shuttingDown.load() || !ctx.threadPool) { + UnrefCallback(L, cbRef); + UnrefCallback(L, progRef); + return nullptr; + } + + { + std::lock_guard lock(ctx.limitMutex); + if (ctx.stateRequestCount[L] >= ctx.maxRequestsPerPlugin) { + auto pCtx = GetPluginContext(L); + if (ShouldWarn(pCtx)) { + beammp_lua_warnf("Plugin reached HTTP request limit ({}). Further requests silenced for 10s.", ctx.maxRequestsPerPlugin); + } + UnrefCallback(L, cbRef); + UnrefCallback(L, progRef); + return nullptr; + } + ctx.stateRequestCount[L]++; + } + + uint64_t reqId = ctx.nextRequestId++; + auto info = std::make_shared(); + info->L = L; + info->callbackRef = cbRef; + info->progressRef = progRef; + + { + std::lock_guard lock(ctx.pendingRequestsMutex); + ctx.pendingRequests[reqId] = info; + } + + ctx.threadPool->enqueue([reqId, info, task = std::move(task)]() { + task(reqId, info); + }); + + return info; +} + +static sol::table Dispatch(std::string method, std::string url, std::map headers, + std::string body, int connectTimeout, int readTimeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { + + auto info = EnqueueTask(L, cbRef, progRef,[=, b = std::move(body), hMap = std::move(headers)] + (uint64_t reqId, std::shared_ptr pReq) { + std::string path; + std::unique_ptr cli; + + if (!SetupClient(url, connectTimeout, readTimeout, verifySSL, cli, path)) { + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}; + PushResult(pReq->L, std::move(res)); + return; + } + + httplib::Headers h; + bool hasUA = false; + std::string cType = "application/json"; + + for (const auto&[key, val] : hMap) { + std::string kLower = ToLower(key); + if (kLower == "user-agent") hasUA = true; + if (kLower == "content-type") { + cType = val; + if (method == "POST" || method == "PUT" || method == "PATCH") continue; + } + h.emplace(key, val); + } + if (!hasUA) h.emplace("User-Agent", DEFAULT_USER_AGENT); + + auto lastProg = std::chrono::steady_clock::now(); + auto prog_func = [&](uint64_t len, uint64_t total) { + if (ctx.shuttingDown.load() || pReq->abandoned.load()) return false; + + if (pReq->progressRef != LUA_REFNIL) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { + HttpResult res{HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}; + PushResult(pReq->L, std::move(res)); + lastProg = now; + } + } + return true; + }; + + httplib::Result response; + if (method == "POST") response = cli->Post(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "PUT") response = cli->Put(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "PATCH") response = cli->Patch(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "DELETE") response = cli->Delete(path.c_str(), h, prog_func); + else if (method == "HEAD") response = cli->Head(path.c_str(), h); + else response = cli->Get(path.c_str(), h, prog_func); + + if (pReq->abandoned.load()) return; + + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "", {}}; + if (response) { + res.status = response->status; + res.body = std::move(response->body); + ExtractHeaders(response->headers, res.headers); + } else { + res.body = "Network Error: " + httplib::to_string(response.error()); + } + PushResult(pReq->L, std::move(res)); + }); + + return CreateHandle(L, info); +} + +// --- AsyncHttpProxy --- + +AsyncHttpProxy::AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders) : mBaseUrl(std::move(baseUrl)) { + SetDefaultHeaders(defaultHeaders); +} + +void AsyncHttpProxy::SetConnectTimeout(int seconds) { mConnectTimeoutSeconds = seconds; } +void AsyncHttpProxy::SetReadTimeout(int seconds) { mReadTimeoutSeconds = seconds; } +void AsyncHttpProxy::VerifySSL(bool verify) { mVerifySSL = verify; } + +void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { + mDefaultHeaders.clear(); + if (headers != sol::lua_nil && headers.valid()) { + for (auto const& pair : headers) { + if (pair.first.is() && pair.second.is()) + mDefaultHeaders[pair.first.as()] = pair.second.as(); + } + } +} + +std::map AsyncHttpProxy::PrepareHeaders(sol::object overrides) { + auto finalHeaders = mDefaultHeaders; + if (overrides.is()) { + for (auto const& pair : overrides.as()) { + if (pair.first.is() && pair.second.is()) + finalHeaders[pair.first.as()] = pair.second.as(); + } + } + return finalHeaders; +} + +void AsyncHttpProxy::PreparePayload(sol::object data, sol::object overrides, std::string& outBody, std::map& outHeaders) { + outHeaders = PrepareHeaders(overrides); + + bool hasCT = false; + for (const auto& [k, v] : outHeaders) { + if (ToLower(k) == "content-type") hasCT = true; + } + + if (data.is()) { + outBody = LuaAPI::MP::JsonEncode(data.as()); + if (!hasCT) outHeaders["Content-Type"] = "application/json"; + } else { + outBody = data.is() ? data.as() : ""; + } +} + +sol::table AsyncHttpProxy::Get(std::string ep, sol::object h, sol::function cb, sol::object prog) { + return Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), MakeRef(prog)); +} + +sol::table AsyncHttpProxy::Post(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + return Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +sol::table AsyncHttpProxy::Put(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + return Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +sol::table AsyncHttpProxy::Patch(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + return Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +sol::table AsyncHttpProxy::Delete(std::string ep, sol::object h, sol::function cb) { + return Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +sol::table AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) { + return Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { + auto hMap = PrepareHeaders(headers); + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, + [this, ep, fieldName, filePath, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) mutable { + std::string path; + std::unique_ptr cli; + + if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + return; + } + + if (!fs::exists(filePath)) { + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "File not found", {}}); + return; + } + + auto file_stream = std::make_shared(filePath, std::ios::binary); + if (!file_stream || !file_stream->is_open()) { + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file", {}}); + return; + } + + httplib::UploadFormDataItems regular_items; + httplib::FormDataProviderItems provider_items = { + { + fieldName, + [file_stream, pReq](size_t offset, httplib::DataSink &sink) { + if (pReq->abandoned.load()) return false; + + if (static_cast(file_stream->tellg()) != offset) { + file_stream->clear(); + file_stream->seekg(static_cast(offset), std::ios::beg); + } + + char buffer[8192]; + file_stream->read(buffer, sizeof(buffer)); + std::streamsize read_bytes = file_stream->gcount(); + if (read_bytes > 0) sink.write(buffer, static_cast(read_bytes)); + if (file_stream->eof()) sink.done(); + return true; + }, + fs::path(filePath).filename().string(), + "application/octet-stream" + } + }; + + httplib::Headers finalH; + bool hasUA = false; + for (const auto& [key, val] : hMap) { + std::string kLower = ToLower(key); + if (kLower == "user-agent") hasUA = true; + if (kLower == "content-type") continue; + finalH.emplace(key, val); + } + if (!hasUA) finalH.emplace("User-Agent", DEFAULT_USER_AGENT); + + auto response = cli->Post(path.c_str(), finalH, regular_items, provider_items); + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "", {}}; + + if (response) { + res.status = response->status; + res.body = std::move(response->body); + ExtractHeaders(response->headers, res.headers); + } else { + res.body = "Upload Failed: " + httplib::to_string(response.error()); + } + PushResult(pReq->L, std::move(res)); + }); + + return CreateHandle(cb.lua_state(), info); +} + +sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { + auto hMap = PrepareHeaders(sol::lua_nil); + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), + [this, ep, savePath, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) mutable { + std::string path; + std::unique_ptr cli; + + if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + return; + } + + std::ofstream ofs(savePath, std::ios::binary); + if (!ofs) { + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file for writing", {}}); + return; + } + + httplib::Headers finalH; + bool hasUA = false; + for (const auto&[key, val] : hMap) { + if (ToLower(key) == "user-agent") hasUA = true; + finalH.emplace(key, val); + } + if (!hasUA) finalH.emplace("User-Agent", DEFAULT_USER_AGENT); + + int status_code = 0; + std::map> resHeaders; + auto lastProg = std::chrono::steady_clock::now(); + + auto res = cli->Get(path.c_str(), finalH, + [&](const httplib::Response &r) { + status_code = r.status; + ExtractHeaders(r.headers, resHeaders); + return !ctx.shuttingDown.load() && !pReq->abandoned.load(); + }, + [&](const char *b, size_t l) { + if (ctx.shuttingDown.load() || pReq->abandoned.load()) return false; + ofs.write(b, static_cast(l)); + return true; + }, + [&](uint64_t len, uint64_t total) { + if (pReq->progressRef != LUA_REFNIL && !pReq->abandoned.load()) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { + PushResult(pReq->L, {HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}); + lastProg = now; + } + } + return true; + } + ); + ofs.close(); + + HttpResult fres{HttpResult::Type::COMPLETE, reqId, 0, 0, status_code, "", std::move(resHeaders)}; + fres.body = res ? "Success" : "Download Failed: " + httplib::to_string(res.error()); + PushResult(pReq->L, std::move(fres)); + }); + + return CreateHandle(cb.lua_state(), info); +} + + +// --- AsyncWebSocket --- + +AsyncWebSocket::AsyncWebSocket(std::string url, sol::table headers, lua_State* state) + : mUrl(std::move(url)), L(state) { + if (headers.valid()) { + for (auto const& pair : headers) { + if (pair.first.is() && pair.second.is()) { + mHeaders.emplace(pair.first.as(), pair.second.as()); + } + } + } +} + +AsyncWebSocket::~AsyncWebSocket() { + Abandon(); + if (mThread.joinable()) mThread.detach(); +} + +sol::object AsyncWebSocket::Create(sol::this_state s, std::string url, sol::object headers) { + lua_State* L = s.lua_state(); + + { + std::lock_guard lock(ctx.limitMutex); + if (ctx.currentWsGlobal >= ctx.maxWsGlobal || ctx.stateWsCount[L] >= ctx.maxWsPerPlugin) { + auto pCtx = GetPluginContext(L); + if (ShouldWarn(pCtx)) { + beammp_lua_warnf("WebSocket limit reached (Global: {}/{}, Plugin: {}/{}). Silencing for 10s.", + ctx.currentWsGlobal, ctx.maxWsGlobal, ctx.stateWsCount[L], ctx.maxWsPerPlugin); + } + return sol::make_object(s, sol::lua_nil); + } + ctx.stateWsCount[L]++; + ctx.currentWsGlobal++; + } + + if (!IsValidWsUrl(url)) { + beammp_lua_warnf("Invalid WebSocket URL: {}. Use 'ws://' or 'wss://'.", url); + std::lock_guard lock(ctx.limitMutex); + ctx.stateWsCount[L]--; + ctx.currentWsGlobal--; + return sol::make_object(s, sol::lua_nil); + } + + auto ws = std::make_shared(url, headers.is() ? headers.as() : sol::table(), L); + + auto pCtx = GetPluginContext(L); + std::lock_guard lock(pCtx->wsMutex); + pCtx->webSockets.push_back(ws); + return sol::make_object(s, ws); +} + +void AsyncWebSocket::VerifySSL(bool verify) { mVerifySSL = verify; } + +void AsyncWebSocket::Connect() { + if (mIsRunning.exchange(true)) return; + + std::weak_ptr weakSelf = shared_from_this(); + std::string url = mUrl; + std::map headers = mHeaders; + bool verifySSL = mVerifySSL; + + mThread = std::thread([weakSelf, url = std::move(url), headers = std::move(headers), verifySSL]() { + httplib::Headers h; + bool hasUA = false; + + for (const auto&[k, v] : headers) { + if (ToLower(k) == "user-agent") hasUA = true; + h.emplace(k, v); + } + if (!hasUA) h.emplace("User-Agent", DEFAULT_USER_AGENT); + + httplib::ws::WebSocketClient client(url, h); + client.enable_server_certificate_verification(verifySSL); + + if (!client.connect()) { + if (auto self = weakSelf.lock()) { + if (self->mIsRunning.exchange(false)) { + self->PushEvent({WSEventType::ERROR_EVENT, "Failed to connect", 0}); + } + } + return; + } + + if (auto self = weakSelf.lock()) { + std::lock_guard lock(self->mClientMutex); + if (self->mAbandoned) return; + + if (!self->mIsRunning) { + return; + } + + self->mClient = &client; + self->PushEvent({WSEventType::OPEN, "", 0}); + } else { + return; + } + + std::string msg; + while (true) { + if (auto self = weakSelf.lock()) { + if (!self->mIsRunning || self->mAbandoned) break; + } else break; + + if (client.read(msg) == httplib::ws::ReadResult::Fail) break; + + if (auto self = weakSelf.lock()) { + if (!self->mIsRunning || self->mAbandoned) break; + self->PushEvent({WSEventType::MESSAGE, msg, 0}); + } else break; + + msg.clear(); + } + + if (auto self = weakSelf.lock()) { + if (self->mIsRunning.exchange(false)) { + self->PushEvent({WSEventType::CLOSE, "Connection closed by peer", 1000}); + } + std::lock_guard lock(self->mClientMutex); + self->mClient = nullptr; + } + }); +} + +void AsyncWebSocket::Send(const std::string& data) { + std::lock_guard lock(mClientMutex); + if (mClient && mClient->is_open()) mClient->send(data); +} + +void AsyncWebSocket::Close() { + if (mIsRunning.exchange(false)) { + PushEvent({WSEventType::CLOSE, "Connection closed locally", 1000}); + } +} + +void AsyncWebSocket::PushEvent(WSEvent ev) { + if (mAbandoned) return; + std::lock_guard lock(mMutex); + mEvents.push(std::move(ev)); +} + +void AsyncWebSocket::OnOpen(sol::object cb) { UnrefCallback(L, mOnOpenRef); mOnOpenRef = MakeRef(cb); } +void AsyncWebSocket::OnMessage(sol::object cb) { UnrefCallback(L, mOnMessageRef); mOnMessageRef = MakeRef(cb); } +void AsyncWebSocket::OnClose(sol::object cb) { UnrefCallback(L, mOnCloseRef); mOnCloseRef = MakeRef(cb); } +void AsyncWebSocket::OnError(sol::object cb) { UnrefCallback(L, mOnErrorRef); mOnErrorRef = MakeRef(cb); } + +void AsyncWebSocket::ProcessEvents() { + if (mAbandoned) return; + + std::queue events; + { + std::lock_guard lock(mMutex); + std::swap(events, mEvents); + } + + while (!events.empty() && !mAbandoned) { + auto ev = events.front(); + events.pop(); + + try { + switch (ev.type) { + case WSEventType::OPEN: + InvokeLuaCallback(L, mOnOpenRef, "WS OnOpen Error"); + break; + case WSEventType::MESSAGE: + InvokeLuaCallback(L, mOnMessageRef, "WS OnMessage Error", ev.payload); + break; + case WSEventType::CLOSE: + InvokeLuaCallback(L, mOnCloseRef, "WS OnClose Error", ev.closeCode, ev.payload); + break; + case WSEventType::ERROR_EVENT: + InvokeLuaCallback(L, mOnErrorRef, "WS OnError Error", ev.payload); + break; + default: + break; + } + } catch (const std::exception& e) { + beammp_lua_errorf("WebSocket Exception: {}", e.what()); + } + } +} + +void AsyncWebSocket::Abandon() { + if (mAbandoned.exchange(true)) return; + + { + std::lock_guard lock(ctx.limitMutex); + ctx.stateWsCount[L]--; + if (ctx.stateWsCount[L] <= 0) ctx.stateWsCount.erase(L); + ctx.currentWsGlobal--; + } + + Close(); + + UnrefCallback(L, mOnOpenRef); + UnrefCallback(L, mOnMessageRef); + UnrefCallback(L, mOnCloseRef); + UnrefCallback(L, mOnErrorRef); +} + +// --- Lifecycle & Lua Bindings --- + +void RegisterBindings(sol::state_view& lua) { + lua.new_usertype("AsyncHttp", sol::no_constructor, + "SetConnectTimeout", &AsyncHttpProxy::SetConnectTimeout, + "SetReadTimeout", &AsyncHttpProxy::SetReadTimeout, + "VerifySSL", &AsyncHttpProxy::VerifySSL, + "SetDefaultHeaders", &AsyncHttpProxy::SetDefaultHeaders, + "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { return self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), + "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), + "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { return self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), + "Put", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Put(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Put), + "Patch", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Patch(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Patch), + "Delete", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { return self.Delete(ep, sol::nil, cb); }, &AsyncHttpProxy::Delete), + "Head", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { return self.Head(ep, sol::nil, cb); }, &AsyncHttpProxy::Head), + "Download", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string p, sol::function cb) { return self.Download(ep, p, cb, sol::nil); }, &AsyncHttpProxy::Download) + ); + + lua["AsyncHttp"]["new"] = sol::overload([](std::string url) { return std::make_shared(url, sol::table(sol::lua_nil)); },[](std::string url, sol::table headers) { return std::make_shared(url, headers); } + ); + + lua.new_usertype("AsyncWebSocket", sol::no_constructor, + "Connect", &AsyncWebSocket::Connect, + "Send", &AsyncWebSocket::Send, + "Close", &AsyncWebSocket::Close, + "VerifySSL", &AsyncWebSocket::VerifySSL, + "OnOpen", &AsyncWebSocket::OnOpen, + "OnMessage", &AsyncWebSocket::OnMessage, + "OnClose", &AsyncWebSocket::OnClose, + "OnError", &AsyncWebSocket::OnError + ); + + lua["AsyncWebSocket"]["new"] = &AsyncWebSocket::Create; +} + +void Update(sol::state_view& lua) { + lua_State* L = lua.lua_state(); + auto pCtx = GetPluginContext(L); + + std::deque toProcess; + + // 1. Gather relevant results safely + { + std::lock_guard lock(pCtx->resultsMutex); + toProcess.swap(pCtx->results); + } + + // 2. Dispatch Lua callbacks + for (const auto& res : toProcess) { + std::shared_ptr info; + { + std::lock_guard lock(ctx.pendingRequestsMutex); + auto it = ctx.pendingRequests.find(res.requestId); + if (it != ctx.pendingRequests.end()) info = it->second; + } + + if (!info || info->abandoned.load()) { + if (info) { + ReleasePendingRequest(info); + + std::lock_guard lock(ctx.limitMutex); + ctx.stateRequestCount[L]--; + } + std::lock_guard lock(ctx.pendingRequestsMutex); + ctx.pendingRequests.erase(res.requestId); + continue; + } + + if (res.type == HttpResult::Type::PROGRESS) { + InvokeLuaCallback(L, info->progressRef, "AsyncHttp Progress Error", res.current, res.total); + } else { + if (info->callbackRef != LUA_REFNIL) { + sol::table luaHeaders = lua.create_table(); + for (auto const& [name, values] : res.headers) { + if (values.empty()) continue; + std::string key = ToLower(name); + if (values.size() > 1 || key == "set-cookie") luaHeaders[key] = sol::as_table(values); + else luaHeaders[key] = values[0]; + } + InvokeLuaCallback(L, info->callbackRef, "AsyncHttp Callback Error", res.status, res.body, luaHeaders); + } + + ReleasePendingRequest(info); + + { + std::lock_guard lock(ctx.limitMutex); + ctx.stateRequestCount[L]--; + } + + std::lock_guard lock(ctx.pendingRequestsMutex); + ctx.pendingRequests.erase(res.requestId); + } + } + + // 3. Update WebSockets cleanly without cross-plugin locking + std::vector> websocketsToUpdate; + { + std::lock_guard wsLock(pCtx->wsMutex); + auto wsIt = pCtx->webSockets.begin(); + while (wsIt != pCtx->webSockets.end()) { + if (auto ws = wsIt->lock()) { + websocketsToUpdate.push_back(ws); + ++wsIt; + } else { + wsIt = pCtx->webSockets.erase(wsIt); + } + } + } + + for (auto& ws : websocketsToUpdate) { + ws->ProcessEvents(); + } +} + +void CleanupState(lua_State* L) { + auto pCtx = GetPluginContext(L); + + // Purge pending HTTP requests + { + std::lock_guard lock(ctx.pendingRequestsMutex); + for (auto it = ctx.pendingRequests.begin(); it != ctx.pendingRequests.end(); ) { + if (it->second->L == L) { + it->second->abandoned.store(true); + ReleasePendingRequest(it->second); + it = ctx.pendingRequests.erase(it); + } else { + ++it; + } + } + } + + // Clear out un-polled results + { + std::lock_guard lock(pCtx->resultsMutex); + pCtx->results.clear(); + } + + // Purge attached WebSockets + { + std::lock_guard wsLock(pCtx->wsMutex); + for (auto& weak_ws : pCtx->webSockets) { + if (auto ws = weak_ws.lock()) ws->Abandon(); + } + pCtx->webSockets.clear(); + } + + std::lock_guard lock(ctx.pluginContextsMutex); + ctx.pluginContexts.erase(L); +} + +void Init() { + ctx.shuttingDown.store(false); + int cores = static_cast(std::thread::hardware_concurrency()); + if (cores <= 0) cores = 4; + + ctx.actualPoolSize = std::clamp(cores * 4, 16, 128); + ctx.maxRequestsPerPlugin = std::max(ctx.actualPoolSize / 2, 5); + ctx.threadPool = std::make_unique(ctx.actualPoolSize); + + ctx.maxWsGlobal = std::max(cores * 8, 32); + ctx.maxWsPerPlugin = std::max(ctx.maxWsGlobal / 4, 4); + + beammp_infof("AsyncHttp initialized. HTTP Pool: {} ({} per plugin). WS Quota: {} ({} per plugin).", + ctx.actualPoolSize, ctx.maxRequestsPerPlugin, ctx.maxWsGlobal, ctx.maxWsPerPlugin); +} + +void Shutdown() { + ctx.shuttingDown.store(true); + + if (ctx.threadPool) { + ctx.threadPool->shutdown(); + ctx.threadPool.reset(); + } + + { + std::lock_guard lock(ctx.pendingRequestsMutex); + ctx.pendingRequests.clear(); + } + + { + std::lock_guard lock(ctx.pluginContextsMutex); + for (auto& [L, pCtx] : ctx.pluginContexts) { + std::lock_guard wsLock(pCtx->wsMutex); + for (auto& weak_ws : pCtx->webSockets) { + if (auto ws = weak_ws.lock()) ws->Abandon(); + } + pCtx->webSockets.clear(); + } + ctx.pluginContexts.clear(); + } +} + +} // namespace HttpAsync \ No newline at end of file diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d7bee3e6..7b998685 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -21,6 +21,7 @@ #include "Common.h" #include "CustomAssert.h" #include "Http.h" +#include "HttpAsync.h" #include "LuaAPI.h" #include "Env.h" #include "Profiling.h" @@ -41,6 +42,7 @@ TLuaEngine::TLuaEngine() : mResourceServerPath(fs::path(Application::Settings.getAsString(Settings::Key::General_ResourceFolder)) / "Server") { Application::SetSubsystemStatus("LuaEngine", Application::Status::Starting); LuaAPI::MP::Engine = this; + HttpAsync::Init(); if (!fs::exists(Application::Settings.getAsString(Settings::Key::General_ResourceFolder))) { fs::create_directory(Application::Settings.getAsString(Settings::Key::General_ResourceFolder)); } @@ -49,6 +51,7 @@ TLuaEngine::TLuaEngine() } Application::RegisterShutdownHandler([&] { Application::SetSubsystemStatus("LuaEngine", Application::Status::ShuttingDown); + HttpAsync::Shutdown(); if (mThread.joinable()) { mThread.join(); } @@ -1070,9 +1073,16 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI FSTable.set_function("ListDirectories", [this](const std::string& Path) { return Lua_FS_ListDirectories(Path); }); + HttpAsync::RegisterBindings(mStateView); Start(); } +TLuaEngine::StateThreadData::~StateThreadData() noexcept { + HttpAsync::CleanupState(mState); + + beammp_debug("\"" + mStateId + "\" destroyed"); +} + std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLuaChunk& Script) { std::unique_lock Lock(mStateExecuteQueueMutex); auto Result = std::make_shared(); @@ -1119,6 +1129,7 @@ void TLuaEngine::StateThreadData::RegisterEvent(const std::string& EventName, co void TLuaEngine::StateThreadData::operator()() { RegisterThread("Lua:" + mStateId); while (!Application::IsShuttingDown()) { + HttpAsync::Update(mStateView); { // StateExecuteQueue Scope std::unique_lock Lock(mStateExecuteQueueMutex); if (!mStateExecuteQueue.empty()) { diff --git a/vcpkg b/vcpkg index 5bf0c552..c3867e71 160000 --- a/vcpkg +++ b/vcpkg @@ -1 +1 @@ -Subproject commit 5bf0c55239da398b8c6f450818c9e28d36bf9966 +Subproject commit c3867e714dd3a51c272826eea77267876517ed99 diff --git a/vcpkg.json b/vcpkg.json index 1814b289..b4d95e35 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -27,5 +27,5 @@ "version": "5.3.5#6" } ], - "builtin-baseline": "5bf0c55239da398b8c6f450818c9e28d36bf9966" + "builtin-baseline": "c3867e714dd3a51c272826eea77267876517ed99" }