From 26383d534669baa3f3f61318b848bc38dc53cf5e Mon Sep 17 00:00:00 2001 From: Lion Kortlepel Date: Sun, 8 Nov 2020 02:50:17 +0100 Subject: [PATCH] Add lots of memory safety to client interface --- include/Client.hpp | 19 +++++++++++++------ src/Init/Heartbeat.cpp | 2 +- src/Lua/LuaSystem.cpp | 9 +++++---- src/Network/Auth.cpp | 9 +++++---- src/Network/GParser.cpp | 4 ++-- src/Network/InitClient.cpp | 14 +++++++------- src/Network/StatMonitor.cpp | 2 +- src/Network/VehicleData.cpp | 8 ++++---- 8 files changed, 38 insertions(+), 29 deletions(-) diff --git a/include/Client.hpp b/include/Client.hpp index 0cf733d..120059b 100644 --- a/include/Client.hpp +++ b/include/Client.hpp @@ -14,6 +14,7 @@ #include #include #include +#include struct VData{ int ID = -1; @@ -57,17 +58,23 @@ public: int GetID(); }; struct ClientInterface{ - std::set Clients; - void RemoveClient(Client *c){ + std::set> Clients; + void RemoveClient(Client*& c){ Assert(c); c->ClearCars(); - Clients.erase(c); - delete c; + auto Iter = std::find_if(Clients.begin(), Clients.end(), [&](auto& ptr) { + return c == ptr.get(); + }); + Assert(Iter != Clients.end()); + if (Iter == Clients.end()) { + return; + } + Clients.erase(Iter); c = nullptr; } - void AddClient(Client *c){ + void AddClient(Client*&& c){ Assert(c); - Clients.insert(c); + Clients.insert(std::move(std::unique_ptr(c))); } int Size(){ return int(Clients.size()); diff --git a/src/Init/Heartbeat.cpp b/src/Init/Heartbeat.cpp index 763dd47..46bca7d 100644 --- a/src/Init/Heartbeat.cpp +++ b/src/Init/Heartbeat.cpp @@ -14,7 +14,7 @@ void WebsocketInit(); std::string GetPlayers(){ std::string Return; - for(Client* c : CI->Clients){ + for(auto& c : CI->Clients){ if(c != nullptr){ Return += c->GetName() + ";"; } diff --git a/src/Lua/LuaSystem.cpp b/src/Lua/LuaSystem.cpp index d1b8fa9..cef2737 100644 --- a/src/Lua/LuaSystem.cpp +++ b/src/Lua/LuaSystem.cpp @@ -240,9 +240,10 @@ int lua_Sleep(lua_State* L) { return 1; } Client* GetClient(int ID) { - for (Client* c : CI->Clients) { - if (c != nullptr && c->GetID() == ID) - return c; + for (auto& c : CI->Clients) { + if (c != nullptr && c->GetID() == ID) { + return c.get(); + } } return nullptr; } @@ -295,7 +296,7 @@ int lua_GetDID(lua_State* L) { int lua_GetAllPlayers(lua_State* L) { lua_newtable(L); int i = 1; - for (Client* c : CI->Clients) { + for (auto& c : CI->Clients) { if (c == nullptr) continue; lua_pushinteger(L, c->GetID()); diff --git a/src/Network/Auth.cpp b/src/Network/Auth.cpp index cdcb7da..b87722f 100644 --- a/src/Network/Auth.cpp +++ b/src/Network/Auth.cpp @@ -81,7 +81,7 @@ void Check(Hold* S){ } int Max(){ int M = MaxPlayers; - for(Client*c : CI->Clients){ + for(auto& c : CI->Clients){ if(c != nullptr){ if(c->GetRole() == Sec("MDEV"))M++; } @@ -94,8 +94,9 @@ void CreateClient(SOCKET TCPSock,const std::string &Name, const std::string &DID c->SetName(Name); c->SetRole(Role); c->SetDID(DID); - CI->AddClient(c); - InitClient(c); + Client& Client = *c; + CI->AddClient(std::move(c)); + InitClient(&Client); } std::pair Parse(const std::string& msg){ std::stringstream ss(msg); @@ -175,7 +176,7 @@ void Identification(SOCKET TCPSock,Hold*S,RSA*Skey){ } DebugPrintTIDInternal(std::string("Client(") + Name + ")"); debug(Sec("Name -> ") + Name + Sec(", Role -> ") + Role + Sec(", ID -> ") + DID); - for(Client*c: CI->Clients){ + for(auto& c : CI->Clients){ if(c != nullptr){ if(c->GetDID() == DID){ error("died on " + std::string(__func__) + ":" + std::to_string(__LINE__)); diff --git a/src/Network/GParser.cpp b/src/Network/GParser.cpp index fe566fb..9191735 100644 --- a/src/Network/GParser.cpp +++ b/src/Network/GParser.cpp @@ -125,9 +125,9 @@ void SyncClient(Client*c){ Respond(c,Sec("Sn")+c->GetName(),true); SendToAll(c,Sec("JWelcome ")+c->GetName()+"!",false,true); TriggerLuaEvent(Sec("onPlayerJoin"),false,nullptr,std::unique_ptr(new LuaArg{{c->GetID()}}),false); - for (Client*client : CI->Clients) { + for (auto& client : CI->Clients) { if(client != nullptr){ - if (client != c) { + if (client.get() != c) { for (VData *v : client->GetAllCars()) { if(v != nullptr){ Respond(c, v->Data, true); diff --git a/src/Network/InitClient.cpp b/src/Network/InitClient.cpp index 8fda185..8e0bdbe 100644 --- a/src/Network/InitClient.cpp +++ b/src/Network/InitClient.cpp @@ -12,7 +12,7 @@ int OpenID(){ bool found; do { found = true; - for (Client *c : CI->Clients){ + for (auto& c : CI->Clients){ if(c != nullptr){ if(c->GetID() == ID){ found = false; @@ -35,15 +35,15 @@ void Respond(Client*c, const std::string& MSG, bool Rel){ void SendToAll(Client*c, const std::string& Data, bool Self, bool Rel){ if (!Self)Assert(c); char C = Data.at(0); - for(Client*client : CI->Clients){ + for(auto& client : CI->Clients){ if(client != nullptr) { - if (Self || client != c) { + if (Self || client.get() != c) { if (client->isSynced) { if (Rel || C == 'W' || C == 'Y' || C == 'V' || C == 'E') { if (C == 'O' || C == 'T' || - Data.length() > 1000)SendLarge(client, Data); - else TCPSend(client, Data); - } else UDPSend(client, Data); + Data.length() > 1000)SendLarge(client.get(), Data); + else TCPSend(client.get(), Data); + } else UDPSend(client.get(), Data); } } } @@ -51,7 +51,7 @@ void SendToAll(Client*c, const std::string& Data, bool Self, bool Rel){ } void UpdatePlayers(){ std::string Packet = Sec("Ss") + std::to_string(CI->Size())+"/"+std::to_string(MaxPlayers) + ":"; - for (Client*c : CI->Clients) { + for (auto& c : CI->Clients) { if(c != nullptr)Packet += c->GetName() + ","; } Packet = Packet.substr(0,Packet.length()-1); diff --git a/src/Network/StatMonitor.cpp b/src/Network/StatMonitor.cpp index 9a2db80..7efa8ed 100644 --- a/src/Network/StatMonitor.cpp +++ b/src/Network/StatMonitor.cpp @@ -14,7 +14,7 @@ void Monitor() { StatReport = "-"; return; } - for (Client *c : CI->Clients) { + for (auto& c : CI->Clients) { if (c != nullptr && c->GetCarCount() > 0) { C++; V += c->GetCarCount(); diff --git a/src/Network/VehicleData.cpp b/src/Network/VehicleData.cpp index 1662024..e5d5699 100644 --- a/src/Network/VehicleData.cpp +++ b/src/Network/VehicleData.cpp @@ -311,11 +311,11 @@ void LOOP() { ZeroMemory(clientIp, 256); ///Code to get IP we don't need that yet inet_ntop(AF_INET, &client.sin_addr, clientIp, 256);*/ uint8_t ID = Data.at(0) - 1; - for (Client* c : CI->Clients) { + for (auto& c : CI->Clients) { if (c != nullptr && c->GetID() == ID) { c->SetUDPAddr(client); c->isConnected = true; - UDPParser(c, Data.substr(2)); + UDPParser(c.get(), Data.substr(2)); } } } catch (const std::exception& e) { @@ -357,11 +357,11 @@ void LOOP() { ZeroMemory(clientIp, 256); ///Code to get IP we don't need that yet inet_ntop(AF_INET, &client.sin_addr, clientIp, 256);*/ uint8_t ID = uint8_t(Data.at(0)) - 1; - for (Client* c : CI->Clients) { + for (auto& c : CI->Clients) { if (c != nullptr && c->GetID() == ID) { c->SetUDPAddr(client); c->isConnected = true; - UDPParser(c, Data.substr(2)); + UDPParser(c.get(), Data.substr(2)); } } } catch (const std::exception& e) {