feat: WebSocket connection
This commit is contained in:
105
modules/binance_connection.py
Normal file
105
modules/binance_connection.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Binance-specific WebSocket connection implementation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from modules.connections import BaseConnection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BinanceConnection(BaseConnection):
|
||||
"""
|
||||
Binance WebSocket connection implementation.
|
||||
|
||||
This class implements the BaseConnection interface for Binance exchange
|
||||
with specific handling for Binance WebSocket API format.
|
||||
"""
|
||||
|
||||
def __init__(self, ws_url: str = "wss://stream.binance.com:9443"):
|
||||
"""
|
||||
Initialize Binance connection.
|
||||
|
||||
Args:
|
||||
ws_url: Binance WebSocket URL
|
||||
"""
|
||||
super().__init__("binance", ws_url)
|
||||
|
||||
async def _parse_message(self, message: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse Binance WebSocket message.
|
||||
|
||||
Args:
|
||||
message: Raw WebSocket message
|
||||
|
||||
Returns:
|
||||
Parsed message data
|
||||
"""
|
||||
try:
|
||||
data = json.loads(message)
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse Binance message: {message}")
|
||||
return {}
|
||||
|
||||
def _get_subscribe_message(self, topics: List[str]) -> str:
|
||||
"""
|
||||
Generate Binance subscription message.
|
||||
|
||||
Args:
|
||||
topics: List of topics to subscribe to
|
||||
|
||||
Returns:
|
||||
Subscription message as string
|
||||
"""
|
||||
# Binance uses a different format for subscription
|
||||
subscribe_msg = {
|
||||
"method": "SUBSCRIBE",
|
||||
"params": topics,
|
||||
"id": 1
|
||||
}
|
||||
return json.dumps(subscribe_msg)
|
||||
|
||||
def _get_orderbook_topic(self, symbol: str) -> str:
|
||||
"""
|
||||
Generate Binance orderbook topic name.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Orderbook topic name
|
||||
"""
|
||||
return f"{symbol.lower()}@depth20@100ms"
|
||||
|
||||
def _get_trade_topic(self, symbol: str) -> str:
|
||||
"""
|
||||
Generate Binance trade topic name.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Trade topic name
|
||||
"""
|
||||
return f"{symbol.lower()}@aggTrade"
|
||||
|
||||
def _get_symbol_from_topic(self, topic: str) -> str:
|
||||
"""
|
||||
Extract symbol from Binance topic name.
|
||||
|
||||
Args:
|
||||
topic: Topic name
|
||||
|
||||
Returns:
|
||||
Symbol extracted from topic
|
||||
"""
|
||||
# Example Binance topics: "btcusdt@depth20@100ms", "btcusdt@aggTrade"
|
||||
# Extract the symbol part
|
||||
if "@" in topic:
|
||||
symbol = topic.split("@")[0].upper()
|
||||
return symbol
|
||||
return topic.upper()
|
||||
68
modules/config.py
Normal file
68
modules/config.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Configuration management for the cryptocurrency trading platform.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def get_redis_url() -> str:
|
||||
"""
|
||||
Get Redis connection URL from environment variables.
|
||||
|
||||
Returns:
|
||||
Redis connection URL
|
||||
"""
|
||||
return os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
|
||||
|
||||
def get_timescaledb_url() -> str:
|
||||
"""
|
||||
Get TimescaleDB connection URL from environment variables.
|
||||
|
||||
Returns:
|
||||
TimescaleDB connection URL
|
||||
"""
|
||||
return os.getenv("TIMESCALEDB_URL", "postgresql://localhost:5432/trading")
|
||||
|
||||
|
||||
def get_supported_exchanges() -> list:
|
||||
"""
|
||||
Get list of supported exchanges from environment variables.
|
||||
|
||||
Returns:
|
||||
List of supported exchange names
|
||||
"""
|
||||
exchanges = os.getenv("SUPPORTED_EXCHANGES", "binance")
|
||||
return [ex.strip() for ex in exchanges.split(",") if ex.strip()]
|
||||
|
||||
|
||||
def get_exchange_config(exchange_name: str) -> dict:
|
||||
"""
|
||||
Get configuration for a specific exchange.
|
||||
|
||||
Args:
|
||||
exchange_name: Name of the exchange
|
||||
|
||||
Returns:
|
||||
Exchange configuration dictionary
|
||||
"""
|
||||
config = {
|
||||
"binance": {
|
||||
"ws_url": "wss://stream.binance.com:9443",
|
||||
"api_url": "https://api.binance.com",
|
||||
"max_concurrent_connections": 1000,
|
||||
},
|
||||
"okx": {
|
||||
"ws_url": "wss://ws.okx.com:8443",
|
||||
"api_url": "https://www.okx.com",
|
||||
"max_concurrent_connections": 1000,
|
||||
}
|
||||
}
|
||||
|
||||
return config.get(exchange_name.lower(), {})
|
||||
427
modules/connections.py
Normal file
427
modules/connections.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
Base connection class for cryptocurrency exchange WebSocket clients.
|
||||
|
||||
This module provides a base class for WebSocket connections to various
|
||||
cryptocurrency exchanges. It handles connection management, heartbeat
|
||||
processing, subscription management, and data routing to Redis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aioredis
|
||||
import websockets
|
||||
|
||||
from modules.config import get_redis_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseConnection(ABC):
|
||||
"""
|
||||
Abstract base class for cryptocurrency exchange WebSocket connections.
|
||||
|
||||
This class provides a common interface for WebSocket connections to
|
||||
different exchanges with support for:
|
||||
- Connection management
|
||||
- Heartbeat/ping processing
|
||||
- Topic subscription
|
||||
- Data routing to Redis
|
||||
"""
|
||||
|
||||
def __init__(self, exchange_name: str, ws_url: str, redis_url: Optional[str] = None):
|
||||
"""
|
||||
Initialize the base connection.
|
||||
|
||||
Args:
|
||||
exchange_name: Name of the exchange (e.g., 'binance', 'okx')
|
||||
ws_url: WebSocket URL for the exchange
|
||||
redis_url: Redis connection URL (uses default if not provided)
|
||||
"""
|
||||
self.exchange_name = exchange_name
|
||||
self.ws_url = ws_url
|
||||
self.redis_url = redis_url or get_redis_url()
|
||||
self.redis_client: Optional[aioredis.Redis] = None
|
||||
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self.is_connected = False
|
||||
self.is_connecting = False
|
||||
self.subscribed_topics: Set[str] = set()
|
||||
self._heartbeat_interval = 30 # seconds
|
||||
self._last_heartbeat = 0
|
||||
self._connection_attempts = 0
|
||||
self._max_reconnect_attempts = 5
|
||||
|
||||
@abstractmethod
|
||||
async def _parse_message(self, message: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse incoming WebSocket message and extract relevant data.
|
||||
|
||||
This method should be implemented by subclasses to handle
|
||||
exchange-specific message formats.
|
||||
|
||||
Args:
|
||||
message: Raw WebSocket message
|
||||
|
||||
Returns:
|
||||
Parsed message data as dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_subscribe_message(self, topics: List[str]) -> str:
|
||||
"""
|
||||
Generate subscription message for the exchange.
|
||||
|
||||
Args:
|
||||
topics: List of topics to subscribe to
|
||||
|
||||
Returns:
|
||||
Subscription message as string
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_orderbook_topic(self, symbol: str) -> str:
|
||||
"""
|
||||
Generate orderbook topic name for the exchange.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'BTCUSDT')
|
||||
|
||||
Returns:
|
||||
Orderbook topic name
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_trade_topic(self, symbol: str) -> str:
|
||||
"""
|
||||
Generate trade topic name for the exchange.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'BTCUSDT')
|
||||
|
||||
Returns:
|
||||
Trade topic name
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_symbol_from_topic(self, topic: str) -> str:
|
||||
"""
|
||||
Extract symbol from topic name.
|
||||
|
||||
Args:
|
||||
topic: Topic name
|
||||
|
||||
Returns:
|
||||
Symbol extracted from topic
|
||||
"""
|
||||
pass
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Establish WebSocket connection to the exchange.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
if self.is_connected or self.is_connecting:
|
||||
return True
|
||||
|
||||
self.is_connecting = True
|
||||
try:
|
||||
logger.info(f"Connecting to {self.exchange_name} WebSocket: {self.ws_url}")
|
||||
|
||||
# Connect to Redis
|
||||
redis_url = self.redis_url
|
||||
if not self.redis_client:
|
||||
self.redis_client = await aioredis.from_url(redis_url)
|
||||
logger.info(f"Connected to Redis: {redis_url}")
|
||||
|
||||
# Connect to WebSocket
|
||||
self.websocket = await websockets.connect(
|
||||
self.ws_url,
|
||||
ping_interval=self._heartbeat_interval,
|
||||
ping_timeout=self._heartbeat_interval
|
||||
)
|
||||
|
||||
self.is_connected = True
|
||||
self.is_connecting = False
|
||||
self._connection_attempts = 0
|
||||
|
||||
logger.info(f"Connected to {self.exchange_name} WebSocket")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {self.exchange_name}: {e}")
|
||||
self.is_connecting = False
|
||||
self._connection_attempts += 1
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close WebSocket connection and Redis connection."""
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
self.websocket = None
|
||||
if self.redis_client:
|
||||
await self.redis_client.close()
|
||||
self.redis_client = None
|
||||
self.is_connected = False
|
||||
|
||||
async def _handle_heartbeat(self, message: str) -> bool:
|
||||
"""
|
||||
Process heartbeat/ping messages.
|
||||
|
||||
Args:
|
||||
message: Incoming message
|
||||
|
||||
Returns:
|
||||
True if message was a heartbeat, False otherwise
|
||||
"""
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
# Common heartbeat patterns across exchanges
|
||||
if (
|
||||
"ping" in data or
|
||||
"event" in data and data["event"] == "ping" or
|
||||
(isinstance(data.get("op"), str) and data["op"] == "ping")
|
||||
):
|
||||
# Send pong response
|
||||
if "pong" in message:
|
||||
# Already a pong, just update heartbeat time
|
||||
self._last_heartbeat = time.time()
|
||||
else:
|
||||
# Send pong response
|
||||
pong_msg = message.replace("ping", "pong")
|
||||
await self._send_message(pong_msg)
|
||||
self._last_heartbeat = time.time()
|
||||
|
||||
logger.debug(f"Heartbeat received from {self.exchange_name}")
|
||||
return True
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
async def subscribe(self, topics: List[str]) -> bool:
|
||||
"""
|
||||
Subscribe to topics on the exchange.
|
||||
|
||||
Args:
|
||||
topics: List of topics to subscribe to
|
||||
|
||||
Returns:
|
||||
True if subscription successful, False otherwise
|
||||
"""
|
||||
if not self.is_connected:
|
||||
if not await self.connect():
|
||||
return False
|
||||
|
||||
if not self.websocket:
|
||||
logger.error("WebSocket not connected for subscription")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Generate subscription message
|
||||
subscribe_msg = self._get_subscribe_message(topics)
|
||||
|
||||
# Send subscription message
|
||||
await self._send_message(subscribe_msg)
|
||||
|
||||
# Update subscribed topics
|
||||
self.subscribed_topics.update(topics)
|
||||
|
||||
logger.info(f"Subscribed to topics on {self.exchange_name}: {topics}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription failed on {self.exchange_name}: {e}")
|
||||
return False
|
||||
|
||||
async def _send_message(self, message: str) -> None:
|
||||
"""
|
||||
Send message through WebSocket connection.
|
||||
|
||||
Args:
|
||||
message: Message to send
|
||||
"""
|
||||
if self.websocket and self.is_connected:
|
||||
await self.websocket.send(message)
|
||||
else:
|
||||
logger.warning("WebSocket not connected, cannot send message")
|
||||
|
||||
async def _process_message(self, message: str) -> None:
|
||||
"""
|
||||
Process incoming WebSocket message and route to Redis.
|
||||
|
||||
Args:
|
||||
message: Raw WebSocket message
|
||||
"""
|
||||
# Check if it's a heartbeat message
|
||||
if await self._handle_heartbeat(message):
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse the message
|
||||
data = await self._parse_message(message)
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Determine the appropriate output channel based on message type
|
||||
channel = self._determine_output_channel(data)
|
||||
if not channel:
|
||||
return
|
||||
|
||||
# Add metadata
|
||||
data["exchange"] = self.exchange_name
|
||||
data["timestamp"] = time.time()
|
||||
|
||||
# Send to Redis
|
||||
await self._send_to_redis(channel, data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message from {self.exchange_name}: {e}")
|
||||
|
||||
def _determine_output_channel(self, data: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Determine the appropriate Redis channel based on data type.
|
||||
|
||||
Args:
|
||||
data: Parsed message data
|
||||
|
||||
Returns:
|
||||
Redis channel name or None if no channel should be used
|
||||
"""
|
||||
# Default to None if no channel is determined
|
||||
channel = None
|
||||
|
||||
# Try to determine channel from message structure
|
||||
if "e" in data: # Binance-style events
|
||||
event_type = data.get("e")
|
||||
if event_type == "depthUpdate":
|
||||
channel = "orderbook"
|
||||
elif event_type == "trade":
|
||||
channel = "trade"
|
||||
elif "event" in data: # OKX-style events
|
||||
event_type = data.get("event")
|
||||
if event_type == "depth":
|
||||
channel = "orderbook"
|
||||
elif event_type == "trade":
|
||||
channel = "trade"
|
||||
elif "topic" in data: # Generic topic-based
|
||||
topic = data.get("topic")
|
||||
if "depth" in topic:
|
||||
channel = "orderbook"
|
||||
elif "trade" in topic:
|
||||
channel = "trade"
|
||||
|
||||
# If we have a symbol, we might want to add it to the channel
|
||||
symbol = data.get("s") or data.get("symbol")
|
||||
if symbol and channel:
|
||||
return f"{channel}:{symbol}"
|
||||
elif channel:
|
||||
return channel
|
||||
|
||||
return None
|
||||
|
||||
async def _send_to_redis(self, channel: str, data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Send data to Redis channel.
|
||||
|
||||
Args:
|
||||
channel: Redis channel to send data to
|
||||
data: Data to send
|
||||
"""
|
||||
if not self.redis_client:
|
||||
logger.warning("Redis client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
# Convert data to JSON string
|
||||
message = json.dumps(data)
|
||||
|
||||
# Send to Redis
|
||||
await self.redis_client.publish(channel, message)
|
||||
logger.debug(f"Published to Redis channel {channel}: {data}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to Redis channel {channel}: {e}")
|
||||
|
||||
async def start_listening(self) -> None:
|
||||
"""
|
||||
Start listening for incoming messages from WebSocket.
|
||||
|
||||
This method should be run in an asyncio event loop.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
if not await self.connect():
|
||||
logger.error(f"Cannot start listening - failed to connect to {self.exchange_name}")
|
||||
return
|
||||
|
||||
try:
|
||||
while self.is_connected and self.websocket:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
self.websocket.recv(),
|
||||
timeout=60.0 # 60 second timeout
|
||||
)
|
||||
await self._process_message(message)
|
||||
except asyncio.TimeoutError:
|
||||
# Handle timeout - send ping to keep connection alive
|
||||
await self._send_message('{"ping":1}')
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning(f"Connection closed by {self.exchange_name}")
|
||||
await self.reconnect()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in message loop for {self.exchange_name}: {e}")
|
||||
await self.reconnect()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listening loop for {self.exchange_name}: {e}")
|
||||
|
||||
async def reconnect(self) -> bool:
|
||||
"""
|
||||
Attempt to reconnect to the exchange.
|
||||
|
||||
Returns:
|
||||
True if reconnection successful, False otherwise
|
||||
"""
|
||||
logger.info(f"Attempting to reconnect to {self.exchange_name}")
|
||||
|
||||
# Disconnect first
|
||||
await self.disconnect()
|
||||
|
||||
# If we've exceeded max attempts, stop
|
||||
if self._connection_attempts >= self._max_reconnect_attempts:
|
||||
logger.error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Wait before reconnecting
|
||||
await asyncio.sleep(2 ** self._connection_attempts) # Exponential backoff
|
||||
|
||||
# Try to reconnect
|
||||
success = await self.connect()
|
||||
if success:
|
||||
# Resubscribe to topics
|
||||
if self.subscribed_topics:
|
||||
await self.subscribe(list(self.subscribed_topics))
|
||||
else:
|
||||
logger.error(f"Reconnection failed for {self.exchange_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during reconnection to {self.exchange_name}: {e}")
|
||||
self._connection_attempts += 1
|
||||
return False
|
||||
Reference in New Issue
Block a user