tls.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. from __future__ import annotations
  2. import logging
  3. import re
  4. import ssl
  5. import sys
  6. from collections.abc import Callable, Mapping
  7. from dataclasses import dataclass
  8. from functools import wraps
  9. from ssl import SSLContext
  10. from typing import Any, TypeVar
  11. from .. import (
  12. BrokenResourceError,
  13. EndOfStream,
  14. aclose_forcefully,
  15. get_cancelled_exc_class,
  16. to_thread,
  17. )
  18. from .._core._typedattr import TypedAttributeSet, typed_attribute
  19. from ..abc import (
  20. AnyByteStream,
  21. AnyByteStreamConnectable,
  22. ByteStream,
  23. ByteStreamConnectable,
  24. Listener,
  25. TaskGroup,
  26. )
  27. if sys.version_info >= (3, 10):
  28. from typing import TypeAlias
  29. else:
  30. from typing_extensions import TypeAlias
  31. if sys.version_info >= (3, 11):
  32. from typing import TypeVarTuple, Unpack
  33. else:
  34. from typing_extensions import TypeVarTuple, Unpack
  35. if sys.version_info >= (3, 12):
  36. from typing import override
  37. else:
  38. from typing_extensions import override
  39. T_Retval = TypeVar("T_Retval")
  40. PosArgsT = TypeVarTuple("PosArgsT")
  41. _PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
  42. _PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
  43. class TLSAttribute(TypedAttributeSet):
  44. """Contains Transport Layer Security related attributes."""
  45. #: the selected ALPN protocol
  46. alpn_protocol: str | None = typed_attribute()
  47. #: the channel binding for type ``tls-unique``
  48. channel_binding_tls_unique: bytes = typed_attribute()
  49. #: the selected cipher
  50. cipher: tuple[str, str, int] = typed_attribute()
  51. #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
  52. # for more information)
  53. peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
  54. #: the peer certificate in binary form
  55. peer_certificate_binary: bytes | None = typed_attribute()
  56. #: ``True`` if this is the server side of the connection
  57. server_side: bool = typed_attribute()
  58. #: ciphers shared by the client during the TLS handshake (``None`` if this is the
  59. #: client side)
  60. shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
  61. #: the :class:`~ssl.SSLObject` used for encryption
  62. ssl_object: ssl.SSLObject = typed_attribute()
  63. #: ``True`` if this stream does (and expects) a closing TLS handshake when the
  64. #: stream is being closed
  65. standard_compatible: bool = typed_attribute()
  66. #: the TLS protocol version (e.g. ``TLSv1.2``)
  67. tls_version: str = typed_attribute()
  68. @dataclass(eq=False)
  69. class TLSStream(ByteStream):
  70. """
  71. A stream wrapper that encrypts all sent data and decrypts received data.
  72. This class has no public initializer; use :meth:`wrap` instead.
  73. All extra attributes from :class:`~TLSAttribute` are supported.
  74. :var AnyByteStream transport_stream: the wrapped stream
  75. """
  76. transport_stream: AnyByteStream
  77. standard_compatible: bool
  78. _ssl_object: ssl.SSLObject
  79. _read_bio: ssl.MemoryBIO
  80. _write_bio: ssl.MemoryBIO
  81. @classmethod
  82. async def wrap(
  83. cls,
  84. transport_stream: AnyByteStream,
  85. *,
  86. server_side: bool | None = None,
  87. hostname: str | None = None,
  88. ssl_context: ssl.SSLContext | None = None,
  89. standard_compatible: bool = True,
  90. ) -> TLSStream:
  91. """
  92. Wrap an existing stream with Transport Layer Security.
  93. This performs a TLS handshake with the peer.
  94. :param transport_stream: a bytes-transporting stream to wrap
  95. :param server_side: ``True`` if this is the server side of the connection,
  96. ``False`` if this is the client side (if omitted, will be set to ``False``
  97. if ``hostname`` has been provided, ``False`` otherwise). Used only to create
  98. a default context when an explicit context has not been provided.
  99. :param hostname: host name of the peer (if host name checking is desired)
  100. :param ssl_context: the SSLContext object to use (if not provided, a secure
  101. default will be created)
  102. :param standard_compatible: if ``False``, skip the closing handshake when
  103. closing the connection, and don't raise an exception if the peer does the
  104. same
  105. :raises ~ssl.SSLError: if the TLS handshake fails
  106. """
  107. if server_side is None:
  108. server_side = not hostname
  109. if not ssl_context:
  110. purpose = (
  111. ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
  112. )
  113. ssl_context = ssl.create_default_context(purpose)
  114. # Re-enable detection of unexpected EOFs if it was disabled by Python
  115. if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
  116. ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
  117. bio_in = ssl.MemoryBIO()
  118. bio_out = ssl.MemoryBIO()
  119. # External SSLContext implementations may do blocking I/O in wrap_bio(),
  120. # but the standard library implementation won't
  121. if type(ssl_context) is ssl.SSLContext:
  122. ssl_object = ssl_context.wrap_bio(
  123. bio_in, bio_out, server_side=server_side, server_hostname=hostname
  124. )
  125. else:
  126. ssl_object = await to_thread.run_sync(
  127. ssl_context.wrap_bio,
  128. bio_in,
  129. bio_out,
  130. server_side,
  131. hostname,
  132. None,
  133. )
  134. wrapper = cls(
  135. transport_stream=transport_stream,
  136. standard_compatible=standard_compatible,
  137. _ssl_object=ssl_object,
  138. _read_bio=bio_in,
  139. _write_bio=bio_out,
  140. )
  141. await wrapper._call_sslobject_method(ssl_object.do_handshake)
  142. return wrapper
  143. async def _call_sslobject_method(
  144. self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  145. ) -> T_Retval:
  146. while True:
  147. try:
  148. result = func(*args)
  149. except ssl.SSLWantReadError:
  150. try:
  151. # Flush any pending writes first
  152. if self._write_bio.pending:
  153. await self.transport_stream.send(self._write_bio.read())
  154. data = await self.transport_stream.receive()
  155. except EndOfStream:
  156. self._read_bio.write_eof()
  157. except OSError as exc:
  158. self._read_bio.write_eof()
  159. self._write_bio.write_eof()
  160. raise BrokenResourceError from exc
  161. else:
  162. self._read_bio.write(data)
  163. except ssl.SSLWantWriteError:
  164. await self.transport_stream.send(self._write_bio.read())
  165. except ssl.SSLSyscallError as exc:
  166. self._read_bio.write_eof()
  167. self._write_bio.write_eof()
  168. raise BrokenResourceError from exc
  169. except ssl.SSLError as exc:
  170. self._read_bio.write_eof()
  171. self._write_bio.write_eof()
  172. if isinstance(exc, ssl.SSLEOFError) or (
  173. exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
  174. ):
  175. if self.standard_compatible:
  176. raise BrokenResourceError from exc
  177. else:
  178. raise EndOfStream from None
  179. raise
  180. else:
  181. # Flush any pending writes first
  182. if self._write_bio.pending:
  183. await self.transport_stream.send(self._write_bio.read())
  184. return result
  185. async def unwrap(self) -> tuple[AnyByteStream, bytes]:
  186. """
  187. Does the TLS closing handshake.
  188. :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
  189. """
  190. await self._call_sslobject_method(self._ssl_object.unwrap)
  191. self._read_bio.write_eof()
  192. self._write_bio.write_eof()
  193. return self.transport_stream, self._read_bio.read()
  194. async def aclose(self) -> None:
  195. if self.standard_compatible:
  196. try:
  197. await self.unwrap()
  198. except BaseException:
  199. await aclose_forcefully(self.transport_stream)
  200. raise
  201. await self.transport_stream.aclose()
  202. async def receive(self, max_bytes: int = 65536) -> bytes:
  203. data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
  204. if not data:
  205. raise EndOfStream
  206. return data
  207. async def send(self, item: bytes) -> None:
  208. await self._call_sslobject_method(self._ssl_object.write, item)
  209. async def send_eof(self) -> None:
  210. tls_version = self.extra(TLSAttribute.tls_version)
  211. match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
  212. if match:
  213. major, minor = int(match.group(1)), int(match.group(2) or 0)
  214. if (major, minor) < (1, 3):
  215. raise NotImplementedError(
  216. f"send_eof() requires at least TLSv1.3; current "
  217. f"session uses {tls_version}"
  218. )
  219. raise NotImplementedError(
  220. "send_eof() has not yet been implemented for TLS streams"
  221. )
  222. @property
  223. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  224. return {
  225. **self.transport_stream.extra_attributes,
  226. TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
  227. TLSAttribute.channel_binding_tls_unique: (
  228. self._ssl_object.get_channel_binding
  229. ),
  230. TLSAttribute.cipher: self._ssl_object.cipher,
  231. TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
  232. TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
  233. True
  234. ),
  235. TLSAttribute.server_side: lambda: self._ssl_object.server_side,
  236. TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
  237. if self._ssl_object.server_side
  238. else None,
  239. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  240. TLSAttribute.ssl_object: lambda: self._ssl_object,
  241. TLSAttribute.tls_version: self._ssl_object.version,
  242. }
  243. @dataclass(eq=False)
  244. class TLSListener(Listener[TLSStream]):
  245. """
  246. A convenience listener that wraps another listener and auto-negotiates a TLS session
  247. on every accepted connection.
  248. If the TLS handshake times out or raises an exception,
  249. :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
  250. deemed necessary.
  251. Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
  252. :param Listener listener: the listener to wrap
  253. :param ssl_context: the SSL context object
  254. :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
  255. :param handshake_timeout: time limit for the TLS handshake
  256. (passed to :func:`~anyio.fail_after`)
  257. """
  258. listener: Listener[Any]
  259. ssl_context: ssl.SSLContext
  260. standard_compatible: bool = True
  261. handshake_timeout: float = 30
  262. @staticmethod
  263. async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
  264. """
  265. Handle an exception raised during the TLS handshake.
  266. This method does 3 things:
  267. #. Forcefully closes the original stream
  268. #. Logs the exception (unless it was a cancellation exception) using the
  269. ``anyio.streams.tls`` logger
  270. #. Reraises the exception if it was a base exception or a cancellation exception
  271. :param exc: the exception
  272. :param stream: the original stream
  273. """
  274. await aclose_forcefully(stream)
  275. # Log all except cancellation exceptions
  276. if not isinstance(exc, get_cancelled_exc_class()):
  277. # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
  278. # any asyncio implementation, so we explicitly pass the exception to log
  279. # (https://github.com/python/cpython/issues/108668). Trio does not have this
  280. # issue because it works around the CPython bug.
  281. logging.getLogger(__name__).exception(
  282. "Error during TLS handshake", exc_info=exc
  283. )
  284. # Only reraise base exceptions and cancellation exceptions
  285. if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
  286. raise
  287. async def serve(
  288. self,
  289. handler: Callable[[TLSStream], Any],
  290. task_group: TaskGroup | None = None,
  291. ) -> None:
  292. @wraps(handler)
  293. async def handler_wrapper(stream: AnyByteStream) -> None:
  294. from .. import fail_after
  295. try:
  296. with fail_after(self.handshake_timeout):
  297. wrapped_stream = await TLSStream.wrap(
  298. stream,
  299. ssl_context=self.ssl_context,
  300. standard_compatible=self.standard_compatible,
  301. )
  302. except BaseException as exc:
  303. await self.handle_handshake_error(exc, stream)
  304. else:
  305. await handler(wrapped_stream)
  306. await self.listener.serve(handler_wrapper, task_group)
  307. async def aclose(self) -> None:
  308. await self.listener.aclose()
  309. @property
  310. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  311. return {
  312. TLSAttribute.standard_compatible: lambda: self.standard_compatible,
  313. }
  314. class TLSConnectable(ByteStreamConnectable):
  315. """
  316. Wraps another connectable and does TLS negotiation after a successful connection.
  317. :param connectable: the connectable to wrap
  318. :param hostname: host name of the server (if host name checking is desired)
  319. :param ssl_context: the SSLContext object to use (if not provided, a secure default
  320. will be created)
  321. :param standard_compatible: if ``False``, skip the closing handshake when closing
  322. the connection, and don't raise an exception if the server does the same
  323. """
  324. def __init__(
  325. self,
  326. connectable: AnyByteStreamConnectable,
  327. *,
  328. hostname: str | None = None,
  329. ssl_context: ssl.SSLContext | None = None,
  330. standard_compatible: bool = True,
  331. ) -> None:
  332. self.connectable = connectable
  333. self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
  334. ssl.Purpose.SERVER_AUTH
  335. )
  336. if not isinstance(self.ssl_context, ssl.SSLContext):
  337. raise TypeError(
  338. "ssl_context must be an instance of ssl.SSLContext, not "
  339. f"{type(self.ssl_context).__name__}"
  340. )
  341. self.hostname = hostname
  342. self.standard_compatible = standard_compatible
  343. @override
  344. async def connect(self) -> TLSStream:
  345. stream = await self.connectable.connect()
  346. try:
  347. return await TLSStream.wrap(
  348. stream,
  349. hostname=self.hostname,
  350. ssl_context=self.ssl_context,
  351. standard_compatible=self.standard_compatible,
  352. )
  353. except BaseException:
  354. await aclose_forcefully(stream)
  355. raise