From 7d2e4d4581bc2c9fd39e56b1c7d81f4c7b2c81a4 Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Wed, 5 Oct 2022 15:44:32 +0200 Subject: [PATCH] replace tcp networking with boost::asio tcp networking --- include/Client.h | 34 ++- include/Common.h | 2 +- include/TNetwork.h | 34 +-- include/TServer.h | 19 +- src/Client.cpp | 30 +- src/LuaAPI.cpp | 12 +- src/THeartbeatThread.cpp | 2 +- src/TNetwork.cpp | 642 +++++++++++++++++---------------------- src/TServer.cpp | 66 ++-- 9 files changed, 388 insertions(+), 453 deletions(-) diff --git a/include/Client.h b/include/Client.h index fa70208..b8858d1 100644 --- a/include/Client.h +++ b/include/Client.h @@ -20,9 +20,8 @@ class TServer; #endif // WINDOWS struct TConnection final { - SOCKET Socket; - struct sockaddr SockAddr; - socklen_t SockAddrLen; + ip::tcp::socket Socket; + ip::tcp::endpoint SockAddr; }; class TClient final { @@ -34,8 +33,9 @@ public: std::unique_lock Lock; }; - explicit TClient(TServer& Server); + TClient(TServer& Server, ip::tcp::socket&& Socket); TClient(const TClient&) = delete; + ~TClient(); TClient& operator=(const TClient&) = delete; void AddNewCar(int Ident, const std::string& Data); @@ -48,16 +48,19 @@ public: std::string GetCarData(int Ident); std::string GetCarPositionRaw(int Ident); void SetUDPAddr(const ip::udp::endpoint& Addr) { mUDPAddress = Addr; } - void SetDownSock(SOCKET CSock) { mSocket[1] = CSock; } - void SetTCPSock(SOCKET CSock) { mSocket[0] = CSock; } - void SetStatus(int Status) { mStatus = Status; } + void SetDownSock(ip::tcp::socket&& CSock) { mDownSocket = std::move(CSock); } + void SetTCPSock(ip::tcp::socket&& CSock) { mSocket = std::move(CSock); } + void Disconnect(std::string_view Reason); + bool IsDisconnected() const { return !mSocket.is_open(); } // locks void DeleteCar(int Ident); [[nodiscard]] const std::unordered_map& GetIdentifiers() const { return mIdentifiers; } [[nodiscard]] const ip::udp::endpoint& GetUDPAddr() const { return mUDPAddress; } [[nodiscard]] ip::udp::endpoint& GetUDPAddr() { return mUDPAddress; } - [[nodiscard]] SOCKET GetDownSock() const { return mSocket[1]; } - [[nodiscard]] SOCKET GetTCPSock() const { return mSocket[0]; } + [[nodiscard]] ip::tcp::socket& GetDownSock() { return mDownSocket; } + [[nodiscard]] const ip::tcp::socket& GetDownSock() const { return mDownSocket; } + [[nodiscard]] ip::tcp::socket& GetTCPSock() { return mSocket; } + [[nodiscard]] const ip::tcp::socket& GetTCPSock() const { return mSocket; } [[nodiscard]] std::string GetRoles() const { return mRole; } [[nodiscard]] std::string GetName() const { return mName; } void SetUnicycleID(int ID) { mUnicycleID = ID; } @@ -65,7 +68,6 @@ public: [[nodiscard]] int GetOpenCarID() const; [[nodiscard]] int GetCarCount() const; void ClearCars(); - [[nodiscard]] int GetStatus() const { return mStatus; } [[nodiscard]] int GetID() const { return mID; } [[nodiscard]] int GetUnicycleID() const { return mUnicycleID; } [[nodiscard]] bool IsConnected() const { return mIsConnected; } @@ -75,9 +77,9 @@ public: void SetIsGuest(bool NewIsGuest) { mIsGuest = NewIsGuest; } void SetIsSynced(bool NewIsSynced) { mIsSynced = NewIsSynced; } void SetIsSyncing(bool NewIsSyncing) { mIsSyncing = NewIsSyncing; } - void EnqueuePacket(const std::string& Packet); - [[nodiscard]] std::queue& MissedPacketQueue() { return mPacketsSync; } - [[nodiscard]] const std::queue& MissedPacketQueue() const { return mPacketsSync; } + void EnqueuePacket(const std::vector& Packet); + [[nodiscard]] std::queue>& MissedPacketQueue() { return mPacketsSync; } + [[nodiscard]] const std::queue>& MissedPacketQueue() const { return mPacketsSync; } [[nodiscard]] size_t MissedPacketQueueSize() const { return mPacketsSync.size(); } [[nodiscard]] std::mutex& MissedPacketQueueMutex() const { return mMissedPacketsMutex; } void SetIsConnected(bool NewIsConnected) { mIsConnected = NewIsConnected; } @@ -93,7 +95,7 @@ private: bool mIsSynced = false; bool mIsSyncing = false; mutable std::mutex mMissedPacketsMutex; - std::queue mPacketsSync; + std::queue> mPacketsSync; std::unordered_map mIdentifiers; bool mIsGuest = false; mutable std::mutex mVehicleDataMutex; @@ -101,12 +103,12 @@ private: TSetOfVehicleData mVehicleData; SparseArray mVehiclePosition; std::string mName = "Unknown Client"; - SOCKET mSocket[2] { SOCKET(0), SOCKET(0) }; + ip::tcp::socket mSocket; + ip::tcp::socket mDownSocket; ip::udp::endpoint mUDPAddress {}; int mUnicycleID = -1; std::string mRole; std::string mDID; - int mStatus = 0; int mID = -1; std::chrono::time_point mLastPingTime; }; diff --git a/include/Common.h b/include/Common.h index 3b0f205..06b24c8 100644 --- a/include/Common.h +++ b/include/Common.h @@ -80,7 +80,7 @@ public: static TConsole& Console() { return *mConsole; } static std::string ServerVersionString(); static const Version& ServerVersion() { return mVersion; } - static std::string ClientVersionString() { return "2.0"; } + static uint8_t ClientMajorVersion() { return 2; } static std::string PPS() { return mPPS; } static void SetPPS(const std::string& NewPPS) { mPPS = NewPPS; } diff --git a/include/TNetwork.h b/include/TNetwork.h index f18f17f..3b4980e 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -13,19 +13,18 @@ class TNetwork { public: TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager); - [[nodiscard]] bool TCPSend(TClient& c, const std::string& Data, bool IsSync = false); - [[nodiscard]] bool SendLarge(TClient& c, std::string Data, bool isSync = false); - [[nodiscard]] bool Respond(TClient& c, const std::string& MSG, bool Rel, bool isSync = false); - std::shared_ptr CreateClient(SOCKET TCPSock); - std::string TCPRcv(TClient& c); + [[nodiscard]] bool TCPSend(TClient& c, const std::vector& Data, bool IsSync = false); + [[nodiscard]] bool SendLarge(TClient& c, std::vector Data, bool isSync = false); + [[nodiscard]] bool Respond(TClient& c, const std::vector& MSG, bool Rel, bool isSync = false); + std::shared_ptr CreateClient(ip::tcp::socket&& TCPSock); + std::vector TCPRcv(TClient& c); void ClientKick(TClient& c, const std::string& R); [[nodiscard]] bool SyncClient(const std::weak_ptr& c); - void Identify(const TConnection& client); - void Authentication(const TConnection& ClientConnection); - [[nodiscard]] bool CheckBytes(TClient& c, int32_t BytesRcv); + void Identify(TConnection&& client); + std::shared_ptr Authentication(TConnection&& ClientConnection); void SyncResources(TClient& c); - [[nodiscard]] bool UDPSend(TClient& Client, std::string Data); - void SendToAll(TClient* c, const std::string& Data, bool Self, bool Rel); + [[nodiscard]] bool UDPSend(TClient& Client, std::vector Data); + void SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel); void UpdatePlayer(TClient& Client); private: @@ -34,22 +33,23 @@ private: TServer& mServer; TPPSMonitor& mPPSMonitor; - io_context mIoCtx; ip::udp::socket mUDPSock; TResourceManager& mResourceManager; std::thread mUDPThread; std::thread mTCPThread; - std::string UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint); - void HandleDownload(SOCKET TCPSock); + std::vector UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint); + void HandleDownload(TConnection&& TCPSock); void OnConnect(const std::weak_ptr& c); void TCPClient(const std::weak_ptr& c); void Looper(const std::weak_ptr& c); int OpenID(); - void OnDisconnect(const std::weak_ptr& ClientPtr, bool kicked); - void Parse(TClient& c, const std::string& Packet); + void OnDisconnect(const std::weak_ptr& ClientPtr); + void Parse(TClient& c, const std::vector& Packet); void SendFile(TClient& c, const std::string& Name); - static bool TCPSendRaw(TClient& C, SOCKET socket, char* Data, int32_t Size); + static bool TCPSendRaw(TClient& C, ip::tcp::socket& socket, const uint8_t* Data, size_t Size); static void SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std::string& Name); - static uint8_t* SendSplit(TClient& c, SOCKET Socket, uint8_t* DataPtr, size_t Size); + static const uint8_t* SendSplit(TClient& c, ip::tcp::socket& Socket, const uint8_t* DataPtr, size_t Size); }; + +std::vector StringToVector(const std::string& Str); diff --git a/include/TServer.h b/include/TServer.h index cf7891a..0967381 100644 --- a/include/TServer.h +++ b/include/TServer.h @@ -8,6 +8,8 @@ #include #include +#include "BoostAliases.h" + class TClient; class TNetwork; class TPPSMonitor; @@ -19,19 +21,22 @@ public: TServer(const std::vector& Arguments); void InsertClient(const std::shared_ptr& Ptr); - std::weak_ptr InsertNewClient(); void RemoveClient(const std::weak_ptr&); // in Fn, return true to continue, return false to break void ForEachClient(const std::function)>& Fn); size_t ClientCount() const; - static void GlobalParser(const std::weak_ptr& Client, std::string Packet, TPPSMonitor& PPSMonitor, TNetwork& Network); + static void GlobalParser(const std::weak_ptr& Client, std::vector&& Packet, TPPSMonitor& PPSMonitor, TNetwork& Network); static void HandleEvent(TClient& c, const std::string& Data); RWMutex& GetClientMutex() const { return mClientsMutex; } - const TScopedTimer UptimeTimer; + + // asio io context + io_context& IoCtx() { return mIoCtx; } + private: + io_context mIoCtx {}; TClientSet mClients; mutable RWMutex mClientsMutex; static void ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Network); @@ -40,3 +45,11 @@ private: static void Apply(TClient& c, int VID, const std::string& pckt); static void HandlePosition(TClient& c, const std::string& Packet); }; + +struct BufferView { + uint8_t* Data { nullptr }; + size_t Size { 0 }; + const uint8_t* data() const { return Data; } + uint8_t* data() { return Data; } + size_t size() const { return Size; } +}; diff --git a/src/Client.cpp b/src/Client.cpp index f93f825..991f0b7 100644 --- a/src/Client.cpp +++ b/src/Client.cpp @@ -2,6 +2,7 @@ #include "CustomAssert.h" #include "TServer.h" +#include #include #include @@ -49,16 +50,27 @@ TClient::TVehicleDataLockPair TClient::GetAllCars() { std::string TClient::GetCarPositionRaw(int Ident) { std::unique_lock lock(mVehiclePositionMutex); - try - { + try { return mVehiclePosition.at(Ident); - } - catch (const std::out_of_range& oor) { + } catch (const std::out_of_range& oor) { return ""; } return ""; } +void TClient::Disconnect(std::string_view Reason) { + beammp_debugf("Disconnecting client {} for reason: {}", GetID(), Reason); + boost::system::error_code ec; + mSocket.shutdown(socket_base::shutdown_both, ec); + if (ec) { + beammp_warnf("Failed to shutdown client socket: {}", ec.what()); + } + mSocket.close(ec); + if (ec) { + beammp_warnf("Failed to close client socket: {}", ec.what()); + } +} + void TClient::SetCarPosition(int Ident, const std::string& Data) { std::unique_lock lock(mVehiclePositionMutex); mVehiclePosition[Ident] = Data; @@ -98,16 +110,22 @@ TServer& TClient::Server() const { return mServer; } -void TClient::EnqueuePacket(const std::string& Packet) { +void TClient::EnqueuePacket(const std::vector& Packet) { std::unique_lock Lock(mMissedPacketsMutex); mPacketsSync.push(Packet); } -TClient::TClient(TServer& Server) +TClient::TClient(TServer& Server, ip::tcp::socket&& Socket) : mServer(Server) + , mSocket(std::move(Socket)) + , mDownSocket(ip::tcp::socket(Server.IoCtx())) , mLastPingTime(std::chrono::high_resolution_clock::now()) { } +TClient::~TClient() { + beammp_debugf("client destroyed: {} ('{}')", this->GetID(), this->GetName()); +} + void TClient::UpdatePingTime() { mLastPingTime = std::chrono::high_resolution_clock::now(); } diff --git a/src/LuaAPI.cpp b/src/LuaAPI.cpp index 30a0b86..08dec7e 100644 --- a/src/LuaAPI.cpp +++ b/src/LuaAPI.cpp @@ -116,7 +116,7 @@ TEST_CASE("LuaAPI::MP::GetServerVersion") { static inline std::pair InternalTriggerClientEvent(int PlayerID, const std::string& EventName, const std::string& Data) { std::string Packet = "E:" + EventName + ":" + Data; if (PlayerID == -1) { - LuaAPI::MP::Engine->Network().SendToAll(nullptr, Packet, true, true); + LuaAPI::MP::Engine->Network().SendToAll(nullptr, StringToVector(Packet), true, true); return { true, "" }; } else { auto MaybeClient = GetClient(LuaAPI::MP::Engine->Server(), PlayerID); @@ -125,7 +125,7 @@ static inline std::pair InternalTriggerClientEvent(int Player return { false, "Invalid Player ID" }; } auto c = MaybeClient.value().lock(); - if (!LuaAPI::MP::Engine->Network().Respond(*c, Packet, true)) { + if (!LuaAPI::MP::Engine->Network().Respond(*c, StringToVector(Packet), true)) { beammp_lua_errorf("Respond failed, dropping client {}", PlayerID); LuaAPI::MP::Engine->Network().ClientKick(*c, "Disconnected after failing to receive packets"); return { false, "Respond failed, dropping client" }; @@ -155,7 +155,7 @@ std::pair LuaAPI::MP::SendChatMessage(int ID, const std::stri std::string Packet = "C:Server: " + Message; if (ID == -1) { LogChatMessage(" (to everyone) ", -1, Message); - Engine->Network().SendToAll(nullptr, Packet, true, true); + Engine->Network().SendToAll(nullptr, StringToVector(Packet), true, true); Result.first = true; } else { auto MaybeClient = GetClient(Engine->Server(), ID); @@ -167,7 +167,7 @@ std::pair LuaAPI::MP::SendChatMessage(int ID, const std::stri return Result; } LogChatMessage(" (to \"" + c->GetName() + "\")", -1, Message); - if (!Engine->Network().Respond(*c, Packet, true)) { + if (!Engine->Network().Respond(*c, StringToVector(Packet), true)) { beammp_errorf("Failed to send chat message back to sender (id {}) - did the sender disconnect?", ID); // TODO: should we return an error here? } @@ -194,7 +194,7 @@ std::pair LuaAPI::MP::RemoveVehicle(int PID, int VID) { auto c = MaybeClient.value().lock(); if (!c->GetCarData(VID).empty()) { std::string Destroy = "Od:" + std::to_string(PID) + "-" + std::to_string(VID); - Engine->Network().SendToAll(nullptr, Destroy, true, true); + Engine->Network().SendToAll(nullptr, StringToVector(Destroy), true, true); c->DeleteCar(VID); Result.first = true; } else { @@ -526,7 +526,7 @@ static void JsonEncodeRecursive(nlohmann::json& json, const sol::object& left, c beammp_lua_error("json serialize will not go deeper than 100 nested tables, internal references assumed, aborted this path"); return; } - std::string key{}; + std::string key {}; switch (left.get_type()) { case sol::type::lua_nil: case sol::type::none: diff --git a/src/THeartbeatThread.cpp b/src/THeartbeatThread.cpp index edd0ed2..96504a8 100644 --- a/src/THeartbeatThread.cpp +++ b/src/THeartbeatThread.cpp @@ -148,7 +148,7 @@ std::string THeartbeatThread::GenerateCall() { << "&map=" << Application::Settings.MapName << "&private=" << (Application::Settings.Private ? "true" : "false") << "&version=" << Application::ServerVersionString() - << "&clientversion=" << Application::ClientVersionString() + << "&clientversion=" << Application::ClientMajorVersion() << "&name=" << Application::Settings.ServerName << "&modlist=" << mResourceManager.TrimmedList() << "&modstotalsize=" << mResourceManager.MaxModSize() diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index 5bed070..35ace21 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -1,7 +1,9 @@ #include "TNetwork.h" #include "Client.h" +#include "Common.h" #include "LuaAPI.h" #include "TLuaEngine.h" +#include "nlohmann/json.hpp" #include #include #include @@ -10,11 +12,23 @@ #include #include +std::vector StringToVector(const std::string& Str) { + return std::vector(Str.data(), Str.data() + Str.size()); +} + +static void CompressProperly(std::vector& Data) { + constexpr std::string_view ABG = "ABG:"; + auto CombinedData = std::vector(ABG.begin(), ABG.end()); + auto CompData = Comp(Data); + CombinedData.resize(ABG.size() + CompData.size()); + std::copy(CompData.begin(), CompData.end(), CombinedData.begin() + ABG.size()); + Data = CombinedData; +} + TNetwork::TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager) : mServer(Server) , mPPSMonitor(PPSMonitor) - , mIoCtx {} - , mUDPSock(mIoCtx) + , mUDPSock(Server.IoCtx()) , mResourceManager(ResourceManager) { Application::SetSubsystemStatus("TCPNetwork", Application::Status::Starting); Application::SetSubsystemStatus("UDPNetwork", Application::Status::Starting); @@ -67,9 +81,9 @@ void TNetwork::UDPServerMain() { while (!Application::IsShuttingDown()) { try { ip::udp::endpoint client {}; - std::string Data = UDPRcvFromClient(client); // Receives any data from Socket - size_t Pos = Data.find(':'); - if (Data.empty() || Pos > 2) + std::vector Data = UDPRcvFromClient(client); // Receives any data from Socket + auto Pos = std::find(Data.begin(), Data.end(), ':'); + if (Data.empty() || Pos > Data.begin() + 2) continue; uint8_t ID = uint8_t(Data.at(0)) - 1; mServer.ForEachClient([&](std::weak_ptr ClientPtr) -> bool { @@ -85,7 +99,8 @@ void TNetwork::UDPServerMain() { if (Client->GetID() == ID) { Client->SetUDPAddr(client); Client->SetIsConnected(true); - TServer::GlobalParser(ClientPtr, Data.substr(2), mPPSMonitor, *this); + Data.erase(Data.begin(), Data.begin() + 2); + TServer::GlobalParser(ClientPtr, std::move(Data), mPPSMonitor, *this); } return true; @@ -98,47 +113,30 @@ void TNetwork::UDPServerMain() { void TNetwork::TCPServerMain() { RegisterThread("TCPServer"); -#if defined(BEAMMP_WINDOWS) - WSADATA wsaData; - if (WSAStartup(514, &wsaData)) { - beammp_error("Can't start Winsock! Shutting down"); - Application::GracefullyShutdown(); + + ip::tcp::endpoint ListenEp(ip::address::from_string("0.0.0.0"), Application::Settings.Port); + ip::tcp::socket Listener(mServer.IoCtx()); + boost::system::error_code ec; + Listener.open(ListenEp.protocol(), ec); + if (ec) { + beammp_errorf("Failed to open socket: {}", ec.what()); + return; } -#endif // WINDOWS - TConnection client {}; - SOCKET Listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (Listener == BEAMMP_INVALID_SOCKET) { - beammp_error("Failed to create socket: " + GetPlatformAgnosticErrorString() - + ". This is a fatal error, as a socket is needed for the server to operate. Shutting down."); - Application::GracefullyShutdown(); + socket_base::linger LingerOpt {}; + LingerOpt.enabled(false); + Listener.set_option(LingerOpt, ec); + if (ec) { + beammp_errorf("Failed to set up listening socket to not linger / reuse address. " + "This may cause the socket to refuse to bind(). Error: {}", + ec.what()); } -#if defined(BEAMMP_WINDOWS) - const char optval = 0; - int ret = ::setsockopt(Listener, SOL_SOCKET, SO_DONTLINGER, &optval, sizeof(optval)); -#elif defined(BEAMMP_LINUX) || defined(BEAMMP_APPLE) - int optval = true; - int ret = ::setsockopt(Listener, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&optval), sizeof(optval)); -#endif - // not a fatal error - if (ret < 0) { - beammp_error("Failed to set up listening socket to not linger / reuse address. " - "This may cause the socket to refuse to bind(). Error: " - + GetPlatformAgnosticErrorString()); - } - sockaddr_in addr {}; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_family = AF_INET; - addr.sin_port = htons(uint16_t(Application::Settings.Port)); - if (bind(Listener, reinterpret_cast(&addr), sizeof(addr)) < 0) { - beammp_error("bind() failed, the server cannot operate and will shut down now. " - "Error: " - + GetPlatformAgnosticErrorString()); - Application::GracefullyShutdown(); - } - if (listen(Listener, SOMAXCONN) < 0) { - beammp_error("listen() failed, which is needed for the server to operate. " - "Shutting down. Error: " - + GetPlatformAgnosticErrorString()); + + ip::tcp::acceptor Acceptor(mServer.IoCtx(), ListenEp); + Acceptor.listen(socket_base::max_listen_connections, ec); + if (ec) { + beammp_errorf("listen() failed, which is needed for the server to operate. " + "Shutting down. Error: {}", + ec.what()); Application::GracefullyShutdown(); } Application::SetSubsystemStatus("TCPNetwork", Application::Status::Good); @@ -149,39 +147,22 @@ void TNetwork::TCPServerMain() { beammp_debug("shutdown during TCP wait for accept loop"); break; } - client.SockAddrLen = sizeof(client.SockAddr); - client.Socket = accept(Listener, &client.SockAddr, &client.SockAddrLen); - if (client.Socket == -1) { - beammp_warn(("Got an invalid client socket on connect! Skipping...")); - continue; + ip::tcp::endpoint ClientEp; + ip::tcp::socket ClientSocket = Acceptor.accept(ClientEp, ec); + if (ec) { + beammp_errorf("failed to accept: {}", ec.what()); } - // set timeout - size_t SendTimeoutMS = 30 * 1000; -#if defined(BEAMMP_WINDOWS) - int ret = ::setsockopt(client.Socket, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&SendTimeoutMS), sizeof(SendTimeoutMS)); -#else // POSIX - struct timeval optval; - optval.tv_sec = int(SendTimeoutMS / 1000); - optval.tv_usec = (SendTimeoutMS % 1000) * 1000; - ret = ::setsockopt(client.Socket, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&optval), sizeof(optval)); -#endif - if (ret < 0) { - throw std::runtime_error("setsockopt recv timeout: " + GetPlatformAgnosticErrorString()); + ClientSocket.set_option(boost::asio::detail::socket_option::integer { 30 * 1000 }, ec); + if (!ec) { + beammp_errorf("failed to set send timeout on client socket: {}", ec.what()); } - std::thread ID(&TNetwork::Identify, this, client); + TConnection Conn { std::move(ClientSocket), ClientEp }; + std::thread ID(&TNetwork::Identify, this, std::move(Conn)); ID.detach(); // TODO: Add to a queue and attempt to join periodically } catch (const std::exception& e) { beammp_error("fatal: " + std::string(e.what())); } - } while (client.Socket != BEAMMP_INVALID_SOCKET); - - beammp_debug("all ok, arrived at " + std::string(__func__) + ":" + std::to_string(__LINE__)); - - CloseSocketProper(client.Socket); -#ifdef BEAMMP_WINDOWS - CloseSocketProper(client.Socket); - WSACleanup(); -#endif // WINDOWS + } while (!Application::IsShuttingDown()); } #undef GetObject // Fixes Windows @@ -189,34 +170,38 @@ void TNetwork::TCPServerMain() { #include "Json.h" namespace json = rapidjson; -void TNetwork::Identify(const TConnection& client) { +void TNetwork::Identify(TConnection&& RawConnection) { RegisterThreadAuto(); char Code; - if (recv(client.Socket, &Code, 1, 0) != 1) { - CloseSocketProper(client.Socket); + + boost::system::error_code ec; + read(RawConnection.Socket, buffer(&Code, 1), ec); + if (ec) { + // TODO: is this right?! + RawConnection.Socket.shutdown(socket_base::shutdown_both); return; } + std::shared_ptr Client { nullptr }; if (Code == 'C') { - Authentication(client); + Client = Authentication(std::move(RawConnection)); } else if (Code == 'D') { - HandleDownload(client.Socket); + HandleDownload(std::move(RawConnection)); } else if (Code == 'P') { -#if defined(BEAMMP_LINUX) || defined(BEAMMP_APPLE) - send(client.Socket, "P", 1, MSG_NOSIGNAL); -#else - send(client.Socket, "P", 1, 0); -#endif - CloseSocketProper(client.Socket); + boost::system::error_code ec; + write(RawConnection.Socket, buffer("P"), ec); return; } else { - CloseSocketProper(client.Socket); + beammp_errorf("Invalid code got in Identify: '{}'", Code); } } -void TNetwork::HandleDownload(SOCKET TCPSock) { +void TNetwork::HandleDownload(TConnection&& Conn) { char D; - if (recv(TCPSock, &D, 1, 0) != 1) { - CloseSocketProper(TCPSock); + boost::system::error_code ec; + read(Conn.Socket, buffer(&D, 1), ec); + if (ec) { + Conn.Socket.shutdown(socket_base::shutdown_both, ec); + // ignore ec return; } auto ID = uint8_t(D); @@ -225,110 +210,78 @@ void TNetwork::HandleDownload(SOCKET TCPSock) { if (!ClientPtr.expired()) { auto c = ClientPtr.lock(); if (c->GetID() == ID) { - c->SetDownSock(TCPSock); + c->SetDownSock(std::move(Conn.Socket)); } } return true; }); } -static int get_ip_str(const struct sockaddr* sa, char* strBuf, socklen_t strBufSize) { - switch (sa->sa_family) { - case AF_INET: - inet_ntop(AF_INET, &reinterpret_cast(sa)->sin_addr, strBuf, strBufSize); - break; - case AF_INET6: - inet_ntop(AF_INET6, &reinterpret_cast(sa)->sin6_addr, strBuf, strBufSize); - break; - default: - return 1; - } - return 0; -} +std::shared_ptr TNetwork::Authentication(TConnection&& RawConnection) { + auto Client = CreateClient(std::move(RawConnection.Socket)); + Client->SetIdentifier("ip", RawConnection.SockAddr.address().to_string()); + beammp_tracef("This thread is ip {}", RawConnection.SockAddr.address().to_string()); -void TNetwork::Authentication(const TConnection& ClientConnection) { - auto Client = CreateClient(ClientConnection.Socket); - char AddrBuf[INET6_ADDRSTRLEN]; - get_ip_str(&ClientConnection.SockAddr, AddrBuf, sizeof(AddrBuf)); - beammp_trace("This thread is ip " + std::string(AddrBuf)); - Client->SetIdentifier("ip", AddrBuf); - - std::string Rc; // TODO: figure out why this is not default constructed beammp_info("Identifying new ClientConnection..."); - Rc = TCPRcv(*Client); + auto Data = TCPRcv(*Client); - if (Rc.size() > 3 && Rc.substr(0, 2) == "VC") { - Rc = Rc.substr(2); - if (Rc.length() > 4 || Rc != Application::ClientVersionString()) { + constexpr std::string_view VC = "VC"; + if (Data.size() > 3 && std::equal(Data.begin(), Data.begin() + VC.size(), VC.begin(), VC.end())) { + std::string ClientVersionStr(reinterpret_cast(Data.data() + 2), Data.size() - 2); + Version ClientVersion = Application::VersionStrToInts(ClientVersionStr); + if (ClientVersion.major != Application::ClientMajorVersion()) { + beammp_errorf("Client tried to connect with version '{}', but only versions '{}.x.x' is allowed", + ClientVersion.AsString(), Application::ClientMajorVersion()); ClientKick(*Client, "Outdated Version!"); - return; + return nullptr; } } else { - ClientKick(*Client, "Invalid version header!"); - return; + ClientKick(*Client, fmt::format("Invalid version header: '{}' ({})", std::string(reinterpret_cast(Data.data()), Data.size()), Data.size())); + return nullptr; } - if (!TCPSend(*Client, "S")) { + if (!TCPSend(*Client, StringToVector("S"))) { // TODO: handle } - Rc = TCPRcv(*Client); + Data = TCPRcv(*Client); - if (Rc.size() > 50) { - ClientKick(*Client, "Invalid Key!"); - return; + if (Data.size() > 50) { + ClientKick(*Client, "Invalid Key (too long)!"); + return nullptr; } - auto RequestString = R"({"key":")" + Rc + "\"}"; - + nlohmann::json AuthReq { + { "key", std::string(reinterpret_cast(Data.data()), Data.size()) } + }; auto Target = "/pkToUser"; unsigned int ResponseCode = 0; - if (!Rc.empty()) { - Rc = Http::POST(Application::GetBackendUrlForAuth(), 443, Target, RequestString, "application/json", &ResponseCode); - } + const auto AuthResStr = Http::POST(Application::GetBackendUrlForAuth(), 443, Target, AuthReq.dump(), "application/json", &ResponseCode); - json::Document AuthResponse; - AuthResponse.Parse(Rc.c_str()); - if (Rc == Http::ErrorString || AuthResponse.HasParseError()) { + try { + nlohmann::json AuthRes = nlohmann::json::parse(AuthResStr); + + if (AuthRes["username"].is_string() && AuthRes["roles"].is_string() + && AuthRes["guest"].is_boolean() && AuthRes["identifiers"].is_array()) { + + Client->SetName(AuthRes["username"]); + Client->SetRoles(AuthRes["roles"]); + Client->SetIsGuest(AuthRes["guest"]); + for (const auto& ID : AuthRes["identifier"]) { + auto Raw = std::string(ID); + auto SepIndex = Raw.find(':'); + Client->SetIdentifier(Raw.substr(0, SepIndex), Raw.substr(SepIndex + 1)); + } + } else { + beammp_error("Invalid authentication data received from authentication backend"); + ClientKick(*Client, "Invalid authentication data!"); + return nullptr; + } + } catch (const std::exception& e) { + beammp_errorf("Client sent invalid key: {}", e.what()); + // TODO: we should really clarify that this was a backend response or parsing error ClientKick(*Client, "Invalid key! Please restart your game."); - return; - } - - if (!AuthResponse.IsObject()) { - if (Rc == "0") { - auto Lock = Sentry.CreateExclusiveContext(); - Sentry.SetContext("auth", - { { "response-body", Rc }, - { "key", RequestString } }); - Sentry.SetTransaction(Application::GetBackendUrlForAuth() + Target); - Sentry.Log(SentryLevel::Info, "default", "backend returned 0 instead of json (" + std::to_string(ResponseCode) + ")"); - } else { // Rc != "0" - ClientKick(*Client, "Backend returned invalid auth response format."); - beammp_error("Backend returned invalid auth response format. This should never happen."); - auto Lock = Sentry.CreateExclusiveContext(); - Sentry.SetContext("auth", - { { "response-body", Rc }, - { "key", RequestString } }); - Sentry.SetTransaction(Application::GetBackendUrlForAuth() + Target); - Sentry.Log(SentryLevel::Error, "default", "unexpected backend response (" + std::to_string(ResponseCode) + ")"); - } - return; - } - - if (AuthResponse["username"].IsString() && AuthResponse["roles"].IsString() - && AuthResponse["guest"].IsBool() && AuthResponse["identifiers"].IsArray()) { - - Client->SetName(AuthResponse["username"].GetString()); - Client->SetRoles(AuthResponse["roles"].GetString()); - Client->SetIsGuest(AuthResponse["guest"].GetBool()); - for (const auto& ID : AuthResponse["identifiers"].GetArray()) { - auto Raw = std::string(ID.GetString()); - auto SepIndex = Raw.find(':'); - Client->SetIdentifier(Raw.substr(0, SepIndex), Raw.substr(SepIndex + 1)); - } - } else { - ClientKick(*Client, "Invalid authentication data!"); - return; + return nullptr; } beammp_debug("Name -> " + Client->GetName() + ", Guest -> " + std::to_string(Client->IsGuest()) + ", Roles -> " + Client->GetRoles()); @@ -342,8 +295,7 @@ void TNetwork::Authentication(const TConnection& ClientConnection) { return true; } if (Cl->GetName() == Client->GetName() && Cl->IsGuest() == Client->IsGuest()) { - CloseSocketProper(Cl->GetTCPSock()); - Cl->SetStatus(-2); + Cl->Disconnect("Stale Client (not a real player)"); return false; } @@ -368,27 +320,28 @@ void TNetwork::Authentication(const TConnection& ClientConnection) { if (NotAllowed) { ClientKick(*Client, "you are not allowed on the server!"); - return; + return {}; } else if (NotAllowedWithReason) { ClientKick(*Client, Reason); - return; + return {}; } if (mServer.ClientCount() < size_t(Application::Settings.MaxPlayers)) { beammp_info("Identification success"); mServer.InsertClient(Client); TCPClient(Client); - } else + } else { ClientKick(*Client, "Server full!"); + } + return Client; } -std::shared_ptr TNetwork::CreateClient(SOCKET TCPSock) { - auto c = std::make_shared(mServer); - c->SetTCPSock(TCPSock); +std::shared_ptr TNetwork::CreateClient(ip::tcp::socket&& TCPSock) { + auto c = std::make_shared(mServer, std::move(TCPSock)); return c; } -bool TNetwork::TCPSend(TClient& c, const std::string& Data, bool IsSync) { +bool TNetwork::TCPSend(TClient& c, const std::vector& Data, bool IsSync) { if (!IsSync) { if (c.IsSyncing()) { if (!Data.empty()) { @@ -400,120 +353,101 @@ bool TNetwork::TCPSend(TClient& c, const std::string& Data, bool IsSync) { } } - int32_t Size, Sent; - std::string Send(4, 0); - Size = int32_t(Data.size()); - memcpy(&Send[0], &Size, sizeof(Size)); - Send += Data; - Sent = 0; - Size += 4; - do { -#if defined(BEAMMP_WINDOWS) - int32_t Temp = send(c.GetTCPSock(), &Send[Sent], Size - Sent, 0); -#elif defined(BEAMMP_LINUX) || defined(BEAMMP_APPLE) - int32_t Temp = send(c.GetTCPSock(), &Send[Sent], Size - Sent, MSG_NOSIGNAL); -#endif - if (Temp == 0) { - beammp_debug("send() == 0: " + GetPlatformAgnosticErrorString()); - if (c.GetStatus() > -1) - c.SetStatus(-1); - return false; - } else if (Temp < 0) { - beammp_debug("send() < 0: " + GetPlatformAgnosticErrorString()); // TODO fix it was spamming yet everyone stayed on the server - if (c.GetStatus() > -1) - c.SetStatus(-1); - CloseSocketProper(c.GetTCPSock()); - return false; - } - Sent += Temp; - c.UpdatePingTime(); - } while (Sent < Size); + auto& Sock = c.GetTCPSock(); + + /* + * our TCP protocol sends a header of 4 bytes, followed by the data. + * + * [][][][][][]...[] + * ^------^^---...-^ + * size data + */ + + const auto Size = int32_t(Data.size()); + std::vector ToSend; + ToSend.resize(Data.size() + sizeof(Size)); + std::memcpy(ToSend.data(), &Size, sizeof(Size)); + std::memcpy(ToSend.data() + sizeof(Size), Data.data(), Data.size()); + boost::system::error_code ec; + write(Sock, buffer(ToSend), ec); + if (ec) { + beammp_debugf("write(): {}", ec.what()); + c.Disconnect("write() failed"); + return false; + } + c.UpdatePingTime(); return true; } -bool TNetwork::CheckBytes(TClient& c, int32_t BytesRcv) { - if (BytesRcv == 0) { - beammp_trace("(TCP) Connection closing..."); - if (c.GetStatus() > -1) - c.SetStatus(-1); - return false; - } else if (BytesRcv < 0) { - beammp_debug("(TCP) recv() failed: " + GetPlatformAgnosticErrorString()); - if (c.GetStatus() > -1) - c.SetStatus(-1); - CloseSocketProper(c.GetTCPSock()); - return false; +std::vector TNetwork::TCPRcv(TClient& c) { + if (c.IsDisconnected()) { + beammp_error("Client disconnected, cancelling TCPRcv"); + return {}; } - return true; -} -std::string TNetwork::TCPRcv(TClient& c) { - int32_t Header {}, BytesRcv = 0, Temp {}; - if (c.GetStatus() < 0) - return ""; + int32_t Header {}; + auto& Sock = c.GetTCPSock(); - std::vector Data(sizeof(Header)); - do { - Temp = recv(c.GetTCPSock(), &Data[BytesRcv], 4 - BytesRcv, 0); - if (!CheckBytes(c, Temp)) { - return ""; - } - BytesRcv += Temp; - } while (size_t(BytesRcv) < sizeof(Header)); - memcpy(&Header, &Data[0], sizeof(Header)); - - if (!CheckBytes(c, BytesRcv)) { - return ""; + boost::system::error_code ec; + std::array HeaderData; + read(Sock, buffer(HeaderData), ec); + if (ec) { + // TODO: handle this case (read failed) + beammp_debugf("TCPRcv: Reading header failed: {}", ec.what()); + return {}; } + Header = *reinterpret_cast(HeaderData.data()); + beammp_tracef("Expecting to read {} bytes", Header); + + std::vector Data; + // TODO: This is arbitrary, this needs to be handled another way if (Header < int32_t(100 * MB)) { Data.resize(Header); } else { ClientKick(c, "Header size limit exceeded"); beammp_warn("Client " + c.GetName() + " (" + std::to_string(c.GetID()) + ") sent header of >100MB - assuming malicious intent and disconnecting the client."); - return ""; + return {}; + } + auto N = read(Sock, buffer(Data), ec); + if (ec) { + // TODO: handle this case properly + beammp_debugf("TCPRcv: Reading data failed: {}", ec.what()); + return {}; } - BytesRcv = 0; - do { - Temp = recv(c.GetTCPSock(), &Data[BytesRcv], Header - BytesRcv, 0); - if (!CheckBytes(c, Temp)) { - return ""; - } - BytesRcv += Temp; - } while (BytesRcv < Header); - std::string Ret(Data.data(), Header); - if (Ret.substr(0, 4) == "ABG:") { - Ret = DeComp(Ret.substr(4)); + if (N != Header) { + beammp_errorf("Expected to read {} bytes, instead got {}", Header, N); + } + + constexpr std::string_view ABG = "ABG:"; + if (Data.size() >= ABG.size() && std::equal(Data.begin(), Data.begin() + ABG.size(), ABG.begin(), ABG.end())) { + Data.erase(Data.begin(), Data.begin() + ABG.size()); + return DeComp(Data); + } else { + return Data; } - return Ret; } void TNetwork::ClientKick(TClient& c, const std::string& R) { beammp_info("Client kicked: " + R); - if (!TCPSend(c, "K" + R)) { + if (!TCPSend(c, StringToVector("K" + R))) { beammp_warn("tried to kick player '" + c.GetName() + "' (id " + std::to_string(c.GetID()) + "), but was already disconnected"); } - c.SetStatus(-2); - - if (c.GetTCPSock()) - CloseSocketProper(c.GetTCPSock()); - - if (c.GetDownSock()) - CloseSocketProper(c.GetDownSock()); + c.Disconnect("Kicked"); } void TNetwork::Looper(const std::weak_ptr& c) { RegisterThreadAuto(); while (!c.expired()) { auto Client = c.lock(); - if (Client->GetStatus() < 0) { - beammp_debug("client status < 0, breaking client loop"); + if (Client->IsDisconnected()) { + beammp_debug("client is disconnected, breaking client loop"); break; } if (!Client->IsSyncing() && Client->IsSynced() && Client->MissedPacketQueueSize() != 0) { // debug("sending " + std::to_string(Client->MissedPacketQueueSize()) + " queued packets"); while (Client->MissedPacketQueueSize() > 0) { - std::string QData {}; + std::vector QData {}; { // locked context std::unique_lock lock(Client->MissedPacketQueueMutex()); if (Client->MissedPacketQueueSize() <= 0) { @@ -524,15 +458,15 @@ void TNetwork::Looper(const std::weak_ptr& c) { } // end locked context // beammp_debug("sending a missed packet: " + QData); if (!TCPSend(*Client, QData, true)) { - if (Client->GetStatus() > -1) - Client->SetStatus(-1); + if (!Client->IsDisconnected()) + Client->Disconnect("Failed to TCPSend while clearing the missed packet queue"); { std::unique_lock lock(Client->MissedPacketQueueMutex()); while (!Client->MissedPacketQueue().empty()) { Client->MissedPacketQueue().pop(); } } - CloseSocketProper(Client->GetTCPSock()); + Client->Disconnect("WHY THE FUCK NOT"); break; } } @@ -544,7 +478,7 @@ void TNetwork::Looper(const std::weak_ptr& c) { void TNetwork::TCPClient(const std::weak_ptr& c) { // TODO: the c.expired() might cause issues here, remove if you end up here with your debugger - if (c.expired() || c.lock()->GetTCPSock() == -1) { + if (c.expired() || !c.lock()->GetTCPSock().is_open()) { mServer.RemoveClient(c); return; } @@ -557,24 +491,23 @@ void TNetwork::TCPClient(const std::weak_ptr& c) { if (c.expired()) break; auto Client = c.lock(); - if (Client->GetStatus() < 0) { + if (Client->IsDisconnected()) { beammp_debug("client status < 0, breaking client loop"); break; } auto res = TCPRcv(*Client); - if (res == "") { - beammp_debug("TCPRcv error, break client loop"); - break; + if (res.empty()) { + beammp_debug("TCPRcv empty, ignoring"); } - TServer::GlobalParser(c, res, mPPSMonitor, *this); + TServer::GlobalParser(c, std::move(res), mPPSMonitor, *this); } if (QueueSync.joinable()) QueueSync.join(); if (!c.expired()) { auto Client = c.lock(); - OnDisconnect(c, Client->GetStatus() == -2); + OnDisconnect(c); } else { beammp_warn("client expired in TCPClient, should never happen"); } @@ -591,11 +524,11 @@ void TNetwork::UpdatePlayer(TClient& Client) { return true; }); Packet = Packet.substr(0, Packet.length() - 1); - Client.EnqueuePacket(Packet); + Client.EnqueuePacket(StringToVector(Packet)); //(void)Respond(Client, Packet, true); } -void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr, bool kicked) { +void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr) { beammp_assert(!ClientPtr.expired()); auto LockedClientPtr = ClientPtr.lock(); TClient& c = *LockedClientPtr; @@ -608,20 +541,14 @@ void TNetwork::OnDisconnect(const std::weak_ptr& ClientPtr, bool kicked } // End Vehicle Data Lock Scope for (auto& v : VehicleData) { Packet = "Od:" + std::to_string(c.GetID()) + "-" + std::to_string(v.ID()); - SendToAll(&c, Packet, false, true); + SendToAll(&c, StringToVector(Packet), false, true); } - if (kicked) - Packet = ("L") + c.GetName() + (" was kicked!"); - else - Packet = ("L") + c.GetName() + (" left the server!"); - SendToAll(&c, Packet, false, true); + Packet = ("L") + c.GetName() + (" left the server!"); + SendToAll(&c, StringToVector(Packet), false, true); Packet.clear(); auto Futures = LuaAPI::MP::Engine->TriggerEvent("onPlayerDisconnect", "", c.GetID()); LuaAPI::MP::Engine->ReportErrors(Futures); - if (c.GetTCPSock()) - CloseSocketProper(c.GetTCPSock()); - if (c.GetDownSock()) - CloseSocketProper(c.GetDownSock()); + c.Disconnect("Already Disconnected (OnDisconnect)"); mServer.RemoveClient(ClientPtr); } @@ -653,44 +580,39 @@ void TNetwork::OnConnect(const std::weak_ptr& c) { beammp_info("Assigned ID " + std::to_string(LockedClient->GetID()) + " to " + LockedClient->GetName()); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onPlayerConnecting", "", LockedClient->GetID())); SyncResources(*LockedClient); - if (LockedClient->GetStatus() < 0) + if (LockedClient->IsDisconnected()) return; - (void)Respond(*LockedClient, "M" + Application::Settings.MapName, true); // Send the Map on connect + (void)Respond(*LockedClient, StringToVector("M" + Application::Settings.MapName), true); // Send the Map on connect beammp_info(LockedClient->GetName() + " : Connected"); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onPlayerJoining", "", LockedClient->GetID())); } void TNetwork::SyncResources(TClient& c) { -#ifndef DEBUG - try { -#endif - if (!TCPSend(c, "P" + std::to_string(c.GetID()))) { - // TODO handle - } - std::string Data; - while (c.GetStatus() > -1) { - Data = TCPRcv(c); - if (Data == "Done") - break; - Parse(c, Data); - } -#ifndef DEBUG - } catch (std::exception& e) { - beammp_error("Exception! : " + std::string(e.what())); - c.SetStatus(-1); + if (!TCPSend(c, StringToVector("P" + std::to_string(c.GetID())))) { + // TODO handle + } + std::vector Data; + while (!c.IsDisconnected()) { + Data = TCPRcv(c); + if (Data.empty()) { + break; + } + constexpr std::string_view Done = "Done"; + if (std::equal(Data.begin(), Data.end(), Done.begin(), Done.end())) + break; + Parse(c, Data); } -#endif } -void TNetwork::Parse(TClient& c, const std::string& Packet) { +void TNetwork::Parse(TClient& c, const std::vector& Packet) { if (Packet.empty()) return; char Code = Packet.at(0), SubCode = 0; - if (Packet.length() > 1) + if (Packet.size() > 1) SubCode = Packet.at(1); switch (Code) { case 'f': - SendFile(c, Packet.substr(1)); + SendFile(c, std::string(reinterpret_cast(Packet.data() + 1), Packet.size() - 1)); return; case 'S': if (SubCode == 'R') { @@ -698,7 +620,7 @@ void TNetwork::Parse(TClient& c, const std::string& Packet) { std::string ToSend = mResourceManager.FileList() + mResourceManager.FileSizes(); if (ToSend.empty()) ToSend = "-"; - if (!TCPSend(c, ToSend)) { + if (!TCPSend(c, StringToVector(ToSend))) { // TODO: error } } @@ -712,7 +634,7 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { beammp_info(c.GetName() + " requesting : " + UnsafeName.substr(UnsafeName.find_last_of('/'))); if (!fs::path(UnsafeName).has_filename()) { - if (!TCPSend(c, "CO")) { + if (!TCPSend(c, StringToVector("CO"))) { // TODO: handle } beammp_warn("File " + UnsafeName + " is not a file!"); @@ -722,28 +644,28 @@ void TNetwork::SendFile(TClient& c, const std::string& UnsafeName) { FileName = Application::Settings.Resource + "/Client/" + FileName; if (!std::filesystem::exists(FileName)) { - if (!TCPSend(c, "CO")) { + if (!TCPSend(c, StringToVector("CO"))) { // TODO: handle } beammp_warn("File " + UnsafeName + " could not be accessed!"); return; } - if (!TCPSend(c, "AG")) { + if (!TCPSend(c, StringToVector("AG"))) { // TODO: handle } /// Wait for connections int T = 0; - while (c.GetDownSock() < 1 && T < 50) { + while (!c.GetDownSock().is_open() && T < 50) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); T++; } - if (c.GetDownSock() < 1) { + if (!c.GetDownSock().is_open()) { beammp_error("Client doesn't have a download socket!"); - if (c.GetStatus() > -1) - c.SetStatus(-1); + if (!c.IsDisconnected()) + c.Disconnect("Missing download socket"); return; } @@ -807,8 +729,8 @@ TEST_CASE("SplitIntoChunks") { CHECK((Count * ChunkSize) + LastSize == FullSize); } -uint8_t* /* end ptr */ TNetwork::SendSplit(TClient& c, SOCKET Socket, uint8_t* DataPtr, size_t Size) { - if (TCPSendRaw(c, Socket, reinterpret_cast(DataPtr), Size)) { +const uint8_t* /* end ptr */ TNetwork::SendSplit(TClient& c, ip::tcp::socket& Socket, const uint8_t* DataPtr, size_t Size) { + if (TCPSendRaw(c, Socket, DataPtr, Size)) { return DataPtr + Size; } else { return nullptr; @@ -823,29 +745,28 @@ void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std Data.resize(Split); else Data.resize(Size); - SOCKET TCPSock; + ip::tcp::socket* TCPSock { nullptr }; if (D) - TCPSock = c.GetDownSock(); + TCPSock = &c.GetDownSock(); else - TCPSock = c.GetTCPSock(); - beammp_debug("Split load Socket " + std::to_string(TCPSock)); - while (c.GetStatus() > -1 && Sent < Size) { + TCPSock = &c.GetTCPSock(); + while (!c.IsDisconnected() && Sent < Size) { size_t Diff = Size - Sent; if (Diff > Split) { f.seekg(Sent, std::ios_base::beg); f.read(reinterpret_cast(Data.data()), Split); - if (!TCPSendRaw(c, TCPSock, reinterpret_cast(Data.data()), Split)) { - if (c.GetStatus() > -1) - c.SetStatus(-1); + if (!TCPSendRaw(c, *TCPSock, Data.data(), Split)) { + if (!c.IsDisconnected()) + c.Disconnect("TCPSendRaw failed in mod download (1)"); break; } Sent += Split; } else { f.seekg(Sent, std::ios_base::beg); f.read(reinterpret_cast(Data.data()), Diff); - if (!TCPSendRaw(c, TCPSock, reinterpret_cast(Data.data()), int32_t(Diff))) { - if (c.GetStatus() > -1) - c.SetStatus(-1); + if (!TCPSendRaw(c, *TCPSock, Data.data(), int32_t(Diff))) { + if (!c.IsDisconnected()) + c.Disconnect("TCPSendRaw failed in mod download (2)"); break; } Sent += Diff; @@ -853,37 +774,28 @@ void TNetwork::SplitLoad(TClient& c, size_t Sent, size_t Size, bool D, const std } } -bool TNetwork::TCPSendRaw(TClient& C, SOCKET socket, char* Data, int32_t Size) { - intmax_t Sent = 0; - do { -#if defined(BEAMMP_LINUX) || defined(BEAMMP_APPLE) - intmax_t Temp = send(socket, &Data[Sent], int(Size - Sent), MSG_NOSIGNAL); -#else - intmax_t Temp = send(socket, &Data[Sent], int(Size - Sent), 0); -#endif - if (Temp < 1) { - beammp_info("Socket Closed! " + std::to_string(socket)); - CloseSocketProper(socket); - return false; - } - Sent += Temp; - C.UpdatePingTime(); - } while (Sent < Size); +bool TNetwork::TCPSendRaw(TClient& C, ip::tcp::socket& socket, const uint8_t* Data, size_t Size) { + boost::system::error_code ec; + write(socket, buffer(Data, Size), ec); + if (ec) { + beammp_errorf("Failed to send raw data to client: {}", ec.what()); + return false; + } + C.UpdatePingTime(); return true; } -bool TNetwork::SendLarge(TClient& c, std::string Data, bool isSync) { - if (Data.length() > 400) { - std::string CMP(Comp(Data)); - Data = "ABG:" + CMP; +bool TNetwork::SendLarge(TClient& c, std::vector Data, bool isSync) { + if (Data.size() > 400) { + CompressProperly(Data); } return TCPSend(c, Data, isSync); } -bool TNetwork::Respond(TClient& c, const std::string& MSG, bool Rel, bool isSync) { +bool TNetwork::Respond(TClient& c, const std::vector& MSG, bool Rel, bool isSync) { char C = MSG.at(0); if (Rel || C == 'W' || C == 'Y' || C == 'V' || C == 'E') { - if (C == 'O' || C == 'T' || MSG.length() > 1000) { + if (C == 'O' || C == 'T' || MSG.size() > 1000) { return SendLarge(c, MSG, isSync); } else { return TCPSend(c, MSG, isSync); @@ -902,11 +814,11 @@ bool TNetwork::SyncClient(const std::weak_ptr& c) { return true; // Syncing, later set isSynced // after syncing is done, we apply all packets they missed - if (!Respond(*LockedClient, ("Sn") + LockedClient->GetName(), true)) { + if (!Respond(*LockedClient, StringToVector("Sn" + LockedClient->GetName()), true)) { return false; } // ignore error - (void)SendToAll(LockedClient.get(), ("JWelcome ") + LockedClient->GetName() + "!", false, true); + (void)SendToAll(LockedClient.get(), StringToVector("JWelcome " + LockedClient->GetName() + "!"), false, true); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onPlayerJoin", "", LockedClient->GetID())); LockedClient->SetIsSyncing(true); @@ -928,12 +840,12 @@ bool TNetwork::SyncClient(const std::weak_ptr& c) { } // End Vehicle Data Lock Scope if (client != LockedClient) { for (auto& v : VehicleData) { - if (LockedClient->GetStatus() < 0) { + if (LockedClient->IsDisconnected()) { Return = true; res = false; return false; } - res = Respond(*LockedClient, v.Data(), true, true); + res = Respond(*LockedClient, StringToVector(v.Data()), true, true); } } @@ -948,7 +860,7 @@ bool TNetwork::SyncClient(const std::weak_ptr& c) { return true; } -void TNetwork::SendToAll(TClient* c, const std::string& Data, bool Self, bool Rel) { +void TNetwork::SendToAll(TClient* c, const std::vector& Data, bool Self, bool Rel) { if (!Self) beammp_assert(c); char C = Data.at(0); @@ -965,10 +877,11 @@ void TNetwork::SendToAll(TClient* c, const std::string& Data, bool Self, bool Re if (Self || Client.get() != c) { if (Client->IsSynced() || Client->IsSyncing()) { if (Rel || C == 'W' || C == 'Y' || C == 'V' || C == 'E') { - if (C == 'O' || C == 'T' || Data.length() > 1000) { - if (Data.length() > 400) { - std::string CMP(Comp(Data)); - Client->EnqueuePacket("ABG:" + CMP); + if (C == 'O' || C == 'T' || Data.size() > 1000) { + if (Data.size() > 400) { + auto CompressedData = Data; + CompressProperly(CompressedData); + Client->EnqueuePacket(CompressedData); } else { Client->EnqueuePacket(Data); } @@ -990,8 +903,8 @@ void TNetwork::SendToAll(TClient* c, const std::string& Data, bool Self, bool Re return; } -bool TNetwork::UDPSend(TClient& Client, std::string Data) { - if (!Client.IsConnected() || Client.GetStatus() < 0) { +bool TNetwork::UDPSend(TClient& Client, std::vector Data) { + if (!Client.IsConnected() || Client.IsDisconnected()) { // this can happen if we try to send a packet to a client that is either // 1. not yet fully connected, or // 2. disconnected and not yet fully removed @@ -999,22 +912,21 @@ bool TNetwork::UDPSend(TClient& Client, std::string Data) { return true; } const auto Addr = Client.GetUDPAddr(); - if (Data.length() > 400) { - std::string CMP(Comp(Data)); - Data = "ABG:" + CMP; + if (Data.size() > 400) { + CompressProperly(Data); } boost::system::error_code ec; mUDPSock.send_to(buffer(Data), Addr, 0, ec); if (ec) { beammp_debugf("UDP sendto() failed: {}", ec.what()); - if (Client.GetStatus() > -1) - Client.SetStatus(-1); + if (!Client.IsDisconnected()) + Client.Disconnect("UDP send failed"); return false; } return true; } -std::string TNetwork::UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint) { +std::vector TNetwork::UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint) { std::array Ret {}; boost::system::error_code ec; beammp_debugf("receiving data from {}:{}", ClientEndpoint.address().to_string(), ClientEndpoint.port()); @@ -1022,8 +934,8 @@ std::string TNetwork::UDPRcvFromClient(ip::udp::endpoint& ClientEndpoint) { beammp_debugf("received {} bytes from {}:{}", Rcv, ClientEndpoint.address().to_string(), ClientEndpoint.port()); if (ec) { beammp_errorf("UDP recvfrom() failed: {}", ec.what()); - return ""; + return {}; } // FIXME: This breaks binary data due to \0. - return std::string(Ret.begin(), Ret.begin() + Rcv); + return std::vector(Ret.begin(), Ret.end()); } diff --git a/src/TServer.cpp b/src/TServer.cpp index 8cba141..690ec50 100644 --- a/src/TServer.cpp +++ b/src/TServer.cpp @@ -4,6 +4,7 @@ #include "TNetwork.h" #include "TPPSMonitor.h" #include +#include #include #include @@ -102,13 +103,6 @@ void TServer::RemoveClient(const std::weak_ptr& WeakClientPtr) { } } -std::weak_ptr TServer::InsertNewClient() { - beammp_debug("inserting new client (" + std::to_string(ClientCount()) + ")"); - WriteLock Lock(mClientsMutex); - auto [Iter, Replaced] = mClients.insert(std::make_shared(*this)); - return *Iter; -} - void TServer::ForEachClient(const std::function)>& Fn) { decltype(mClients) Clients; { @@ -127,12 +121,11 @@ size_t TServer::ClientCount() const { return mClients.size(); } -void TServer::GlobalParser(const std::weak_ptr& Client, std::string Packet, TPPSMonitor& PPSMonitor, TNetwork& Network) { - if (Packet.find("Zp") != std::string::npos && Packet.size() > 500) { - // abort(); - } - if (Packet.substr(0, 4) == "ABG:") { - Packet = DeComp(Packet.substr(4)); +void TServer::GlobalParser(const std::weak_ptr& Client, std::vector&& Packet, TPPSMonitor& PPSMonitor, TNetwork& Network) { + constexpr std::string_view ABG = "ABG:"; + if (Packet.size() >= ABG.size() && std::equal(Packet.begin(), Packet.begin() + ABG.size(), ABG.begin(), ABG.end())) { + Packet.erase(Packet.begin(), Packet.begin() + ABG.size()); + Packet = DeComp(Packet); } if (Packet.empty()) { return; @@ -146,6 +139,8 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::string Pac std::any Res; char Code = Packet.at(0); + std::string StringPacket(reinterpret_cast(Packet.data()), Packet.size()); + // V to Y if (Code <= 89 && Code >= 86) { PPSMonitor.IncrementInternalPPS(); @@ -154,38 +149,34 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::string Pac } switch (Code) { case 'H': // initial connection - beammp_trace(std::string("got 'H' packet: '") + Packet + "' (" + std::to_string(Packet.size()) + ")"); if (!Network.SyncClient(Client)) { // TODO handle } return; case 'p': - if (!Network.Respond(*LockedClient, ("p"), false)) { + if (!Network.Respond(*LockedClient, StringToVector("p"), false)) { // failed to send - if (LockedClient->GetStatus() > -1) { - LockedClient->SetStatus(-1); - } + LockedClient->Disconnect("Failed to send ping"); } else { Network.UpdatePlayer(*LockedClient); } return; case 'O': - if (Packet.length() > 1000) { - beammp_debug(("Received data from: ") + LockedClient->GetName() + (" Size: ") + std::to_string(Packet.length())); + if (Packet.size() > 1000) { + beammp_debug(("Received data from: ") + LockedClient->GetName() + (" Size: ") + std::to_string(Packet.size())); } - ParseVehicle(*LockedClient, Packet, Network); + ParseVehicle(*LockedClient, StringPacket, Network); return; case 'J': - beammp_trace(std::string(("got 'J' packet: '")) + Packet + ("' (") + std::to_string(Packet.size()) + (")")); Network.SendToAll(LockedClient.get(), Packet, false, true); return; case 'C': { - beammp_trace(std::string(("got 'C' packet: '")) + Packet + ("' (") + std::to_string(Packet.size()) + (")")); - if (Packet.length() < 4 || Packet.find(':', 3) == std::string::npos) + if (Packet.size() < 4 || std::find(Packet.begin() + 3, Packet.end(), ':') == Packet.end()) break; - auto Futures = LuaAPI::MP::Engine->TriggerEvent("onChatMessage", "", LockedClient->GetID(), LockedClient->GetName(), Packet.substr(Packet.find(':', 3) + 2)); + const auto PacketAsString = std::string(reinterpret_cast(Packet.data()), Packet.size()); + auto Futures = LuaAPI::MP::Engine->TriggerEvent("onChatMessage", "", LockedClient->GetID(), LockedClient->GetName(), PacketAsString.substr(PacketAsString.find(':', 3) + 2)); TLuaEngine::WaitForAll(Futures); - LogChatMessage(LockedClient->GetName(), LockedClient->GetID(), Packet.substr(Packet.find(':', 3) + 1)); + LogChatMessage(LockedClient->GetName(), LockedClient->GetID(), PacketAsString.substr(PacketAsString.find(':', 3) + 1)); if (std::any_of(Futures.begin(), Futures.end(), [](const std::shared_ptr& Elem) { return !Elem->Error @@ -198,8 +189,7 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::string Pac return; } case 'E': - beammp_trace(std::string(("got 'E' packet: '")) + Packet + ("' (") + std::to_string(Packet.size()) + (")")); - HandleEvent(*LockedClient, Packet); + HandleEvent(*LockedClient, StringPacket); return; case 'N': beammp_trace("got 'N' packet (" + std::to_string(Packet.size()) + ")"); @@ -209,7 +199,7 @@ void TServer::GlobalParser(const std::weak_ptr& Client, std::string Pac PPSMonitor.IncrementInternalPPS(); Network.SendToAll(LockedClient.get(), Packet, false, false); - HandlePosition(*LockedClient, Packet); + HandlePosition(*LockedClient, StringPacket); default: return; } @@ -275,13 +265,13 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ if (ShouldSpawn(c, CarJson, CarID) && !ShouldntSpawn) { c.AddNewCar(CarID, Packet); - Network.SendToAll(nullptr, Packet, true, true); + Network.SendToAll(nullptr, StringToVector(Packet), true, true); } else { - if (!Network.Respond(c, Packet, true)) { + if (!Network.Respond(c, StringToVector(Packet), true)) { // TODO: handle } std::string Destroy = "Od:" + std::to_string(c.GetID()) + "-" + std::to_string(CarID); - if (!Network.Respond(c, Destroy, true)) { + if (!Network.Respond(c, StringToVector(Destroy), true)) { // TODO: handle } beammp_debugf("{} (force : car limit/lua) removed ID {}", c.GetName(), CarID); @@ -306,14 +296,14 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ FoundPos = FoundPos == std::string::npos ? 0 : FoundPos; // attempt at sanitizing this if ((c.GetUnicycleID() != VID || IsUnicycle(c, Packet.substr(FoundPos))) && !ShouldntAllow) { - Network.SendToAll(&c, Packet, false, true); + Network.SendToAll(&c, StringToVector(Packet), false, true); Apply(c, VID, Packet); } else { if (c.GetUnicycleID() == VID) { c.SetUnicycleID(-1); } std::string Destroy = "Od:" + std::to_string(c.GetID()) + "-" + std::to_string(VID); - Network.SendToAll(nullptr, Destroy, true, true); + Network.SendToAll(nullptr, StringToVector(Destroy), true, true); c.DeleteCar(VID); } } @@ -329,7 +319,7 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ if (c.GetUnicycleID() == VID) { c.SetUnicycleID(-1); } - Network.SendToAll(nullptr, Packet, true, true); + Network.SendToAll(nullptr, StringToVector(Packet), true, true); // TODO: should this trigger on all vehicle deletions? LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleDeleted", "", c.GetID(), VID)); c.DeleteCar(VID); @@ -347,16 +337,16 @@ void TServer::ParseVehicle(TClient& c, const std::string& Pckt, TNetwork& Networ if (PID != -1 && VID != -1 && PID == c.GetID()) { Data = Data.substr(Data.find('{')); LuaAPI::MP::Engine->ReportErrors(LuaAPI::MP::Engine->TriggerEvent("onVehicleReset", "", c.GetID(), VID, Data)); - Network.SendToAll(&c, Packet, false, true); + Network.SendToAll(&c, StringToVector(Packet), false, true); } return; } case 't': beammp_trace(std::string(("got 'Ot' packet: '")) + Packet + ("' (") + std::to_string(Packet.size()) + (")")); - Network.SendToAll(&c, Packet, false, true); + Network.SendToAll(&c, StringToVector(Packet), false, true); return; case 'm': - Network.SendToAll(&c, Packet, true, true); + Network.SendToAll(&c, StringToVector(Packet), true, true); return; default: beammp_trace(std::string(("possibly not implemented: '") + Packet + ("' (") + std::to_string(Packet.size()) + (")")));