feat: WebSocket connection
This commit is contained in:
105
modules/binance_connection.py
Normal file
105
modules/binance_connection.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
Binance-specific WebSocket connection implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from modules.connections import BaseConnection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BinanceConnection(BaseConnection):
|
||||||
|
"""
|
||||||
|
Binance WebSocket connection implementation.
|
||||||
|
|
||||||
|
This class implements the BaseConnection interface for Binance exchange
|
||||||
|
with specific handling for Binance WebSocket API format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ws_url: str = "wss://stream.binance.com:9443"):
|
||||||
|
"""
|
||||||
|
Initialize Binance connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ws_url: Binance WebSocket URL
|
||||||
|
"""
|
||||||
|
super().__init__("binance", ws_url)
|
||||||
|
|
||||||
|
async def _parse_message(self, message: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse Binance WebSocket message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Raw WebSocket message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed message data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
return data
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Failed to parse Binance message: {message}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _get_subscribe_message(self, topics: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Generate Binance subscription message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topics: List of topics to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subscription message as string
|
||||||
|
"""
|
||||||
|
# Binance uses a different format for subscription
|
||||||
|
subscribe_msg = {
|
||||||
|
"method": "SUBSCRIBE",
|
||||||
|
"params": topics,
|
||||||
|
"id": 1
|
||||||
|
}
|
||||||
|
return json.dumps(subscribe_msg)
|
||||||
|
|
||||||
|
def _get_orderbook_topic(self, symbol: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate Binance orderbook topic name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Orderbook topic name
|
||||||
|
"""
|
||||||
|
return f"{symbol.lower()}@depth20@100ms"
|
||||||
|
|
||||||
|
def _get_trade_topic(self, symbol: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate Binance trade topic name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trade topic name
|
||||||
|
"""
|
||||||
|
return f"{symbol.lower()}@aggTrade"
|
||||||
|
|
||||||
|
def _get_symbol_from_topic(self, topic: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract symbol from Binance topic name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: Topic name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol extracted from topic
|
||||||
|
"""
|
||||||
|
# Example Binance topics: "btcusdt@depth20@100ms", "btcusdt@aggTrade"
|
||||||
|
# Extract the symbol part
|
||||||
|
if "@" in topic:
|
||||||
|
symbol = topic.split("@")[0].upper()
|
||||||
|
return symbol
|
||||||
|
return topic.upper()
|
||||||
68
modules/config.py
Normal file
68
modules/config.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Configuration management for the cryptocurrency trading platform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from .env file
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_url() -> str:
|
||||||
|
"""
|
||||||
|
Get Redis connection URL from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis connection URL
|
||||||
|
"""
|
||||||
|
return os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||||
|
|
||||||
|
|
||||||
|
def get_timescaledb_url() -> str:
|
||||||
|
"""
|
||||||
|
Get TimescaleDB connection URL from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TimescaleDB connection URL
|
||||||
|
"""
|
||||||
|
return os.getenv("TIMESCALEDB_URL", "postgresql://localhost:5432/trading")
|
||||||
|
|
||||||
|
|
||||||
|
def get_supported_exchanges() -> list:
|
||||||
|
"""
|
||||||
|
Get list of supported exchanges from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of supported exchange names
|
||||||
|
"""
|
||||||
|
exchanges = os.getenv("SUPPORTED_EXCHANGES", "binance")
|
||||||
|
return [ex.strip() for ex in exchanges.split(",") if ex.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_exchange_config(exchange_name: str) -> dict:
|
||||||
|
"""
|
||||||
|
Get configuration for a specific exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange_name: Name of the exchange
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exchange configuration dictionary
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"binance": {
|
||||||
|
"ws_url": "wss://stream.binance.com:9443",
|
||||||
|
"api_url": "https://api.binance.com",
|
||||||
|
"max_concurrent_connections": 1000,
|
||||||
|
},
|
||||||
|
"okx": {
|
||||||
|
"ws_url": "wss://ws.okx.com:8443",
|
||||||
|
"api_url": "https://www.okx.com",
|
||||||
|
"max_concurrent_connections": 1000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.get(exchange_name.lower(), {})
|
||||||
427
modules/connections.py
Normal file
427
modules/connections.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
"""
|
||||||
|
Base connection class for cryptocurrency exchange WebSocket clients.
|
||||||
|
|
||||||
|
This module provides a base class for WebSocket connections to various
|
||||||
|
cryptocurrency exchanges. It handles connection management, heartbeat
|
||||||
|
processing, subscription management, and data routing to Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aioredis
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from modules.config import get_redis_url
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnection(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for cryptocurrency exchange WebSocket connections.
|
||||||
|
|
||||||
|
This class provides a common interface for WebSocket connections to
|
||||||
|
different exchanges with support for:
|
||||||
|
- Connection management
|
||||||
|
- Heartbeat/ping processing
|
||||||
|
- Topic subscription
|
||||||
|
- Data routing to Redis
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, exchange_name: str, ws_url: str, redis_url: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the base connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange_name: Name of the exchange (e.g., 'binance', 'okx')
|
||||||
|
ws_url: WebSocket URL for the exchange
|
||||||
|
redis_url: Redis connection URL (uses default if not provided)
|
||||||
|
"""
|
||||||
|
self.exchange_name = exchange_name
|
||||||
|
self.ws_url = ws_url
|
||||||
|
self.redis_url = redis_url or get_redis_url()
|
||||||
|
self.redis_client: Optional[aioredis.Redis] = None
|
||||||
|
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||||
|
self.is_connected = False
|
||||||
|
self.is_connecting = False
|
||||||
|
self.subscribed_topics: Set[str] = set()
|
||||||
|
self._heartbeat_interval = 30 # seconds
|
||||||
|
self._last_heartbeat = 0
|
||||||
|
self._connection_attempts = 0
|
||||||
|
self._max_reconnect_attempts = 5
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _parse_message(self, message: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse incoming WebSocket message and extract relevant data.
|
||||||
|
|
||||||
|
This method should be implemented by subclasses to handle
|
||||||
|
exchange-specific message formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Raw WebSocket message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed message data as dictionary
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_subscribe_message(self, topics: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Generate subscription message for the exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topics: List of topics to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subscription message as string
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_orderbook_topic(self, symbol: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate orderbook topic name for the exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol (e.g., 'BTCUSDT')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Orderbook topic name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_trade_topic(self, symbol: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate trade topic name for the exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol (e.g., 'BTCUSDT')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trade topic name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_symbol_from_topic(self, topic: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract symbol from topic name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: Topic name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol extracted from topic
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def connect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Establish WebSocket connection to the exchange.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection successful, False otherwise
|
||||||
|
"""
|
||||||
|
if self.is_connected or self.is_connecting:
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.is_connecting = True
|
||||||
|
try:
|
||||||
|
logger.info(f"Connecting to {self.exchange_name} WebSocket: {self.ws_url}")
|
||||||
|
|
||||||
|
# Connect to Redis
|
||||||
|
redis_url = self.redis_url
|
||||||
|
if not self.redis_client:
|
||||||
|
self.redis_client = await aioredis.from_url(redis_url)
|
||||||
|
logger.info(f"Connected to Redis: {redis_url}")
|
||||||
|
|
||||||
|
# Connect to WebSocket
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
self.ws_url,
|
||||||
|
ping_interval=self._heartbeat_interval,
|
||||||
|
ping_timeout=self._heartbeat_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_connected = True
|
||||||
|
self.is_connecting = False
|
||||||
|
self._connection_attempts = 0
|
||||||
|
|
||||||
|
logger.info(f"Connected to {self.exchange_name} WebSocket")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to {self.exchange_name}: {e}")
|
||||||
|
self.is_connecting = False
|
||||||
|
self._connection_attempts += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Close WebSocket connection and Redis connection."""
|
||||||
|
if self.websocket:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.websocket = None
|
||||||
|
if self.redis_client:
|
||||||
|
await self.redis_client.close()
|
||||||
|
self.redis_client = None
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
async def _handle_heartbeat(self, message: str) -> bool:
|
||||||
|
"""
|
||||||
|
Process heartbeat/ping messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Incoming message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if message was a heartbeat, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
|
||||||
|
# Common heartbeat patterns across exchanges
|
||||||
|
if (
|
||||||
|
"ping" in data or
|
||||||
|
"event" in data and data["event"] == "ping" or
|
||||||
|
(isinstance(data.get("op"), str) and data["op"] == "ping")
|
||||||
|
):
|
||||||
|
# Send pong response
|
||||||
|
if "pong" in message:
|
||||||
|
# Already a pong, just update heartbeat time
|
||||||
|
self._last_heartbeat = time.time()
|
||||||
|
else:
|
||||||
|
# Send pong response
|
||||||
|
pong_msg = message.replace("ping", "pong")
|
||||||
|
await self._send_message(pong_msg)
|
||||||
|
self._last_heartbeat = time.time()
|
||||||
|
|
||||||
|
logger.debug(f"Heartbeat received from {self.exchange_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def subscribe(self, topics: List[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Subscribe to topics on the exchange.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topics: List of topics to subscribe to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if subscription successful, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
if not await self.connect():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not self.websocket:
|
||||||
|
logger.error("WebSocket not connected for subscription")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate subscription message
|
||||||
|
subscribe_msg = self._get_subscribe_message(topics)
|
||||||
|
|
||||||
|
# Send subscription message
|
||||||
|
await self._send_message(subscribe_msg)
|
||||||
|
|
||||||
|
# Update subscribed topics
|
||||||
|
self.subscribed_topics.update(topics)
|
||||||
|
|
||||||
|
logger.info(f"Subscribed to topics on {self.exchange_name}: {topics}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Subscription failed on {self.exchange_name}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _send_message(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Send message through WebSocket connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Message to send
|
||||||
|
"""
|
||||||
|
if self.websocket and self.is_connected:
|
||||||
|
await self.websocket.send(message)
|
||||||
|
else:
|
||||||
|
logger.warning("WebSocket not connected, cannot send message")
|
||||||
|
|
||||||
|
async def _process_message(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Process incoming WebSocket message and route to Redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Raw WebSocket message
|
||||||
|
"""
|
||||||
|
# Check if it's a heartbeat message
|
||||||
|
if await self._handle_heartbeat(message):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse the message
|
||||||
|
data = await self._parse_message(message)
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine the appropriate output channel based on message type
|
||||||
|
channel = self._determine_output_channel(data)
|
||||||
|
if not channel:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add metadata
|
||||||
|
data["exchange"] = self.exchange_name
|
||||||
|
data["timestamp"] = time.time()
|
||||||
|
|
||||||
|
# Send to Redis
|
||||||
|
await self._send_to_redis(channel, data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing message from {self.exchange_name}: {e}")
|
||||||
|
|
||||||
|
def _determine_output_channel(self, data: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Determine the appropriate Redis channel based on data type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Parsed message data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis channel name or None if no channel should be used
|
||||||
|
"""
|
||||||
|
# Default to None if no channel is determined
|
||||||
|
channel = None
|
||||||
|
|
||||||
|
# Try to determine channel from message structure
|
||||||
|
if "e" in data: # Binance-style events
|
||||||
|
event_type = data.get("e")
|
||||||
|
if event_type == "depthUpdate":
|
||||||
|
channel = "orderbook"
|
||||||
|
elif event_type == "trade":
|
||||||
|
channel = "trade"
|
||||||
|
elif "event" in data: # OKX-style events
|
||||||
|
event_type = data.get("event")
|
||||||
|
if event_type == "depth":
|
||||||
|
channel = "orderbook"
|
||||||
|
elif event_type == "trade":
|
||||||
|
channel = "trade"
|
||||||
|
elif "topic" in data: # Generic topic-based
|
||||||
|
topic = data.get("topic")
|
||||||
|
if "depth" in topic:
|
||||||
|
channel = "orderbook"
|
||||||
|
elif "trade" in topic:
|
||||||
|
channel = "trade"
|
||||||
|
|
||||||
|
# If we have a symbol, we might want to add it to the channel
|
||||||
|
symbol = data.get("s") or data.get("symbol")
|
||||||
|
if symbol and channel:
|
||||||
|
return f"{channel}:{symbol}"
|
||||||
|
elif channel:
|
||||||
|
return channel
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _send_to_redis(self, channel: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Send data to Redis channel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: Redis channel to send data to
|
||||||
|
data: Data to send
|
||||||
|
"""
|
||||||
|
if not self.redis_client:
|
||||||
|
logger.warning("Redis client not initialized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert data to JSON string
|
||||||
|
message = json.dumps(data)
|
||||||
|
|
||||||
|
# Send to Redis
|
||||||
|
await self.redis_client.publish(channel, message)
|
||||||
|
logger.debug(f"Published to Redis channel {channel}: {data}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send to Redis channel {channel}: {e}")
|
||||||
|
|
||||||
|
async def start_listening(self) -> None:
|
||||||
|
"""
|
||||||
|
Start listening for incoming messages from WebSocket.
|
||||||
|
|
||||||
|
This method should be run in an asyncio event loop.
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
if not await self.connect():
|
||||||
|
logger.error(f"Cannot start listening - failed to connect to {self.exchange_name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.is_connected and self.websocket:
|
||||||
|
try:
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
self.websocket.recv(),
|
||||||
|
timeout=60.0 # 60 second timeout
|
||||||
|
)
|
||||||
|
await self._process_message(message)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Handle timeout - send ping to keep connection alive
|
||||||
|
await self._send_message('{"ping":1}')
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logger.warning(f"Connection closed by {self.exchange_name}")
|
||||||
|
await self.reconnect()
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in message loop for {self.exchange_name}: {e}")
|
||||||
|
await self.reconnect()
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in listening loop for {self.exchange_name}: {e}")
|
||||||
|
|
||||||
|
async def reconnect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Attempt to reconnect to the exchange.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if reconnection successful, False otherwise
|
||||||
|
"""
|
||||||
|
logger.info(f"Attempting to reconnect to {self.exchange_name}")
|
||||||
|
|
||||||
|
# Disconnect first
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
# If we've exceeded max attempts, stop
|
||||||
|
if self._connection_attempts >= self._max_reconnect_attempts:
|
||||||
|
logger.error(f"Max reconnection attempts ({self._max_reconnect_attempts}) exceeded for {self.exchange_name}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait before reconnecting
|
||||||
|
await asyncio.sleep(2 ** self._connection_attempts) # Exponential backoff
|
||||||
|
|
||||||
|
# Try to reconnect
|
||||||
|
success = await self.connect()
|
||||||
|
if success:
|
||||||
|
# Resubscribe to topics
|
||||||
|
if self.subscribed_topics:
|
||||||
|
await self.subscribe(list(self.subscribed_topics))
|
||||||
|
else:
|
||||||
|
logger.error(f"Reconnection failed for {self.exchange_name}")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during reconnection to {self.exchange_name}: {e}")
|
||||||
|
self._connection_attempts += 1
|
||||||
|
return False
|
||||||
83
tests/test_connections.py
Normal file
83
tests/test_connections.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
Test module for WebSocket connection classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from modules.binance_connection import BinanceConnection
|
||||||
|
from modules.connections import BaseConnection
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseConnection(unittest.TestCase):
|
||||||
|
"""Test the BaseConnection class functionality."""
|
||||||
|
|
||||||
|
def test_base_connection_initialization(self):
|
||||||
|
"""Test that BaseConnection initializes correctly."""
|
||||||
|
connection = BinanceConnection()
|
||||||
|
self.assertEqual(connection.exchange_name, "binance")
|
||||||
|
self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443")
|
||||||
|
self.assertIsNone(connection.websocket)
|
||||||
|
self.assertIsNone(connection.redis_client)
|
||||||
|
self.assertFalse(connection.is_connected)
|
||||||
|
self.assertFalse(connection.is_connecting)
|
||||||
|
|
||||||
|
@patch('modules.connections.aioredis.from_url')
|
||||||
|
@patch('modules.connections.websockets.connect')
|
||||||
|
async def test_connect_success(self, mock_websocket_connect, mock_redis_connect):
|
||||||
|
"""Test successful connection."""
|
||||||
|
# Mock the websocket connection
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
mock_websocket_connect.return_value = mock_ws
|
||||||
|
|
||||||
|
# Mock the redis connection
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis_connect.return_value = mock_redis
|
||||||
|
|
||||||
|
connection = BinanceConnection()
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
result = await connection.connect()
|
||||||
|
self.assertTrue(result)
|
||||||
|
self.assertTrue(connection.is_connected)
|
||||||
|
|
||||||
|
# Verify connections were made
|
||||||
|
mock_websocket_connect.assert_called_once()
|
||||||
|
mock_redis_connect.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBinanceConnection(unittest.TestCase):
|
||||||
|
"""Test the BinanceConnection class functionality."""
|
||||||
|
|
||||||
|
def test_binance_connection_initialization(self):
|
||||||
|
"""Test that BinanceConnection initializes correctly."""
|
||||||
|
connection = BinanceConnection()
|
||||||
|
self.assertEqual(connection.exchange_name, "binance")
|
||||||
|
self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443")
|
||||||
|
|
||||||
|
def test_get_subscribe_message(self):
|
||||||
|
"""Test subscription message generation."""
|
||||||
|
connection = BinanceConnection()
|
||||||
|
topics = ["btcusdt@depth20@100ms", "ethusdt@aggTrade"]
|
||||||
|
message = connection._get_subscribe_message(topics)
|
||||||
|
data = eval(message) # Parse the dict string
|
||||||
|
self.assertEqual(data["method"], "SUBSCRIBE")
|
||||||
|
self.assertEqual(data["params"], topics)
|
||||||
|
self.assertEqual(data["id"], 1)
|
||||||
|
|
||||||
|
def test_get_orderbook_topic(self):
|
||||||
|
"""Test orderbook topic generation."""
|
||||||
|
connection = BinanceConnection()
|
||||||
|
topic = connection._get_orderbook_topic("BTCUSDT")
|
||||||
|
self.assertEqual(topic, "btcusdt@depth20@100ms")
|
||||||
|
|
||||||
|
def test_get_trade_topic(self):
|
||||||
|
"""Test trade topic generation."""
|
||||||
|
connection = BinanceConnection()
|
||||||
|
topic = connection._get_trade_topic("BTCUSDT")
|
||||||
|
self.assertEqual(topic, "btcusdt@aggTrade")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user