"""
Test module for WebSocket connection classes.
"""

import ast
import asyncio
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from modules.binance_connection import BinanceConnection
from modules.connections import BaseConnection


class TestBaseConnection(unittest.IsolatedAsyncioTestCase):
    """Test the BaseConnection class functionality.

    This class derives from ``IsolatedAsyncioTestCase`` so that ``async def``
    test methods are actually awaited. Under a plain ``unittest.TestCase``
    the coroutine object is created but never run, so the test would
    "pass" without executing any of its assertions. Synchronous test
    methods run normally under this base class as well.
    """

    def test_base_connection_initialization(self):
        """Test that BaseConnection initializes correctly."""
        connection = BinanceConnection()

        self.assertEqual(connection.exchange_name, "binance")
        self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443")
        # A fresh connection holds no clients and is in neither active state.
        self.assertIsNone(connection.websocket)
        self.assertIsNone(connection.redis_client)
        self.assertFalse(connection.is_connected)
        self.assertFalse(connection.is_connecting)

    # Decorators apply bottom-up, so the innermost patch (websockets.connect)
    # is passed as the first mock argument.
    @patch('modules.connections.aioredis.from_url')
    @patch('modules.connections.websockets.connect', new_callable=AsyncMock)
    async def test_connect_success(self, mock_websocket_connect, mock_redis_connect):
        """Test successful connection."""
        # Mock the websocket connection.
        # NOTE(review): AsyncMock assumes connect() awaits
        # `websockets.connect(...)`; the awaited call resolves to mock_ws.
        # A default MagicMock cannot be used in an `await` expression.
        mock_ws = AsyncMock()
        mock_websocket_connect.return_value = mock_ws

        # Mock the redis connection (from_url() returns the client directly).
        mock_redis = AsyncMock()
        mock_redis_connect.return_value = mock_redis

        connection = BinanceConnection()

        # Test connection
        result = await connection.connect()

        self.assertTrue(result)
        self.assertTrue(connection.is_connected)

        # Verify connections were made
        mock_websocket_connect.assert_called_once()
        mock_redis_connect.assert_called_once()


class TestBinanceConnection(unittest.TestCase):
    """Test the BinanceConnection class functionality."""

    def test_binance_connection_initialization(self):
        """Test that BinanceConnection initializes correctly."""
        connection = BinanceConnection()

        self.assertEqual(connection.exchange_name, "binance")
        self.assertEqual(connection.ws_url, "wss://stream.binance.com:9443")

    def test_get_subscribe_message(self):
        """Test subscription message generation."""
        connection = BinanceConnection()
        topics = ["btcusdt@depth20@100ms", "ethusdt@aggTrade"]

        message = connection._get_subscribe_message(topics)
        # ast.literal_eval only parses literals, unlike eval(), which would
        # execute arbitrary code. It accepts both a Python dict repr and a
        # JSON object of strings/lists/ints, so either serialization works.
        data = ast.literal_eval(message)

        self.assertEqual(data["method"], "SUBSCRIBE")
        self.assertEqual(data["params"], topics)
        self.assertEqual(data["id"], 1)

    def test_get_orderbook_topic(self):
        """Test orderbook topic generation."""
        connection = BinanceConnection()
        topic = connection._get_orderbook_topic("BTCUSDT")

        self.assertEqual(topic, "btcusdt@depth20@100ms")

    def test_get_trade_topic(self):
        """Test trade topic generation."""
        connection = BinanceConnection()
        topic = connection._get_trade_topic("BTCUSDT")

        self.assertEqual(topic, "btcusdt@aggTrade")


if __name__ == '__main__':
    unittest.main()