_abnf.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. import array
  2. import os
  3. import struct
  4. import sys
  5. from threading import Lock
  6. from typing import Callable, Optional, Union, Any
  7. from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
  8. from ._utils import validate_utf8
  9. """
  10. _abnf.py
  11. websocket - WebSocket client library for Python
  12. Copyright 2025 engn33r
  13. Licensed under the Apache License, Version 2.0 (the "License");
  14. you may not use this file except in compliance with the License.
  15. You may obtain a copy of the License at
  16. http://www.apache.org/licenses/LICENSE-2.0
  17. Unless required by applicable law or agreed to in writing, software
  18. distributed under the License is distributed on an "AS IS" BASIS,
  19. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. See the License for the specific language governing permissions and
  21. limitations under the License.
  22. """
  23. try:
  24. # If wsaccel is available, use compiled routines to mask data.
  25. # wsaccel only provides around a 10% speed boost compared
  26. # to the websocket-client _mask() implementation.
  27. # Note that wsaccel is unmaintained.
  28. from wsaccel.xormask import XorMaskerSimple
  29. def _mask(mask_value: array.array, data_value: array.array) -> bytes:
  30. mask_result: bytes = XorMaskerSimple(mask_value).process(data_value)
  31. return mask_result
  32. except ImportError:
  33. # wsaccel is not available, use websocket-client _mask()
  34. native_byteorder = sys.byteorder
  35. def _mask(mask_value: array.array, data_value: array.array) -> bytes:
  36. datalen = len(data_value)
  37. int_data_value = int.from_bytes(data_value, native_byteorder)
  38. int_mask_value = int.from_bytes(
  39. mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder
  40. )
  41. return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder)
  42. __all__ = [
  43. "ABNF",
  44. "continuous_frame",
  45. "frame_buffer",
  46. "STATUS_NORMAL",
  47. "STATUS_GOING_AWAY",
  48. "STATUS_PROTOCOL_ERROR",
  49. "STATUS_UNSUPPORTED_DATA_TYPE",
  50. "STATUS_STATUS_NOT_AVAILABLE",
  51. "STATUS_ABNORMAL_CLOSED",
  52. "STATUS_INVALID_PAYLOAD",
  53. "STATUS_POLICY_VIOLATION",
  54. "STATUS_MESSAGE_TOO_BIG",
  55. "STATUS_INVALID_EXTENSION",
  56. "STATUS_UNEXPECTED_CONDITION",
  57. "STATUS_BAD_GATEWAY",
  58. "STATUS_TLS_HANDSHAKE_ERROR",
  59. ]
  60. # closing frame status codes.
  61. STATUS_NORMAL = 1000
  62. STATUS_GOING_AWAY = 1001
  63. STATUS_PROTOCOL_ERROR = 1002
  64. STATUS_UNSUPPORTED_DATA_TYPE = 1003
  65. STATUS_STATUS_NOT_AVAILABLE = 1005
  66. STATUS_ABNORMAL_CLOSED = 1006
  67. STATUS_INVALID_PAYLOAD = 1007
  68. STATUS_POLICY_VIOLATION = 1008
  69. STATUS_MESSAGE_TOO_BIG = 1009
  70. STATUS_INVALID_EXTENSION = 1010
  71. STATUS_UNEXPECTED_CONDITION = 1011
  72. STATUS_SERVICE_RESTART = 1012
  73. STATUS_TRY_AGAIN_LATER = 1013
  74. STATUS_BAD_GATEWAY = 1014
  75. STATUS_TLS_HANDSHAKE_ERROR = 1015
  76. VALID_CLOSE_STATUS = (
  77. STATUS_NORMAL,
  78. STATUS_GOING_AWAY,
  79. STATUS_PROTOCOL_ERROR,
  80. STATUS_UNSUPPORTED_DATA_TYPE,
  81. STATUS_INVALID_PAYLOAD,
  82. STATUS_POLICY_VIOLATION,
  83. STATUS_MESSAGE_TOO_BIG,
  84. STATUS_INVALID_EXTENSION,
  85. STATUS_UNEXPECTED_CONDITION,
  86. STATUS_SERVICE_RESTART,
  87. STATUS_TRY_AGAIN_LATER,
  88. STATUS_BAD_GATEWAY,
  89. )
  90. class ABNF:
  91. """
  92. ABNF frame class.
  93. See http://tools.ietf.org/html/rfc5234
  94. and http://tools.ietf.org/html/rfc6455#section-5.2
  95. """
  96. # operation code values.
  97. OPCODE_CONT = 0x0
  98. OPCODE_TEXT = 0x1
  99. OPCODE_BINARY = 0x2
  100. OPCODE_CLOSE = 0x8
  101. OPCODE_PING = 0x9
  102. OPCODE_PONG = 0xA
  103. # available operation code value tuple
  104. OPCODES = (
  105. OPCODE_CONT,
  106. OPCODE_TEXT,
  107. OPCODE_BINARY,
  108. OPCODE_CLOSE,
  109. OPCODE_PING,
  110. OPCODE_PONG,
  111. )
  112. # opcode human readable string
  113. OPCODE_MAP = {
  114. OPCODE_CONT: "cont",
  115. OPCODE_TEXT: "text",
  116. OPCODE_BINARY: "binary",
  117. OPCODE_CLOSE: "close",
  118. OPCODE_PING: "ping",
  119. OPCODE_PONG: "pong",
  120. }
  121. # data length threshold.
  122. LENGTH_7 = 0x7E
  123. LENGTH_16 = 1 << 16
  124. LENGTH_63 = 1 << 63
  125. def __init__(
  126. self,
  127. fin: int = 0,
  128. rsv1: int = 0,
  129. rsv2: int = 0,
  130. rsv3: int = 0,
  131. opcode: int = OPCODE_TEXT,
  132. mask_value: int = 1,
  133. data: Optional[Union[str, bytes]] = "",
  134. ) -> None:
  135. """
  136. Constructor for ABNF. Please check RFC for arguments.
  137. """
  138. self.fin = fin
  139. self.rsv1 = rsv1
  140. self.rsv2 = rsv2
  141. self.rsv3 = rsv3
  142. self.opcode = opcode
  143. self.mask_value = mask_value
  144. if data is None:
  145. data = ""
  146. self.data = data
  147. self.get_mask_key = os.urandom
  148. def validate(self, skip_utf8_validation: bool = False) -> None:
  149. """
  150. Validate the ABNF frame.
  151. Parameters
  152. ----------
  153. skip_utf8_validation: skip utf8 validation.
  154. """
  155. if self.rsv1 or self.rsv2 or self.rsv3:
  156. raise WebSocketProtocolException("rsv is not implemented, yet")
  157. if self.opcode not in ABNF.OPCODES:
  158. raise WebSocketProtocolException("Invalid opcode %r", self.opcode)
  159. if self.opcode == ABNF.OPCODE_PING and not self.fin:
  160. raise WebSocketProtocolException("Invalid ping frame.")
  161. if self.opcode == ABNF.OPCODE_CLOSE:
  162. data_length = len(self.data)
  163. if not data_length:
  164. return
  165. if data_length == 1 or data_length >= 126:
  166. raise WebSocketProtocolException("Invalid close frame.")
  167. if (
  168. data_length > 2
  169. and not skip_utf8_validation
  170. and not validate_utf8(self.data[2:])
  171. ):
  172. raise WebSocketProtocolException("Invalid close frame.")
  173. data_bytes = (
  174. self.data[:2]
  175. if isinstance(self.data, bytes)
  176. else self.data[:2].encode("utf-8")
  177. )
  178. code = struct.unpack("!H", data_bytes)[0]
  179. if not self._is_valid_close_status(code):
  180. raise WebSocketProtocolException("Invalid close opcode %r", code)
  181. @staticmethod
  182. def _is_valid_close_status(code: int) -> bool:
  183. return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)
  184. def __str__(self) -> str:
  185. data_repr = self.data if isinstance(self.data, str) else repr(self.data)
  186. return f"fin={self.fin} opcode={self.opcode} data={data_repr}"
  187. @staticmethod
  188. def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
  189. """
  190. Create frame to send text, binary and other data.
  191. Parameters
  192. ----------
  193. data: str
  194. data to send. This is string value(byte array).
  195. If opcode is OPCODE_TEXT and this value is unicode,
  196. data value is converted into unicode string, automatically.
  197. opcode: int
  198. operation code. please see OPCODE_MAP.
  199. fin: int
  200. fin flag. if set to 0, create continue fragmentation.
  201. """
  202. if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
  203. data = data.encode("utf-8")
  204. # mask must be set if send data from client
  205. return ABNF(fin, 0, 0, 0, opcode, 1, data)
  206. def format(self) -> bytes:
  207. """
  208. Format this object to string(byte array) to send data to server.
  209. """
  210. if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
  211. raise ValueError("not 0 or 1")
  212. if self.opcode not in ABNF.OPCODES:
  213. raise ValueError("Invalid OPCODE")
  214. length = len(self.data)
  215. if length >= ABNF.LENGTH_63:
  216. raise ValueError("data is too long")
  217. frame_header = chr(
  218. self.fin << 7
  219. | self.rsv1 << 6
  220. | self.rsv2 << 5
  221. | self.rsv3 << 4
  222. | self.opcode
  223. ).encode("latin-1")
  224. if length < ABNF.LENGTH_7:
  225. frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
  226. elif length < ABNF.LENGTH_16:
  227. frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
  228. frame_header += struct.pack("!H", length)
  229. else:
  230. frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
  231. frame_header += struct.pack("!Q", length)
  232. if not self.mask_value:
  233. if isinstance(self.data, str):
  234. self.data = self.data.encode("utf-8")
  235. return frame_header + self.data
  236. mask_key = self.get_mask_key(4)
  237. return frame_header + self._get_masked(mask_key)
  238. def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
  239. s = ABNF.mask(mask_key, self.data)
  240. if isinstance(mask_key, str):
  241. mask_key = mask_key.encode("utf-8")
  242. return mask_key + s
  243. @staticmethod
  244. def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
  245. """
  246. Mask or unmask data. Just do xor for each byte
  247. Parameters
  248. ----------
  249. mask_key: bytes or str
  250. 4 byte mask.
  251. data: bytes or str
  252. data to mask/unmask.
  253. """
  254. if data is None:
  255. data = ""
  256. if isinstance(mask_key, str):
  257. mask_key = mask_key.encode("latin-1")
  258. if isinstance(data, str):
  259. data = data.encode("latin-1")
  260. return _mask(array.array("B", mask_key), array.array("B", data))
  261. class frame_buffer:
  262. _HEADER_MASK_INDEX = 5
  263. _HEADER_LENGTH_INDEX = 6
  264. def __init__(
  265. self, recv_fn: Callable[[int], int], skip_utf8_validation: bool
  266. ) -> None:
  267. self.recv = recv_fn
  268. self.skip_utf8_validation = skip_utf8_validation
  269. # Buffers over the packets from the layer beneath until desired amount
  270. # bytes of bytes are received.
  271. self.recv_buffer: list = []
  272. self.clear()
  273. self.lock = Lock()
  274. def clear(self) -> None:
  275. self.header: Optional[tuple] = None
  276. self.length: Optional[int] = None
  277. self.mask_value: Optional[Union[bytes, str]] = None
  278. def needs_header(self) -> bool:
  279. return self.header is None
  280. def recv_header(self) -> None:
  281. header = self.recv_strict(2)
  282. b1 = header[0]
  283. fin = b1 >> 7 & 1
  284. rsv1 = b1 >> 6 & 1
  285. rsv2 = b1 >> 5 & 1
  286. rsv3 = b1 >> 4 & 1
  287. opcode = b1 & 0xF
  288. b2 = header[1]
  289. has_mask = b2 >> 7 & 1
  290. length_bits = b2 & 0x7F
  291. self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)
  292. def has_mask(self) -> Union[bool, int]:
  293. if not self.header:
  294. return False
  295. header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
  296. return header_val
  297. def needs_length(self) -> bool:
  298. return self.length is None
  299. def recv_length(self) -> None:
  300. if self.header is None:
  301. raise WebSocketProtocolException("Header not received")
  302. bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
  303. length_bits = bits & 0x7F
  304. if length_bits == 0x7E:
  305. v = self.recv_strict(2)
  306. self.length = struct.unpack("!H", v)[0]
  307. elif length_bits == 0x7F:
  308. v = self.recv_strict(8)
  309. self.length = struct.unpack("!Q", v)[0]
  310. else:
  311. self.length = length_bits
  312. def needs_mask(self) -> bool:
  313. return self.mask_value is None
  314. def recv_mask(self) -> None:
  315. self.mask_value = self.recv_strict(4) if self.has_mask() else ""
  316. def recv_frame(self) -> ABNF:
  317. with self.lock:
  318. # Header
  319. if self.needs_header():
  320. self.recv_header()
  321. if self.header is None:
  322. raise WebSocketProtocolException("Header not received")
  323. (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header
  324. # Frame length
  325. if self.needs_length():
  326. self.recv_length()
  327. length = self.length
  328. # Mask
  329. if self.needs_mask():
  330. self.recv_mask()
  331. mask_value = self.mask_value
  332. # Payload
  333. if length is None:
  334. raise WebSocketProtocolException("Length not received")
  335. payload = self.recv_strict(length)
  336. if has_mask:
  337. if mask_value is None:
  338. raise WebSocketProtocolException("Mask not received")
  339. payload = ABNF.mask(mask_value, payload)
  340. # Reset for next frame
  341. self.clear()
  342. frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
  343. frame.validate(self.skip_utf8_validation)
  344. return frame
  345. def recv_strict(self, bufsize: int) -> bytes:
  346. if not isinstance(bufsize, int):
  347. raise ValueError("bufsize must be an integer")
  348. shortage = bufsize - sum(len(buf) for buf in self.recv_buffer)
  349. while shortage > 0:
  350. # Limit buffer size that we pass to socket.recv() to avoid
  351. # fragmenting the heap -- the number of bytes recv() actually
  352. # reads is limited by socket buffer and is relatively small,
  353. # yet passing large numbers repeatedly causes lots of large
  354. # buffers allocated and then shrunk, which results in
  355. # fragmentation.
  356. bytes_ = self.recv(min(16384, shortage))
  357. if isinstance(bytes_, bytes):
  358. self.recv_buffer.append(bytes_)
  359. shortage -= len(bytes_)
  360. else:
  361. # Handle case where recv returns int or other type
  362. break
  363. unified = b"".join(self.recv_buffer)
  364. if shortage == 0:
  365. self.recv_buffer = []
  366. return unified
  367. else:
  368. self.recv_buffer = [unified[bufsize:]]
  369. return unified[:bufsize]
  370. class continuous_frame:
  371. def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
  372. self.fire_cont_frame = fire_cont_frame
  373. self.skip_utf8_validation = skip_utf8_validation
  374. self.cont_data: Optional[list[Any]] = None
  375. self.recving_frames: Optional[int] = None
  376. def validate(self, frame: ABNF) -> None:
  377. if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
  378. raise WebSocketProtocolException("Illegal frame")
  379. if self.recving_frames and frame.opcode in (
  380. ABNF.OPCODE_TEXT,
  381. ABNF.OPCODE_BINARY,
  382. ):
  383. raise WebSocketProtocolException("Illegal frame")
  384. def add(self, frame: ABNF) -> None:
  385. if self.cont_data:
  386. self.cont_data[1] += frame.data
  387. else:
  388. if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
  389. self.recving_frames = frame.opcode
  390. self.cont_data = [frame.opcode, frame.data]
  391. if frame.fin:
  392. self.recving_frames = None
  393. def is_fire(self, frame: ABNF) -> Union[bool, int]:
  394. return frame.fin or self.fire_cont_frame
  395. def extract(self, frame: ABNF) -> tuple:
  396. data = self.cont_data
  397. if data is None:
  398. raise WebSocketProtocolException("No continuation data available")
  399. self.cont_data = None
  400. frame.data = data[1]
  401. if (
  402. not self.fire_cont_frame
  403. and data is not None
  404. and data[0] == ABNF.OPCODE_TEXT
  405. and not self.skip_utf8_validation
  406. and not validate_utf8(frame.data)
  407. ):
  408. raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
  409. if data is None:
  410. raise WebSocketProtocolException("No continuation data available")
  411. return data[0], frame