test_websocket.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # -*- coding: utf-8 -*-
  2. #
  3. import os
  4. import os.path
  5. import socket
  6. import unittest
  7. from base64 import decodebytes as base64decode
  8. import websocket as ws
  9. from websocket._exceptions import (
  10. WebSocketBadStatusException,
  11. WebSocketAddressException,
  12. WebSocketException,
  13. )
  14. from websocket._handshake import _create_sec_websocket_key
  15. from websocket._handshake import _validate as _validate_header
  16. from websocket._http import read_headers
  17. from websocket._utils import validate_utf8
  18. """
  19. test_websocket.py
  20. websocket - WebSocket client library for Python
  21. Copyright 2025 engn33r
  22. Licensed under the Apache License, Version 2.0 (the "License");
  23. you may not use this file except in compliance with the License.
  24. You may obtain a copy of the License at
  25. http://www.apache.org/licenses/LICENSE-2.0
  26. Unless required by applicable law or agreed to in writing, software
  27. distributed under the License is distributed on an "AS IS" BASIS,
  28. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  29. See the License for the specific language governing permissions and
  30. limitations under the License.
  31. """
  32. try:
  33. import ssl
  34. except ImportError:
  35. # dummy class of SSLError for ssl none-support environment.
  36. class SSLError(Exception):
  37. pass
  38. # Skip test to access the internet unless TEST_WITH_INTERNET == 1
  39. TEST_WITH_INTERNET = os.environ.get("TEST_WITH_INTERNET", "0") == "1"
  40. # Skip tests relying on local websockets server unless LOCAL_WS_SERVER_PORT != -1
  41. LOCAL_WS_SERVER_PORT = os.environ.get("LOCAL_WS_SERVER_PORT", "-1")
  42. TEST_WITH_LOCAL_SERVER = LOCAL_WS_SERVER_PORT != "-1"
  43. TRACEABLE = True
  44. def create_mask_key(_):
  45. return "abcd"
  46. class SockMock:
  47. def __init__(self):
  48. self.data = []
  49. self.sent = []
  50. def add_packet(self, data):
  51. self.data.append(data)
  52. def gettimeout(self):
  53. return None
  54. def recv(self, bufsize):
  55. if self.data:
  56. e = self.data.pop(0)
  57. if isinstance(e, Exception):
  58. raise e
  59. if len(e) > bufsize:
  60. self.data.insert(0, e[bufsize:])
  61. return e[:bufsize]
  62. def send(self, data):
  63. self.sent.append(data)
  64. return len(data)
  65. def close(self):
  66. pass
  67. class HeaderSockMock(SockMock):
  68. def __init__(self, fname):
  69. SockMock.__init__(self)
  70. path = os.path.join(os.path.dirname(__file__), fname)
  71. with open(path, "rb") as f:
  72. self.add_packet(f.read())
  73. class WebSocketTest(unittest.TestCase):
  74. def setUp(self):
  75. ws.enableTrace(TRACEABLE)
  76. def tearDown(self):
  77. pass
  78. def test_default_timeout(self):
  79. self.assertEqual(ws.getdefaulttimeout(), None)
  80. ws.setdefaulttimeout(10)
  81. self.assertEqual(ws.getdefaulttimeout(), 10)
  82. ws.setdefaulttimeout(None)
  83. def test_ws_key(self):
  84. key = _create_sec_websocket_key()
  85. self.assertTrue(key != 24)
  86. self.assertTrue("¥n" not in key)
  87. def test_nonce(self):
  88. """WebSocket key should be a random 16-byte nonce."""
  89. key = _create_sec_websocket_key()
  90. nonce = base64decode(key.encode("utf-8"))
  91. self.assertEqual(16, len(nonce))
  92. def test_ws_utils(self):
  93. key = "c6b8hTg4EeGb2gQMztV1/g=="
  94. required_header = {
  95. "upgrade": "websocket",
  96. "connection": "upgrade",
  97. "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=",
  98. }
  99. self.assertEqual(_validate_header(required_header, key, None), (True, None))
  100. header = required_header.copy()
  101. header["upgrade"] = "http"
  102. self.assertEqual(_validate_header(header, key, None), (False, None))
  103. del header["upgrade"]
  104. self.assertEqual(_validate_header(header, key, None), (False, None))
  105. header = required_header.copy()
  106. header["connection"] = "something"
  107. self.assertEqual(_validate_header(header, key, None), (False, None))
  108. del header["connection"]
  109. self.assertEqual(_validate_header(header, key, None), (False, None))
  110. header = required_header.copy()
  111. header["sec-websocket-accept"] = "something"
  112. self.assertEqual(_validate_header(header, key, None), (False, None))
  113. del header["sec-websocket-accept"]
  114. self.assertEqual(_validate_header(header, key, None), (False, None))
  115. header = required_header.copy()
  116. header["sec-websocket-protocol"] = "sub1"
  117. self.assertEqual(
  118. _validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")
  119. )
  120. # This case will print out a logging error using the error() function, but that is expected
  121. self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
  122. header = required_header.copy()
  123. header["sec-websocket-protocol"] = "sUb1"
  124. self.assertEqual(
  125. _validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")
  126. )
  127. header = required_header.copy()
  128. # This case will print out a logging error using the error() function, but that is expected
  129. self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
  130. def test_read_header(self):
  131. status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
  132. self.assertEqual(status, 101)
  133. self.assertEqual(header["connection"], "Upgrade")
  134. status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
  135. self.assertEqual(status, 101)
  136. self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
  137. HeaderSockMock("data/header02.txt")
  138. self.assertRaises(
  139. ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
  140. )
  141. def test_send(self):
  142. # TODO: add longer frame data
  143. sock = ws.WebSocket()
  144. sock.set_mask_key(create_mask_key)
  145. s = sock.sock = HeaderSockMock("data/header01.txt")
  146. sock.send("Hello")
  147. self.assertEqual(s.sent[0], b"\x81\x85abcd)\x07\x0f\x08\x0e")
  148. sock.send("こんにちは")
  149. self.assertEqual(
  150. s.sent[1],
  151. b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc",
  152. )
  153. # sock.send("x" * 5000)
  154. # self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
  155. self.assertEqual(sock.send_binary(b"1111111111101"), 19)
  156. def test_recv(self):
  157. # TODO: add longer frame data
  158. sock = ws.WebSocket()
  159. s = sock.sock = SockMock()
  160. something = (
  161. b"\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"
  162. )
  163. s.add_packet(something)
  164. data = sock.recv()
  165. self.assertEqual(data, "こんにちは")
  166. s.add_packet(b"\x81\x85abcd)\x07\x0f\x08\x0e")
  167. data = sock.recv()
  168. self.assertEqual(data, "Hello")
  169. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  170. def test_iter(self):
  171. count = 2
  172. s = ws.create_connection("wss://api.bitfinex.com/ws/2")
  173. s.send('{"event": "subscribe", "channel": "ticker"}')
  174. for _ in s:
  175. count -= 1
  176. if count == 0:
  177. break
  178. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  179. def test_next(self):
  180. sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
  181. self.assertEqual(str, type(next(sock)))
  182. def test_internal_recv_strict(self):
  183. sock = ws.WebSocket()
  184. s = sock.sock = SockMock()
  185. s.add_packet(b"foo")
  186. s.add_packet(socket.timeout())
  187. s.add_packet(b"bar")
  188. # s.add_packet(SSLError("The read operation timed out"))
  189. s.add_packet(b"baz")
  190. with self.assertRaises(ws.WebSocketTimeoutException):
  191. sock.frame_buffer.recv_strict(9)
  192. # with self.assertRaises(SSLError):
  193. # data = sock._recv_strict(9)
  194. data = sock.frame_buffer.recv_strict(9)
  195. self.assertEqual(data, b"foobarbaz")
  196. with self.assertRaises(ws.WebSocketConnectionClosedException):
  197. sock.frame_buffer.recv_strict(1)
  198. def test_recv_timeout(self):
  199. sock = ws.WebSocket()
  200. s = sock.sock = SockMock()
  201. s.add_packet(b"\x81")
  202. s.add_packet(socket.timeout())
  203. s.add_packet(b"\x8dabcd\x29\x07\x0f\x08\x0e")
  204. s.add_packet(socket.timeout())
  205. s.add_packet(b"\x4e\x43\x33\x0e\x10\x0f\x00\x40")
  206. with self.assertRaises(ws.WebSocketTimeoutException):
  207. sock.recv()
  208. with self.assertRaises(ws.WebSocketTimeoutException):
  209. sock.recv()
  210. data = sock.recv()
  211. self.assertEqual(data, "Hello, World!")
  212. with self.assertRaises(ws.WebSocketConnectionClosedException):
  213. sock.recv()
  214. def test_recv_with_simple_fragmentation(self):
  215. sock = ws.WebSocket()
  216. s = sock.sock = SockMock()
  217. # OPCODE=TEXT, FIN=0, MSG="Brevity is "
  218. s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
  219. # OPCODE=CONT, FIN=1, MSG="the soul of wit"
  220. s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
  221. data = sock.recv()
  222. self.assertEqual(data, "Brevity is the soul of wit")
  223. with self.assertRaises(ws.WebSocketConnectionClosedException):
  224. sock.recv()
  225. def test_recv_with_fire_event_of_fragmentation(self):
  226. sock = ws.WebSocket(fire_cont_frame=True)
  227. s = sock.sock = SockMock()
  228. # OPCODE=TEXT, FIN=0, MSG="Brevity is "
  229. s.add_packet(b"\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
  230. # OPCODE=CONT, FIN=0, MSG="Brevity is "
  231. s.add_packet(b"\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
  232. # OPCODE=CONT, FIN=1, MSG="the soul of wit"
  233. s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
  234. _, data = sock.recv_data()
  235. self.assertEqual(data, b"Brevity is ")
  236. _, data = sock.recv_data()
  237. self.assertEqual(data, b"Brevity is ")
  238. _, data = sock.recv_data()
  239. self.assertEqual(data, b"the soul of wit")
  240. # OPCODE=CONT, FIN=0, MSG="Brevity is "
  241. s.add_packet(b"\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")
  242. with self.assertRaises(ws.WebSocketException):
  243. sock.recv_data()
  244. with self.assertRaises(ws.WebSocketConnectionClosedException):
  245. sock.recv()
  246. def test_close(self):
  247. sock = ws.WebSocket()
  248. sock.connected = True
  249. sock.close()
  250. sock = ws.WebSocket()
  251. s = sock.sock = SockMock()
  252. sock.connected = True
  253. s.add_packet(b"\x88\x80\x17\x98p\x84")
  254. sock.recv()
  255. self.assertEqual(sock.connected, False)
  256. def test_recv_cont_fragmentation(self):
  257. sock = ws.WebSocket()
  258. s = sock.sock = SockMock()
  259. # OPCODE=CONT, FIN=1, MSG="the soul of wit"
  260. s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
  261. self.assertRaises(ws.WebSocketException, sock.recv)
  262. def test_recv_with_prolonged_fragmentation(self):
  263. sock = ws.WebSocket()
  264. s = sock.sock = SockMock()
  265. # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
  266. s.add_packet(
  267. b"\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"
  268. )
  269. # OPCODE=CONT, FIN=0, MSG="dear friends, "
  270. s.add_packet(b"\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB")
  271. # OPCODE=CONT, FIN=1, MSG="once more"
  272. s.add_packet(b"\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")
  273. data = sock.recv()
  274. self.assertEqual(data, "Once more unto the breach, dear friends, once more")
  275. with self.assertRaises(ws.WebSocketConnectionClosedException):
  276. sock.recv()
  277. def test_recv_with_fragmentation_and_control_frame(self):
  278. sock = ws.WebSocket()
  279. sock.set_mask_key(create_mask_key)
  280. s = sock.sock = SockMock()
  281. # OPCODE=TEXT, FIN=0, MSG="Too much "
  282. s.add_packet(b"\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")
  283. # OPCODE=PING, FIN=1, MSG="Please PONG this"
  284. s.add_packet(b"\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")
  285. # OPCODE=CONT, FIN=1, MSG="of a good thing"
  286. s.add_packet(b"\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04")
  287. data = sock.recv()
  288. self.assertEqual(data, "Too much of a good thing")
  289. with self.assertRaises(ws.WebSocketConnectionClosedException):
  290. sock.recv()
  291. self.assertEqual(
  292. s.sent[0], b"\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17"
  293. )
  294. @unittest.skipUnless(
  295. TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
  296. )
  297. def test_websocket(self):
  298. s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
  299. self.assertNotEqual(s, None)
  300. s.send("Hello, World")
  301. result = s.next()
  302. s.fileno()
  303. self.assertEqual(result, "Hello, World")
  304. s.send("こにゃにゃちは、世界")
  305. result = s.recv()
  306. self.assertEqual(result, "こにゃにゃちは、世界")
  307. self.assertRaises(ValueError, s.send_close, -1, "")
  308. s.close()
  309. @unittest.skipUnless(
  310. TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
  311. )
  312. def test_ping_pong(self):
  313. s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
  314. self.assertNotEqual(s, None)
  315. s.ping("Hello")
  316. s.pong("Hi")
  317. s.close()
  318. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  319. def test_support_redirect(self):
  320. s = ws.WebSocket()
  321. self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
  322. # Need to find a URL that has a redirect code leading to a websocket
  323. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  324. def test_secure_websocket(self):
  325. s = ws.create_connection("wss://api.bitfinex.com/ws/2")
  326. self.assertNotEqual(s, None)
  327. self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
  328. self.assertEqual(s.getstatus(), 101)
  329. self.assertNotEqual(s.getheaders(), None)
  330. s.settimeout(10)
  331. self.assertEqual(s.gettimeout(), 10)
  332. self.assertEqual(s.getsubprotocol(), None)
  333. s.abort()
  334. @unittest.skipUnless(
  335. TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
  336. )
  337. def test_websocket_with_custom_header(self):
  338. s = ws.create_connection(
  339. f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
  340. headers={"User-Agent": "PythonWebsocketClient"},
  341. )
  342. self.assertNotEqual(s, None)
  343. self.assertEqual(s.getsubprotocol(), None)
  344. s.send("Hello, World")
  345. result = s.recv()
  346. self.assertEqual(result, "Hello, World")
  347. self.assertRaises(ValueError, s.close, -1, "")
  348. s.close()
  349. @unittest.skipUnless(
  350. TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
  351. )
  352. def test_after_close(self):
  353. s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
  354. self.assertNotEqual(s, None)
  355. s.close()
  356. self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
  357. self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
  358. class SockOptTest(unittest.TestCase):
  359. @unittest.skipUnless(
  360. TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
  361. )
  362. def test_sockopt(self):
  363. sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
  364. s = ws.create_connection(
  365. f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
  366. )
  367. self.assertNotEqual(
  368. s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0
  369. )
  370. s.close()
  371. class UtilsTest(unittest.TestCase):
  372. def test_utf8_validator(self):
  373. state = validate_utf8(b"\xf0\x90\x80\x80")
  374. self.assertEqual(state, True)
  375. state = validate_utf8(
  376. b"\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited"
  377. )
  378. self.assertEqual(state, False)
  379. state = validate_utf8(b"")
  380. self.assertEqual(state, True)
  381. class HandshakeTest(unittest.TestCase):
  382. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  383. def test_http_ssl(self):
  384. websock1 = ws.WebSocket(
  385. sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
  386. enable_multithread=False,
  387. )
  388. self.assertRaises(ValueError, websock1.connect, "wss://api.bitfinex.com/ws/2")
  389. websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
  390. self.assertRaises(
  391. WebSocketException, websock2.connect, "wss://api.bitfinex.com/ws/2"
  392. )
  393. @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
  394. def test_manual_headers(self):
  395. websock3 = ws.WebSocket(
  396. sslopt={
  397. "ca_certs": ssl.get_default_verify_paths().cafile,
  398. "ca_cert_path": ssl.get_default_verify_paths().capath,
  399. }
  400. )
  401. self.assertRaises(
  402. WebSocketBadStatusException,
  403. websock3.connect,
  404. "wss://api.bitfinex.com/ws/2",
  405. cookie="chocolate",
  406. origin="testing_websockets.com",
  407. host="echo.websocket.events/websocket-client-test",
  408. subprotocols=["testproto"],
  409. connection="Upgrade",
  410. header={
  411. "CustomHeader1": "123",
  412. "Cookie": "TestValue",
  413. "Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
  414. "Sec-WebSocket-Protocol": "newprotocol",
  415. },
  416. )
  417. def test_ipv6(self):
  418. websock2 = ws.WebSocket()
  419. self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
  420. def test_bad_urls(self):
  421. websock3 = ws.WebSocket()
  422. self.assertRaises(ValueError, websock3.connect, "ws//example.com")
  423. self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
  424. self.assertRaises(ValueError, websock3.connect, "example.com")
  425. if __name__ == "__main__":
  426. unittest.main()