diff --git a/src/core/tcp_server.py b/src/core/tcp_server.py index 3d354eb..dddcf67 100644 --- a/src/core/tcp_server.py +++ b/src/core/tcp_server.py @@ -23,6 +23,8 @@ class TCPServer: self.port = port self.run = False self.rl = RateLimiter(50, 10, 300) + console.add_command("rl", self.rl.parse_console, None, "RateLimiter menu", + {"rl": {"info": None, "unban": None, "ban": None, "help": None}}) async def auth_client(self, reader, writer): client = self.Core.create_client(reader, writer) diff --git a/src/modules/RateLimiter/__init__.py b/src/modules/RateLimiter/__init__.py index 8ebc7b3..ac94717 100644 --- a/src/modules/RateLimiter/__init__.py +++ b/src/modules/RateLimiter/__init__.py @@ -1,4 +1,5 @@ import asyncio +import textwrap from collections import defaultdict, deque from datetime import datetime, timedelta @@ -7,7 +8,7 @@ from core import utils class RateLimiter: def __init__(self, max_calls: int, period: float, ban_time: float): - self.log = utils.get_logger("DOSProtect") + self.log = utils.get_logger("RateLimiter") self.max_calls = max_calls self.period = timedelta(seconds=period) self.ban_time = timedelta(seconds=ban_time) @@ -15,6 +16,52 @@ class RateLimiter: self._banned_until = defaultdict(lambda: datetime.min) self._notified = {} + def parse_console(self, x): + help_msg = textwrap.dedent("""\ + + RateLimiter menu: + info - list banned ip's + ban - put ip in banlist + unban - force remove ip from banlist + help - print that message""") + _banned_ips = [i for i in self._banned_until if self.is_banned(i, False)] + if len(x) > 0: + match x[0]: + case "info": + self.log.info(f"Trigger {self.max_calls}req/{self.period}. IP will be banned for {self.ban_time}.") + if len(_banned_ips) == 0: + return "No one ip in banlist." + else: + _msg = f"Banned ip{'' if len(_banned_ips) == 1 else 's'}: " + for ip in _banned_ips: + _msg += f"{ip}; " + return _msg + case "unban": + if len(x) == 2: + ip = x[1] + if ip in _banned_ips: + self._calls[ip].clear() + self._banned_until[ip] = datetime.now() + return f"{ip} removed from banlist." + return f"{ip} not banned." + else: + return 'rl unban ' + case "ban": + if len(x) == 3: + ip = x[1] + sec = x[2] + if not sec.isdigit(): + return f"{sec!r} is not digit." + self._calls[ip].clear() + self._banned_until[ip] = datetime.now() + timedelta(seconds=int(sec)) + return f"{ip} banned until {self._banned_until[ip]}" + else: + return 'rl ban ' + case _: + return help_msg + else: + return help_msg + async def notify(self, ip, writer): if not self._notified[ip]: self._notified[ip] = True @@ -26,14 +73,13 @@ class RateLimiter: except Exception: pass - def is_banned(self, ip: str) -> bool: + def is_banned(self, ip: str, _add_call=True) -> bool: now = datetime.now() if now < self._banned_until[ip]: return True - now = datetime.now() - self._calls[ip].append(now) - + if _add_call: + self._calls[ip].append(now) while self._calls[ip] and self._calls[ip][0] + self.period < now: self._calls[ip].popleft() @@ -49,8 +95,7 @@ class RateLimiter: async def handle_request(ip: str, rate_limiter: RateLimiter): if rate_limiter.is_banned(ip): print(f"Request from {ip} is banned at {datetime.now()}") - print(f"{rate_limiter._banned_until[ip]}") - return + rate_limiter.parse_console(["info"]) async def server_simulation():