_socket.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import errno
  2. import selectors
  3. import socket
  4. from typing import Optional, Union, Any
  5. from ._exceptions import (
  6. WebSocketConnectionClosedException,
  7. WebSocketTimeoutException,
  8. )
  9. from ._ssl_compat import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
  10. from ._utils import extract_error_code, extract_err_message
  11. """
  12. _socket.py
  13. websocket - WebSocket client library for Python
  14. Copyright 2025 engn33r
  15. Licensed under the Apache License, Version 2.0 (the "License");
  16. you may not use this file except in compliance with the License.
  17. You may obtain a copy of the License at
  18. http://www.apache.org/licenses/LICENSE-2.0
  19. Unless required by applicable law or agreed to in writing, software
  20. distributed under the License is distributed on an "AS IS" BASIS,
  21. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. See the License for the specific language governing permissions and
  23. limitations under the License.
  24. """
  25. DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
  26. if hasattr(socket, "SO_KEEPALIVE"):
  27. DEFAULT_SOCKET_OPTION.append((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
  28. if hasattr(socket, "TCP_KEEPIDLE"):
  29. DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPIDLE, 30))
  30. if hasattr(socket, "TCP_KEEPINTVL"):
  31. DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPINTVL, 10))
  32. if hasattr(socket, "TCP_KEEPCNT"):
  33. DEFAULT_SOCKET_OPTION.append((socket.SOL_TCP, socket.TCP_KEEPCNT, 3))
  34. _default_timeout = None
  35. __all__ = [
  36. "DEFAULT_SOCKET_OPTION",
  37. "sock_opt",
  38. "setdefaulttimeout",
  39. "getdefaulttimeout",
  40. "recv",
  41. "recv_line",
  42. "send",
  43. ]
  44. class sock_opt:
  45. def __init__(
  46. self, sockopt: Optional[list[tuple]], sslopt: Optional[dict[str, Any]]
  47. ) -> None:
  48. if sockopt is None:
  49. sockopt = []
  50. if sslopt is None:
  51. sslopt = {}
  52. self.sockopt = sockopt
  53. self.sslopt = sslopt
  54. self.timeout: Optional[Union[int, float]] = None
  55. def setdefaulttimeout(timeout: Optional[Union[int, float]]) -> None:
  56. """
  57. Set the global timeout setting to connect.
  58. Parameters
  59. ----------
  60. timeout: int or float
  61. default socket timeout time (in seconds)
  62. """
  63. global _default_timeout
  64. _default_timeout = timeout
  65. def getdefaulttimeout() -> Optional[Union[int, float]]:
  66. """
  67. Get default timeout
  68. Returns
  69. ----------
  70. _default_timeout: int or float
  71. Return the global timeout setting (in seconds) to connect.
  72. """
  73. return _default_timeout
  74. def recv(sock: socket.socket, bufsize: int) -> bytes:
  75. if not sock:
  76. raise WebSocketConnectionClosedException("socket is already closed.")
  77. def _recv():
  78. try:
  79. return sock.recv(bufsize)
  80. except SSLWantReadError:
  81. # Don't return None implicitly - fall through to retry logic
  82. pass
  83. except socket.error as exc:
  84. error_code = extract_error_code(exc)
  85. if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
  86. raise
  87. # Don't return None implicitly - fall through to retry logic
  88. # Retry logic using selector for both SSLWantReadError and EAGAIN/EWOULDBLOCK
  89. sel = selectors.DefaultSelector()
  90. sel.register(sock, selectors.EVENT_READ)
  91. r = sel.select(sock.gettimeout())
  92. sel.close()
  93. if r:
  94. return sock.recv(bufsize)
  95. else:
  96. # Selector timeout should raise WebSocketTimeoutException
  97. # not return None which gets misclassified as connection closed
  98. raise WebSocketTimeoutException("Connection timed out waiting for data")
  99. try:
  100. if sock.gettimeout() == 0:
  101. bytes_ = sock.recv(bufsize)
  102. else:
  103. bytes_ = _recv()
  104. except TimeoutError:
  105. raise WebSocketTimeoutException("Connection timed out")
  106. except socket.timeout as e:
  107. message = extract_err_message(e)
  108. raise WebSocketTimeoutException(message)
  109. except SSLError as e:
  110. message = extract_err_message(e)
  111. if isinstance(message, str) and "timed out" in message:
  112. raise WebSocketTimeoutException(message)
  113. else:
  114. raise
  115. if bytes_ is None:
  116. raise WebSocketConnectionClosedException("Connection to remote host was lost.")
  117. if not bytes_:
  118. raise WebSocketConnectionClosedException("Connection to remote host was lost.")
  119. return bytes_
  120. def recv_line(sock: socket.socket) -> bytes:
  121. line = []
  122. while True:
  123. c = recv(sock, 1)
  124. line.append(c)
  125. if c == b"\n":
  126. break
  127. return b"".join(line)
  128. def send(sock: socket.socket, data: Union[bytes, str]) -> int:
  129. if isinstance(data, str):
  130. data = data.encode("utf-8")
  131. if not sock:
  132. raise WebSocketConnectionClosedException("socket is already closed.")
  133. def _send() -> int:
  134. try:
  135. return sock.send(data)
  136. except SSLEOFError:
  137. raise WebSocketConnectionClosedException("socket is already closed.")
  138. except SSLWantWriteError:
  139. pass
  140. except socket.error as exc:
  141. error_code = extract_error_code(exc)
  142. if error_code is None:
  143. raise
  144. if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
  145. raise
  146. sel = selectors.DefaultSelector()
  147. sel.register(sock, selectors.EVENT_WRITE)
  148. w = sel.select(sock.gettimeout())
  149. sel.close()
  150. if w:
  151. return sock.send(data)
  152. return 0
  153. try:
  154. if sock.gettimeout() == 0:
  155. return sock.send(data)
  156. else:
  157. return _send()
  158. except socket.timeout as e:
  159. message = extract_err_message(e)
  160. raise WebSocketTimeoutException(message)
  161. except (OSError, SSLError) as e:
  162. message = extract_err_message(e)
  163. if isinstance(message, str) and "timed out" in message:
  164. raise WebSocketTimeoutException(message)
  165. else:
  166. raise