_trio.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384
  1. from __future__ import annotations
  2. import array
  3. import math
  4. import os
  5. import socket
  6. import sys
  7. import types
  8. import weakref
  9. from collections.abc import (
  10. AsyncGenerator,
  11. AsyncIterator,
  12. Awaitable,
  13. Callable,
  14. Collection,
  15. Coroutine,
  16. Iterable,
  17. Sequence,
  18. )
  19. from concurrent.futures import Future
  20. from contextlib import AbstractContextManager
  21. from dataclasses import dataclass
  22. from functools import partial
  23. from io import IOBase
  24. from os import PathLike
  25. from signal import Signals
  26. from socket import AddressFamily, SocketKind
  27. from types import TracebackType
  28. from typing import (
  29. IO,
  30. TYPE_CHECKING,
  31. Any,
  32. Generic,
  33. NoReturn,
  34. TypeVar,
  35. cast,
  36. overload,
  37. )
  38. import trio.from_thread
  39. import trio.lowlevel
  40. from outcome import Error, Outcome, Value
  41. from trio.lowlevel import (
  42. current_root_task,
  43. current_task,
  44. notify_closing,
  45. wait_readable,
  46. wait_writable,
  47. )
  48. from trio.socket import SocketType as TrioSocketType
  49. from trio.to_thread import run_sync
  50. from .. import (
  51. CapacityLimiterStatistics,
  52. EventStatistics,
  53. LockStatistics,
  54. RunFinishedError,
  55. TaskInfo,
  56. WouldBlock,
  57. abc,
  58. )
  59. from .._core._eventloop import claim_worker_thread
  60. from .._core._exceptions import (
  61. BrokenResourceError,
  62. BusyResourceError,
  63. ClosedResourceError,
  64. EndOfStream,
  65. )
  66. from .._core._sockets import convert_ipv6_sockaddr
  67. from .._core._streams import create_memory_object_stream
  68. from .._core._synchronization import (
  69. CapacityLimiter as BaseCapacityLimiter,
  70. )
  71. from .._core._synchronization import Event as BaseEvent
  72. from .._core._synchronization import Lock as BaseLock
  73. from .._core._synchronization import (
  74. ResourceGuard,
  75. SemaphoreStatistics,
  76. )
  77. from .._core._synchronization import Semaphore as BaseSemaphore
  78. from .._core._tasks import CancelScope as BaseCancelScope
  79. from ..abc import IPSockAddrType, UDPPacketType, UNIXDatagramPacketType
  80. from ..abc._eventloop import AsyncBackend, StrOrBytesPath
  81. from ..streams.memory import MemoryObjectSendStream
  82. if TYPE_CHECKING:
  83. from _typeshed import FileDescriptorLike
  84. if sys.version_info >= (3, 10):
  85. from typing import ParamSpec
  86. else:
  87. from typing_extensions import ParamSpec
  88. if sys.version_info >= (3, 11):
  89. from typing import TypeVarTuple, Unpack
  90. else:
  91. from exceptiongroup import BaseExceptionGroup
  92. from typing_extensions import TypeVarTuple, Unpack
