| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- # -*- coding: utf-8 -*-
- import errno
- import socket
- import unittest
- from unittest.mock import Mock, patch, MagicMock
- import time
- from websocket._socket import recv, recv_line, send, DEFAULT_SOCKET_OPTION
- from websocket._ssl_compat import (
- SSLError,
- SSLEOFError,
- SSLWantWriteError,
- SSLWantReadError,
- )
- from websocket._exceptions import (
- WebSocketTimeoutException,
- WebSocketConnectionClosedException,
- )
- """
- test_socket.py
- websocket - WebSocket client library for Python
- Copyright 2025 engn33r
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
class SocketTest(unittest.TestCase):
    """Exercise ``recv``, ``recv_line`` and ``send`` against ``Mock`` sockets.

    Wherever a test needs a selector, ``selectors.DefaultSelector`` is patched
    so that the readiness result seen by the helper under test is scripted.
    """

    @staticmethod
    def _make_socket_error(code, text):
        """Return a ``socket.error`` with ``code`` set as ``errno`` and in ``args``."""
        failure = socket.error(text)
        failure.errno = code
        failure.args = (code, text)
        return failure

    @staticmethod
    def _selector_returning(ready):
        """Return a ``Mock`` selector whose ``select()`` answers with ``ready``."""
        fake_selector = Mock()
        fake_selector.select.return_value = ready
        return fake_selector

    def test_default_socket_option(self):
        """DEFAULT_SOCKET_OPTION is a non-empty list that includes TCP_NODELAY."""
        self.assertIsInstance(DEFAULT_SOCKET_OPTION, list)
        self.assertGreater(len(DEFAULT_SOCKET_OPTION), 0)

        # The option name is the second field of each entry.
        has_nodelay = False
        for option in DEFAULT_SOCKET_OPTION:
            if option[1] == socket.TCP_NODELAY:
                has_nodelay = True
                break
        self.assertTrue(has_nodelay)

    def test_recv_normal(self):
        """recv hands back exactly what the socket returned."""
        payload = b"test data"
        fake_sock = Mock()
        fake_sock.recv.return_value = payload

        self.assertEqual(recv(fake_sock, 9), payload)
        fake_sock.recv.assert_called_once_with(9)

    def test_recv_timeout_error(self):
        """A TimeoutError from the socket surfaces as WebSocketTimeoutException."""
        fake_sock = Mock()
        fake_sock.recv.side_effect = TimeoutError("Connection timed out")

        with self.assertRaises(WebSocketTimeoutException) as caught:
            recv(fake_sock, 9)

        self.assertEqual(str(caught.exception), "Connection timed out")

    def test_recv_socket_timeout(self):
        """A socket.timeout from the socket surfaces as WebSocketTimeoutException."""
        timeout_exc = socket.timeout("Socket timed out")
        timeout_exc.args = ("Socket timed out",)

        fake_sock = Mock()
        fake_sock.recv.side_effect = timeout_exc
        fake_sock.gettimeout.return_value = 30.0

        with self.assertRaises(WebSocketTimeoutException) as caught:
            recv(fake_sock, 9)

        # From Python 3.10 on, socket.timeout is TimeoutError, so the
        # TimeoutError handler wins and its hardcoded message is reported.
        # On Python 3.9 the dedicated socket.timeout handler runs instead and
        # keeps the original message.
        import sys

        if sys.version_info >= (3, 10):
            expected_message = "Connection timed out"
        else:
            expected_message = "Socket timed out"
        self.assertEqual(str(caught.exception), expected_message)

    def test_recv_ssl_timeout(self):
        """An SSLError whose text reports a timeout becomes a timeout exception."""
        ssl_exc = SSLError("The operation timed out")
        ssl_exc.args = ("The operation timed out",)
        fake_sock = Mock()
        fake_sock.recv.side_effect = ssl_exc

        with self.assertRaises(WebSocketTimeoutException) as caught:
            recv(fake_sock, 9)

        self.assertEqual(str(caught.exception), "The operation timed out")

    def test_recv_ssl_non_timeout_error(self):
        """An SSLError unrelated to timeouts is propagated as an SSLError."""
        ssl_exc = SSLError("SSL certificate error")
        ssl_exc.args = ("SSL certificate error",)
        fake_sock = Mock()
        fake_sock.recv.side_effect = ssl_exc

        # The original SSL error must reach the caller.
        with self.assertRaises(SSLError):
            recv(fake_sock, 9)

    def test_recv_empty_response(self):
        """An empty read is reported as a closed connection."""
        fake_sock = Mock()
        fake_sock.recv.return_value = b""

        with self.assertRaises(WebSocketConnectionClosedException) as caught:
            recv(fake_sock, 9)

        self.assertEqual(
            str(caught.exception), "Connection to remote host was lost."
        )

    def test_recv_ssl_want_read_error(self):
        """SSLWantReadError triggers a selector wait followed by a second read."""
        fake_sock = Mock()
        # The first read asks to wait, the second one delivers data.
        fake_sock.recv.side_effect = [SSLWantReadError(), b"data after retry"]
        selector = self._selector_returning([True])  # socket is readable

        with patch("selectors.DefaultSelector", return_value=selector):
            received = recv(fake_sock, 100)

        self.assertEqual(received, b"data after retry")
        selector.register.assert_called()
        selector.close.assert_called()

    def test_recv_ssl_want_read_timeout(self):
        """SSLWantReadError followed by an empty select() result times out."""
        fake_sock = Mock()
        fake_sock.recv.side_effect = SSLWantReadError()
        fake_sock.gettimeout.return_value = 1.0
        selector = self._selector_returning([])  # nothing became readable

        with patch("selectors.DefaultSelector", return_value=selector):
            with self.assertRaises(WebSocketTimeoutException):
                recv(fake_sock, 100)

    def test_recv_line(self):
        """recv_line gathers single-byte reads up to and including the newline."""
        fake_sock = Mock()
        # One byte per call: b"H", b"e", b"l", b"l", b"o", b"\n".
        single_bytes = [bytes([value]) for value in b"Hello\n"]

        with patch("websocket._socket.recv", side_effect=single_bytes) as recv_stub:
            line = recv_line(fake_sock)

        self.assertEqual(line, b"Hello\n")
        self.assertEqual(recv_stub.call_count, 6)

    def test_send_normal(self):
        """send returns the byte count reported by the socket."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        fake_sock.send.return_value = 9

        self.assertEqual(send(fake_sock, b"test data"), 9)
        fake_sock.send.assert_called_with(b"test data")

    def test_send_zero_timeout(self):
        """With a zero timeout (non-blocking) the socket is written exactly once."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 0
        fake_sock.send.return_value = 9

        self.assertEqual(send(fake_sock, b"test data"), 9)
        fake_sock.send.assert_called_once_with(b"test data")

    def test_send_ssl_eof_error(self):
        """SSLEOFError while sending is reported as an already-closed socket."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        fake_sock.send.side_effect = SSLEOFError("Connection closed")

        with self.assertRaises(WebSocketConnectionClosedException) as caught:
            send(fake_sock, b"test data")

        self.assertEqual(str(caught.exception), "socket is already closed.")

    def test_send_ssl_want_write_error(self):
        """SSLWantWriteError triggers a selector wait followed by a second write."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        # The first write asks to wait, the second one succeeds.
        fake_sock.send.side_effect = [SSLWantWriteError(), 9]
        selector = self._selector_returning([True])  # socket is writable

        with patch("selectors.DefaultSelector", return_value=selector):
            sent = send(fake_sock, b"test data")

        self.assertEqual(sent, 9)
        selector.register.assert_called()
        selector.close.assert_called()

    def test_send_socket_eagain_error(self):
        """An EAGAIN socket error is retried once the selector reports readiness."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        eagain = self._make_socket_error(
            errno.EAGAIN, "Resource temporarily unavailable"
        )
        # EAGAIN on the first write, success on the second.
        fake_sock.send.side_effect = [eagain, 9]
        selector = self._selector_returning([True])  # socket is writable

        with patch("selectors.DefaultSelector", return_value=selector):
            sent = send(fake_sock, b"test data")

        self.assertEqual(sent, 9)

    def test_send_socket_ewouldblock_error(self):
        """An EWOULDBLOCK socket error is retried like EAGAIN."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        would_block = self._make_socket_error(
            errno.EWOULDBLOCK, "Operation would block"
        )
        # EWOULDBLOCK on the first write, success on the second.
        fake_sock.send.side_effect = [would_block, 9]
        selector = self._selector_returning([True])  # socket is writable

        with patch("selectors.DefaultSelector", return_value=selector):
            sent = send(fake_sock, b"test data")

        self.assertEqual(sent, 9)

    def test_send_socket_other_error(self):
        """A socket error with a non-retryable errno is raised to the caller."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        fake_sock.send.side_effect = self._make_socket_error(
            errno.ECONNRESET, "Connection reset by peer"
        )

        with self.assertRaises(socket.error):
            send(fake_sock, b"test data")

    def test_send_socket_error_no_errno(self):
        """A socket error that was never given an errno is raised to the caller."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        # Only a message is supplied, so no error code is attached.
        bare_error = socket.error("Generic socket error")
        bare_error.args = ("Generic socket error",)
        fake_sock.send.side_effect = bare_error

        with self.assertRaises(socket.error):
            send(fake_sock, b"test data")

    def test_send_write_timeout(self):
        """send yields 0 when the socket never becomes writable after EAGAIN."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        # Every write attempt fails with EAGAIN.
        fake_sock.send.side_effect = self._make_socket_error(
            errno.EAGAIN, "Resource temporarily unavailable"
        )
        selector = self._selector_returning([])  # wait expires, nothing ready

        with patch("selectors.DefaultSelector", return_value=selector):
            sent = send(fake_sock, b"test data")

        # A timed-out write is reported as zero bytes sent.
        self.assertEqual(sent, 0)

    def test_send_string_data(self):
        """Text passed to send reaches the socket as encoded bytes."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        fake_sock.send.return_value = 9

        self.assertEqual(send(fake_sock, "test data"), 9)
        fake_sock.send.assert_called_with(b"test data")

    def test_send_partial_send_retry(self):
        """The retry path registers with, waits on and closes the selector."""
        fake_sock = Mock()
        fake_sock.gettimeout.return_value = 30.0
        eagain = self._make_socket_error(
            errno.EAGAIN, "Resource temporarily unavailable"
        )
        # First write fails with EAGAIN; the write after the wait succeeds.
        fake_sock.send.side_effect = [eagain, 9]
        selector = self._selector_returning([True])  # socket is writable

        with patch("selectors.DefaultSelector", return_value=selector):
            sent = send(fake_sock, b"test data")

        self.assertEqual(sent, 9)
        # The selector must have been driven through its full lifecycle.
        selector.register.assert_called()
        selector.select.assert_called()
        selector.close.assert_called()
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|