diff --git a/owrx/client.py b/owrx/client.py index 96b1567f..8e1ef813 100644 --- a/owrx/client.py +++ b/owrx/client.py @@ -1,6 +1,7 @@ from owrx.config import Config from owrx.color import ColorCache from datetime import datetime, timedelta +from ipaddress import ip_address import threading import re @@ -44,7 +45,7 @@ class ClientRegistry(object): def addClient(self, client): pm = Config.get() - if self.isIpBanned(client.conn.getIp()): + if self.isBanned(client.conn.handler): raise BannedClientException() elif len(self.clients) >= pm["max_clients"]: raise TooManyClientsException() @@ -94,13 +95,23 @@ class ClientRegistry(object): for c in self.clients: c.write_chat_message(name, text, color) + # Get client IP address from the handler. + def getIp(self, handler): + ip = handler.client_address[0] + # If address private and there is X-Forwarded-For header... + if ip_address(ip).is_private and hasattr(handler, "headers"): + if "x-forwarded-for" in handler.headers: + ip = handler.headers['x-forwarded-for'].split(',')[0] + # Done + return ip + # List all active and banned clients. def listAll(self): result = [] for c in self.clients: result.append({ - "ts" : c.conn.getStartTime(), - "ip" : c.conn.getIp(), + "ts" : c.conn.startTime, + "ip" : self.getIp(c.conn.handler), "sdr" : c.sdr.getName(), "band" : c.sdr.getProfileName(), "ban" : False @@ -120,7 +131,7 @@ class ClientRegistry(object): self.bans[ip] = datetime.now() + timedelta(minutes=minutes) banned = [] for c in self.clients: - if ip == c.conn.getIp(): + if ip == self.getIp(c.conn.handler): banned.append(c) for c in banned: try: @@ -134,7 +145,8 @@ class ClientRegistry(object): del self.bans[ip] # Check if given IP is banned at the moment. - def isIpBanned(self, ip: str): + def isBanned(self, handler): + ip = self.getIp(handler) return ip in self.bans and datetime.now() < self.bans[ip] # Delete all expired bans. diff --git a/owrx/websocket.py b/owrx/websocket.py index 1e7e7796..d9ffbd76 100644 --- a/owrx/websocket.py +++ b/owrx/websocket.py @@ -298,9 +298,3 @@ class WebSocketConnection(object): def sendPong(self): header = self.get_header(0, OPCODE_PONG) self._sendBytes(header) - - def getIp(self): - return self.handler.client_address[0] - - def getStartTime(self): - return self.startTime