# Generic type variables used in the signatures throughout this module.
T = TypeVar("T")
T_Retval = TypeVar("T_Retval")
# Socket address type: constrained to either ``str`` (used by the UNIX
# datagram sockets below) or ``IPSockAddrType`` (used by the UDP sockets).
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")
#
# Event loop
#
# Alias Trio's run-local variable class under the name used in this module.
RunVar = trio.lowlevel.RunVar
  102. #
  103. # Timeouts and cancellation
  104. #
  105. class CancelScope(BaseCancelScope):
  106. def __new__(
  107. cls, original: trio.CancelScope | None = None, **kwargs: object
  108. ) -> CancelScope:
  109. return object.__new__(cls)
  110. def __init__(self, original: trio.CancelScope | None = None, **kwargs: Any) -> None:
  111. self.__original = original or trio.CancelScope(**kwargs)
  112. def __enter__(self) -> CancelScope:
  113. self.__original.__enter__()
  114. return self
  115. def __exit__(
  116. self,
  117. exc_type: type[BaseException] | None,
  118. exc_val: BaseException | None,
  119. exc_tb: TracebackType | None,
  120. ) -> bool:
  121. return self.__original.__exit__(exc_type, exc_val, exc_tb)
  122. def cancel(self, reason: str | None = None) -> None:
  123. self.__original.cancel(reason)
  124. @property
  125. def deadline(self) -> float:
  126. return self.__original.deadline
  127. @deadline.setter
  128. def deadline(self, value: float) -> None:
  129. self.__original.deadline = value
  130. @property
  131. def cancel_called(self) -> bool:
  132. return self.__original.cancel_called
  133. @property
  134. def cancelled_caught(self) -> bool:
  135. return self.__original.cancelled_caught
  136. @property
  137. def shield(self) -> bool:
  138. return self.__original.shield
  139. @shield.setter
  140. def shield(self, value: bool) -> None:
  141. self.__original.shield = value
  142. #
  143. # Task groups
  144. #
  145. class TaskGroup(abc.TaskGroup):
  146. def __init__(self) -> None:
  147. self._active = False
  148. self._nursery_manager = trio.open_nursery(strict_exception_groups=True)
  149. self.cancel_scope = None # type: ignore[assignment]
  150. async def __aenter__(self) -> TaskGroup:
  151. self._active = True
  152. self._nursery = await self._nursery_manager.__aenter__()
  153. self.cancel_scope = CancelScope(self._nursery.cancel_scope)
  154. return self
  155. async def __aexit__(
  156. self,
  157. exc_type: type[BaseException] | None,
  158. exc_val: BaseException | None,
  159. exc_tb: TracebackType | None,
  160. ) -> bool:
  161. try:
  162. # trio.Nursery.__exit__ returns bool; .open_nursery has wrong type
  163. return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) # type: ignore[return-value]
  164. except BaseExceptionGroup as exc:
  165. if not exc.split(trio.Cancelled)[1]:
  166. raise trio.Cancelled._create() from exc
  167. raise
  168. finally:
  169. del exc_val, exc_tb
  170. self._active = False
  171. def start_soon(
  172. self,
  173. func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
  174. *args: Unpack[PosArgsT],
  175. name: object = None,
  176. ) -> None:
  177. if not self._active:
  178. raise RuntimeError(
  179. "This task group is not active; no new tasks can be started."
  180. )
  181. self._nursery.start_soon(func, *args, name=name)
  182. async def start(
  183. self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
  184. ) -> Any:
  185. if not self._active:
  186. raise RuntimeError(
  187. "This task group is not active; no new tasks can be started."
  188. )
  189. return await self._nursery.start(func, *args, name=name)
  190. #
  191. # Threads
  192. #
  193. class BlockingPortal(abc.BlockingPortal):
  194. def __new__(cls) -> BlockingPortal:
  195. return object.__new__(cls)
  196. def __init__(self) -> None:
  197. super().__init__()
  198. self._token = trio.lowlevel.current_trio_token()
  199. def _spawn_task_from_thread(
  200. self,
  201. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  202. args: tuple[Unpack[PosArgsT]],
  203. kwargs: dict[str, Any],
  204. name: object,
  205. future: Future[T_Retval],
  206. ) -> None:
  207. trio.from_thread.run_sync(
  208. partial(self._task_group.start_soon, name=name),
  209. self._call_func,
  210. func,
  211. args,
  212. kwargs,
  213. future,
  214. trio_token=self._token,
  215. )
  216. #
  217. # Subprocesses
  218. #
  219. @dataclass(eq=False)
  220. class ReceiveStreamWrapper(abc.ByteReceiveStream):
  221. _stream: trio.abc.ReceiveStream
  222. async def receive(self, max_bytes: int | None = None) -> bytes:
  223. try:
  224. data = await self._stream.receive_some(max_bytes)
  225. except trio.ClosedResourceError as exc:
  226. raise ClosedResourceError from exc.__cause__
  227. except trio.BrokenResourceError as exc:
  228. raise BrokenResourceError from exc.__cause__
  229. if data:
  230. return bytes(data)
  231. else:
  232. raise EndOfStream
  233. async def aclose(self) -> None:
  234. await self._stream.aclose()
  235. @dataclass(eq=False)
  236. class SendStreamWrapper(abc.ByteSendStream):
  237. _stream: trio.abc.SendStream
  238. async def send(self, item: bytes) -> None:
  239. try:
  240. await self._stream.send_all(item)
  241. except trio.ClosedResourceError as exc:
  242. raise ClosedResourceError from exc.__cause__
  243. except trio.BrokenResourceError as exc:
  244. raise BrokenResourceError from exc.__cause__
  245. async def aclose(self) -> None:
  246. await self._stream.aclose()
  247. @dataclass(eq=False)
  248. class Process(abc.Process):
  249. _process: trio.Process
  250. _stdin: abc.ByteSendStream | None
  251. _stdout: abc.ByteReceiveStream | None
  252. _stderr: abc.ByteReceiveStream | None
  253. async def aclose(self) -> None:
  254. with CancelScope(shield=True):
  255. if self._stdin:
  256. await self._stdin.aclose()
  257. if self._stdout:
  258. await self._stdout.aclose()
  259. if self._stderr:
  260. await self._stderr.aclose()
  261. try:
  262. await self.wait()
  263. except BaseException:
  264. self.kill()
  265. with CancelScope(shield=True):
  266. await self.wait()
  267. raise
  268. async def wait(self) -> int:
  269. return await self._process.wait()
  270. def terminate(self) -> None:
  271. self._process.terminate()
  272. def kill(self) -> None:
  273. self._process.kill()
  274. def send_signal(self, signal: Signals) -> None:
  275. self._process.send_signal(signal)
  276. @property
  277. def pid(self) -> int:
  278. return self._process.pid
  279. @property
  280. def returncode(self) -> int | None:
  281. return self._process.returncode
  282. @property
  283. def stdin(self) -> abc.ByteSendStream | None:
  284. return self._stdin
  285. @property
  286. def stdout(self) -> abc.ByteReceiveStream | None:
  287. return self._stdout
  288. @property
  289. def stderr(self) -> abc.ByteReceiveStream | None:
  290. return self._stderr
  291. class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
  292. def after_run(self) -> None:
  293. super().after_run()
