From 9ea0931e138e1141e577519aa47f5226effbc236 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Tue, 16 Jan 2024 00:58:14 +0100 Subject: [PATCH] implement udp connection --- cmake/CompilerWarnings.cmake | 1 - include/Network.h | 24 ++++- include/THeartbeatThread.h | 2 - include/TLuaEngine.h | 3 +- src/Http.cpp | 3 +- src/LuaAPI.cpp | 1 - src/Network.cpp | 201 ++++++++++++++++++++++++++++++++++- src/TConsole.cpp | 1 - src/THeartbeatThread.cpp | 1 - src/TLuaEngine.cpp | 1 - src/main.cpp | 10 +- 11 files changed, 224 insertions(+), 24 deletions(-) diff --git a/cmake/CompilerWarnings.cmake b/cmake/CompilerWarnings.cmake index 0e38b5e..c654ec0 100644 --- a/cmake/CompilerWarnings.cmake +++ b/cmake/CompilerWarnings.cmake @@ -71,7 +71,6 @@ function(set_project_warnings project_name) -Werror=write-strings -Werror=strict-aliasing -fstrict-aliasing -Werror=missing-declarations - -Werror=missing-field-initializers -Werror=ctor-dtor-privacy -Wswitch-default -Werror=unused-result diff --git a/include/Network.h b/include/Network.h index f308d9b..c8548aa 100644 --- a/include/Network.h +++ b/include/Network.h @@ -29,6 +29,11 @@ struct Client { ClientID id; bmp::State state { bmp::State::None }; + Sync name; + Sync role; + Sync is_guest; + Sync> identifiers; + /// Reads a single packet from the TCP stream. Blocks all other reads (not writes). Packet tcp_read(); /// Writes the packet to the TCP stream. Blocks all other writes. @@ -37,16 +42,18 @@ struct Client { /// conjunction with something else. Blocks other writes. void tcp_write_file_raw(const std::filesystem::path& path); - Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_socket); + Client(ClientID id, class Network& network, ip::tcp::socket&& tcp_sockem_udp_endpointst); ~Client(); ip::tcp::socket& tcp_socket() { return m_tcp_socket; } - [[nodiscard]] const ip::udp::endpoint& udp_endpoint() const { return m_udp_ep; } - void set_udp_endpoint(const ip::udp::endpoint& ep) { m_udp_ep = ep; } - void start_tcp(); + /// Used to associate the udp socket with this client. + /// This isn't very secure and still allows spoofing of the UDP connection (technically), + /// but better than simply using the ID like the old protocol. + const uint64_t udp_magic; + private: void tcp_main(); @@ -55,7 +62,6 @@ private: std::mutex m_tcp_write_mtx; std::mutex m_udp_read_mtx; - ip::udp::endpoint m_udp_ep; ip::tcp::socket m_tcp_socket; boost::scoped_thread<> m_tcp_thread; @@ -89,6 +95,8 @@ private: Sync> m_clients {}; Sync> m_vehicles {}; + Sync> m_client_magics {}; + Sync> m_udp_endpoints {}; ClientID new_client_id() { static Sync s_id { 0 }; @@ -105,5 +113,11 @@ private: thread_pool m_threadpool {}; Sync m_shutdown { false }; ip::udp::socket m_udp_socket { m_io }; + void handle_identification(ClientID id, const Packet& packet, std::shared_ptr& client); + + void handle_authentication(ClientID id, const Packet& packet, std::shared_ptr& client); + + /// On failure, throws an exception with the error for the client. + static void authenticate_user(const std::string& public_key, std::shared_ptr& client); }; diff --git a/include/THeartbeatThread.h b/include/THeartbeatThread.h index 1063be6..d41c96c 100644 --- a/include/THeartbeatThread.h +++ b/include/THeartbeatThread.h @@ -2,8 +2,6 @@ #include "Common.h" #include "IThreaded.h" -#include "TResourceManager.h" -#include "TServer.h" class THeartbeatThread : public IThreaded { public: diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 7e28984..dbd4740 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -1,7 +1,6 @@ #pragma once -#include "TNetwork.h" -#include "TServer.h" +#include "Network.h" #include #include #include diff --git a/src/Http.cpp b/src/Http.cpp index 9f87ea4..99e01ae 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -1,6 +1,5 @@ #include "Http.h" -#include "Client.h" #include "Common.h" #include "CustomAssert.h" #include "LuaAPI.h" @@ -181,7 +180,7 @@ void Http::Server::THttpServerInstance::operator()() try { } } res.set_content( - json { + nlohmann::json { { "ok", SystemsBad == 0 }, } .dump(), diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index 3a92cb0..4328c93 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -1,5 +1,4 @@ #include "LuaAPI.h" -#include "Client.h" #include "Common.h" #include "CustomAssert.h" #include "TLuaEngine.h" diff --git a/src/Network.cpp b/src/Network.cpp index d6ce276..db92672 100644 --- a/src/Network.cpp +++ b/src/Network.cpp @@ -2,8 +2,15 @@ #include "ClientInfo.h" #include "Common.h" #include "Environment.h" +#include "Http.h" +#include "LuaAPI.h" #include "ProtocolVersion.h" #include "ServerInfo.h" +#include "TLuaEngine.h" +#include "Util.h" +#include +#include +#include #if defined(BEAMMP_LINUX) #include @@ -30,6 +37,7 @@ Packet Client::tcp_read() { } void Client::tcp_write(const Packet& packet) { + beammp_tracef("Sending 0x{:x} to {}", int(packet.purpose), id); std::unique_lock lock(m_tcp_write_mtx); auto header = packet.header(); std::vector header_data(bmp::Header::SERIALIZED_SIZE); @@ -73,6 +81,7 @@ Client::~Client() { Client::Client(ClientID id, Network& network, ip::tcp::socket&& tcp_socket) : id(id) + , udp_magic(id ^ uint64_t(std::rand()) ^ uint64_t(this)) , m_tcp_socket(std::forward(tcp_socket)) , m_network(network) { beammp_debugf("Client {} created", id); @@ -116,6 +125,8 @@ Packet Network::udp_read(ip::udp::endpoint& out_ep) { beammp_errorf("Flags are not implemented"); return {}; } + packet.data.resize(header.size); + std::copy(s_buffer.begin() + offset, s_buffer.begin() + offset + header.size, packet.data.begin()); return packet; } @@ -183,13 +194,76 @@ void Network::tcp_listen_main() { } void Network::udp_read_main() { + m_udp_socket = ip::udp::socket(m_io, ip::udp::endpoint(ip::udp::v4(), Application::Settings.Port)); while (!*m_shutdown) { + try { + ip::udp::endpoint ep; + auto packet = udp_read(ep); + // special case for new udp connections, only happens once + if (packet.purpose == bmp::Purpose::StartUDP) [[unlikely]] { + auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics); + auto& clients = std::get<0>(all); + auto& endpoints = std::get<1>(all); + auto& magics = std::get<2>(all); + ClientID id = 0xffffffff; + uint64_t recv_magic; + bmp::deserialize(recv_magic, packet.data); + if (magics->contains(recv_magic)) { + id = magics->at(recv_magic); + magics->erase(recv_magic); + } else { + beammp_debugf("Invalid magic received on UDP from [{}]:{}, ignoring.", ep.address().to_string(), ep.port()); + continue; + } + if (clients->contains(id)) { + auto client = clients->at(id); + // check if endpoint already exists for this client! + auto iter = std::find_if(endpoints->begin(), endpoints->end(), [&](const auto& item) { + return item.second == id; + }); + if (iter != endpoints->end()) { + // already exists, malicious attempt! + beammp_debugf("[{}]:{} tried to replace {}'s UDP endpoint, ignoring.", ep.address().to_string(), ep.port(), id); + continue; + } + // not yet set! nice! set! + endpoints->emplace(ep, id); + // now transfer them to the next state + beammp_debugf("Client {} successfully connected via UDP", client->id); + Packet state_change { + .purpose = bmp::Purpose::StateChangeModDownload, + }; + client->tcp_write(state_change); + client->state = bmp::State::ModDownload; + } else { + beammp_warnf("Received magic for client who doesn't exist anymore: {}. Ignoring.", id); + } + } + } catch (const std::exception& e) { + beammp_errorf("Failed to UDP read: {}", e.what()); + } } } void Network::disconnect(ClientID id, const std::string& msg) { beammp_infof("Disconnecting client {}: {}", id, msg); - m_clients->erase(id); + // deadlock-free algorithm to acquire a lock on all these + // this is a little ugly but saves a headache here in the future + auto all = boost::synchronize(m_clients, m_udp_endpoints, m_client_magics); + auto& clients = std::get<0>(all); + auto& endpoints = std::get<1>(all); + auto& magics = std::get<2>(all); + + if (clients->contains(id)) { + auto client = clients->at(id); + beammp_debugf("Removing client udp magic {}", client->udp_magic); + magics->erase(client->udp_magic); + } + std::erase_if(*endpoints, [&](const auto& item) { + const auto& [key, value] = item; + return value == id; + }); + clients->erase(id); } void Network::handle_packet(ClientID id, const Packet& packet) { std::shared_ptr client; @@ -211,6 +285,7 @@ void Network::handle_packet(ClientID id, const Packet& packet) { handle_identification(id, packet, client); break; case bmp::State::Authentication: + handle_authentication(id, packet, client); break; case bmp::State::ModDownload: break; @@ -283,11 +358,131 @@ void Network::handle_identification(ClientID id, const Packet& packet, std::shar .purpose = bmp::StateChangeAuthentication, }; client->tcp_write(auth_state); + client->state = bmp::State::Authentication; break; } default: - beammp_errorf("Got 0x{:x} in state {}. This is not allowed disconnecting the client", uint16_t(packet.purpose), int(client->state)); + beammp_errorf("Got 0x{:x} in state {}. This is not allowed. Disconnecting the client", uint16_t(packet.purpose), int(client->state)); + disconnect(id, "invalid purpose in current state"); + } +} + +void Network::authenticate_user(const std::string& public_key, std::shared_ptr& client) { + nlohmann::json AuthReq {}; + std::string auth_res_str {}; + try { + AuthReq = nlohmann::json { + { "key", public_key } + }; + + auto Target = "/pkToUser"; + unsigned int ResponseCode = 0; + auth_res_str = Http::POST(Application::GetBackendUrlForAuth(), 443, Target, AuthReq.dump(), "application/json", &ResponseCode); + } catch (const std::exception& e) { + beammp_debugf("Invalid key sent by client {}: {}", client->id, e.what()); + throw std::runtime_error("Public key was of an invalid format"); + } + + try { + nlohmann::json auth_response = nlohmann::json::parse(auth_res_str); + + if (auth_response["username"].is_string() && auth_response["roles"].is_string() + && auth_response["guest"].is_boolean() && auth_response["identifiers"].is_array()) { + + *client->name = auth_response["username"]; + *client->role = auth_response["roles"]; + *client->is_guest = auth_response["guest"]; + for (const auto& identifier : auth_response["identifiers"]) { + auto identifier_str = std::string(identifier); + auto identifier_sep_idx = identifier_str.find(':'); + client->identifiers->emplace(identifier_str.substr(0, identifier_sep_idx), identifier_str.substr(identifier_sep_idx + 1)); + } + } else { + beammp_errorf("Invalid authentication data received from authentication backend for client {}", client->id); + throw std::runtime_error("Backend failed to authenticate the client"); + } + } catch (const std::exception& e) { + beammp_errorf("Client {} sent invalid key. Error was: {}", client->id, e.what()); + throw std::runtime_error("Invalid public key"); + } +} + +void Network::handle_authentication(ClientID id, const Packet& packet, std::shared_ptr& client) { + switch (packet.purpose) { + case bmp::Purpose::PlayerPublicKey: { + auto public_key = std::string(packet.data.begin(), packet.data.end()); + try { + authenticate_user(public_key, client); + } catch (const std::exception& e) { + // propragate to client and disconnect + auto err = std::string(e.what()); + beammp_errorf("Client {} failed to authenticate: {}", id, err); + Packet auth_fail_packet { + .purpose = bmp::Purpose::AuthFailed, + .data = std::vector(err.begin(), err.end()), + }; + client->tcp_write(auth_fail_packet); + disconnect(id, err); + return; + } + auto Futures = LuaAPI::MP::Engine->TriggerEvent("onPlayerAuth", "", client->name.get(), client->role.get(), client->is_guest.get(), client->identifiers.get()); + TLuaEngine::WaitForAll(Futures); + bool NotAllowed = std::any_of(Futures.begin(), Futures.end(), + [](const std::shared_ptr& Result) { + return !Result->Error && Result->Result.is() && bool(Result->Result.as()); + }); + std::string Reason; + bool NotAllowedWithReason = std::any_of(Futures.begin(), Futures.end(), + [&Reason](const std::shared_ptr& Result) -> bool { + if (!Result->Error && Result->Result.is()) { + Reason = Result->Result.as(); + return true; + } + return false; + }); + + if (NotAllowed) { + Packet auth_fail_packet { + .purpose = bmp::Purpose::PlayerRejected + }; + client->tcp_write(auth_fail_packet); + disconnect(id, "Rejected by a plugin"); + return; + } else if (NotAllowedWithReason) { + Packet auth_fail_packet { + .purpose = bmp::Purpose::PlayerRejected, + .data = std::vector(Reason.begin(), Reason.end()), + }; + client->tcp_write(auth_fail_packet); + disconnect(id, fmt::format("Rejected by a plugin for reason: {}", Reason)); + return; + } + beammp_debugf("Client {} successfully authenticated as {} '{}'", id, client->role.get(), client->name.get()); + // send auth ok since auth succeeded + Packet auth_ok { + .purpose = bmp::Purpose::AuthOk, + .data = std::vector(4), + }; + // with the player id + bmp::serialize(client->id, auth_ok.data); + client->tcp_write(auth_ok); + + // save the udp magic + m_udp_magics.emplace(client->udp_magic, client->id); + + // send the udp start packet, which should get the client to start udp with + // this packet as the first message + Packet udp_start { + .purpose = bmp::Purpose::StartUDP, + .data = std::vector(8), + }; + bmp::serialize(client->udp_magic, udp_start.data); + client->tcp_write(udp_start); + // player must start udp to advance now, so no state change + break; + } + default: + beammp_errorf("Got 0x{:x} in state {}. This is not allowed. Disconnecting the client", uint16_t(packet.purpose), int(client->state)); disconnect(id, "invalid purpose in current state"); } - } diff --git a/src/TConsole.cpp b/src/TConsole.cpp index 556c5f6..6a67582 100644 --- a/src/TConsole.cpp +++ b/src/TConsole.cpp @@ -2,7 +2,6 @@ #include "Common.h" #include "Compat.h" -#include "Client.h" #include "CustomAssert.h" #include "LuaAPI.h" #include "TLuaEngine.h" diff --git a/src/THeartbeatThread.cpp b/src/THeartbeatThread.cpp index 5a1a71d..955f90d 100644 --- a/src/THeartbeatThread.cpp +++ b/src/THeartbeatThread.cpp @@ -1,6 +1,5 @@ #include "THeartbeatThread.h" -#include "Client.h" #include "Http.h" //#include "SocketIO.h" #include diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 4c5f16e..eab9007 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -1,5 +1,4 @@ #include "TLuaEngine.h" -#include "Client.h" #include "CustomAssert.h" #include "Http.h" #include "LuaAPI.h" diff --git a/src/main.cpp b/src/main.cpp index d73c6c1..ccf07ec 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -6,12 +6,9 @@ #include "TConfig.h" #include "THeartbeatThread.h" #include "TLuaEngine.h" -#include "TNetwork.h" -#include "TPPSMonitor.h" #include "TPluginMonitor.h" -#include "TResourceManager.h" -#include "TServer.h" +#include #include #include @@ -87,6 +84,9 @@ int BeamMPServerMain(MainArguments Arguments) { return 0; } + // badly seed C's rng - this is only because rand() is used here and there for unimportant stuff + std::srand(std::time(0)); + std::string ConfigPath = "ServerConfig.toml"; if (Parser.FoundArgument({ "config" })) { auto MaybeConfigPath = Parser.GetValueOfArgument({ "config" }); @@ -106,7 +106,7 @@ int BeamMPServerMain(MainArguments Arguments) { } } } - + TConfig Config(ConfigPath); if (Config.Failed()) {