diff --git a/include/TNetwork.h b/include/TNetwork.h index 1640be6..7a646a2 100644 --- a/include/TNetwork.h +++ b/include/TNetwork.h @@ -27,6 +27,14 @@ struct TConnection; +class WatchingConnecting{ +public: + bool IsConnectionAllowed(const std::string& clientAddress); +private: + void BlockIP(const std::string& clientAddress); + bool IsIPBlocked(const std::string& clientAddress); +}; + class TNetwork { public: TNetwork(TServer& Server, TPPSMonitor& PPSMonitor, TResourceManager& ResourceManager); @@ -48,7 +56,6 @@ public: private: void UDPServerMain(); void TCPServerMain(); - TServer& mServer; TPPSMonitor& mPPSMonitor; ip::udp::socket mUDPSock; diff --git a/src/TNetwork.cpp b/src/TNetwork.cpp index e88ac53..b2cdfea 100644 --- a/src/TNetwork.cpp +++ b/src/TNetwork.cpp @@ -31,6 +31,63 @@ typedef boost::asio::detail::socket_option::integer rcv_timeout_option; +#include +#include +#include +#include +#include + +std::unordered_map>> connectionAttempts; +std::mutex connectionAttemptsMutex; + +bool WatchingConnecting::IsConnectionAllowed(const std::string& clientAddress) { + // we check if there is an IP in the blocked list + if (WatchingConnecting::IsIPBlocked(clientAddress)) { + return false; + } + std::lock_guard lock(connectionAttemptsMutex); + auto currentTime = std::chrono::high_resolution_clock::now(); + auto& violations = connectionAttempts[clientAddress]; + + // Deleting old violations (older than 5 seconds) + violations.erase(std::remove_if(violations.begin(), violations.end(), + [&](const auto& timestamp) { + return std::chrono::duration_cast(currentTime - timestamp).count() > 5; + }), violations.end()); + + // Adding the current violation + violations.push_back(currentTime); + + // We check the number of violations + if (violations.size() >= 4) { + WatchingConnecting::BlockIP(clientAddress); + beammp_errorf("[DOS] Blocked IP: {}", clientAddress); + return false; + } + + return true; // We allow the connection + } + +void WatchingConnecting::BlockIP(const std::string& clientAddress) { + std::ofstream blockFile("blocked_ips.txt", std::ios::app); + if (blockFile.is_open()) { + blockFile << clientAddress << std::endl; + } + } + +bool WatchingConnecting::IsIPBlocked(const std::string& clientAddress) { + std::ifstream blockFile("blocked_ips.txt"); + std::unordered_set blockedIPs; + + if (blockFile.is_open()) { + std::string line; + while (std::getline(blockFile, line)) { + blockedIPs.insert(line); + } + } + return blockedIPs.find(clientAddress) != blockedIPs.end(); +}; + std::vector StringToVector(const std::string& Str) { return std::vector(Str.data(), Str.data() + Str.size()); } @@ -196,10 +253,16 @@ void TNetwork::Identify(TConnection&& RawConnection) { RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); return; } + std::string clientAddress = RawConnection.SockAddr.address().to_string(); std::shared_ptr Client { nullptr }; + WatchingConnecting connectionManager; try { - if (Code == 'C') { - Client = Authentication(std::move(RawConnection)); + if (Code == 'C') { + if (connectionManager.IsConnectionAllowed(clientAddress)) { + Client = Authentication(std::move(RawConnection)); + } else { + RawConnection.Socket.shutdown(socket_base::shutdown_both, ec); + } } else if (Code == 'D') { HandleDownload(std::move(RawConnection)); } else if (Code == 'P') {