# Run-local slot for the default worker-process limiter; only declared here,
# presumably populated lazily by code outside this view — TODO confirm.
current_default_worker_process_limiter: trio.lowlevel.RunVar = RunVar(
    "current_default_worker_process_limiter"
)
  297. async def _shutdown_process_pool(workers: set[abc.Process]) -> None:
  298. try:
  299. await trio.sleep(math.inf)
  300. except trio.Cancelled:
  301. for process in workers:
  302. if process.returncode is None:
  303. process.kill()
  304. with CancelScope(shield=True):
  305. for process in workers:
  306. await process.aclose()
  307. #
  308. # Sockets and networking
  309. #
  310. class _TrioSocketMixin(Generic[T_SockAddr]):
  311. def __init__(self, trio_socket: TrioSocketType) -> None:
  312. self._trio_socket = trio_socket
  313. self._closed = False
  314. def _check_closed(self) -> None:
  315. if self._closed:
  316. raise ClosedResourceError
  317. if self._trio_socket.fileno() < 0:
  318. raise BrokenResourceError
  319. @property
  320. def _raw_socket(self) -> socket.socket:
  321. return self._trio_socket._sock # type: ignore[attr-defined]
  322. async def aclose(self) -> None:
  323. if self._trio_socket.fileno() >= 0:
  324. self._closed = True
  325. self._trio_socket.close()
  326. def _convert_socket_error(self, exc: BaseException) -> NoReturn:
  327. if isinstance(exc, trio.ClosedResourceError):
  328. raise ClosedResourceError from exc
  329. elif self._trio_socket.fileno() < 0 and self._closed:
  330. raise ClosedResourceError from None
  331. elif isinstance(exc, OSError):
  332. raise BrokenResourceError from exc
  333. else:
  334. raise exc
  335. class SocketStream(_TrioSocketMixin, abc.SocketStream):
  336. def __init__(self, trio_socket: TrioSocketType) -> None:
  337. super().__init__(trio_socket)
  338. self._receive_guard = ResourceGuard("reading from")
  339. self._send_guard = ResourceGuard("writing to")
  340. async def receive(self, max_bytes: int = 65536) -> bytes:
  341. with self._receive_guard:
  342. try:
  343. data = await self._trio_socket.recv(max_bytes)
  344. except BaseException as exc:
  345. self._convert_socket_error(exc)
  346. if data:
  347. return data
  348. else:
  349. raise EndOfStream
  350. async def send(self, item: bytes) -> None:
  351. with self._send_guard:
  352. view = memoryview(item)
  353. while view:
  354. try:
  355. bytes_sent = await self._trio_socket.send(view)
  356. except BaseException as exc:
  357. self._convert_socket_error(exc)
  358. view = view[bytes_sent:]
  359. async def send_eof(self) -> None:
  360. self._trio_socket.shutdown(socket.SHUT_WR)
  361. class UNIXSocketStream(SocketStream, abc.UNIXSocketStream):
  362. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  363. if not isinstance(msglen, int) or msglen < 0:
  364. raise ValueError("msglen must be a non-negative integer")
  365. if not isinstance(maxfds, int) or maxfds < 1:
  366. raise ValueError("maxfds must be a positive integer")
  367. fds = array.array("i")
  368. await trio.lowlevel.checkpoint()
  369. with self._receive_guard:
  370. while True:
  371. try:
  372. message, ancdata, flags, addr = await self._trio_socket.recvmsg(
  373. msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
  374. )
  375. except BaseException as exc:
  376. self._convert_socket_error(exc)
  377. else:
  378. if not message and not ancdata:
  379. raise EndOfStream
  380. break
  381. for cmsg_level, cmsg_type, cmsg_data in ancdata:
  382. if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
  383. raise RuntimeError(
  384. f"Received unexpected ancillary data; message = {message!r}, "
  385. f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
  386. )
  387. fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
  388. return message, list(fds)
  389. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  390. if not message:
  391. raise ValueError("message must not be empty")
  392. if not fds:
  393. raise ValueError("fds must not be empty")
  394. filenos: list[int] = []
  395. for fd in fds:
  396. if isinstance(fd, int):
  397. filenos.append(fd)
  398. elif isinstance(fd, IOBase):
  399. filenos.append(fd.fileno())
  400. fdarray = array.array("i", filenos)
  401. await trio.lowlevel.checkpoint()
  402. with self._send_guard:
  403. while True:
  404. try:
  405. await self._trio_socket.sendmsg(
  406. [message],
  407. [
  408. (
  409. socket.SOL_SOCKET,
  410. socket.SCM_RIGHTS,
  411. fdarray,
  412. )
  413. ],
  414. )
  415. break
  416. except BaseException as exc:
  417. self._convert_socket_error(exc)
  418. class TCPSocketListener(_TrioSocketMixin, abc.SocketListener):
  419. def __init__(self, raw_socket: socket.socket):
  420. super().__init__(trio.socket.from_stdlib_socket(raw_socket))
  421. self._accept_guard = ResourceGuard("accepting connections from")
  422. async def accept(self) -> SocketStream:
  423. with self._accept_guard:
  424. try:
  425. trio_socket, _addr = await self._trio_socket.accept()
  426. except BaseException as exc:
  427. self._convert_socket_error(exc)
  428. trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  429. return SocketStream(trio_socket)
  430. class UNIXSocketListener(_TrioSocketMixin, abc.SocketListener):
  431. def __init__(self, raw_socket: socket.socket):
  432. super().__init__(trio.socket.from_stdlib_socket(raw_socket))
  433. self._accept_guard = ResourceGuard("accepting connections from")
  434. async def accept(self) -> UNIXSocketStream:
  435. with self._accept_guard:
  436. try:
  437. trio_socket, _addr = await self._trio_socket.accept()
  438. except BaseException as exc:
  439. self._convert_socket_error(exc)
  440. return UNIXSocketStream(trio_socket)
  441. class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket):
  442. def __init__(self, trio_socket: TrioSocketType) -> None:
  443. super().__init__(trio_socket)
  444. self._receive_guard = ResourceGuard("reading from")
  445. self._send_guard = ResourceGuard("writing to")
  446. async def receive(self) -> tuple[bytes, IPSockAddrType]:
  447. with self._receive_guard:
  448. try:
  449. data, addr = await self._trio_socket.recvfrom(65536)
  450. return data, convert_ipv6_sockaddr(addr)
  451. except BaseException as exc:
  452. self._convert_socket_error(exc)
  453. async def send(self, item: UDPPacketType) -> None:
  454. with self._send_guard:
  455. try:
  456. await self._trio_socket.sendto(*item)
  457. except BaseException as exc:
  458. self._convert_socket_error(exc)
  459. class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket):
  460. def __init__(self, trio_socket: TrioSocketType) -> None:
  461. super().__init__(trio_socket)
  462. self._receive_guard = ResourceGuard("reading from")
  463. self._send_guard = ResourceGuard("writing to")
  464. async def receive(self) -> bytes:
  465. with self._receive_guard:
  466. try:
  467. return await self._trio_socket.recv(65536)
  468. except BaseException as exc:
  469. self._convert_socket_error(exc)
  470. async def send(self, item: bytes) -> None:
  471. with self._send_guard:
  472. try:
  473. await self._trio_socket.send(item)
  474. except BaseException as exc:
  475. self._convert_socket_error(exc)
  476. class UNIXDatagramSocket(_TrioSocketMixin[str], abc.UNIXDatagramSocket):
  477. def __init__(self, trio_socket: TrioSocketType) -> None:
  478. super().__init__(trio_socket)
  479. self._receive_guard = ResourceGuard("reading from")
  480. self._send_guard = ResourceGuard("writing to")
  481. async def receive(self) -> UNIXDatagramPacketType:
  482. with self._receive_guard:
  483. try:
  484. data, addr = await self._trio_socket.recvfrom(65536)
  485. return data, addr
  486. except BaseException as exc:
  487. self._convert_socket_error(exc)
  488. async def send(self, item: UNIXDatagramPacketType) -> None:
  489. with self._send_guard:
  490. try:
  491. await self._trio_socket.sendto(*item)
  492. except BaseException as exc:
  493. self._convert_socket_error(exc)
  494. class ConnectedUNIXDatagramSocket(
  495. _TrioSocketMixin[str], abc.ConnectedUNIXDatagramSocket
  496. ):
  497. def __init__(self, trio_socket: TrioSocketType) -> None:
  498. super().__init__(trio_socket)
  499. self._receive_guard = ResourceGuard("reading from")
  500. self._send_guard = ResourceGuard("writing to")
  501. async def receive(self) -> bytes:
  502. with self._receive_guard:
  503. try:
  504. return await self._trio_socket.recv(65536)
  505. except BaseException as exc:
  506. self._convert_socket_error(exc)
  507. async def send(self, item: bytes) -> None:
  508. with self._send_guard:
  509. try:
  510. await self._trio_socket.send(item)
  511. except BaseException as exc:
  512. self._convert_socket_error(exc)
  513. #
  514. # Synchronization
  515. #
  516. class Event(BaseEvent):
  517. def __new__(cls) -> Event:
  518. return object.__new__(cls)
  519. def __init__(self) -> None:
  520. self.__original = trio.Event()
  521. def is_set(self) -> bool:
  522. return self.__original.is_set()
  523. async def wait(self) -> None:
  524. return await self.__original.wait()
  525. def statistics(self) -> EventStatistics:
  526. orig_statistics = self.__original.statistics()
  527. return EventStatistics(tasks_waiting=orig_statistics.tasks_waiting)
  528. def set(self) -> None:
  529. self.__original.set()
  530. class Lock(BaseLock):
  531. def __new__(cls, *, fast_acquire: bool = False) -> Lock:
  532. return object.__new__(cls)
  533. def __init__(self, *, fast_acquire: bool = False) -> None:
  534. self._fast_acquire = fast_acquire
  535. self.__original = trio.Lock()
  536. @staticmethod
  537. def _convert_runtime_error_msg(exc: RuntimeError) -> None:
  538. if exc.args == ("attempt to re-acquire an already held Lock",):
  539. exc.args = ("Attempted to acquire an already held Lock",)
  540. async def acquire(self) -> None:
  541. if not self._fast_acquire:
  542. try:
  543. await self.__original.acquire()
  544. except RuntimeError as exc:
  545. self._convert_runtime_error_msg(exc)
  546. raise
  547. return
  548. # This is the "fast path" where we don't let other tasks run
  549. await trio.lowlevel.checkpoint_if_cancelled()
  550. try:
  551. self.__original.acquire_nowait()
  552. except trio.WouldBlock:
  553. await self.__original._lot.park()
  554. except RuntimeError as exc:
  555. self._convert_runtime_error_msg(exc)
  556. raise
  557. def acquire_nowait(self) -> None:
  558. try:
  559. self.__original.acquire_nowait()
  560. except trio.WouldBlock:
  561. raise WouldBlock from None
  562. except RuntimeError as exc:
  563. self._convert_runtime_error_msg(exc)
  564. raise
  565. def locked(self) -> bool:
  566. return self.__original.locked()
  567. def release(self) -> None:
  568. self.__original.release()
  569. def statistics(self) -> LockStatistics:
  570. orig_statistics = self.__original.statistics()
  571. owner = TrioTaskInfo(orig_statistics.owner) if orig_statistics.owner else None
  572. return LockStatistics(
  573. orig_statistics.locked, owner, orig_statistics.tasks_waiting
  574. )
  575. class Semaphore(BaseSemaphore):
  576. def __new__(
  577. cls,
  578. initial_value: int,
  579. *,
  580. max_value: int | None = None,
  581. fast_acquire: bool = False,
  582. ) -> Semaphore:
  583. return object.__new__(cls)
  584. def __init__(
  585. self,
  586. initial_value: int,
  587. *,
  588. max_value: int | None = None,
  589. fast_acquire: bool = False,
  590. ) -> None:
  591. super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  592. self.__original = trio.Semaphore(initial_value, max_value=max_value)
  593. async def acquire(self) -> None:
  594. if not self._fast_acquire:
  595. await self.__original.acquire()
  596. return
  597. # This is the "fast path" where we don't let other tasks run
  598. await trio.lowlevel.checkpoint_if_cancelled()
  599. try:
  600. self.__original.acquire_nowait()
  601. except trio.WouldBlock:
  602. await self.__original._lot.park()
  603. def acquire_nowait(self) -> None:
  604. try:
  605. self.__original.acquire_nowait()
  606. except trio.WouldBlock:
  607. raise WouldBlock from None
  608. @property
  609. def max_value(self) -> int | None:
  610. return self.__original.max_value
  611. @property
  612. def value(self) -> int:
  613. return self.__original.value
  614. def release(self) -> None:
  615. self.__original.release()
  616. def statistics(self) -> SemaphoreStatistics:
  617. orig_statistics = self.__original.statistics()
  618. return SemaphoreStatistics(orig_statistics.tasks_waiting)
  619. class CapacityLimiter(BaseCapacityLimiter):
  620. def __new__(
  621. cls,
  622. total_tokens: float | None = None,
  623. *,
  624. original: trio.CapacityLimiter | None = None,
  625. ) -> CapacityLimiter:
  626. return object.__new__(cls)
  627. def __init__(
  628. self,
  629. total_tokens: float | None = None,
  630. *,
  631. original: trio.CapacityLimiter | None = None,
  632. ) -> None:
  633. if original is not None:
  634. self.__original = original
  635. else:
  636. assert total_tokens is not None
  637. self.__original = trio.CapacityLimiter(total_tokens)
  638. async def __aenter__(self) -> None:
  639. return await self.__original.__aenter__()
  640. async def __aexit__(
  641. self,
  642. exc_type: type[BaseException] | None,
  643. exc_val: BaseException | None,
  644. exc_tb: TracebackType | None,
  645. ) -> None:
  646. await self.__original.__aexit__(exc_type, exc_val, exc_tb)
  647. @property
  648. def total_tokens(self) -> float:
  649. return self.__original.total_tokens
  650. @total_tokens.setter
  651. def total_tokens(self, value: float) -> None:
  652. self.__original.total_tokens = value
  653. @property
  654. def borrowed_tokens(self) -> int:
  655. return self.__original.borrowed_tokens
  656. @property
  657. def available_tokens(self) -> float:
  658. return self.__original.available_tokens
  659. def acquire_nowait(self) -> None:
  660. self.__original.acquire_nowait()
  661. def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
  662. self.__original.acquire_on_behalf_of_nowait(borrower)
  663. async def acquire(self) -> None:
  664. await self.__original.acquire()
  665. async def acquire_on_behalf_of(self, borrower: object) -> None:
  666. await self.__original.acquire_on_behalf_of(borrower)
  667. def release(self) -> None:
  668. return self.__original.release()
  669. def release_on_behalf_of(self, borrower: object) -> None:
  670. return self.__original.release_on_behalf_of(borrower)
  671. def statistics(self) -> CapacityLimiterStatistics:
  672. orig = self.__original.statistics()
  673. return CapacityLimiterStatistics(
  674. borrowed_tokens=orig.borrowed_tokens,
  675. total_tokens=orig.total_tokens,
  676. borrowers=tuple(orig.borrowers),
  677. tasks_waiting=orig.tasks_waiting,
  678. )
