diff --git a/CMakeLists.txt b/CMakeLists.txt index 0527453..be7c9d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,8 @@ set(PRJ_HEADERS include/VehicleData.h include/Env.h include/Settings.h + include/Profiling.h + include/ChronoWrapper.h ) # add all source files (.cpp) to this, except the one with main() set(PRJ_SOURCES @@ -74,6 +76,8 @@ set(PRJ_SOURCES src/VehicleData.cpp src/Env.cpp src/Settings.cpp + src/Profiling.cpp + src/ChronoWrapper.cpp ) find_package(Lua REQUIRED) @@ -173,6 +177,11 @@ add_library(commandline_static deps/commandline/src/backends/BufferedBackend.cpp deps/commandline/src/backends/BufferedBackend.h ) + +# Ensure the commandline library uses C++11 +set_target_properties(commandline_static PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED YES) + + if (WIN32) target_compile_definitions(commandline_static PRIVATE -DPLATFORM_WINDOWS=1) else () @@ -216,4 +225,3 @@ if(${PROJECT_NAME}_ENABLE_UNIT_TESTING) target_link_options(${PROJECT_NAME}-tests PRIVATE "/SUBSYSTEM:CONSOLE") endif(MSVC) endif() - diff --git a/include/ChronoWrapper.h b/include/ChronoWrapper.h new file mode 100644 index 0000000..009e911 --- /dev/null +++ b/include/ChronoWrapper.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include + +namespace ChronoWrapper { + std::chrono::high_resolution_clock::duration TimeFromStringWithLiteral(const std::string& time_str); +} diff --git a/include/Client.h b/include/Client.h index 63f0316..898a0a9 100644 --- a/include/Client.h +++ b/include/Client.h @@ -128,7 +128,7 @@ private: std::string mRole; std::string mDID; int mID = -1; - std::chrono::time_point mLastPingTime; + std::chrono::time_point mLastPingTime = std::chrono::high_resolution_clock::now(); }; std::optional> GetClient(class TServer& Server, int ID); diff --git a/include/Common.h b/include/Common.h index 259cb1d..64d7194 100644 --- a/include/Common.h +++ b/include/Common.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -61,6 +62,7 @@ public: // types + using TShutdownHandler = std::function; // methods @@ -126,7 +128,7 @@ private: static inline std::mutex mShutdownHandlersMutex {}; static inline std::deque mShutdownHandlers {}; - static inline Version mVersion { 3, 3, 0 }; + static inline Version mVersion { 3, 5, 0 }; }; void SplitString(std::string const& str, const char delim, std::vector& out); diff --git a/include/Profiling.h b/include/Profiling.h new file mode 100644 index 0000000..c0b23c8 --- /dev/null +++ b/include/Profiling.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace prof { + +using Duration = std::chrono::duration; +using TimePoint = std::chrono::high_resolution_clock::time_point; + +/// Returns the current time. +TimePoint now(); + +/// Returns a sub-millisecond resolution duration between start and end. +Duration duration(const TimePoint& start, const TimePoint& end); + +struct Stats { + double mean; + double stdev; + double min; + double max; + size_t n; +}; + +/// Calculates and stores the moving average over K samples of execution time data +/// for some single unit of code. Threadsafe. +struct UnitExecutionTime { + UnitExecutionTime(); + + /// Adds a sample to the collection, overriding the oldest sample if needed. + void add_sample(const Duration& dur); + + /// Calculates the mean duration over the `measurement_count()` measurements, + /// as well as the standard deviation. + Stats stats() const; + + /// Returns the number of elements the moving average is calculated over. + size_t measurement_count() const; + +private: + mutable std::mutex m_mtx {}; + size_t m_total_calls {}; + double m_sum {}; + // sum of measurements squared (for running stdev) + double m_measurement_sqr_sum {}; + double m_min { std::numeric_limits::max() }; + double m_max { std::numeric_limits::min() }; +}; + +/// Holds profiles for multiple units by name. Threadsafe. +struct UnitProfileCollection { + /// Adds a sample to the collection, overriding the oldest sample if needed. + void add_sample(const std::string& unit, const Duration& duration); + + /// Calculates the mean duration over the `measurement_count()` measurements, + /// as well as the standard deviation. + Stats stats(const std::string& unit); + + /// Returns the number of elements the moving average is calculated over. + size_t measurement_count(const std::string& unit); + + /// Returns the stats for all stored units. + std::unordered_map all_stats(); + +private: + boost::synchronized_value> m_map; +}; + +} diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index b4c14b0..3d74483 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -18,9 +18,11 @@ #pragma once +#include "Profiling.h" #include "TNetwork.h" #include "TServer.h" #include +#include #include #include #include @@ -28,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -36,19 +39,34 @@ #include #define SOL_ALL_SAFETIES_ON 1 +#define SOL_USER_C_ASSERT SOL_ON +#define SOL_C_ASSERT(...) \ + beammp_lua_errorf("SOL2 assertion failure: Assertion `{}` failed in {}:{}. This *should* be a fatal error, but BeamMP Server overrides it to not be fatal. This may cause the Lua Engine to crash, or cause other issues.", #__VA_ARGS__, __FILE__, __LINE__) #include +struct JsonString { + std::string value; +}; + +// value used to keep nils in a table or array, across serialization boundaries like +// JsonEncode, so that the nil stays at the same index and isn't treated like a special +// value (e.g. one that can be ignored or discarded). +const inline std::string BEAMMP_INTERNAL_NIL = "BEAMMP_SERVER_INTERNAL_NIL_VALUE"; + using TLuaStateId = std::string; namespace fs = std::filesystem; /** * std::variant means, that TLuaArgTypes may be one of the Types listed as template args */ -using TLuaArgTypes = std::variant>; -static constexpr size_t TLuaArgTypes_String = 0; -static constexpr size_t TLuaArgTypes_Int = 1; -static constexpr size_t TLuaArgTypes_VariadicArgs = 2; -static constexpr size_t TLuaArgTypes_Bool = 3; -static constexpr size_t TLuaArgTypes_StringStringMap = 4; +using TLuaValue = std::variant, float>; +enum TLuaType { + String = 0, + Int = 1, + Json = 2, + Bool = 3, + StringStringMap = 4, + Float = 5, +}; class TLuaPlugin; @@ -96,7 +114,7 @@ public: struct QueuedFunction { std::string FunctionName; std::shared_ptr Result; - std::vector Args; + std::vector Args; std::string EventName; // optional, may be empty }; @@ -149,7 +167,7 @@ public: void ReportErrors(const std::vector>& Results); bool HasState(TLuaStateId StateId); [[nodiscard]] std::shared_ptr EnqueueScript(TLuaStateId StateID, const TLuaChunk& Script); - [[nodiscard]] std::shared_ptr EnqueueFunctionCall(TLuaStateId StateID, const std::string& FunctionName, const std::vector& Args); + [[nodiscard]] std::shared_ptr EnqueueFunctionCall(TLuaStateId StateID, const std::string& FunctionName, const std::vector& Args); void EnsureStateExists(TLuaStateId StateId, const std::string& Name, bool DontCallOnInit = false); void RegisterEvent(const std::string& EventName, TLuaStateId StateId, const std::string& FunctionName); /** @@ -169,7 +187,7 @@ public: } std::vector> Results; - std::vector Arguments { TLuaArgTypes { std::forward(Args) }... }; + std::vector Arguments { TLuaValue { std::forward(Args) }... }; for (const auto& Event : mLuaEvents.at(EventName)) { for (const auto& Function : Event.second) { @@ -188,7 +206,7 @@ public: return {}; } std::vector> Results; - std::vector Arguments { TLuaArgTypes { std::forward(Args) }... }; + std::vector Arguments { TLuaValue { std::forward(Args) }... }; const auto Handlers = GetEventHandlersForState(EventName, StateId); for (const auto& Handler : Handlers) { Results.push_back(EnqueueFunctionCall(StateId, Handler, Arguments)); @@ -225,8 +243,8 @@ private: StateThreadData(const StateThreadData&) = delete; virtual ~StateThreadData() noexcept { beammp_debug("\"" + mStateId + "\" destroyed"); } [[nodiscard]] std::shared_ptr EnqueueScript(const TLuaChunk& Script); - [[nodiscard]] std::shared_ptr EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args); - [[nodiscard]] std::shared_ptr EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy); + [[nodiscard]] std::shared_ptr EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args); + [[nodiscard]] std::shared_ptr EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy); void RegisterEvent(const std::string& EventName, const std::string& FunctionName); void AddPath(const fs::path& Path); // to be added to path and cpath void operator()() override; @@ -253,6 +271,9 @@ private: sol::table Lua_FS_ListFiles(const std::string& Path); sol::table Lua_FS_ListDirectories(const std::string& Path); + prof::UnitProfileCollection mProfile {}; + std::unordered_map mProfileStarts; + std::string mName; TLuaStateId mStateId; lua_State* mState; @@ -268,6 +289,7 @@ private: std::recursive_mutex mPathsMutex; std::mt19937 mMersenneTwister; std::uniform_real_distribution mUniformRealDistribution01; + std::vector JsonStringToArray(JsonString Str); }; struct TimedEvent { diff --git a/src/ChronoWrapper.cpp b/src/ChronoWrapper.cpp new file mode 100644 index 0000000..8b2b917 --- /dev/null +++ b/src/ChronoWrapper.cpp @@ -0,0 +1,27 @@ +#include "ChronoWrapper.h" +#include "Common.h" +#include + +std::chrono::high_resolution_clock::duration ChronoWrapper::TimeFromStringWithLiteral(const std::string& time_str) +{ + // const std::regex time_regex(R"((\d+\.{0,1}\d*)(min|ms|us|ns|[dhs]))"); //i.e one of: "25ns, 6us, 256ms, 2s, 13min, 69h, 356d" will get matched (only available in newer C++ versions) + const std::regex time_regex(R"((\d+\.{0,1}\d*)(min|[dhs]))"); //i.e one of: "2.01s, 13min, 69h, 356.69d" will get matched + std::smatch match; + float time_value; + if (!std::regex_search(time_str, match, time_regex)) return std::chrono::nanoseconds(0); + time_value = stof(match.str(1)); + beammp_debugf("Parsed time was: {}{}", time_value, match.str(2)); + if (match.str(2) == "d") { + return std::chrono::seconds((uint64_t)(time_value * 86400)); //86400 seconds in a day + } + else if (match.str(2) == "h") { + return std::chrono::seconds((uint64_t)(time_value * 3600)); //3600 seconds in an hour + } + else if (match.str(2) == "min") { + return std::chrono::seconds((uint64_t)(time_value * 60)); + } + else if (match.str(2) == "s") { + return std::chrono::seconds((uint64_t)time_value); + } + return std::chrono::nanoseconds(0); +} diff --git a/src/Common.cpp b/src/Common.cpp index f4a4319..e412a6f 100644 --- a/src/Common.cpp +++ b/src/Common.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include "Compat.h" #include "CustomAssert.h" @@ -382,3 +383,4 @@ void SplitString(const std::string& str, const char delim, std::vector #include -// TODO: Add sentry error handling back - using json = nlohmann::json; +struct Connection { + std::string host{}; + int port{}; + Connection() = default; + Connection(std::string host, int port) + : host(host) + , port(port) {}; +}; +constexpr uint8_t CONNECTION_AMOUNT = 10; +static thread_local uint8_t write_index = 0; +static thread_local std::array connections; +static thread_local std::array, CONNECTION_AMOUNT> clients; + +[[nodiscard]] static std::shared_ptr getClient(Connection connectionInfo) { + for (uint8_t i = 0; i < CONNECTION_AMOUNT; i++) { + if (connectionInfo.host == connections[i].host + && connectionInfo.port == connections[i].port) { + beammp_tracef("Old client reconnected, with ip {} and port {}", connectionInfo.host, connectionInfo.port); + return clients[i]; + } + } + uint8_t i = write_index; + write_index++; + write_index %= CONNECTION_AMOUNT; + clients[i] = std::make_shared(connectionInfo.host, connectionInfo.port); + connections[i] = {connectionInfo.host, connectionInfo.port}; + beammp_tracef("New client connected, with ip {} and port {}", connectionInfo.host, connectionInfo.port); + return clients[i]; +} std::string Http::GET(const std::string& host, int port, const std::string& target, unsigned int* status) { - httplib::SSLClient client(host, port); - client.enable_server_certificate_verification(false); - client.set_address_family(AF_INET); - auto res = client.Get(target.c_str()); + std::shared_ptr client = getClient({host, port}); + client->enable_server_certificate_verification(false); + client->set_address_family(AF_INET); + auto res = client->Get(target.c_str()); if (res) { if (status) { *status = res->status; @@ -48,12 +75,12 @@ std::string Http::GET(const std::string& host, int port, const std::string& targ } std::string Http::POST(const std::string& host, int port, const std::string& target, const std::string& body, const std::string& ContentType, unsigned int* status, const httplib::Headers& headers) { - httplib::SSLClient client(host, port); - client.set_read_timeout(std::chrono::seconds(10)); - beammp_assert(client.is_valid()); - client.enable_server_certificate_verification(false); - client.set_address_family(AF_INET); - auto res = client.Post(target.c_str(), headers, body.c_str(), body.size(), ContentType.c_str()); + std::shared_ptr client = getClient({host, port}); + client->set_read_timeout(std::chrono::seconds(10)); + beammp_assert(client->is_valid()); + client->enable_server_certificate_verification(false); + client->set_address_family(AF_INET); + auto res = client->Post(target.c_str(), headers, body.c_str(), body.size(), ContentType.c_str()); if (res) { if (status) { *status = res->status; diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index 2d1906f..8ec0ec8 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -61,7 +61,11 @@ std::string LuaAPI::LuaToString(const sol::object Value, size_t Indent, bool Quo } case sol::type::number: { std::stringstream ss; - ss << Value.as(); + if (Value.is()) { + ss << Value.as(); + } else { + ss << Value.as(); + } return ss.str(); } case sol::type::lua_nil: @@ -562,7 +566,11 @@ static void JsonEncodeRecursive(nlohmann::json& json, const sol::object& left, c key = left.as(); break; case sol::type::number: - key = std::to_string(left.as()); + if (left.is()) { + key = std::to_string(left.as()); + } else { + key = std::to_string(left.as()); + } break; default: beammp_assert_not_reachable(); @@ -590,21 +598,30 @@ static void JsonEncodeRecursive(nlohmann::json& json, const sol::object& left, c case sol::type::string: value = right.as(); break; - case sol::type::number: - value = right.as(); + case sol::type::number: { + if (right.is()) { + value = right.as(); + } else { + value = right.as(); + } break; + } case sol::type::function: beammp_lua_warn("unsure what to do with function in JsonEncode, ignoring"); return; case sol::type::table: { - bool local_is_array = true; - for (const auto& pair : right.as()) { - if (pair.first.get_type() != sol::type::number) { - local_is_array = false; + if (right.as().empty()) { + value = nlohmann::json::object(); + } else { + bool local_is_array = true; + for (const auto& pair : right.as()) { + if (pair.first.get_type() != sol::type::number) { + local_is_array = false; + } + } + for (const auto& pair : right.as()) { + JsonEncodeRecursive(value, pair.first, pair.second, local_is_array, depth + 1); } - } - for (const auto& pair : right.as()) { - JsonEncodeRecursive(value, pair.first, pair.second, local_is_array, depth + 1); } break; } @@ -621,14 +638,18 @@ static void JsonEncodeRecursive(nlohmann::json& json, const sol::object& left, c std::string LuaAPI::MP::JsonEncode(const sol::table& object) { nlohmann::json json; // table - bool is_array = true; - for (const auto& pair : object.as()) { - if (pair.first.get_type() != sol::type::number) { - is_array = false; + if (object.as().empty()) { + json = nlohmann::json::object(); + } else { + bool is_array = true; + for (const auto& pair : object.as()) { + if (pair.first.get_type() != sol::type::number) { + is_array = false; + } + } + for (const auto& entry : object) { + JsonEncodeRecursive(json, entry.first, entry.second, is_array); } - } - for (const auto& entry : object) { - JsonEncodeRecursive(json, entry.first, entry.second, is_array); } return json.dump(); } diff --git a/src/Profiling.cpp b/src/Profiling.cpp new file mode 100644 index 0000000..f6a41d8 --- /dev/null +++ b/src/Profiling.cpp @@ -0,0 +1,60 @@ +#include "Profiling.h" +#include + +prof::Duration prof::duration(const TimePoint& start, const TimePoint& end) { + return end - start; +} +prof::TimePoint prof::now() { + return std::chrono::high_resolution_clock::now(); +} +prof::Stats prof::UnitProfileCollection::stats(const std::string& unit) { + return m_map->operator[](unit).stats(); +} + +size_t prof::UnitProfileCollection::measurement_count(const std::string& unit) { + return m_map->operator[](unit).measurement_count(); +} + +void prof::UnitProfileCollection::add_sample(const std::string& unit, const Duration& duration) { + m_map->operator[](unit).add_sample(duration); +} + +prof::Stats prof::UnitExecutionTime::stats() const { + std::unique_lock lock(m_mtx); + Stats result {}; + // calculate sum + result.n = m_total_calls; + result.max = m_min; + result.min = m_max; + // calculate mean: mean = sum_x / n + result.mean = m_sum / double(m_total_calls); + // calculate stdev: stdev = sqrt((sum_x2 / n) - (mean * mean)) + result.stdev = std::sqrt((m_measurement_sqr_sum / double(result.n)) - (result.mean * result.mean)); + return result; +} + +void prof::UnitExecutionTime::add_sample(const Duration& dur) { + std::unique_lock lock(m_mtx); + m_sum += dur.count(); + m_measurement_sqr_sum += dur.count() * dur.count(); + m_min = std::min(dur.count(), m_min); + m_max = std::max(dur.count(), m_max); + ++m_total_calls; +} + +prof::UnitExecutionTime::UnitExecutionTime() { +} + +std::unordered_map prof::UnitProfileCollection::all_stats() { + auto map = m_map.synchronize(); + std::unordered_map result {}; + for (const auto& [name, time] : *map) { + result[name] = time.stats(); + } + return result; +} +size_t prof::UnitExecutionTime::measurement_count() const { + std::unique_lock lock(m_mtx); + return m_total_calls; +} + diff --git a/src/TConfig.cpp b/src/TConfig.cpp index 66bcd2b..d144989 100644 --- a/src/TConfig.cpp +++ b/src/TConfig.cpp @@ -54,12 +54,15 @@ static constexpr std::string_view StrAuthKey = "AuthKey"; static constexpr std::string_view EnvStrAuthKey = "BEAMMP_AUTH_KEY"; static constexpr std::string_view StrLogChat = "LogChat"; static constexpr std::string_view EnvStrLogChat = "BEAMMP_LOG_CHAT"; +static constexpr std::string_view StrAllowGuests = "AllowGuests"; +static constexpr std::string_view EnvStrAllowGuests = "BEAMMP_ALLOW_GUESTS"; static constexpr std::string_view StrPassword = "Password"; // Misc static constexpr std::string_view StrSendErrors = "SendErrors"; static constexpr std::string_view StrSendErrorsMessageEnabled = "SendErrorsShowMessage"; static constexpr std::string_view StrHideUpdateMessages = "ImScaredOfUpdates"; +static constexpr std::string_view StrUpdateReminderTime = "UpdateReminderTime"; TEST_CASE("TConfig::TConfig") { const std::string CfgFile = "beammp_server_testconfig.toml"; @@ -129,6 +132,8 @@ void TConfig::FlushToFile() { SetComment(data["General"][StrLogChat.data()].comments(), " Whether to log chat messages in the console / log"); data["General"][StrDebug.data()] = Application::Settings.getAsBool(Settings::Key::General_Debug); data["General"][StrPrivate.data()] = Application::Settings.getAsBool(Settings::Key::General_Private); + data["General"][StrAllowGuests.data()] = Application::Settings.AllowGuests; + SetComment(data["General"][StrAllowGuests.data()].comments(), " Whether to allow guests"); data["General"][StrPort.data()] = Application::Settings.getAsInt(Settings::Key::General_Port); data["General"][StrName.data()] = Application::Settings.getAsString(Settings::Key::General_Name); SetComment(data["General"][StrTags.data()].comments(), " Add custom identifying tags to your server to make it easier to find. Format should be TagA,TagB,TagC. Note the comma seperation."); @@ -144,6 +149,8 @@ void TConfig::FlushToFile() { data["Misc"][StrHideUpdateMessages.data()] = Application::Settings.getAsBool(Settings::Key::Misc_ImScaredOfUpdates); SetComment(data["Misc"][StrHideUpdateMessages.data()].comments(), " Hides the periodic update message which notifies you of a new server version. You should really keep this on and always update as soon as possible. For more information visit https://wiki.beammp.com/en/home/server-maintenance#updating-the-server. An update message will always appear at startup regardless."); data["Misc"][StrSendErrors.data()] = Application::Settings.getAsBool(Settings::Key::Misc_SendErrors); + data["Misc"][StrUpdateReminderTime.data()] = Application::Settings.UpdateReminderTime; + SetComment(data["Misc"][StrUpdateReminderTime.data()].comments(), " Specifies the time between update reminders. You can use any of \"s, min, h, d\" at the end to specify the units seconds, minutes, hours or days. So 30d or 0.5min will print the update message every 30 days or half a minute."); SetComment(data["Misc"][StrSendErrors.data()].comments(), " If SendErrors is `true`, the server will send helpful info about crashes and other issues back to the BeamMP developers. This info may include your config, who is on your server at the time of the error, and similar general information. This kind of data is vital in helping us diagnose and fix issues faster. This has no impact on server performance. You can opt-out of this system by setting this to `false`"); data["Misc"][StrSendErrorsMessageEnabled.data()] = Application::Settings.getAsBool(Settings::Key::Misc_SendErrorsShowMessage); SetComment(data["Misc"][StrSendErrorsMessageEnabled.data()].comments(), " You can turn on/off the SendErrors message you get on startup here"); @@ -260,10 +267,12 @@ void TConfig::ParseFromFile(std::string_view name) { TryReadValue(data, "General", StrResourceFolder, EnvStrResourceFolder, Settings::Key::General_ResourceFolder); TryReadValue(data, "General", StrAuthKey, EnvStrAuthKey, Settings::Key::General_AuthKey); TryReadValue(data, "General", StrLogChat, EnvStrLogChat, Settings::Key::General_LogChat); + TryReadValue(data, "General", StrAllowGuests, EnvStrAllowGuests, Application::Settings.AllowGuests); // Misc TryReadValue(data, "Misc", StrSendErrors, "", Settings::Key::Misc_SendErrors); TryReadValue(data, "Misc", StrHideUpdateMessages, "", Settings::Misc_ImScaredOfUpdates); TryReadValue(data, "Misc", StrSendErrorsMessageEnabled, "", Settings::Misc_SendErrorsShowMessage); + TryReadValue(data, "Misc", StrUpdateReminderTime, "", Application::Settings.UpdateReminderTime); } catch (const std::exception& err) { beammp_error("Error parsing config file value: " + std::string(err.what())); @@ -308,6 +317,7 @@ void TConfig::PrintDebug() { beammp_debug(std::string(StrTags) + ": " + TagsAsPrettyArray()); beammp_debug(std::string(StrLogChat) + ": \"" + (Application::Settings.getAsBool(Settings::Key::General_LogChat) ? "true" : "false") + "\""); beammp_debug(std::string(StrResourceFolder) + ": \"" + Application::Settings.getAsString(Settings::Key::General_ResourceFolder) + "\""); + beammp_debug(std::string(StrAllowGuests) + ": \"" + (Application::Settings.AllowGuests ? "true" : "false") + "\""); // special! beammp_debug("Key Length: " + std::to_string(Application::Settings.getAsString(Settings::Key::General_AuthKey).length()) + ""); } diff --git a/src/TConsole.cpp b/src/TConsole.cpp index 68abaf7..356f3a4 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -26,7 +26,9 @@ #include "TLuaEngine.h" #include +#include #include +#include #include #include #include @@ -240,7 +242,25 @@ void TConsole::Command_Version(const std::string& cmd, const std::vector& args) { diff --git a/src/THeartbeatThread.cpp b/src/THeartbeatThread.cpp index 4148439..fb40669 100644 --- a/src/THeartbeatThread.cpp +++ b/src/THeartbeatThread.cpp @@ -20,6 +20,7 @@ #include "Client.h" #include "Http.h" +#include "ChronoWrapper.h" //#include "SocketIO.h" #include #include @@ -36,15 +37,17 @@ void THeartbeatThread::operator()() { static std::string Last; static std::chrono::high_resolution_clock::time_point LastNormalUpdateTime = std::chrono::high_resolution_clock::now(); + static std::chrono::high_resolution_clock::time_point LastUpdateReminderTime = std::chrono::high_resolution_clock::now(); bool isAuth = false; - size_t UpdateReminderCounter = 0; + std::chrono::high_resolution_clock::duration UpdateReminderTimePassed; + auto UpdateReminderTimeout = ChronoWrapper::TimeFromStringWithLiteral(Application::Settings.UpdateReminderTime); while (!Application::IsShuttingDown()) { - ++UpdateReminderCounter; Body = GenerateCall(); // a hot-change occurs when a setting has changed, to update the backend of that change. auto Now = std::chrono::high_resolution_clock::now(); bool Unchanged = Last == Body; auto TimePassed = (Now - LastNormalUpdateTime); + UpdateReminderTimePassed = (Now - LastUpdateReminderTime); auto Threshold = Unchanged ? 30 : 5; if (TimePassed < std::chrono::seconds(Threshold)) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); @@ -126,7 +129,8 @@ void THeartbeatThread::operator()() { if (isAuth || Application::Settings.getAsBool(Settings::Key::General_Private)) { Application::SetSubsystemStatus("Heartbeat", Application::Status::Good); } - if (!Application::Settings.getAsBool(Settings::Key::Misc_ImScaredOfUpdates) && UpdateReminderCounter % 5) { + if (!Application::Settings.HideUpdateMessages && UpdateReminderTimePassed.count() > UpdateReminderTimeout.count()) { + LastUpdateReminderTime = std::chrono::high_resolution_clock::now(); Application::CheckForUpdates(); } } @@ -145,6 +149,7 @@ std::string THeartbeatThread::GenerateCall() { << "&clientversion=" << std::to_string(Application::ClientMajorVersion()) + ".0" // FIXME: Wtf. << "&name=" << Application::Settings.getAsString(Settings::Key::General_Name) << "&tags=" << Application::Settings.getAsString(Settings::Key::General_Tags) + << "&guests=" << (Application::Settings.AllowGuests ? "true" : "false") << "&modlist=" << mResourceManager.TrimmedList() << "&modstotalsize=" << mResourceManager.MaxModSize() << "&modstotal=" << mResourceManager.ModsLoaded() diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 121044b..32ad0e5 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -22,11 +22,13 @@ #include "CustomAssert.h" #include "Http.h" #include "LuaAPI.h" +#include "Profiling.h" #include "TLuaPlugin.h" #include "sol/object.hpp" #include #include +#include #include #include #include @@ -64,6 +66,7 @@ void TLuaEngine::operator()() { RegisterThread("LuaEngine"); Application::SetSubsystemStatus("LuaEngine", Application::Status::Good); // lua engine main thread + beammp_infof("Lua v{}.{}.{}", LUA_VERSION_MAJOR, LUA_VERSION_MINOR, LUA_VERSION_RELEASE); CollectAndInitPlugins(); // now call all onInit's auto Futures = TriggerEvent("onInit", ""); @@ -266,7 +269,7 @@ std::vector TLuaEngine::StateThreadData::GetStateTableKeys(const st for (size_t i = 0; i < keys.size(); ++i) { auto obj = current.get(keys.at(i)); - if (obj.get_type() == sol::type::nil) { + if (obj.get_type() == sol::type::lua_nil) { // error break; } else if (i == keys.size() - 1) { @@ -351,7 +354,7 @@ std::shared_ptr TLuaEngine::EnqueueScript(TLuaStateId StateID, const return mLuaStates.at(StateID)->EnqueueScript(Script); } -std::shared_ptr TLuaEngine::EnqueueFunctionCall(TLuaStateId StateID, const std::string& FunctionName, const std::vector& Args) { +std::shared_ptr TLuaEngine::EnqueueFunctionCall(TLuaStateId StateID, const std::string& FunctionName, const std::vector& Args) { std::unique_lock Lock(mLuaStatesMutex); return mLuaStates.at(StateID)->EnqueueFunctionCall(FunctionName, Args); } @@ -360,17 +363,30 @@ void TLuaEngine::CollectAndInitPlugins() { if (!fs::exists(mResourceServerPath)) { fs::create_directories(mResourceServerPath); } - for (const auto& Dir : fs::directory_iterator(mResourceServerPath)) { - auto Path = Dir.path(); - Path = fs::relative(Path); - if (!Dir.is_directory()) { - beammp_error("\"" + Dir.path().string() + "\" is not a directory, skipping"); + + std::vector PluginsEntries; + for (const auto& Entry : fs::directory_iterator(mResourceServerPath)) { + if (Entry.is_directory()) { + PluginsEntries.push_back(Entry); } else { - TLuaPluginConfig Config { Path.stem().string() }; - FindAndParseConfig(Path, Config); - InitializePlugin(Path, Config); + beammp_error("\"" + Entry.path().string() + "\" is not a directory, skipping"); } } + + std::sort(PluginsEntries.begin(), PluginsEntries.end(), [](const fs::path& first, const fs::path& second) { + auto firstStr = first.string(); + auto secondStr = second.string(); + std::transform(firstStr.begin(), firstStr.end(), firstStr.begin(), ::tolower); + std::transform(secondStr.begin(), secondStr.end(), secondStr.begin(), ::tolower); + return firstStr < secondStr; + }); + + for (const auto& Dir : PluginsEntries) { + auto Path = fs::relative(Dir); + TLuaPluginConfig Config { Path.stem().string() }; + FindAndParseConfig(Path, Config); + InitializePlugin(Path, Config); + } } void TLuaEngine::InitializePlugin(const fs::path& Folder, const TLuaPluginConfig& Config) { @@ -431,13 +447,52 @@ std::set TLuaEngine::GetEventHandlersForState(const std::string& Ev return mLuaEvents[EventName][StateId]; } +std::vector TLuaEngine::StateThreadData::JsonStringToArray(JsonString Str) { + auto LocalTable = Lua_JsonDecode(Str.value).as>(); + for (auto& value : LocalTable) { + if (value.is() && value.as() == BEAMMP_INTERNAL_NIL) { + value = sol::object {}; + } + } + return LocalTable; +} + sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string& EventName, sol::variadic_args EventArgs) { - auto Return = mEngine->TriggerEvent(EventName, mStateId, EventArgs); + auto Table = mStateView.create_table(); + int i = 1; + for (auto Arg : EventArgs) { + switch (Arg.get_type()) { + case sol::type::none: + case sol::type::userdata: + case sol::type::lightuserdata: + case sol::type::thread: + case sol::type::function: + case sol::type::poly: + Table.set(i, BEAMMP_INTERNAL_NIL); + beammp_warnf("Passed a value of type '{}' to TriggerGlobalEvent(\"{}\", ...). This type can not be serialized, and cannot be passed between states. It will arrive as in handlers.", sol::type_name(EventArgs.lua_state(), Arg.get_type()), EventName); + break; + case sol::type::lua_nil: + Table.set(i, BEAMMP_INTERNAL_NIL); + break; + case sol::type::string: + case sol::type::number: + case sol::type::boolean: + case sol::type::table: + Table.set(i, Arg); + break; + } + ++i; + } + JsonString Str { LuaAPI::MP::JsonEncode(Table) }; + beammp_debugf("json: {}", Str.value); + auto Return = mEngine->TriggerEvent(EventName, mStateId, Str); auto MyHandlers = mEngine->GetEventHandlersForState(EventName, mStateId); + + sol::variadic_results LocalArgs = JsonStringToArray(Str); for (const auto& Handler : MyHandlers) { auto Fn = mStateView[Handler]; if (Fn.valid()) { - auto LuaResult = Fn(EventArgs); + auto LuaResult = Fn(LocalArgs); auto Result = std::make_shared(); if (LuaResult.valid()) { Result->Error = false; @@ -468,11 +523,13 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string sol::state_view StateView(mState); sol::table Result = StateView.create_table(); auto Vector = Self.get>>("ReturnValueImpl"); + int i = 1; for (const auto& Value : Vector) { if (!Value->Ready) { return sol::lua_nil; } - Result.add(Value->Result); + Result.set(i, Value->Result); + ++i; } return Result; }); @@ -482,12 +539,14 @@ sol::table TLuaEngine::StateThreadData::Lua_TriggerGlobalEvent(const std::string sol::table TLuaEngine::StateThreadData::Lua_TriggerLocalEvent(const std::string& EventName, sol::variadic_args EventArgs) { // TODO: make asynchronous? sol::table Result = mStateView.create_table(); + int i = 1; for (const auto& Handler : mEngine->GetEventHandlersForState(EventName, mStateId)) { auto Fn = mStateView[Handler]; if (Fn.valid() && Fn.get_type() == sol::type::function) { auto FnRet = Fn(EventArgs); if (FnRet.valid()) { - Result.add(FnRet); + Result.set(i, FnRet); + ++i; } else { sol::error Err = FnRet; beammp_lua_error(std::string("TriggerLocalEvent: ") + Err.what()); @@ -659,6 +718,7 @@ static void AddToTable(sol::table& table, const std::string& left, const T& valu static void JsonDecodeRecursive(sol::state_view& StateView, sol::table& table, const std::string& left, const nlohmann::json& right) { switch (right.type()) { case nlohmann::detail::value_t::null: + AddToTable(table, left, sol::lua_nil_t {}); return; case nlohmann::detail::value_t::object: { auto value = table.create(); @@ -882,6 +942,30 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI UtilTable.set_function("RandomIntRange", [this](int64_t min, int64_t max) -> int64_t { return std::uniform_int_distribution(min, max)(mMersenneTwister); }); + UtilTable.set_function("DebugExecutionTime", [this]() -> sol::table { + sol::state_view StateView(mState); + sol::table Result = StateView.create_table(); + auto stats = mProfile.all_stats(); + for (const auto& [name, stat] : stats) { + Result[name] = StateView.create_table(); + Result[name]["mean"] = stat.mean; + Result[name]["stdev"] = stat.stdev; + Result[name]["min"] = stat.min; + Result[name]["max"] = stat.max; + Result[name]["n"] = stat.n; + } + return Result; + }); + UtilTable.set_function("DebugStartProfile", [this](const std::string& name) { + mProfileStarts[name] = prof::now(); + }); + UtilTable.set_function("DebugStopProfile", [this](const std::string& name) { + if (!mProfileStarts.contains(name)) { + beammp_lua_errorf("DebugStopProfile('{}') failed, because a profile for '{}' wasn't started", name, name); + return; + } + mProfile.add_sample(name, prof::duration(mProfileStarts.at(name), prof::now())); + }); auto HttpTable = StateView.create_named_table("Http"); HttpTable.set_function("CreateConnection", [this](const std::string& host, uint16_t port) { @@ -929,7 +1013,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLu return Result; } -std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy) { +std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy) { // TODO: Document all this decltype(mStateFunctionQueue)::iterator Iter = mStateFunctionQueue.end(); if (Strategy == CallStrategy::BestEffort) { @@ -951,7 +1035,7 @@ std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCallFrom } } -std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args) { +std::shared_ptr TLuaEngine::StateThreadData::EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args) { auto Result = std::make_shared(); Result->StateId = mStateId; Result->Function = FunctionName; @@ -1019,6 +1103,7 @@ void TLuaEngine::StateThreadData::operator()() { std::chrono::milliseconds(500), [&]() -> bool { return !mStateFunctionQueue.empty(); }); if (NotExpired) { + auto ProfStart = prof::now(); auto TheQueuedFunction = std::move(mStateFunctionQueue.front()); mStateFunctionQueue.erase(mStateFunctionQueue.begin()); Lock.unlock(); @@ -1036,19 +1121,21 @@ void TLuaEngine::StateThreadData::operator()() { continue; } switch (Arg.index()) { - case TLuaArgTypes_String: + case TLuaType::String: LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); break; - case TLuaArgTypes_Int: + case TLuaType::Int: LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); break; - case TLuaArgTypes_VariadicArgs: - LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); + case TLuaType::Json: { + auto LocalArgs = JsonStringToArray(std::get(Arg)); + LuaArgs.insert(LuaArgs.end(), LocalArgs.begin(), LocalArgs.end()); break; - case TLuaArgTypes_Bool: + } + case TLuaType::Bool: LuaArgs.push_back(sol::make_object(StateView, std::get(Arg))); break; - case TLuaArgTypes_StringStringMap: { + case TLuaType::StringStringMap: { auto Map = std::get>(Arg); auto Table = StateView.create_table(); for (const auto& [k, v] : Map) { @@ -1077,6 +1164,9 @@ void TLuaEngine::StateThreadData::operator()() { Result->ErrorMessage = BeamMPFnNotFoundError; // special error kind that we can ignore later Result->MarkAsReady(); } + auto ProfEnd = prof::now(); + auto ProfDuration = prof::duration(ProfStart, ProfEnd); + mProfile.add_sample(FnName, ProfDuration); } } } diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 114d59c..a142f75 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -209,7 +209,7 @@ void TNetwork::Identify(TConnection&& RawConnection) { } else { beammp_errorf("Invalid code got in Identify: '{}'", Code); } - } catch(const std::exception& e) { + } catch (const std::exception& e) { beammp_errorf("Error during handling of code {} - client left in invalid state, closing socket", Code); boost::system::error_code ec; RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); @@ -278,7 +278,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { return nullptr; } - if (!TCPSend(*Client, StringToVector("A"))) { //changed to A for Accepted version + if (!TCPSend(*Client, StringToVector("A"))) { // changed to A for Accepted version // TODO: handle } @@ -289,16 +289,21 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { return nullptr; } - std::string key(reinterpret_cast(Data.data()), Data.size()); - - nlohmann::json AuthReq{}; - std::string AuthResStr{}; + std::string Key(reinterpret_cast(Data.data()), Data.size()); + std::string AuthKey = Application::Settings.Key; + std::string ClientIp = Client->GetIdentifiers().at("ip"); + + nlohmann::json AuthReq {}; + std::string AuthResStr {}; try { AuthReq = nlohmann::json { - { "key", key } + { "key", Key }, + { "auth_key", AuthKey }, + { "client_ip", ClientIp } }; auto Target = "/pkToUser"; + unsigned int ResponseCode = 0; AuthResStr = Http::POST(Application::GetBackendUrlForAuth(), 443, Target, AuthReq.dump(), "application/json", &ResponseCode); @@ -368,6 +373,11 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { return false; }); + if (!NotAllowedWithReason && !Application::Settings.AllowGuests && Client->IsGuest()) { //!NotAllowedWithReason because this message has the lowest priority + NotAllowedWithReason = true; + Reason = "No guests are allowed on this server! To join, sign up at: forum.beammp.com."; + } + if (NotAllowed) { ClientKick(*Client, "you are not allowed on the server!"); return {}; diff --git a/src/TPPSMonitor.cpp b/src/TPPSMonitor.cpp index b891a22..461dfbf 100644 --- a/src/TPPSMonitor.cpp +++ b/src/TPPSMonitor.cpp @@ -76,7 +76,7 @@ void TPPSMonitor::operator()() { return true; }); for (auto& ClientToKick : TimedOutClients) { - Network().ClientKick(*ClientToKick, "Timeout (no ping for way too long)"); + ClientToKick->Disconnect("Timeout"); } TimedOutClients.clear(); if (C == 0 || mInternalPPS == 0) { diff --git a/test/Server/JsonTests/main.lua b/test/Server/JsonTests/main.lua new file mode 100644 index 0000000..fef09b9 --- /dev/null +++ b/test/Server/JsonTests/main.lua @@ -0,0 +1,66 @@ +local function assert_eq(x, y, explain) + if x ~= y then + print("assertion '"..explain.."' failed:\n\tgot:\t", x, "\n\texpected:", y) + end +end + +---@param o1 any|table First object to compare +---@param o2 any|table Second object to compare +---@param ignore_mt boolean True to ignore metatables (a recursive function to tests tables inside tables) +function equals(o1, o2, ignore_mt) + if o1 == o2 then return true end + local o1Type = type(o1) + local o2Type = type(o2) + if o1Type ~= o2Type then return false end + if o1Type ~= 'table' then return false end + + if not ignore_mt then + local mt1 = getmetatable(o1) + if mt1 and mt1.__eq then + --compare using built in method + return o1 == o2 + end + end + + local keySet = {} + + for key1, value1 in pairs(o1) do + local value2 = o2[key1] + if value2 == nil or equals(value1, value2, ignore_mt) == false then + return false + end + keySet[key1] = true + end + + for key2, _ in pairs(o2) do + if not keySet[key2] then return false end + end + return true +end + + +local function assert_table_eq(x, y, explain) + if not equals(x, y, true) then + print("assertion '"..explain.."' failed:\n\tgot:\t", x, "\n\texpected:", y) + end +end + +assert_eq(Util.JsonEncode({1, 2, 3, 4, 5}), "[1,2,3,4,5]", "table to array") +assert_eq(Util.JsonEncode({"a", 1, 2, 3, 4, 5}), '["a",1,2,3,4,5]', "table to array") +assert_eq(Util.JsonEncode({"a", 1, 2.0, 3, 4, 5}), '["a",1,2.0,3,4,5]', "table to array") +assert_eq(Util.JsonEncode({hello="world", john={doe = 1, jane = 2.5, mike = {2, 3, 4}}, dave={}}), '{"dave":{},"hello":"world","john":{"doe":1,"jane":2.5,"mike":[2,3,4]}}', "table to obj") +assert_eq(Util.JsonEncode({a = nil}), "{}", "null obj member") +assert_eq(Util.JsonEncode({1, nil, 3}), "[1,3]", "null array member") +assert_eq(Util.JsonEncode({}), "{}", "empty array/table") +assert_eq(Util.JsonEncode({1234}), "[1234]", "int") +assert_eq(Util.JsonEncode({1234.0}), "[1234.0]", "double") + +assert_table_eq(Util.JsonDecode("[1,2,3,4,5]"), {1, 2, 3, 4, 5}, "decode table to array") +assert_table_eq(Util.JsonDecode('["a",1,2,3,4,5]'), {"a", 1, 2, 3, 4, 5}, "decode table to array") +assert_table_eq(Util.JsonDecode('["a",1,2.0,3,4,5]'), {"a", 1, 2.0, 3, 4, 5}, "decode table to array") +assert_table_eq(Util.JsonDecode('{"dave":{},"hello":"world","john":{"doe":1,"jane":2.5,"mike":[2,3,4]}}'), {hello="world", john={doe = 1, jane = 2.5, mike = {2, 3, 4}}, dave={}}, "decode table to obj") +assert_table_eq(Util.JsonDecode("{}"), {a = nil}, "decode null obj member") +assert_table_eq(Util.JsonDecode("[1,3]"), {1, 3}, "decode null array member") +assert_table_eq(Util.JsonDecode("{}"), {}, "decode empty array/table") +assert_table_eq(Util.JsonDecode("[1234]"), {1234}, "decode int") +assert_table_eq(Util.JsonDecode("[1234.0]"), {1234.0}, "decode double") diff --git a/vcpkg b/vcpkg index 326d8b4..6978381 160000 --- a/vcpkg +++ b/vcpkg @@ -1 +1 @@ -Subproject commit 326d8b43e365352ba3ccadf388d989082fe0f2a6 +Subproject commit 6978381401d33a5ad6a3385895d12e383083712a