From 2f5ea5bac4c78e261552c816d7429f1f9937de84 Mon Sep 17 00:00:00 2001 From: Bogdan Kolesnik Date: Wed, 25 Mar 2026 08:04:34 +0500 Subject: [PATCH] feat: WebSocket connection --- modules/binance_connection.py | 105 +++++++++ modules/config.py | 68 ++++++ modules/connections.py | 427 ++++++++++++++++++++++++++++++++++ tests/test_connections.py | 83 +++++++ 4 files changed, 683 insertions(+) create mode 100644 modules/binance_connection.py create mode 100644 modules/config.py create mode 100644 modules/connections.py create mode 100644 tests/test_connections.py diff --git a/modules/binance_connection.py b/modules/binance_connection.py new file mode 100644 index 0000000..fd745ba --- /dev/null +++ b/modules/binance_connection.py @@ -0,0 +1,105 @@ +""" +Binance-specific WebSocket connection implementation. +""" + +import json +import logging +from typing import Any, Dict, List + +from modules.connections import BaseConnection + +logger = logging.getLogger(__name__) + + +class BinanceConnection(BaseConnection): + """ + Binance WebSocket connection implementation. + + This class implements the BaseConnection interface for Binance exchange + with specific handling for Binance WebSocket API format. + """ + + def __init__(self, ws_url: str = "wss://stream.binance.com:9443"): + """ + Initialize Binance connection. + + Args: + ws_url: Binance WebSocket URL + """ + super().__init__("binance", ws_url) + + async def _parse_message(self, message: str) -> Dict[str, Any]: + """ + Parse Binance WebSocket message. + + Args: + message: Raw WebSocket message + + Returns: + Parsed message data + """ + try: + data = json.loads(message) + return data + except json.JSONDecodeError: + logger.error(f"Failed to parse Binance message: {message}") + return {} + + def _get_subscribe_message(self, topics: List[str]) -> str: + """ + Generate Binance subscription message. + + Args: + topics: List of topics to subscribe to + + Returns: + Subscription message as string + """ + # Binance uses a different format for subscription + subscribe_msg = { + "method": "SUBSCRIBE", + "params": topics, + "id": 1 + } + return json.dumps(subscribe_msg) + + def _get_orderbook_topic(self, symbol: str) -> str: + """ + Generate Binance orderbook topic name. + + Args: + symbol: Trading symbol + + Returns: + Orderbook topic name + """ + return f"{symbol.lower()}@depth20@100ms" + + def _get_trade_topic(self, symbol: str) -> str: + """ + Generate Binance trade topic name. + + Args: + symbol: Trading symbol + + Returns: + Trade topic name + """ + return f"{symbol.lower()}@aggTrade" + + def _get_symbol_from_topic(self, topic: str) -> str: + """ + Extract symbol from Binance topic name. + + Args: + topic: Topic name + + Returns: + Symbol extracted from topic + """ + # Example Binance topics: "btcusdt@depth20@100ms", "btcusdt@aggTrade" + # Extract the symbol part + if "@" in topic: + symbol = topic.split("@")[0].upper() + return symbol + return topic.upper() \ No newline at end of file diff --git a/modules/config.py b/modules/config.py new file mode 100644 index 0000000..cfe9292 --- /dev/null +++ b/modules/config.py @@ -0,0 +1,68 @@ +""" +Configuration management for the cryptocurrency trading platform. +""" + +import os +from typing import Optional + +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + + +def get_redis_url() -> str: + """ + Get Redis connection URL from environment variables. + + Returns: + Redis connection URL + """ + return os.getenv("REDIS_URL", "redis://localhost:6379/0") + + +def get_timescaledb_url() -> str: + """ + Get TimescaleDB connection URL from environment variables. + + Returns: + TimescaleDB connection URL + """ + return os.getenv("TIMESCALEDB_URL", "postgresql://localhost:5432/trading") + + +def get_supported_exchanges() -> list: + """ + Get list of supported exchanges from environment variables. + + Returns: + List of supported exchange names + """ + exchanges = os.getenv("SUPPORTED_EXCHANGES", "binance") + return [ex.strip() for ex in exchanges.split(",") if ex.strip()] + + +def get_exchange_config(exchange_name: str) -> dict: + """ + Get configuration for a specific exchange. + + Args: + exchange_name: Name of the exchange + + Returns: + Exchange configuration dictionary + """ + config = { + "binance": { + "ws_url": "wss://stream.binance.com:9443", + "api_url": "https://api.binance.com", + "max_concurrent_connections": 1000, + }, + "okx": { + "ws_url": "wss://ws.okx.com:8443", + "api_url": "https://www.okx.com", + "max_concurrent_connections": 1000, + } + } + + return config.get(exchange_name.lower(), {}) \ No newline at end of file diff --git a/modules/connections.py b/modules/connections.py new file mode 100644 index 0000000..3b6e667 --- /dev/null +++ b/modules/connections.py @@ -0,0 +1,427 @@ +""" +Base connection class for cryptocurrency exchange WebSocket clients. + +This module provides a base class for WebSocket connections to various +cryptocurrency exchanges. It handles connection management, heartbeat +processing, subscription management, and data routing to Redis. +""" + +import asyncio +import json +import logging +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Set, Union +from urllib.parse import urlparse + +import aioredis +import websockets + +from modules.config import get_redis_url + +logger = logging.getLogger(__name__) + + +class BaseConnection(ABC): + """ + Abstract base class for cryptocurrency exchange WebSocket connections. + + This class provides a common interface for WebSocket connections to + different exchanges with support for: + - Connection management + - Heartbeat/ping processing + - Topic subscription + - Data routing to Redis + """ + + def __init__(self, exchange_name: str, ws_url: str, redis_url: Optional[str] = None): + """ + Initialize the base connection. + + Args: + exchange_name: Name of the exchange (e.g., 'binance', 'okx') + ws_url: WebSocket URL for the exchange + redis_url: Redis connection URL (uses default if not provided) + """ + self.exchange_name = exchange_name + self.ws_url = ws_url + self.redis_url = redis_url or get_redis_url() + self.redis_client: Optional[aioredis.Redis] = None + self.websocket: Optional[websockets.WebSocketClientProtocol] = None + self.is_connected = False + self.is_connecting = False + self.subscribed_topics: Set[str] = set() + self._heartbeat_interval = 30 # seconds + self._last_heartbeat = 0 + self._connection_attempts = 0 + self._max_reconnect_attempts = 5 + + @abstractmethod + async def _parse_message(self, message: str) -> Dict[str, Any]: + """ + Parse incoming WebSocket message and extract relevant data. + + This method should be implemented by subclasses to handle + exchange-specific message formats. + + Args: + message: Raw WebSocket message + + Returns: + Parsed message data as dictionary + """ + pass + + @abstractmethod + def _get_subscribe_message(self, topics: List[str]) -> str: + """ + Generate subscription message for the exchange. + + Args: + topics: List of topics to subscribe to + + Returns: + Subscription message as string + """ + pass + + @abstractmethod + def _get_orderbook_topic(self, symbol: str) -> str: + """ + Generate orderbook topic name for the exchange. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + + Returns: + Orderbook topic name + """ + pass + + @abstractmethod + def _get_trade_topic(self, symbol: str) -> str: + """ + Generate trade topic name for the exchange. + + Args: + symbol: Trading symbol (e.g., 'BTCUSDT') + + Returns: + Trade topic name + """ + pass + + @abstractmethod + def _get_symbol_from_topic(self, topic: str) -> str: + """ + Extract symbol from topic name. + + Args: + topic: Topic name + + Returns: + Symbol extracted from topic + """ + pass + + async def connect(self) -> bool: + """ + Establish WebSocket connection to the exchange. + + Returns: + True if connection successful, False otherwise + """ + if self.is_connected or self.is_connecting: + return True + + self.is_connecting = True + try: + logger.info(f"Connecting to {self.exchange_name} WebSocket: {self.ws_url}") + + # Connect to Redis + redis_url = self.redis_url + if not self.redis_client: + self.redis_client = await aioredis.from_url(redis_url) + logger.info(f"Connected to Redis: {redis_url}") + + # Connect to WebSocket + self.websocket = await websockets.connect( + self.ws_url, + ping_interval=self._heartbeat_interval, + ping_timeout=self._heartbeat_interval + ) + + self.is_connected = True + self.is_connecting = False + self._connection_attempts = 0 + + logger.info(f"Connected to {self.exchange_name} WebSocket") + return True + + except Exception as e: + logger.error(f"Failed to connect to {self.exchange_name}: {e}") + self.is_connecting = False + self._connection_attempts += 1 + return False + + async def disconnect(self) -> None: + """Close WebSocket connection and Redis connection.""" + if self.websocket: + await self.websocket.close() + self.websocket = None + if self.redis_client: + await self.redis_client.close() + self.redis_client = None + self.is_connected = False + + async def _handle_heartbeat(self, message: str) -> bool: + """ + Process heartbeat/ping messages. + + Args: + message: Incoming message + + Returns: + True if message was a heartbeat, False otherwise + """ + try: + data = json.loads(message) + + # Common heartbeat patterns across exchanges + if ( + "ping" in data or + "event" in data and data["event"] == "ping" or + (isinstance(data.get("op"), str) and data["op"] == "ping") + ): + # Send pong response + if "pong" in message: + # Already a pong, just update heartbeat time + self._last_heartbeat = time.time() + else: + # Send pong response + pong_msg = message.replace("ping", "pong") + await self._send_message(pong_msg) + self._last_heartbeat = time.time() + + logger.debug(f"Heartbeat received from {self.exchange_name}") + return True + + except json.JSONDecodeError: + pass + + return False + + async def subscribe(self, topics: List[str]) -> bool: + """ + Subscribe to topics on the exchange. + + Args: + topics: List of topics to subscribe to + + Returns: + True if subscription successful, False otherwise + """ + if not self.is_connected: + if not await self.connect(): + return False + + if not self.websocket: + logger.error("WebSocket not connected for subscription") + return False + + try: + # Generate subscription message + subscribe_msg = self._get_subscribe_message(topics) + + # Send subscription message + await self._send_message(subscribe_msg) + + # Update subscribed topics + self.subscribed_topics.update(topics) + + logger.info(f"Subscribed to topics on {self.exchange_name}: {topics}") + return True + + except Exception as e: + logger.error(f"Subscription failed on {self.exchange_name}: {e}") + return False + + async def _send_message(self, message: str) -> None: + """ + Send message through WebSocket connection. + + Args: + message: Message to send + """ + if self.websocket and self.is_connected: + await self.websocket.send(message) + else: + logger.warning("WebSocket not connected, cannot send message") + + async def _process_message(self, message: str) -> None: + """ + Process incoming WebSocket message and route to Redis. + + Args: + message: Raw WebSocket message + """ + # Check if it's a heartbeat message + if await self._handle_heartbeat(message): + return + + try: + # Parse the message + data = await self._parse_message(message) + if not data: + return + + # Determine the appropriate output channel based on message type + channel = self._determine_output_channel(data) + if not channel: + return + + # Add metadata + data["exchange"] = self.exchange_name + data["timestamp"] = time.time() + + # Send to Redis + await self._send_to_redis(channel, data) + + except Exception as e: + logger.error(f"Error processing message from {self.exchange_name}: {e}") + + def _determine_output_channel(self, data: Dict[str, Any]) -> Optional[str]: + """ + Determine the appropriate Redis channel based on data type. + + Args: + data: Parsed message data + + Returns: + Redis channel name or None if no channel should be used + """ + # Default to None if no channel is determined + channel = None + + # Try to determine channel from message structure + if "e" in data: # Binance-style events + event_type = data.get("e") + if event_type == "depthUpdate": + channel = "orderbook" + elif event_type == "trade": + channel = "trade" + elif "event" in data: # OKX-style events + event_type = data.get("event") + if event_type == "depth": + channel = "orderbook" + elif event_type == "trade": + channel = "trade" + elif "topic" in data: # Generic topic-based + topic = data.get("topic") + if "depth" in topic: + channel = "orderbook" + elif "trade" in topic: + channel = "trade" + + # If we have a symbol, we might want to add it to the channel + symbol = data.get("s") or data.get("symbol") + if symbol and channel: + return f"{channel}:{symbol}" + elif channel: + return channel + + return None + + async def _send_to_redis(self, channel: str, data: Dict[str, Any]) -> None: + """ + Send data to Redis channel. + + Args: + channel: Redis channel to send data to + data: Data to send + """ + if not self.redis_client: + logger.warning("Redis client not initialized") + return + + try: + # Convert data to JSON string + message = json.dumps(data) + + # Send to Redis + await self.redis_client.publish(channel, message) + logger.debug(f"Published to Redis channel {channel}: {data}") + + except Exception as e: + logger.error(f"Failed to send to Redis channel {channel}: {e}") + + async def start_listening(self) -> None: + """ + Start listening for incoming messages from WebSocket. + + This method should be run in an asyncio event loop. + """ + if not self.is_connected: + if not await self.connect(): + logger.error(f"Cannot start listening - failed to connect to {self.exchange_name}") + return + + try: + while self.is_connected and self.websocket: + try: + message = await asyncio.wait_for( + self.websocket.recv(), + timeout=60.0 # 60 second timeout + ) + await self._process_message(message) + except asyncio.TimeoutError: + # Handle timeout - send ping to keep connection alive + await self._send_message('{"ping":1}') + except websockets.exceptions.ConnectionClosed: + logger.warning(f"Connection closed by {self.exchange_name}") + await self.reconnect() + break + except Exception as e: + logger.error(f"Error in message loop for {self.exchange_name}: {e}") + await self.reconnect() + break + + except Exception as e: + logger.error(f"Error in listening loop for {self.exchange_name}: {e}") + + async def reconnect(self) -> bool: + """ + Attempt to reconnect to the exchange. + + Returns: + True if reconnection successful, False otherwise + """ + logger.info(f"Attempting to reconnect to {self.exchange_name}") + + # Disconnect first + await self.disconnect() + + # If we've exceeded max attempts, stop + if self._connection_attempts >= self._max_reconnect_attempts: + logger.error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}") + return False + + try: + # Wait before reconnecting + await asyncio.sleep(2 ** self._connection_attempts) # Exponential backoff + + # Try to reconnect + success = await self.connect() + if success: + # Resubscribe to topics + if self.subscribed_topics: + await self.subscribe(list(self.subscribed_topics)) + else: + logger.error(f"Reconnection failed for {self.exchange_name}") + + return success + + except Exception as e: + logger.error(f"Exception during reconnection to {self.exchange_name}: {e}") + self._connection_attempts += 1 + return False \ No newline at end of file diff --git a/tests/test_connections.py b/tests/test_connections.py new file mode 100644 index 0000000..cd9385d --- /dev/null +++ b/tests/test_connections.py @@ -0,0 +1,83 @@ +""" +Test module for WebSocket connection classes. +""" + +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from modules.binance_connection import BinanceConnection +from modules.connections import BaseConnection + + +class TestBaseConnection(unittest.TestCase): + """Test the BaseConnection class functionality.""" + + def test_base_connection_initialization(self): + """Test that BaseConnection initializes correctly.""" + connection = BinanceConnection() + self.assertEqual(connection.exchange_name, "binance") + self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443") + self.assertIsNone(connection.websocket) + self.assertIsNone(connection.redis_client) + self.assertFalse(connection.is_connected) + self.assertFalse(connection.is_connecting) + + @patch('modules.connections.aioredis.from_url') + @patch('modules.connections.websockets.connect') + async def test_connect_success(self, mock_websocket_connect, mock_redis_connect): + """Test successful connection.""" + # Mock the websocket connection + mock_ws = AsyncMock() + mock_websocket_connect.return_value = mock_ws + + # Mock the redis connection + mock_redis = AsyncMock() + mock_redis_connect.return_value = mock_redis + + connection = BinanceConnection() + + # Test connection + result = await connection.connect() + self.assertTrue(result) + self.assertTrue(connection.is_connected) + + # Verify connections were made + mock_websocket_connect.assert_called_once() + mock_redis_connect.assert_called_once() + + +class TestBinanceConnection(unittest.TestCase): + """Test the BinanceConnection class functionality.""" + + def test_binance_connection_initialization(self): + """Test that BinanceConnection initializes correctly.""" + connection = BinanceConnection() + self.assertEqual(connection.exchange_name, "binance") + self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443") + + def test_get_subscribe_message(self): + """Test subscription message generation.""" + connection = BinanceConnection() + topics = ["btcusdt@depth20@100ms", "ethusdt@aggTrade"] + message = connection._get_subscribe_message(topics) + data = eval(message) # Parse the dict string + self.assertEqual(data["method"], "SUBSCRIBE") + self.assertEqual(data["params"], topics) + self.assertEqual(data["id"], 1) + + def test_get_orderbook_topic(self): + """Test orderbook topic generation.""" + connection = BinanceConnection() + topic = connection._get_orderbook_topic("BTCUSDT") + self.assertEqual(topic, "btcusdt@depth20@100ms") + + def test_get_trade_topic(self): + """Test trade topic generation.""" + connection = BinanceConnection() + topic = connection._get_trade_topic("BTCUSDT") + self.assertEqual(topic, "btcusdt@aggTrade") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file