# Run-local slot; judging by its name it caches a CapacityLimiter wrapper,
# populated by code outside this view — TODO confirm.
_capacity_limiter_wrapper: trio.lowlevel.RunVar = RunVar("_capacity_limiter_wrapper")
  680. #
  681. # Signal handling
  682. #
  683. class _SignalReceiver:
  684. _iterator: AsyncIterator[int]
  685. def __init__(self, signals: tuple[Signals, ...]):
  686. self._signals = signals
  687. def __enter__(self) -> _SignalReceiver:
  688. self._cm = trio.open_signal_receiver(*self._signals)
  689. self._iterator = self._cm.__enter__()
  690. return self
  691. def __exit__(
  692. self,
  693. exc_type: type[BaseException] | None,
  694. exc_val: BaseException | None,
  695. exc_tb: TracebackType | None,
  696. ) -> bool | None:
  697. return self._cm.__exit__(exc_type, exc_val, exc_tb)
  698. def __aiter__(self) -> _SignalReceiver:
  699. return self
  700. async def __anext__(self) -> Signals:
  701. signum = await self._iterator.__anext__()
  702. return Signals(signum)
  703. #
  704. # Testing and debugging
  705. #
class TestRunner(abc.TestRunner):
    """Run async tests and fixtures inside a Trio guest run, driven synchronously.

    Trio is started with ``trio.lowlevel.start_guest_run``; its
    ``run_sync_soon_threadsafe`` callbacks are put on ``_call_queue`` and the
    synchronous caller executes them one by one to make the event loop progress.

    :param options: extra keyword arguments forwarded to ``start_guest_run``
    """

    def __init__(self, **options: Any) -> None:
        from queue import Queue
        # Callbacks Trio wants run; executed by the synchronous side
        self._call_queue: Queue[Callable[[], object]] = Queue()
        # Assigned by _run_tests_and_fixtures once the guest run is up, and
        # reset to None by _main_task_finished when the guest run ends
        self._send_stream: MemoryObjectSendStream | None = None
        self._options = options

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        if self._send_stream:
            # Closing the send side presumably ends the ``async for`` loop in
            # _run_tests_and_fixtures, letting the guest run finish...
            self._send_stream.close()
            # ...and Trio is driven until done_callback clears _send_stream
            while self._send_stream is not None:
                self._call_queue.get()()

    async def _run_tests_and_fixtures(self) -> None:
        # Main task of the guest run: await each submitted coroutine in order
        # and report its outcome through the caller-provided holder list
        self._send_stream, receive_stream = create_memory_object_stream(1)
        with receive_stream:
            async for coro, outcome_holder in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    outcome_holder.append(Error(exc))
                else:
                    outcome_holder.append(Value(retval))

    def _main_task_finished(self, outcome: object) -> None:
        # done_callback of the guest run: marks that Trio has shut down
        self._send_stream = None

    def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the guest run and return its result.

        Starts the guest run on first use. Exceptions raised by the coroutine
        are re-raised here via ``Outcome.unwrap()``.
        """
        if self._send_stream is None:
            trio.lowlevel.start_guest_run(
                self._run_tests_and_fixtures,
                run_sync_soon_threadsafe=self._call_queue.put,
                done_callback=self._main_task_finished,
                **self._options,
            )
            # Drive Trio until the main task has created the send stream
            while self._send_stream is None:
                self._call_queue.get()()

        outcome_holder: list[Outcome] = []
        # The coroutine object is created here but awaited by the runner task
        self._send_stream.send_nowait((func(*args, **kwargs), outcome_holder))
        # Drive Trio until the runner task has recorded the outcome
        while not outcome_holder:
            self._call_queue.get()()

        return outcome_holder[0].unwrap()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        # Set-up: advance the async generator to its first yield
        asyncgen = fixture_func(**kwargs)
        fixturevalue: T_Retval = self._call_in_runner_task(asyncgen.asend, None)

        yield fixturevalue

        # Tear-down: the generator must finish after exactly one yield
        try:
            self._call_in_runner_task(asyncgen.asend, None)
        except StopAsyncIteration:
            pass
        else:
            self._call_in_runner_task(asyncgen.aclose)
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        return self._call_in_runner_task(fixture_func, **kwargs)

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        self._call_in_runner_task(test_func, **kwargs)
  779. class TrioTaskInfo(TaskInfo):
  780. def __init__(self, task: trio.lowlevel.Task):
  781. parent_id = None
  782. if task.parent_nursery and task.parent_nursery.parent_task:
  783. parent_id = id(task.parent_nursery.parent_task)
  784. super().__init__(id(task), parent_id, task.name, task.coro)
  785. self._task = weakref.proxy(task)
  786. def has_pending_cancellation(self) -> bool:
  787. try:
  788. return self._task._cancel_status.effectively_cancelled
  789. except ReferenceError:
  790. # If the task is no longer around, it surely doesn't have a cancellation
  791. # pending
  792. return False
  793. class TrioBackend(AsyncBackend):
  794. @classmethod
  795. def run(
  796. cls,
  797. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  798. args: tuple[Unpack[PosArgsT]],
  799. kwargs: dict[str, Any],
  800. options: dict[str, Any],
  801. ) -> T_Retval:
  802. return trio.run(func, *args)
  803. @classmethod
  804. def current_token(cls) -> object:
  805. return trio.lowlevel.current_trio_token()
  806. @classmethod
  807. def current_time(cls) -> float:
  808. return trio.current_time()
  809. @classmethod
  810. def cancelled_exception_class(cls) -> type[BaseException]:
  811. return trio.Cancelled
  812. @classmethod
  813. async def checkpoint(cls) -> None:
  814. await trio.lowlevel.checkpoint()
  815. @classmethod
  816. async def checkpoint_if_cancelled(cls) -> None:
  817. await trio.lowlevel.checkpoint_if_cancelled()
  818. @classmethod
  819. async def cancel_shielded_checkpoint(cls) -> None:
  820. await trio.lowlevel.cancel_shielded_checkpoint()
  821. @classmethod
  822. async def sleep(cls, delay: float) -> None:
  823. await trio.sleep(delay)
  824. @classmethod
  825. def create_cancel_scope(
  826. cls, *, deadline: float = math.inf, shield: bool = False
  827. ) -> abc.CancelScope:
  828. return CancelScope(deadline=deadline, shield=shield)
  829. @classmethod
  830. def current_effective_deadline(cls) -> float:
  831. return trio.current_effective_deadline()
  832. @classmethod
  833. def create_task_group(cls) -> abc.TaskGroup:
  834. return TaskGroup()
  835. @classmethod
  836. def create_event(cls) -> abc.Event:
  837. return Event()
  838. @classmethod
  839. def create_lock(cls, *, fast_acquire: bool) -> Lock:
  840. return Lock(fast_acquire=fast_acquire)
  841. @classmethod
  842. def create_semaphore(
  843. cls,
  844. initial_value: int,
  845. *,
  846. max_value: int | None = None,
  847. fast_acquire: bool = False,
  848. ) -> abc.Semaphore:
  849. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  850. @classmethod
  851. def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
  852. return CapacityLimiter(total_tokens)
  853. @classmethod
  854. async def run_sync_in_worker_thread(
  855. cls,
  856. func: Callable[[Unpack[PosArgsT]], T_Retval],
  857. args: tuple[Unpack[PosArgsT]],
  858. abandon_on_cancel: bool = False,
  859. limiter: abc.CapacityLimiter | None = None,
  860. ) -> T_Retval:
  861. def wrapper() -> T_Retval:
  862. with claim_worker_thread(TrioBackend, token):
  863. return func(*args)
  864. token = TrioBackend.current_token()
  865. return await run_sync(
  866. wrapper,
  867. abandon_on_cancel=abandon_on_cancel,
  868. limiter=cast(trio.CapacityLimiter, limiter),
  869. )
  870. @classmethod
  871. def check_cancelled(cls) -> None:
  872. trio.from_thread.check_cancelled()
  873. @classmethod
  874. def run_async_from_thread(
  875. cls,
  876. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  877. args: tuple[Unpack[PosArgsT]],
  878. token: object,
  879. ) -> T_Retval:
  880. trio_token = cast("trio.lowlevel.TrioToken | None", token)
  881. try:
  882. return trio.from_thread.run(func, *args, trio_token=trio_token)
  883. except trio.RunFinishedError:
  884. raise RunFinishedError from None
  885. @classmethod
  886. def run_sync_from_thread(
  887. cls,
  888. func: Callable[[Unpack[PosArgsT]], T_Retval],
  889. args: tuple[Unpack[PosArgsT]],
  890. token: object,
  891. ) -> T_Retval:
  892. trio_token = cast("trio.lowlevel.TrioToken | None", token)
  893. try:
  894. return trio.from_thread.run_sync(func, *args, trio_token=trio_token)
  895. except trio.RunFinishedError:
  896. raise RunFinishedError from None
  897. @classmethod
  898. def create_blocking_portal(cls) -> abc.BlockingPortal:
  899. return BlockingPortal()
  900. @classmethod
  901. async def open_process(
  902. cls,
  903. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  904. *,
  905. stdin: int | IO[Any] | None,
  906. stdout: int | IO[Any] | None,
  907. stderr: int | IO[Any] | None,
  908. **kwargs: Any,
  909. ) -> Process:
  910. def convert_item(item: StrOrBytesPath) -> str:
  911. str_or_bytes = os.fspath(item)
  912. if isinstance(str_or_bytes, str):
  913. return str_or_bytes
  914. else:
  915. return os.fsdecode(str_or_bytes)
  916. if isinstance(command, (str, bytes, PathLike)):
  917. process = await trio.lowlevel.open_process(
  918. convert_item(command),
  919. stdin=stdin,
  920. stdout=stdout,
  921. stderr=stderr,
  922. shell=True,
  923. **kwargs,
  924. )
  925. else:
  926. process = await trio.lowlevel.open_process(
  927. [convert_item(item) for item in command],
  928. stdin=stdin,
  929. stdout=stdout,
  930. stderr=stderr,
  931. shell=False,
  932. **kwargs,
  933. )
  934. stdin_stream = SendStreamWrapper(process.stdin) if process.stdin else None
  935. stdout_stream = ReceiveStreamWrapper(process.stdout) if process.stdout else None
  936. stderr_stream = ReceiveStreamWrapper(process.stderr) if process.stderr else None
  937. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  938. @classmethod
  939. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  940. trio.lowlevel.spawn_system_task(_shutdown_process_pool, workers)
  941. @classmethod
  942. async def connect_tcp(
  943. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  944. ) -> SocketStream:
  945. family = socket.AF_INET6 if ":" in host else socket.AF_INET
  946. trio_socket = trio.socket.socket(family)
  947. trio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  948. if local_address:
  949. await trio_socket.bind(local_address)
  950. try:
  951. await trio_socket.connect((host, port))
  952. except BaseException:
  953. trio_socket.close()
  954. raise
  955. return SocketStream(trio_socket)
  956. @classmethod
  957. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  958. trio_socket = trio.socket.socket(socket.AF_UNIX)
  959. try:
  960. await trio_socket.connect(path)
  961. except BaseException:
  962. trio_socket.close()
  963. raise
  964. return UNIXSocketStream(trio_socket)
  965. @classmethod
  966. def create_tcp_listener(cls, sock: socket.socket) -> abc.SocketListener:
  967. return TCPSocketListener(sock)
  968. @classmethod
  969. def create_unix_listener(cls, sock: socket.socket) -> abc.SocketListener:
  970. return UNIXSocketListener(sock)
  971. @classmethod
  972. async def create_udp_socket(
  973. cls,
  974. family: socket.AddressFamily,
  975. local_address: IPSockAddrType | None,
  976. remote_address: IPSockAddrType | None,
  977. reuse_port: bool,
  978. ) -> UDPSocket | ConnectedUDPSocket:
  979. trio_socket = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
  980. if reuse_port:
  981. trio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
  982. if local_address:
  983. await trio_socket.bind(local_address)
  984. if remote_address:
  985. await trio_socket.connect(remote_address)
  986. return ConnectedUDPSocket(trio_socket)
  987. else:
  988. return UDPSocket(trio_socket)
  989. @classmethod
  990. @overload
  991. async def create_unix_datagram_socket(
  992. cls, raw_socket: socket.socket, remote_path: None
  993. ) -> abc.UNIXDatagramSocket: ...
  994. @classmethod
  995. @overload
  996. async def create_unix_datagram_socket(
  997. cls, raw_socket: socket.socket, remote_path: str | bytes
  998. ) -> abc.ConnectedUNIXDatagramSocket: ...
  999. @classmethod
  1000. async def create_unix_datagram_socket(
  1001. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  1002. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  1003. trio_socket = trio.socket.from_stdlib_socket(raw_socket)
  1004. if remote_path:
  1005. await trio_socket.connect(remote_path)
  1006. return ConnectedUNIXDatagramSocket(trio_socket)
  1007. else:
  1008. return UNIXDatagramSocket(trio_socket)
  1009. @classmethod
  1010. async def getaddrinfo(
  1011. cls,
  1012. host: bytes | str | None,
  1013. port: str | int | None,
  1014. *,
  1015. family: int | AddressFamily = 0,
  1016. type: int | SocketKind = 0,
  1017. proto: int = 0,
  1018. flags: int = 0,
  1019. ) -> Sequence[
  1020. tuple[
  1021. AddressFamily,
  1022. SocketKind,
  1023. int,
  1024. str,
  1025. tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
  1026. ]
  1027. ]:
  1028. return await trio.socket.getaddrinfo(host, port, family, type, proto, flags)
  1029. @classmethod
  1030. async def getnameinfo(
  1031. cls, sockaddr: IPSockAddrType, flags: int = 0
  1032. ) -> tuple[str, str]:
  1033. return await trio.socket.getnameinfo(sockaddr, flags)
  1034. @classmethod
  1035. async def wait_readable(cls, obj: FileDescriptorLike) -> None:
  1036. try:
  1037. await wait_readable(obj)
  1038. except trio.ClosedResourceError as exc:
  1039. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1040. except trio.BusyResourceError:
  1041. raise BusyResourceError("reading from") from None
  1042. @classmethod
  1043. async def wait_writable(cls, obj: FileDescriptorLike) -> None:
  1044. try:
  1045. await wait_writable(obj)
  1046. except trio.ClosedResourceError as exc:
  1047. raise ClosedResourceError().with_traceback(exc.__traceback__) from None
  1048. except trio.BusyResourceError:
  1049. raise BusyResourceError("writing to") from None
  1050. @classmethod
  1051. def notify_closing(cls, obj: FileDescriptorLike) -> None:
  1052. notify_closing(obj)
  1053. @classmethod
  1054. async def wrap_listener_socket(cls, sock: socket.socket) -> abc.SocketListener:
  1055. return TCPSocketListener(sock)
  1056. @classmethod
  1057. async def wrap_stream_socket(cls, sock: socket.socket) -> SocketStream:
  1058. trio_sock = trio.socket.from_stdlib_socket(sock)
  1059. return SocketStream(trio_sock)
  1060. @classmethod
  1061. async def wrap_unix_stream_socket(cls, sock: socket.socket) -> UNIXSocketStream:
  1062. trio_sock = trio.socket.from_stdlib_socket(sock)
  1063. return UNIXSocketStream(trio_sock)
  1064. @classmethod
  1065. async def wrap_udp_socket(cls, sock: socket.socket) -> UDPSocket:
  1066. trio_sock = trio.socket.from_stdlib_socket(sock)
  1067. return UDPSocket(trio_sock)
  1068. @classmethod
  1069. async def wrap_connected_udp_socket(cls, sock: socket.socket) -> ConnectedUDPSocket:
  1070. trio_sock = trio.socket.from_stdlib_socket(sock)
  1071. return ConnectedUDPSocket(trio_sock)
  1072. @classmethod
  1073. async def wrap_unix_datagram_socket(cls, sock: socket.socket) -> UNIXDatagramSocket:
  1074. trio_sock = trio.socket.from_stdlib_socket(sock)
  1075. return UNIXDatagramSocket(trio_sock)
  1076. @classmethod
  1077. async def wrap_connected_unix_datagram_socket(
  1078. cls, sock: socket.socket
  1079. ) -> ConnectedUNIXDatagramSocket:
  1080. trio_sock = trio.socket.from_stdlib_socket(sock)
  1081. return ConnectedUNIXDatagramSocket(trio_sock)
  1082. @classmethod
  1083. def current_default_thread_limiter(cls) -> CapacityLimiter:
  1084. try:
  1085. return _capacity_limiter_wrapper.get()
  1086. except LookupError:
  1087. limiter = CapacityLimiter(
  1088. original=trio.to_thread.current_default_thread_limiter()
  1089. )
  1090. _capacity_limiter_wrapper.set(limiter)
  1091. return limiter
  1092. @classmethod
  1093. def open_signal_receiver(
  1094. cls, *signals: Signals
  1095. ) -> AbstractContextManager[AsyncIterator[Signals]]:
  1096. return _SignalReceiver(signals)
  1097. @classmethod
  1098. def get_current_task(cls) -> TaskInfo:
  1099. task = current_task()
  1100. return TrioTaskInfo(task)
  1101. @classmethod
  1102. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  1103. root_task = current_root_task()
  1104. assert root_task
  1105. task_infos = [TrioTaskInfo(root_task)]
  1106. nurseries = root_task.child_nurseries
  1107. while nurseries:
  1108. new_nurseries: list[trio.Nursery] = []
  1109. for nursery in nurseries:
  1110. for task in nursery.child_tasks:
  1111. task_infos.append(TrioTaskInfo(task))
  1112. new_nurseries.extend(task.child_nurseries)
  1113. nurseries = new_nurseries
  1114. return task_infos
  1115. @classmethod
  1116. async def wait_all_tasks_blocked(cls) -> None:
  1117. from trio.testing import wait_all_tasks_blocked
  1118. await wait_all_tasks_blocked()
  1119. @classmethod
  1120. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  1121. return TestRunner(**options)
# Module-level handle to this module's AsyncBackend implementation; presumably
# looked up by name by anyio's backend loader (verify against the loader).
backend_class = TrioBackend