From 5c2a0db7efe608f7f5d49e0e88634b94162439ea Mon Sep 17 00:00:00 2001 From: SaltySnail Date: Tue, 31 Mar 2026 22:32:57 +0200 Subject: [PATCH] Added timeout on read --- include/TNetwork.h | 17 +++-- src/TNetwork.cpp | 155 ++++++++++++++++++++++++++++++++------------- 2 files changed, 123 insertions(+), 49 deletions(-) diff --git a/include/TNetwork.h b/include/TNetwork.h index 81931d9..d5e5473 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -34,9 +34,11 @@ public: [[nodiscard]] bool TCPSend(TClient& c, const std::vector& Data, bool IsSync = false); [[nodiscard]] bool SendLarge(TClient& c, std::vector Data, bool isSync = false); [[nodiscard]] bool Respond(TClient& c, const std::vector& MSG, bool Rel, bool isSync = false); - std::shared_ptr CreateClient(ip::tcp::socket&& TCPSock); + std::shared_ptr CreateClient(boost::asio::ip::tcp::socket&& TCPSock); std::vector TCPRcv(TClient& c); void ClientKick(TClient& c, const std::string& R); + void DisconnectClient(const std::weak_ptr& c, const std::string& R); + void DisconnectClient(TClient& c, const std::string& R); [[nodiscard]] bool SyncClient(const std::weak_ptr& c); void Identify(TConnection&& client); std::shared_ptr Authentication(TConnection&& ClientConnection); @@ -44,6 +46,7 @@ public: [[nodiscard]] bool UDPSend(TClient& Client, std::vector Data); void SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel); void UpdatePlayer(TClient& Client); + boost::system::error_code ReadWithTimeout(TConnection& Connection, void* Buf, size_t Len, std::chrono::steady_clock::duration Timeout); TResourceManager& ResourceManager() const { return mResourceManager; } @@ -53,13 +56,15 @@ private: TServer& mServer; TPPSMonitor& mPPSMonitor; - ip::udp::socket mUDPSock; + boost::asio::ip::udp::socket mUDPSock; TResourceManager& mResourceManager; std::thread mUDPThread; std::thread mTCPThread; std::mutex mOpenIDMutex; + std::map mClientMap; + std::mutex mClientMapMutex; - std::vector UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint); + std::vector UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint); void OnConnect(const std::weak_ptr& c); void TCPClient(const std::weak_ptr& c); void Looper(const std::weak_ptr& c); @@ -67,9 +72,9 @@ private: void OnDisconnect(const std::weak_ptr& ClientPtr); void Parse(TClient& c, const std::vector& Packet); void SendFile(TClient& c, const std::string& Name); - static bool TCPSendRaw(TClient& C, ip::tcp::socket& socket, const uint8_t* Data, size_t Size); - static void SendFileToClient(TClient& c, size_t Size, const std::string& Name); - static const uint8_t* SendSplit(TClient& c, ip::tcp::socket& Socket, const uint8_t* DataPtr, size_t Size); + static bool TCPSendRaw(TClient& C, boost::asio::ip::tcp::socket& socket, const uint8_t* Data, size_t Size); + void SendFileToClient(TClient& c, size_t Size, const std::string& Name); + static const uint8_t* SendSplit(TClient& c, boost::asio::ip::tcp::socket& Socket, const uint8_t* DataPtr, size_t Size); }; std::string HashPassword(const std::string& str); diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 9c88dea..8f006c6 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -38,6 +38,9 @@ typedef boost::asio::detail::socket_option::integer rcv_timeout_option; +static constexpr uint8_t MAX_CONCURRENT_CONNECTIONS = 10; +static constexpr uint8_t MAX_GLOBAL_CONNECTIONS = 128; + std::vector StringToVector(const std::string& Str) { return std::vector(Str.data(), Str.data() + Str.size()); } @@ -89,14 +92,14 @@ void TNetwork::UDPServerMain() { RegisterThread("UDPServer"); boost::system::error_code ec; - auto address = ip::make_address(Application::Settings.getAsString(Settings::Key::General_IP), ec); + auto address = boost::asio::ip::make_address(Application::Settings.getAsString(Settings::Key::General_IP), ec); if (ec) { beammp_errorf("Failed to parse IP: {}", ec.message()); Application::GracefullyShutdown(); } - ip::udp::endpoint UdpListenEndpoint(address, Application::Settings.getAsInt(Settings::Key::General_Port)); + boost::asio::ip::udp::endpoint UdpListenEndpoint(address, Application::Settings.getAsInt(Settings::Key::General_Port)); mUDPSock.open(UdpListenEndpoint.protocol(), ec); if (ec) { @@ -121,13 +124,13 @@ void TNetwork::UDPServerMain() { + std::to_string(Application::Settings.getAsInt(Settings::Key::General_MaxPlayers)) + (" Clients")); while (!Application::IsShuttingDown()) { try { - ip::udp::endpoint remote_client_ep {}; + boost::asio::ip::udp::endpoint remote_client_ep {}; std::vector Data = UDPRcvFromClient(remote_client_ep); if (Data.empty()) { continue; } if (Data.size() == 1 && Data.at(0) == 'P') { - mUDPSock.send_to(const_buffer("P", 1), remote_client_ep, {}, ec); + mUDPSock.send_to(boost::asio::const_buffer("P", 1), remote_client_ep, {}, ec); // ignore errors (void)ec; continue; @@ -148,7 +151,7 @@ void TNetwork::UDPServerMain() { } if (Client->GetID() == ID) { - if (Client->GetUDPAddr() == ip::udp::endpoint {} && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { + if (Client->GetUDPAddr() == boost::asio::ip::udp::endpoint {} && !Client->IsUDPConnected() && !Client->GetMagic().empty()) { if (Data.size() != 66) { beammp_debugf("Invalid size for UDP value. IP: {} ID: {}", remote_client_ep.address().to_string(), ID); return false; @@ -188,16 +191,16 @@ void TNetwork::TCPServerMain() { RegisterThread("TCPServer"); boost::system::error_code ec; - auto address = ip::make_address(Application::Settings.getAsString(Settings::Key::General_IP), ec); + auto address = boost::asio::ip::make_address(Application::Settings.getAsString(Settings::Key::General_IP), ec); if (ec) { beammp_errorf("Failed to parse IP: {}", ec.message()); return; } - ip::tcp::endpoint ListenEp(address, + boost::asio::ip::tcp::endpoint ListenEp(address, uint16_t(Application::Settings.getAsInt(Settings::Key::General_Port))); - ip::tcp::socket Listener(mServer.IoCtx()); + boost::asio::ip::tcp::socket Listener(mServer.IoCtx()); Listener.open(ListenEp.protocol(), ec); if (ec) { beammp_errorf("Failed to open socket: {}", ec.message()); @@ -222,7 +225,7 @@ void TNetwork::TCPServerMain() { ec.message()); } - ip::tcp::acceptor Acceptor(mServer.IoCtx(), ListenEp); + boost::asio::ip::tcp::acceptor Acceptor(mServer.IoCtx(), ListenEp); Acceptor.listen(socket_base::max_listen_connections, ec); if (ec) { beammp_errorf("listen() failed, which is needed for the server to operate. " @@ -239,12 +242,24 @@ void TNetwork::TCPServerMain() { beammp_debug("shutdown during TCP wait for accept loop"); break; } - ip::tcp::endpoint ClientEp; - ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); + boost::asio::ip::tcp::endpoint ClientEp; + boost::asio::ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); + std::string ClientIP = ClientEp.address().to_string(); if (!ec) { - TConnection Conn { std::move(ClientSocket), ClientEp }; - std::thread ID(&TNetwork::Identify, this, std::move(Conn)); - ID.detach(); // TODO: Add to a queue and attempt to join periodically + mClientMapMutex.lock(); + if (mClientMap[ClientIP] >= MAX_CONCURRENT_CONNECTIONS) { + beammp_debugf("The connection was rejected for {}, as it had {} concurrent connections.", ClientIP, mClientMap[ClientIP]); + } + else if (mClientMap.size() >= MAX_GLOBAL_CONNECTIONS) { + beammp_debugf("The connection was rejected for {}, as there are {} global connections.", ClientIP, mClientMap.size()); + } + else { + TConnection Conn { std::move(ClientSocket), ClientEp }; + std::thread ID(&TNetwork::Identify, this, std::move(Conn)); + ID.detach(); // TODO: Add to a queue and attempt to join periodically + mClientMap[ClientIP]++; + } + mClientMapMutex.unlock(); } else { beammp_errorf("Failed to accept() new client: {}", ec.message()); @@ -265,7 +280,7 @@ void TNetwork::Identify(TConnection&& RawConnection) { char Code; boost::system::error_code ec; - read(RawConnection.Socket, buffer(&Code, 1), ec); + ReadWithTimeout(RawConnection, &Code, 1, std::chrono::seconds(10)); if (ec) { // TODO: is this right?! RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); @@ -279,8 +294,7 @@ void TNetwork::Identify(TConnection&& RawConnection) { beammp_errorf("Old download packet detected - the client is wildly out of date, this will be ignored"); return; } else if (Code == 'P') { - boost::system::error_code ec; - write(RawConnection.Socket, buffer("P"), ec); + boost::asio::write(RawConnection.Socket, boost::asio::buffer("P"), ec); return; } else if (Code == 'I') { const std::string Data = Application::Settings.getAsBool(Settings::Key::General_InformationPacket) ? THeartbeatThread::lastCall : ""; @@ -292,14 +306,14 @@ void TNetwork::Identify(TConnection&& RawConnection) { std::memcpy(ToSend.data() + sizeof(Size), Data.data(), Data.size()); boost::system::error_code ec; - write(RawConnection.Socket, buffer(ToSend), ec); + boost::asio::write(RawConnection.Socket, boost::asio::buffer(ToSend), ec); } else { beammp_errorf("Invalid code got in Identify: '{}'", Code); } } catch (const std::exception& e) { beammp_errorf("Error during handling of code {} - client left in invalid state, closing socket: {}", Code, e.what()); boost::system::error_code ec; - RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); + RawConnection.Socket.shutdown(boost::asio::socket_base::shutdown_both, ec); if (ec) { beammp_debugf("Failed to shutdown client socket: {}", ec.message()); } @@ -397,7 +411,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { try { nlohmann::json AuthRes = nlohmann::json::parse(AuthResStr); - if (AuthRes["username"].is_string() && AuthRes["roles"].is_string() + if (AuthRes["username"].is_string() && AuthRes["username"].size() > 0 && AuthRes["roles"].is_string() && AuthRes["guest"].is_boolean() && AuthRes["identifiers"].is_array()) { Client->SetName(AuthRes["username"]); @@ -431,7 +445,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { return true; } if (Cl->GetName() == Client->GetName() && Cl->IsGuest() == Client->IsGuest()) { - Cl->Disconnect("Stale Client (not a real player)"); + DisconnectClient(Cl, "Stale Client (not a real player)"); return false; } @@ -494,7 +508,7 @@ std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { return Client; } -std::shared_ptr TNetwork::CreateClient(ip::tcp::socket&& TCPSock) { +std::shared_ptr TNetwork::CreateClient(boost::asio::ip::tcp::socket&& TCPSock) { auto c = std::make_shared(mServer, std::move(TCPSock)); return c; } @@ -527,10 +541,10 @@ bool TNetwork::TCPSend(TClient& c, const std::vector& Data, bool IsSync std::memcpy(ToSend.data(), &Size, sizeof(Size)); std::memcpy(ToSend.data() + sizeof(Size), Data.data(), Data.size()); boost::system::error_code ec; - write(Sock, buffer(ToSend), ec); + boost::asio::write(Sock, boost::asio::buffer(ToSend), ec); if (ec) { beammp_debugf("write(): {}", ec.message()); - c.Disconnect("write() failed"); + DisconnectClient(c, "write() failed"); return false; } c.UpdatePingTime(); @@ -548,7 +562,7 @@ std::vector TNetwork::TCPRcv(TClient& c) { boost::system::error_code ec; std::array HeaderData; - read(Sock, buffer(HeaderData), ec); + boost::asio::read(Sock, boost::asio::buffer(HeaderData), ec); if (ec) { // TODO: handle this case (read failed) beammp_debugf("TCPRcv: Reading header failed: {}", ec.message()); @@ -564,14 +578,16 @@ std::vector TNetwork::TCPRcv(TClient& c) { std::vector Data; // TODO: This is arbitrary, this needs to be handled another way - if (Header < int32_t(100 * MB)) { + bool isUnauthenticated = c.GetName().empty(); + int32_t maxHeaderSize = isUnauthenticated ? 4096 : int32_t(100 * MB); + if (Header < maxHeaderSize) { Data.resize(Header); } else { ClientKick(c, "Header size limit exceeded"); - beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header of >100MB - assuming malicious intent and disconnecting the client."); + beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header larger than expected - assuming malicious intent and disconnecting the client."); return {}; } - auto N = read(Sock, buffer(Data), ec); + auto N = boost::asio::read(Sock, boost::asio::buffer(Data), ec); if (ec) { // TODO: handle this case properly beammp_debugf("TCPRcv: Reading data failed: {}", ec.message()); @@ -606,7 +622,32 @@ void TNetwork::ClientKick(TClient& c, const std::string& R) { if (!TCPSend(c, StringToVector("K" + R))) { beammp_debugf("tried to kick player '{}' (id {}), but was already disconnected", c.GetName(), c.GetID()); } - c.Disconnect("Kicked"); + DisconnectClient(c, "Kicked"); +} + +void TNetwork::DisconnectClient(const std::weak_ptr &c, const std::string &R) +{ + if (auto locked = c.lock()) { + DisconnectClient(*locked, R); + } + else { + beammp_debugf("Tried to disconnect a non existant client with reason: {}", R); + } +} + +void TNetwork::DisconnectClient(TClient &c, const std::string &R) +{ + if (c.IsDisconnected()) return; + std::string ClientIP = c.GetTCPSock().remote_endpoint().address().to_string(); + mClientMapMutex.lock(); + if (mClientMap[ClientIP] > 0) { + mClientMap[ClientIP]--; + } + if (mClientMap[ClientIP] == 0) { + mClientMap.erase(ClientIP); + } + mClientMapMutex.unlock(); + c.Disconnect(R); } void TNetwork::Looper(const std::weak_ptr& c) { @@ -631,7 +672,7 @@ void TNetwork::Looper(const std::weak_ptr& c) { } // end locked context // beammp_debug("sending a missed packet: " + QData); if (!TCPSend(*Client, QData, true)) { - Client->Disconnect("Failed to TCPSend while clearing the missed packet queue"); + DisconnectClient(Client, "Failed to TCPSend while clearing the missed packet queue"); std::unique_lock lock(Client->MissedPacketQueueMutex()); while (!Client->MissedPacketQueue().empty()) { Client->MissedPacketQueue().pop(); @@ -668,14 +709,14 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { auto res = TCPRcv(*Client); if (res.empty()) { beammp_debug("TCPRcv empty"); - Client->Disconnect("TCPRcv failed"); + DisconnectClient(Client, "TCPRcv failed"); break; } try { mServer.GlobalParser(c, std::move(res), mPPSMonitor, *this, false); } catch (const std::exception& e) { beammp_warnf("Failed to receive/parse packet via TCP from client {}: {}", Client->GetID(), e.what()); - Client->Disconnect("Failed to parse packet"); + DisconnectClient(Client, "Failed to parse packet"); break; } } @@ -706,6 +747,34 @@ void TNetwork::UpdatePlayer(TClient& Client) { //(void)Respond(Client, Packet, true); } +boost::system::error_code TNetwork::ReadWithTimeout(TConnection& Connection, void *Buf, size_t Len, std::chrono::steady_clock::duration Timeout) +{ + io_context TimerIO; + steady_timer Timer(TimerIO); + Timer.expires_after(Timeout); + + std::atomic TimedOut = false; + + Timer.async_wait([&](const boost::system::error_code& ec) { + if (!ec) { + TimedOut = true; + Connection.Socket.cancel(); + } + }); + std::thread TimerThread([&]() { TimerIO.run(); }); + + boost::system::error_code ReadEc; + boost::asio::read(Connection.Socket, boost::asio::buffer(Buf, Len), ReadEc); + + TimerIO.stop(); + TimerThread.join(); + + if (TimedOut.load()) { + return error::timed_out; // synthesize a clean timeout error + } + return ReadEc; //Succes! +} + void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { std::shared_ptr LockedClientPtr { nullptr }; try { @@ -733,7 +802,7 @@ void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { Packet.clear(); auto Futures = LuaAPI::MP::Engine->TriggerEvent("onPlayerDisconnect", "", c.GetID()); LuaAPI::MP::Engine->WaitForAll(Futures); - c.Disconnect("Already Disconnected (OnDisconnect)"); + DisconnectClient(c, "Already Disconnected (OnDisconnect)"); mServer.RemoveClient(ClientPtr); } @@ -841,7 +910,7 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { for (auto mod : mResourceManager.GetMods()) { if (mod["file_name"].get() == FileName && mod["protected"] == true) { beammp_warn("Client tried to access protected file " + UnsafeName); - c.Disconnect("Mod is protected thus cannot be downloaded"); + DisconnectClient(c, "Mod is protected thus cannot be downloaded"); return; } } @@ -905,7 +974,7 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name Data.resize(Split); else Data.resize(Size); - ip::tcp::socket* TCPSock = &c.GetTCPSock(); + boost::asio::ip::tcp::socket* TCPSock = &c.GetTCPSock(); std::streamsize Sent = 0; while (!c.IsDisconnected() && Sent < Size) { size_t Diff = Size - Sent; @@ -914,7 +983,7 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name f.read(reinterpret_cast(Data.data()), Split); if (!TCPSendRaw(c, *TCPSock, Data.data(), Split)) { if (!c.IsDisconnected()) - c.Disconnect("TCPSendRaw failed in mod download (1)"); + DisconnectClient(c, "TCPSendRaw failed in mod download (1)"); break; } Sent += Split; @@ -923,7 +992,7 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name f.read(reinterpret_cast(Data.data()), Diff); if (!TCPSendRaw(c, *TCPSock, Data.data(), int32_t(Diff))) { if (!c.IsDisconnected()) - c.Disconnect("TCPSendRaw failed in mod download (2)"); + DisconnectClient(c, "TCPSendRaw failed in mod download (2)"); break; } Sent += Diff; @@ -932,9 +1001,9 @@ void TNetwork::SendFileToClient(TClient& c, size_t Size, const std::string& Name #endif } -bool TNetwork::TCPSendRaw(TClient& C, ip::tcp::socket& socket, const uint8_t* Data, size_t Size) { +bool TNetwork::TCPSendRaw(TClient& C, boost::asio::ip::tcp::socket& socket, const uint8_t* Data, size_t Size) { boost::system::error_code ec; - write(socket, buffer(Data, Size), ec); + boost::asio::write(socket, boost::asio::buffer(Data, Size), ec); if (ec) { beammp_errorf("Failed to send raw data to client: {}", ec.message()); return false; @@ -1075,20 +1144,20 @@ bool TNetwork::UDPSend(TClient& Client, std::vector Data) { CompressProperly(Data); } boost::system::error_code ec; - mUDPSock.send_to(buffer(Data), Addr, 0, ec); + mUDPSock.send_to(boost::asio::buffer(Data), Addr, 0, ec); if (ec) { beammp_debugf("UDP sendto() failed: {}", ec.message()); if (!Client.IsDisconnected()) - Client.Disconnect("UDP send failed"); + DisconnectClient(Client, "UDP send failed"); return false; } return true; } -std::vector TNetwork::UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint) { +std::vector TNetwork::UDPRcvFromClient(boost::asio::ip::udp::endpoint& ClientEndpoint) { std::array Ret {}; boost::system::error_code ec; - const auto Rcv = mUDPSock.receive_from(mutable_buffer(Ret.data(), Ret.size()), ClientEndpoint, 0, ec); + const auto Rcv = mUDPSock.receive_from(boost::asio::mutable_buffer(Ret.data(), Ret.size()), ClientEndpoint, 0, ec); if (ec) { beammp_errorf("UDP recvfrom() failed: {}", ec.message()); return {};