"""
Base connection class for cryptocurrency exchange WebSocket clients.

This module provides an abstract base class for WebSocket connections to
various cryptocurrency exchanges. It handles connection management, heartbeat
processing, subscription management, and publishing parsed data to Redis
pub/sub channels.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional, Set, Union
|
|
from urllib.parse import urlparse
|
|
|
|
import aioredis
|
|
import websockets
|
|
|
|
from modules.config import get_redis_url
|
|
|
|
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BaseConnection(ABC):
    """
    Abstract base class for cryptocurrency exchange WebSocket connections.

    This class provides a common interface for WebSocket connections to
    different exchanges with support for:
    - Connection management (bounded reconnects with exponential backoff)
    - Heartbeat/ping processing
    - Topic subscription (topics are re-subscribed after a reconnect)
    - Data routing to Redis pub/sub channels

    Subclasses implement the exchange-specific hooks: message parsing,
    subscribe-message construction and topic naming.
    """

    def __init__(self, exchange_name: str, ws_url: str, redis_url: Optional[str] = None):
        """
        Initialize the base connection.

        Only configuration and state are set up here; no connection is opened.
        Call ``connect()`` (or ``start_listening()``) to connect.

        Args:
            exchange_name: Name of the exchange (e.g., 'binance', 'okx')
            ws_url: WebSocket URL for the exchange
            redis_url: Redis connection URL (``get_redis_url()`` is used if not provided)
        """
        self.exchange_name = exchange_name
        self.ws_url = ws_url
        self.redis_url = redis_url or get_redis_url()
        self.redis_client: Optional[aioredis.Redis] = None
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.is_connected = False
        self.is_connecting = False
        # Every topic passed to subscribe(); replayed by reconnect().
        self.subscribed_topics: Set[str] = set()
        self._heartbeat_interval = 30  # seconds; WebSocket ping interval and timeout
        self._last_heartbeat = 0  # time.time() of the last heartbeat message seen
        # Consecutive failed connection attempts; reset to 0 by a successful connect().
        self._connection_attempts = 0
        self._max_reconnect_attempts = 5
        # Set by disconnect() so the listener/reconnect logic knows the close was
        # requested and must not try to recover from it; cleared by connect().
        self._stop_requested = False

    @abstractmethod
    async def _parse_message(self, message: str) -> Dict[str, Any]:
        """
        Parse incoming WebSocket message and extract relevant data.

        This method should be implemented by subclasses to handle
        exchange-specific message formats.

        Args:
            message: Raw WebSocket message

        Returns:
            Parsed message data as dictionary (a falsy value means "ignore")
        """
        pass

    @abstractmethod
    def _get_subscribe_message(self, topics: List[str]) -> str:
        """
        Generate subscription message for the exchange.

        Args:
            topics: List of topics to subscribe to

        Returns:
            Subscription message as string
        """
        pass

    @abstractmethod
    def _get_orderbook_topic(self, symbol: str) -> str:
        """
        Generate orderbook topic name for the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTCUSDT')

        Returns:
            Orderbook topic name
        """
        pass

    @abstractmethod
    def _get_trade_topic(self, symbol: str) -> str:
        """
        Generate trade topic name for the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTCUSDT')

        Returns:
            Trade topic name
        """
        pass

    @abstractmethod
    def _get_symbol_from_topic(self, topic: str) -> str:
        """
        Extract symbol from topic name.

        Args:
            topic: Topic name

        Returns:
            Symbol extracted from topic
        """
        pass

    @staticmethod
    def _redact_url(url: str) -> str:
        """
        Return ``url`` with any password in its userinfo replaced by ``***``.

        Used so connection URLs can be logged without leaking credentials.
        URLs that carry no password are returned unchanged.

        Args:
            url: URL to redact (e.g. ``redis://:secret@host:6379/0``)

        Returns:
            The redacted URL, or a placeholder if the URL cannot be parsed.
        """
        try:
            parsed = urlparse(url)
        except ValueError:
            return "<unparseable URL>"
        # Split "userinfo@host:port" on the last '@' (as urllib itself does).
        userinfo, at, hostport = parsed.netloc.rpartition("@")
        if not at:
            return url
        username, colon, _password = userinfo.partition(":")
        if not colon:
            return url
        return parsed._replace(netloc=f"{username}:***@{hostport}").geturl()

    async def connect(self) -> bool:
        """
        Create the Redis client (if needed) and open the WebSocket connection.

        Returns:
            True if connection successful (or already connected), False otherwise.
            NOTE: also returns True while another call is still connecting, in
            which case ``self.websocket`` may not be set yet.
        """
        if self.is_connected or self.is_connecting:
            return True

        # An explicit (re)connect supersedes any earlier disconnect() request.
        self._stop_requested = False
        self.is_connecting = True
        try:
            logger.info(f"Connecting to {self.exchange_name} WebSocket: {self.ws_url}")

            # Connect to Redis
            redis_url = self.redis_url
            if not self.redis_client:
                self.redis_client = await aioredis.from_url(redis_url)
                # Redact so a password embedded in the URL never reaches the logs.
                logger.info(f"Connected to Redis: {self._redact_url(redis_url)}")

            # Connect to WebSocket; the library sends protocol-level pings
            # every _heartbeat_interval seconds.
            self.websocket = await websockets.connect(
                self.ws_url,
                ping_interval=self._heartbeat_interval,
                ping_timeout=self._heartbeat_interval,
            )

            self.is_connected = True
            self._connection_attempts = 0

            logger.info(f"Connected to {self.exchange_name} WebSocket")
            return True

        except Exception as e:
            logger.error(f"Failed to connect to {self.exchange_name}: {e}")
            self._connection_attempts += 1
            return False

        finally:
            # Always clear the in-progress flag, including on cancellation;
            # otherwise every later connect() would short-circuit to True.
            self.is_connecting = False

    async def disconnect(self) -> None:
        """
        Close the WebSocket and Redis connections (best effort).

        State is reset before closing, so the object is consistently
        disconnected even if closing either transport fails; such failures are
        logged rather than raised. Also flags the stop as requested so that a
        running ``start_listening()`` exits instead of reconnecting.
        """
        self._stop_requested = True
        websocket, self.websocket = self.websocket, None
        redis_client, self.redis_client = self.redis_client, None
        self.is_connected = False

        if websocket:
            try:
                await websocket.close()
            except Exception as e:
                logger.warning(f"Error closing {self.exchange_name} WebSocket: {e}")
        if redis_client:
            try:
                await redis_client.close()
            except Exception as e:
                logger.warning(f"Error closing Redis connection for {self.exchange_name}: {e}")

    async def _handle_heartbeat(self, message: str) -> bool:
        """
        Process heartbeat (ping/pong) messages.

        A JSON object counts as a ping if it has a ``"ping"`` key or an
        ``"event"``/``"op"`` value of ``"ping"`` (likewise for pong). Pings are
        answered by echoing the message with "ping" replaced by "pong", unless
        the message text already contains "pong" (it is then treated as an
        acknowledgement). Pongs are only recorded.

        Args:
            message: Incoming message (text; UTF-8 bytes are decoded)

        Returns:
            True if message was a heartbeat, False otherwise (including
            non-JSON and non-object JSON payloads).
        """
        if isinstance(message, (bytes, bytearray)):
            try:
                message = message.decode("utf-8")
            except UnicodeDecodeError:
                return False  # opaque binary frame, not a heartbeat we understand

        try:
            data = json.loads(message)
        except (TypeError, ValueError):  # ValueError covers JSONDecodeError
            return False

        # Arrays/numbers/strings are valid JSON but cannot be heartbeats; checking
        # keys on them would raise and abort message processing.
        if not isinstance(data, dict):
            return False

        # Common heartbeat patterns across exchanges
        is_ping = "ping" in data or data.get("event") == "ping" or data.get("op") == "ping"
        is_pong = "pong" in data or data.get("event") == "pong" or data.get("op") == "pong"
        if not (is_ping or is_pong):
            return False

        if not is_pong and "pong" not in message:
            # Unanswered ping: send pong response
            await self._send_message(message.replace("ping", "pong"))
        self._last_heartbeat = time.time()

        logger.debug("Heartbeat received from %s", self.exchange_name)
        return True

    async def subscribe(self, topics: List[str]) -> bool:
        """
        Subscribe to topics on the exchange, connecting first if needed.

        Args:
            topics: List of topics to subscribe to

        Returns:
            True if the subscription message was sent, False otherwise
        """
        if not self.is_connected:
            if not await self.connect():
                return False

        if not self.websocket:
            logger.error("WebSocket not connected for subscription")
            return False

        try:
            subscribe_msg = self._get_subscribe_message(topics)
            await self._send_message(subscribe_msg)

            # Remember topics so reconnect() can resubscribe to them.
            self.subscribed_topics.update(topics)

            logger.info(f"Subscribed to topics on {self.exchange_name}: {topics}")
            return True

        except Exception as e:
            logger.error(f"Subscription failed on {self.exchange_name}: {e}")
            return False

    async def _send_message(self, message: str) -> None:
        """
        Send message through WebSocket connection.

        Logs a warning and sends nothing when not connected.

        Args:
            message: Message to send
        """
        if self.websocket and self.is_connected:
            await self.websocket.send(message)
        else:
            logger.warning("WebSocket not connected, cannot send message")

    async def _process_message(self, message: str) -> None:
        """
        Process incoming WebSocket message and route to Redis.

        Heartbeats are handled and not forwarded. Other messages are parsed by
        the subclass, tagged with ``exchange`` and ``timestamp`` keys, and
        published to the channel chosen by ``_determine_output_channel``.

        Args:
            message: Raw WebSocket message
        """
        # Check if it's a heartbeat message
        if await self._handle_heartbeat(message):
            return

        try:
            data = await self._parse_message(message)
            if not data:
                return

            channel = self._determine_output_channel(data)
            if not channel:
                return

            # Add metadata
            data["exchange"] = self.exchange_name
            data["timestamp"] = time.time()

            await self._send_to_redis(channel, data)

        except Exception as e:
            # logger.exception keeps the traceback, which a bare message loses.
            logger.exception(f"Error processing message from {self.exchange_name}: {e}")

    def _determine_output_channel(self, data: Dict[str, Any]) -> Optional[str]:
        """
        Determine the appropriate Redis channel based on data type.

        Args:
            data: Parsed message data

        Returns:
            ``"orderbook"`` or ``"trade"``, suffixed with ``":<symbol>"`` when the
            data has an ``"s"``/``"symbol"`` value; None if no channel applies.
        """
        channel = None

        if "e" in data:  # Binance-style events
            event_type = data.get("e")
            if event_type == "depthUpdate":
                channel = "orderbook"
            elif event_type == "trade":
                channel = "trade"
        elif "event" in data:  # OKX-style events
            event_type = data.get("event")
            if event_type == "depth":
                channel = "orderbook"
            elif event_type == "trade":
                channel = "trade"
        elif "topic" in data:  # Generic topic-based
            topic = data.get("topic")
            # A null/non-string topic would raise TypeError on the substring checks.
            if isinstance(topic, str):
                if "depth" in topic:
                    channel = "orderbook"
                elif "trade" in topic:
                    channel = "trade"

        if channel is None:
            return None

        symbol = data.get("s") or data.get("symbol")
        return f"{channel}:{symbol}" if symbol else channel

    async def _send_to_redis(self, channel: str, data: Dict[str, Any]) -> None:
        """
        Publish ``data`` as JSON to a Redis channel.

        Failures are logged, not raised.

        Args:
            channel: Redis channel to send data to
            data: Data to send
        """
        if not self.redis_client:
            logger.warning("Redis client not initialized")
            return

        try:
            message = json.dumps(data)
            await self.redis_client.publish(channel, message)
            # Lazy %-formatting: this runs per message, and the f-string form
            # built the (possibly large) repr even with DEBUG disabled.
            logger.debug("Published to Redis channel %s: %s", channel, data)

        except Exception as e:
            logger.error(f"Failed to send to Redis channel {channel}: {e}")

    async def start_listening(self) -> None:
        """
        Receive and process messages until the connection is stopped.

        Connects first if necessary. When the exchange closes the connection or
        the receive loop fails, ``reconnect()`` is attempted and, on success,
        listening resumes on the new connection. Returns when ``disconnect()``
        is called, when reconnection gives up, or when the initial connection
        cannot be established.

        This method should be run in an asyncio event loop.
        """
        if not self.is_connected:
            if not await self.connect():
                logger.error(f"Cannot start listening - failed to connect to {self.exchange_name}")
                return

        try:
            while self.is_connected and self.websocket:
                try:
                    try:
                        message = await asyncio.wait_for(
                            self.websocket.recv(),
                            timeout=60.0,  # 60 second timeout
                        )
                    except asyncio.TimeoutError:
                        # Quiet connection: send a ping to keep it alive. This is
                        # inside the outer try so a dead connection discovered here
                        # is handled by the reconnect logic below.
                        await self._send_message('{"ping":1}')
                        continue
                    await self._process_message(message)

                except websockets.exceptions.ConnectionClosed:
                    if self._stop_requested:
                        break  # closed by our own disconnect()
                    logger.warning(f"Connection closed by {self.exchange_name}")
                    if not await self.reconnect():
                        break
                    # Reconnected: keep listening on the new connection.

                except Exception as e:
                    if self._stop_requested:
                        break
                    logger.error(f"Error in message loop for {self.exchange_name}: {e}")
                    if not await self.reconnect():
                        break

        except Exception as e:
            logger.error(f"Error in listening loop for {self.exchange_name}: {e}")

    async def reconnect(self) -> bool:
        """
        Tear down the current connection and try to re-establish it.

        Attempts are retried with exponential backoff (1s, 2s, 4s, ... based on
        the number of consecutive failures) until one succeeds or
        ``_max_reconnect_attempts`` consecutive failures have accumulated.
        After a successful reconnect all previously subscribed topics are
        re-subscribed.

        Returns:
            True if reconnection successful, False if the attempt limit was
            reached or ``disconnect()`` was called while backing off.
        """
        logger.info(f"Attempting to reconnect to {self.exchange_name}")

        # Disconnect first
        await self.disconnect()
        # disconnect() flags a stop request; this teardown is our own, so clear it.
        self._stop_requested = False

        while self._connection_attempts < self._max_reconnect_attempts:
            try:
                # Exponential backoff
                await asyncio.sleep(2 ** self._connection_attempts)

                if self._stop_requested:
                    logger.info(f"Reconnection to {self.exchange_name} cancelled by disconnect()")
                    return False

                # connect() increments _connection_attempts on failure, so the
                # loop always makes progress towards the limit.
                if await self.connect():
                    if self.subscribed_topics and not await self.subscribe(list(self.subscribed_topics)):
                        logger.warning(f"Reconnected to {self.exchange_name} but resubscription failed")
                    return True

                logger.error(f"Reconnection failed for {self.exchange_name}")

            except Exception as e:
                logger.error(f"Exception during reconnection to {self.exchange_name}: {e}")
                self._connection_attempts += 1

        logger.error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
        return False