_asyncio.py 96 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967
  1. from __future__ import annotations
  2. import array
  3. import asyncio
  4. import concurrent.futures
  5. import contextvars
  6. import math
  7. import os
  8. import socket
  9. import sys
  10. import threading
  11. import weakref
  12. from asyncio import (
  13. AbstractEventLoop,
  14. CancelledError,
  15. all_tasks,
  16. create_task,
  17. current_task,
  18. get_running_loop,
  19. sleep,
  20. )
  21. from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
  22. from collections import OrderedDict, deque
  23. from collections.abc import (
  24. AsyncGenerator,
  25. AsyncIterator,
  26. Awaitable,
  27. Callable,
  28. Collection,
  29. Coroutine,
  30. Iterable,
  31. Sequence,
  32. )
  33. from concurrent.futures import Future
  34. from contextlib import AbstractContextManager, suppress
  35. from contextvars import Context, copy_context
  36. from dataclasses import dataclass
  37. from functools import partial, wraps
  38. from inspect import (
  39. CORO_RUNNING,
  40. CORO_SUSPENDED,
  41. getcoroutinestate,
  42. iscoroutine,
  43. )
  44. from io import IOBase
  45. from os import PathLike
  46. from queue import Queue
  47. from signal import Signals
  48. from socket import AddressFamily, SocketKind
  49. from threading import Thread
  50. from types import CodeType, TracebackType
  51. from typing import (
  52. IO,
  53. TYPE_CHECKING,
  54. Any,
  55. Optional,
  56. TypeVar,
  57. cast,
  58. )
  59. from weakref import WeakKeyDictionary
  60. import sniffio
  61. from .. import (
  62. CapacityLimiterStatistics,
  63. EventStatistics,
  64. LockStatistics,
  65. TaskInfo,
  66. abc,
  67. )
  68. from .._core._eventloop import claim_worker_thread, threadlocals
  69. from .._core._exceptions import (
  70. BrokenResourceError,
  71. BusyResourceError,
  72. ClosedResourceError,
  73. EndOfStream,
  74. RunFinishedError,
  75. WouldBlock,
  76. iterate_exceptions,
  77. )
  78. from .._core._sockets import convert_ipv6_sockaddr
  79. from .._core._streams import create_memory_object_stream
  80. from .._core._synchronization import (
  81. CapacityLimiter as BaseCapacityLimiter,
  82. )
  83. from .._core._synchronization import Event as BaseEvent
  84. from .._core._synchronization import Lock as BaseLock
  85. from .._core._synchronization import (
  86. ResourceGuard,
  87. SemaphoreStatistics,
  88. )
  89. from .._core._synchronization import Semaphore as BaseSemaphore
  90. from .._core._tasks import CancelScope as BaseCancelScope
  91. from ..abc import (
  92. AsyncBackend,
  93. IPSockAddrType,
  94. SocketListener,
  95. UDPPacketType,
  96. UNIXDatagramPacketType,
  97. )
  98. from ..abc._eventloop import StrOrBytesPath
  99. from ..lowlevel import RunVar
  100. from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
  101. if TYPE_CHECKING:
  102. from _typeshed import FileDescriptorLike
  103. else:
  104. FileDescriptorLike = object
  105. if sys.version_info >= (3, 10):
  106. from typing import ParamSpec
  107. else:
  108. from typing_extensions import ParamSpec
  109. if sys.version_info >= (3, 11):
  110. from asyncio import Runner
  111. from typing import TypeVarTuple, Unpack
  112. else:
  113. import contextvars
  114. import enum
  115. import signal
  116. from asyncio import coroutines, events, exceptions, tasks
  117. from exceptiongroup import BaseExceptionGroup
  118. from typing_extensions import TypeVarTuple, Unpack
  119. class _State(enum.Enum):
  120. CREATED = "created"
  121. INITIALIZED = "initialized"
  122. CLOSED = "closed"
  123. class Runner:
  124. # Copied from CPython 3.11
  125. def __init__(
  126. self,
  127. *,
  128. debug: bool | None = None,
  129. loop_factory: Callable[[], AbstractEventLoop] | None = None,
  130. ):
  131. self._state = _State.CREATED
  132. self._debug = debug
  133. self._loop_factory = loop_factory
  134. self._loop: AbstractEventLoop | None = None
  135. self._context = None
  136. self._interrupt_count = 0
  137. self._set_event_loop = False
  138. def __enter__(self) -> Runner:
  139. self._lazy_init()
  140. return self
  141. def __exit__(
  142. self,
  143. exc_type: type[BaseException],
  144. exc_val: BaseException,
  145. exc_tb: TracebackType,
  146. ) -> None:
  147. self.close()
  148. def close(self) -> None:
  149. """Shutdown and close event loop."""
  150. if self._state is not _State.INITIALIZED:
  151. return
  152. try:
  153. loop = self._loop
  154. _cancel_all_tasks(loop)
  155. loop.run_until_complete(loop.shutdown_asyncgens())
  156. if hasattr(loop, "shutdown_default_executor"):
  157. loop.run_until_complete(loop.shutdown_default_executor())
  158. else:
  159. loop.run_until_complete(_shutdown_default_executor(loop))
  160. finally:
  161. if self._set_event_loop:
  162. events.set_event_loop(None)
  163. loop.close()
  164. self._loop = None
  165. self._state = _State.CLOSED
  166. def get_loop(self) -> AbstractEventLoop:
  167. """Return embedded event loop."""
  168. self._lazy_init()
  169. return self._loop
  170. def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
  171. """Run a coroutine inside the embedded event loop."""
  172. if not coroutines.iscoroutine(coro):
  173. raise ValueError(f"a coroutine was expected, got {coro!r}")
  174. if events._get_running_loop() is not None:
  175. # fail fast with short traceback
  176. raise RuntimeError(
  177. "Runner.run() cannot be called from a running event loop"
  178. )
  179. self._lazy_init()
  180. if context is None:
  181. context = self._context
  182. task = context.run(self._loop.create_task, coro)
  183. if (
  184. threading.current_thread() is threading.main_thread()
  185. and signal.getsignal(signal.SIGINT) is signal.default_int_handler
  186. ):
  187. sigint_handler = partial(self._on_sigint, main_task=task)
  188. try:
  189. signal.signal(signal.SIGINT, sigint_handler)
  190. except ValueError:
  191. # `signal.signal` may throw if `threading.main_thread` does
  192. # not support signals (e.g. embedded interpreter with signals
  193. # not registered - see gh-91880)
  194. sigint_handler = None
  195. else:
  196. sigint_handler = None
  197. self._interrupt_count = 0
  198. try:
  199. return self._loop.run_until_complete(task)
  200. except exceptions.CancelledError:
  201. if self._interrupt_count > 0:
  202. uncancel = getattr(task, "uncancel", None)
  203. if uncancel is not None and uncancel() == 0:
  204. raise KeyboardInterrupt # noqa: B904
  205. raise # CancelledError
  206. finally:
  207. if (
  208. sigint_handler is not None
  209. and signal.getsignal(signal.SIGINT) is sigint_handler
  210. ):
  211. signal.signal(signal.SIGINT, signal.default_int_handler)
  212. def _lazy_init(self) -> None:
  213. if self._state is _State.CLOSED:
  214. raise RuntimeError("Runner is closed")
  215. if self._state is _State.INITIALIZED:
  216. return
  217. if self._loop_factory is None:
  218. self._loop = events.new_event_loop()
  219. if not self._set_event_loop:
  220. # Call set_event_loop only once to avoid calling
  221. # attach_loop multiple times on child watchers
  222. events.set_event_loop(self._loop)
  223. self._set_event_loop = True
  224. else:
  225. self._loop = self._loop_factory()
  226. if self._debug is not None:
  227. self._loop.set_debug(self._debug)
  228. self._context = contextvars.copy_context()
  229. self._state = _State.INITIALIZED
  230. def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
  231. self._interrupt_count += 1
  232. if self._interrupt_count == 1 and not main_task.done():
  233. main_task.cancel()
  234. # wakeup loop if it is blocked by select() with long timeout
  235. self._loop.call_soon_threadsafe(lambda: None)
  236. return
  237. raise KeyboardInterrupt()
  238. def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
  239. to_cancel = tasks.all_tasks(loop)
  240. if not to_cancel:
  241. return
  242. for task in to_cancel:
  243. task.cancel()
  244. loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
  245. for task in to_cancel:
  246. if task.cancelled():
  247. continue
  248. if task.exception() is not None:
  249. loop.call_exception_handler(
  250. {
  251. "message": "unhandled exception during asyncio.run() shutdown",
  252. "exception": task.exception(),
  253. "task": task,
  254. }
  255. )
  256. async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
  257. """Schedule the shutdown of the default executor."""
  258. def _do_shutdown(future: asyncio.futures.Future) -> None:
  259. try:
  260. loop._default_executor.shutdown(wait=True) # type: ignore[attr-defined]
  261. loop.call_soon_threadsafe(future.set_result, None)
  262. except Exception as ex:
  263. loop.call_soon_threadsafe(future.set_exception, ex)
  264. loop._executor_shutdown_called = True
  265. if loop._default_executor is None:
  266. return
  267. future = loop.create_future()
  268. thread = threading.Thread(target=_do_shutdown, args=(future,))
  269. thread.start()
  270. try:
  271. await future
  272. finally:
  273. thread.join()
  274. T_Retval = TypeVar("T_Retval")
  275. T_contra = TypeVar("T_contra", contravariant=True)
  276. PosArgsT = TypeVarTuple("PosArgsT")
  277. P = ParamSpec("P")
  278. _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
  279. def find_root_task() -> asyncio.Task:
  280. root_task = _root_task.get(None)
  281. if root_task is not None and not root_task.done():
  282. return root_task
  283. # Look for a task that has been started via run_until_complete()
  284. for task in all_tasks():
  285. if task._callbacks and not task.done():
  286. callbacks = [cb for cb, context in task._callbacks]
  287. for cb in callbacks:
  288. if (
  289. cb is _run_until_complete_cb
  290. or getattr(cb, "__module__", None) == "uvloop.loop"
  291. ):
  292. _root_task.set(task)
  293. return task
  294. # Look up the topmost task in the AnyIO task tree, if possible
  295. task = cast(asyncio.Task, current_task())
  296. state = _task_states.get(task)
  297. if state:
  298. cancel_scope = state.cancel_scope
  299. while cancel_scope and cancel_scope._parent_scope is not None:
  300. cancel_scope = cancel_scope._parent_scope
  301. if cancel_scope is not None:
  302. return cast(asyncio.Task, cancel_scope._host_task)
  303. return task
  304. def get_callable_name(func: Callable) -> str:
  305. module = getattr(func, "__module__", None)
  306. qualname = getattr(func, "__qualname__", None)
  307. return ".".join([x for x in (module, qualname) if x])
  308. #
  309. # Event loop
  310. #
# Module-level mapping keyed weakly by event loop, so an entry disappears once its
# loop is garbage collected.
# NOTE(review): not referenced in this chunk; presumably per-loop storage for RunVar
# values — confirm against its users.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
  312. def _task_started(task: asyncio.Task) -> bool:
  313. """Return ``True`` if the task has been started and has not finished."""
  314. # The task coro should never be None here, as we never add finished tasks to the
  315. # task list
  316. coro = task.get_coro()
  317. assert coro is not None
  318. try:
  319. return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
  320. except AttributeError:
  321. # task coro is async_genenerator_asend https://bugs.python.org/issue37771
  322. raise Exception(f"Cannot determine if task {task} has started or not") from None
  323. #
  324. # Timeouts and cancellation
  325. #
  326. def is_anyio_cancellation(exc: CancelledError) -> bool:
  327. # Sometimes third party frameworks catch a CancelledError and raise a new one, so as
  328. # a workaround we have to look at the previous ones in __context__ too for a
  329. # matching cancel message
  330. while True:
  331. if (
  332. exc.args
  333. and isinstance(exc.args[0], str)
  334. and exc.args[0].startswith("Cancelled via cancel scope ")
  335. ):
  336. return True
  337. if isinstance(exc.__context__, CancelledError):
  338. exc = exc.__context__
  339. continue
  340. return False
class CancelScope(BaseCancelScope):
    """asyncio implementation of an AnyIO cancel scope.

    Each scope records the task that entered it (``_host_task``), the tasks it
    directly contains (``_tasks``), its enclosing scope (``_parent_scope``) and the
    scopes nested inside it (``_child_scopes``).

    Cancelling a scope calls ``Task.cancel()`` on eligible tasks. It keeps retrying
    on subsequent event loop iterations while tasks remain in the scope.
    ``__exit__`` swallows the resulting ``CancelledError`` only when the error came
    from a cancel scope and no cancelled parent scope is visible from here.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Plain object allocation; NOTE(review): presumably bypasses a
        # backend-dispatching BaseCancelScope.__new__ — confirm
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        # Event loop time at which the scope cancels itself (math.inf = never)
        self._deadline = deadline
        # When True, cancellation of enclosing scopes is not delivered into this scope
        self._shield = shield
        # Enclosing scope of the host task, set on __enter__
        self._parent_scope: CancelScope | None = None
        # Scopes entered while this scope was the host task's current scope
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        # Message passed to Task.cancel() when delivering cancellation
        self._cancel_reason: str | None = None
        # Set by __exit__ when it absorbed a cancel-scope CancelledError
        self._cancelled_caught = False
        # True between __enter__ and __exit__
        self._active = False
        # Pending loop.call_at() callback enforcing the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending loop.call_soon() retry of _deliver_cancellation()
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks directly contained in this scope (the host task while it is innermost)
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls this scope made on its host task that must be
        # undone with Task.uncancel() on exit; None before Python 3.11, where
        # Task.uncancel() is not used
        if sys.version_info >= (3, 11):
            self._pending_uncancellations: int | None = 0
        else:
            self._pending_uncancellations = None

    def __enter__(self) -> CancelScope:
        """Make this scope the current task's innermost scope and arm its deadline.

        :raises RuntimeError: if the scope is already active
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First scope ever entered by this task: it has no parent scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest under the task's previous innermost scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                # If using an eager task factory, the parent scope may not even contain
                # the host task
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.discard(host_task)

        # Cancel immediately if the deadline has passed, otherwise schedule it
        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool:
        """Leave the scope, restoring the parent scope as the task's innermost scope.

        :return: ``True`` (suppress *exc_val*) only if this scope was cancelled, no
            cancelled parent scope is visible, and every exception in *exc_val* is
            a cancel-scope ``CancelledError``
        :raises RuntimeError: if the scope is not active, is exited from a different
            task, or is not the host task's innermost scope
        """
        # NOTE(review): the exception references are dropped presumably to avoid
        # reference cycles through the traceback — confirm
        del exc_tb

        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        try:
            self._active = False
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            # Hand the host task back to the parent scope
            self._tasks.remove(self._host_task)
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.remove(self)
                self._parent_scope._tasks.add(self._host_task)

            host_task_state.cancel_scope = self._parent_scope

            # Restart the cancellation effort in the closest visible, cancelled parent
            # scope if necessary
            self._restart_cancellation_in_parent()

            # We only swallow the exception iff it was an AnyIO CancelledError, either
            # directly as exc_val or inside an exception group and there are no cancelled
            # parent cancel scopes visible to us here
            if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
                # For each level-cancel() call made on the host task, call uncancel()
                while self._pending_uncancellations:
                    self._host_task.uncancel()
                    self._pending_uncancellations -= 1

                # Update cancelled_caught and check for exceptions we must not swallow
                cannot_swallow_exc_val = False
                if exc_val is not None:
                    for exc in iterate_exceptions(exc_val):
                        if isinstance(exc, CancelledError) and is_anyio_cancellation(
                            exc
                        ):
                            self._cancelled_caught = True
                        else:
                            cannot_swallow_exc_val = True

                return self._cancelled_caught and not cannot_swallow_exc_val
            else:
                # The cancellation will propagate further: transfer the outstanding
                # uncancel() obligations to the parent scope
                if self._pending_uncancellations:
                    assert self._parent_scope is not None
                    assert self._parent_scope._pending_uncancellations is not None
                    self._parent_scope._pending_uncancellations += (
                        self._pending_uncancellations
                    )
                    self._pending_uncancellations = 0

                return False
        finally:
            self._host_task = None
            del exc_val

    @property
    def _effectively_cancelled(self) -> bool:
        """``True`` if this scope, or an ancestor reachable without crossing a shield,
        has been cancelled. A shielded scope that was itself cancelled counts."""
        cancel_scope: CancelScope | None = self
        while cancel_scope is not None:
            if cancel_scope._cancel_called:
                return True

            if cancel_scope.shield:
                return False

            cancel_scope = cancel_scope._parent_scope

        return False

    @property
    def _parent_cancellation_is_visible_to_us(self) -> bool:
        """``True`` if this scope is unshielded and its parent is effectively
        cancelled."""
        return (
            self._parent_scope is not None
            and not self.shield
            and self._parent_scope._effectively_cancelled
        )

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, else schedule a re-check."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel("deadline exceeded")
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Any task still inside the scope means another round is needed
            should_retry = True
            # Skip tasks that already have a cancellation request pending
            # (private asyncio.Task flag)
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            if task is not current and (task is self._host_task or _task_started(task)):
                # Don't cancel a task whose awaited future has already completed
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    task.cancel(origin._cancel_reason)
                    if (
                        task is origin._host_task
                        and origin._pending_uncancellations is not None
                    ):
                        # Remember to uncancel() the host task when origin exits
                        origin._pending_uncancellations += 1

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if no delivery retry is already scheduled
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def cancel(self, reason: str | None = None) -> None:
        """Cancel this scope; calls after the first one have no effect.

        :param reason: optional text appended to the cancellation message
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            # This prefix is what is_anyio_cancellation() looks for
            self._cancel_reason = f"Cancelled via cancel scope {id(self):x}"
            if task := current_task():
                self._cancel_reason += f" by {task}"

            if reason:
                self._cancel_reason += f"; reason: {reason}"

            # Deliver right away only if the scope has been entered
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        # Replace any previously scheduled deadline check
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Dropping the shield exposes us to a cancelled parent: resume delivery
            if not value:
                self._restart_cancellation_in_parent()
  569. #
  570. # Task states
  571. #
  572. class TaskState:
  573. """
  574. Encapsulates auxiliary task information that cannot be added to the Task instance
  575. itself because there are no guarantees about its implementation.
  576. """
  577. __slots__ = "parent_id", "cancel_scope", "__weakref__"
  578. def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
  579. self.parent_id = parent_id
  580. self.cancel_scope = cancel_scope
  581. _task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
  582. #
  583. # Task groups
  584. #
  585. class _AsyncioTaskStatus(abc.TaskStatus):
  586. def __init__(self, future: asyncio.Future, parent_id: int):
  587. self._future = future
  588. self._parent_id = parent_id
  589. def started(self, value: T_contra | None = None) -> None:
  590. try:
  591. self._future.set_result(value)
  592. except asyncio.InvalidStateError:
  593. if not self._future.cancelled():
  594. raise RuntimeError(
  595. "called 'started' twice on the same task status"
  596. ) from None
  597. task = cast(asyncio.Task, current_task())
  598. _task_states[task].parent_id = self._parent_id
  599. if sys.version_info >= (3, 12):
  600. _eager_task_factory_code: CodeType | None = asyncio.eager_task_factory.__code__
  601. else:
  602. _eager_task_factory_code = None
  603. class TaskGroup(abc.TaskGroup):
  604. def __init__(self) -> None:
  605. self.cancel_scope: CancelScope = CancelScope()
  606. self._active = False
  607. self._exceptions: list[BaseException] = []
  608. self._tasks: set[asyncio.Task] = set()
  609. self._on_completed_fut: asyncio.Future[None] | None = None
  610. async def __aenter__(self) -> TaskGroup:
  611. self.cancel_scope.__enter__()
  612. self._active = True
  613. return self
  614. async def __aexit__(
  615. self,
  616. exc_type: type[BaseException] | None,
  617. exc_val: BaseException | None,
  618. exc_tb: TracebackType | None,
  619. ) -> bool:
  620. try:
  621. if exc_val is not None:
  622. self.cancel_scope.cancel()
  623. if not isinstance(exc_val, CancelledError):
  624. self._exceptions.append(exc_val)
  625. loop = get_running_loop()
  626. try:
  627. if self._tasks:
  628. with CancelScope() as wait_scope:
  629. while self._tasks:
  630. self._on_completed_fut = loop.create_future()
  631. try:
  632. await self._on_completed_fut
  633. except CancelledError as exc:
  634. # Shield the scope against further cancellation attempts,
  635. # as they're not productive (#695)
  636. wait_scope.shield = True
  637. self.cancel_scope.cancel()
  638. # Set exc_val from the cancellation exception if it was
  639. # previously unset. However, we should not replace a native
  640. # cancellation exception with one raise by a cancel scope.
  641. if exc_val is None or (
  642. isinstance(exc_val, CancelledError)
  643. and not is_anyio_cancellation(exc)
  644. ):
  645. exc_val = exc
  646. self._on_completed_fut = None
  647. else:
  648. # If there are no child tasks to wait on, run at least one checkpoint
  649. # anyway
  650. await AsyncIOBackend.cancel_shielded_checkpoint()
  651. self._active = False
  652. if self._exceptions:
  653. # The exception that got us here should already have been
  654. # added to self._exceptions so it's ok to break exception
  655. # chaining and avoid adding a "During handling of above..."
  656. # for each nesting level.
  657. raise BaseExceptionGroup(
  658. "unhandled errors in a TaskGroup", self._exceptions
  659. ) from None
  660. elif exc_val:
  661. raise exc_val
  662. except BaseException as exc:
  663. if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
  664. return True
  665. raise
  666. return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
  667. finally:
  668. del exc_val, exc_tb, self._exceptions
  669. def _spawn(
  670. self,
  671. func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
  672. args: tuple[Unpack[PosArgsT]],
  673. name: object,
  674. task_status_future: asyncio.Future | None = None,
  675. ) -> asyncio.Task:
  676. def task_done(_task: asyncio.Task) -> None:
  677. task_state = _task_states[_task]
  678. assert task_state.cancel_scope is not None
  679. assert _task in task_state.cancel_scope._tasks
  680. task_state.cancel_scope._tasks.remove(_task)
  681. self._tasks.remove(task)
  682. del _task_states[_task]
  683. if self._on_completed_fut is not None and not self._tasks:
  684. try:
  685. self._on_completed_fut.set_result(None)
  686. except asyncio.InvalidStateError:
  687. pass
  688. try:
  689. exc = _task.exception()
  690. except CancelledError as e:
  691. while isinstance(e.__context__, CancelledError):
  692. e = e.__context__
  693. exc = e
  694. if exc is not None:
  695. # The future can only be in the cancelled state if the host task was
  696. # cancelled, so return immediately instead of adding one more
  697. # CancelledError to the exceptions list
  698. if task_status_future is not None and task_status_future.cancelled():
  699. return
  700. if task_status_future is None or task_status_future.done():
  701. if not isinstance(exc, CancelledError):
  702. self._exceptions.append(exc)
  703. if not self.cancel_scope._effectively_cancelled:
  704. self.cancel_scope.cancel()
  705. else:
  706. task_status_future.set_exception(exc)
  707. elif task_status_future is not None and not task_status_future.done():
  708. task_status_future.set_exception(
  709. RuntimeError("Child exited without calling task_status.started()")
  710. )
  711. if not self._active:
  712. raise RuntimeError(
  713. "This task group is not active; no new tasks can be started."
  714. )
  715. kwargs = {}
  716. if task_status_future:
  717. parent_id = id(current_task())
  718. kwargs["task_status"] = _AsyncioTaskStatus(
  719. task_status_future, id(self.cancel_scope._host_task)
  720. )
  721. else:
  722. parent_id = id(self.cancel_scope._host_task)
  723. coro = func(*args, **kwargs)
  724. if not iscoroutine(coro):
  725. prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
  726. raise TypeError(
  727. f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
  728. f"the return value ({coro!r}) is not a coroutine object"
  729. )
  730. name = get_callable_name(func) if name is None else str(name)
  731. loop = asyncio.get_running_loop()
  732. if (
  733. (factory := loop.get_task_factory())
  734. and getattr(factory, "__code__", None) is _eager_task_factory_code
  735. and (closure := getattr(factory, "__closure__", None))
  736. ):
  737. custom_task_constructor = closure[0].cell_contents
  738. task = custom_task_constructor(coro, loop=loop, name=name)
  739. else:
  740. task = create_task(coro, name=name)
  741. # Make the spawned task inherit the task group's cancel scope
  742. _task_states[task] = TaskState(
  743. parent_id=parent_id, cancel_scope=self.cancel_scope
  744. )
  745. self.cancel_scope._tasks.add(task)
  746. self._tasks.add(task)
  747. task.add_done_callback(task_done)
  748. return task
  749. def start_soon(
  750. self,
  751. func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
  752. *args: Unpack[PosArgsT],
  753. name: object = None,
  754. ) -> None:
  755. self._spawn(func, args, name)
  756. async def start(
  757. self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
  758. ) -> Any:
  759. future: asyncio.Future = asyncio.Future()
  760. task = self._spawn(func, args, name, future)
  761. # If the task raises an exception after sending a start value without a switch
  762. # point between, the task group is cancelled and this method never proceeds to
  763. # process the completed future. That's why we have to have a shielded cancel
  764. # scope here.
  765. try:
  766. return await future
  767. except CancelledError:
  768. # Cancel the task and wait for it to exit before returning
  769. task.cancel()
  770. with CancelScope(shield=True), suppress(CancelledError):
  771. await task
  772. raise
  773. #
  774. # Threads
  775. #
  776. _Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
  777. class WorkerThread(Thread):
  778. MAX_IDLE_TIME = 10 # seconds
  779. def __init__(
  780. self,
  781. root_task: asyncio.Task,
  782. workers: set[WorkerThread],
  783. idle_workers: deque[WorkerThread],
  784. ):
  785. super().__init__(name="AnyIO worker thread")
  786. self.root_task = root_task
  787. self.workers = workers
  788. self.idle_workers = idle_workers
  789. self.loop = root_task._loop
  790. self.queue: Queue[
  791. tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
  792. ] = Queue(2)
  793. self.idle_since = AsyncIOBackend.current_time()
  794. self.stopping = False
  795. def _report_result(
  796. self, future: asyncio.Future, result: Any, exc: BaseException | None
  797. ) -> None:
  798. self.idle_since = AsyncIOBackend.current_time()
  799. if not self.stopping:
  800. self.idle_workers.append(self)
  801. if not future.cancelled():
  802. if exc is not None:
  803. if isinstance(exc, StopIteration):
  804. new_exc = RuntimeError("coroutine raised StopIteration")
  805. new_exc.__cause__ = exc
  806. exc = new_exc
  807. future.set_exception(exc)
  808. else:
  809. future.set_result(result)
  810. def run(self) -> None:
  811. with claim_worker_thread(AsyncIOBackend, self.loop):
  812. while True:
  813. item = self.queue.get()
  814. if item is None:
  815. # Shutdown command received
  816. return
  817. context, func, args, future, cancel_scope = item
  818. if not future.cancelled():
  819. result = None
  820. exception: BaseException | None = None
  821. threadlocals.current_cancel_scope = cancel_scope
  822. try:
  823. result = context.run(func, *args)
  824. except BaseException as exc:
  825. exception = exc
  826. finally:
  827. del threadlocals.current_cancel_scope
  828. if not self.loop.is_closed():
  829. self.loop.call_soon_threadsafe(
  830. self._report_result, future, result, exception
  831. )
  832. del result, exception
  833. self.queue.task_done()
  834. del item, context, func, args, future, cancel_scope
  835. def stop(self, f: asyncio.Task | None = None) -> None:
  836. self.stopping = True
  837. self.queue.put_nowait(None)
  838. self.workers.discard(self)
  839. try:
  840. self.idle_workers.remove(self)
  841. except ValueError:
  842. pass
  843. _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
  844. "_threadpool_idle_workers"
  845. )
  846. _threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
  847. class BlockingPortal(abc.BlockingPortal):
  848. def __new__(cls) -> BlockingPortal:
  849. return object.__new__(cls)
  850. def __init__(self) -> None:
  851. super().__init__()
  852. self._loop = get_running_loop()
  853. def _spawn_task_from_thread(
  854. self,
  855. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  856. args: tuple[Unpack[PosArgsT]],
  857. kwargs: dict[str, Any],
  858. name: object,
  859. future: Future[T_Retval],
  860. ) -> None:
  861. AsyncIOBackend.run_sync_from_thread(
  862. partial(self._task_group.start_soon, name=name),
  863. (self._call_func, func, args, kwargs, future),
  864. self._loop,
  865. )
  866. #
  867. # Subprocesses
  868. #
  869. @dataclass(eq=False)
  870. class StreamReaderWrapper(abc.ByteReceiveStream):
  871. _stream: asyncio.StreamReader
  872. async def receive(self, max_bytes: int = 65536) -> bytes:
  873. data = await self._stream.read(max_bytes)
  874. if data:
  875. return data
  876. else:
  877. raise EndOfStream
  878. async def aclose(self) -> None:
  879. self._stream.set_exception(ClosedResourceError())
  880. await AsyncIOBackend.checkpoint()
  881. @dataclass(eq=False)
  882. class StreamWriterWrapper(abc.ByteSendStream):
  883. _stream: asyncio.StreamWriter
  884. async def send(self, item: bytes) -> None:
  885. self._stream.write(item)
  886. await self._stream.drain()
  887. async def aclose(self) -> None:
  888. self._stream.close()
  889. await AsyncIOBackend.checkpoint()
  890. @dataclass(eq=False)
  891. class Process(abc.Process):
  892. _process: asyncio.subprocess.Process
  893. _stdin: StreamWriterWrapper | None
  894. _stdout: StreamReaderWrapper | None
  895. _stderr: StreamReaderWrapper | None
  896. async def aclose(self) -> None:
  897. with CancelScope(shield=True) as scope:
  898. if self._stdin:
  899. await self._stdin.aclose()
  900. if self._stdout:
  901. await self._stdout.aclose()
  902. if self._stderr:
  903. await self._stderr.aclose()
  904. scope.shield = False
  905. try:
  906. await self.wait()
  907. except BaseException:
  908. scope.shield = True
  909. self.kill()
  910. await self.wait()
  911. raise
  912. async def wait(self) -> int:
  913. return await self._process.wait()
  914. def terminate(self) -> None:
  915. self._process.terminate()
  916. def kill(self) -> None:
  917. self._process.kill()
  918. def send_signal(self, signal: int) -> None:
  919. self._process.send_signal(signal)
  920. @property
  921. def pid(self) -> int:
  922. return self._process.pid
  923. @property
  924. def returncode(self) -> int | None:
  925. return self._process.returncode
  926. @property
  927. def stdin(self) -> abc.ByteSendStream | None:
  928. return self._stdin
  929. @property
  930. def stdout(self) -> abc.ByteReceiveStream | None:
  931. return self._stdout
  932. @property
  933. def stderr(self) -> abc.ByteReceiveStream | None:
  934. return self._stderr
  935. def _forcibly_shutdown_process_pool_on_exit(
  936. workers: set[Process], _task: object
  937. ) -> None:
  938. """
  939. Forcibly shuts down worker processes belonging to this event loop."""
  940. child_watcher: asyncio.AbstractChildWatcher | None = None
  941. if sys.version_info < (3, 12):
  942. try:
  943. child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
  944. except NotImplementedError:
  945. pass
  946. # Close as much as possible (w/o async/await) to avoid warnings
  947. for process in workers:
  948. if process.returncode is None:
  949. continue
  950. process._stdin._stream._transport.close() # type: ignore[union-attr]
  951. process._stdout._stream._transport.close() # type: ignore[union-attr]
  952. process._stderr._stream._transport.close() # type: ignore[union-attr]
  953. process.kill()
  954. if child_watcher:
  955. child_watcher.remove_child_handler(process.pid)
  956. async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
  957. """
  958. Shuts down worker processes belonging to this event loop.
  959. NOTE: this only works when the event loop was started using asyncio.run() or
  960. anyio.run().
  961. """
  962. process: abc.Process
  963. try:
  964. await sleep(math.inf)
  965. except asyncio.CancelledError:
  966. for process in workers:
  967. if process.returncode is None:
  968. process.kill()
  969. for process in workers:
  970. await process.aclose()
  971. #
  972. # Sockets and networking
  973. #
  974. class StreamProtocol(asyncio.Protocol):
  975. read_queue: deque[bytes]
  976. read_event: asyncio.Event
  977. write_event: asyncio.Event
  978. exception: Exception | None = None
  979. is_at_eof: bool = False
  980. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  981. self.read_queue = deque()
  982. self.read_event = asyncio.Event()
  983. self.write_event = asyncio.Event()
  984. self.write_event.set()
  985. cast(asyncio.Transport, transport).set_write_buffer_limits(0)
  986. def connection_lost(self, exc: Exception | None) -> None:
  987. if exc:
  988. self.exception = BrokenResourceError()
  989. self.exception.__cause__ = exc
  990. self.read_event.set()
  991. self.write_event.set()
  992. def data_received(self, data: bytes) -> None:
  993. # ProactorEventloop sometimes sends bytearray instead of bytes
  994. self.read_queue.append(bytes(data))
  995. self.read_event.set()
  996. def eof_received(self) -> bool | None:
  997. self.is_at_eof = True
  998. self.read_event.set()
  999. return True
  1000. def pause_writing(self) -> None:
  1001. self.write_event = asyncio.Event()
  1002. def resume_writing(self) -> None:
  1003. self.write_event.set()
  1004. class DatagramProtocol(asyncio.DatagramProtocol):
  1005. read_queue: deque[tuple[bytes, IPSockAddrType]]
  1006. read_event: asyncio.Event
  1007. write_event: asyncio.Event
  1008. exception: Exception | None = None
  1009. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  1010. self.read_queue = deque(maxlen=100) # arbitrary value
  1011. self.read_event = asyncio.Event()
  1012. self.write_event = asyncio.Event()
  1013. self.write_event.set()
  1014. def connection_lost(self, exc: Exception | None) -> None:
  1015. self.read_event.set()
  1016. self.write_event.set()
  1017. def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
  1018. addr = convert_ipv6_sockaddr(addr)
  1019. self.read_queue.append((data, addr))
  1020. self.read_event.set()
  1021. def error_received(self, exc: Exception) -> None:
  1022. self.exception = exc
  1023. def pause_writing(self) -> None:
  1024. self.write_event.clear()
  1025. def resume_writing(self) -> None:
  1026. self.write_event.set()
  1027. class SocketStream(abc.SocketStream):
  1028. def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
  1029. self._transport = transport
  1030. self._protocol = protocol
  1031. self._receive_guard = ResourceGuard("reading from")
  1032. self._send_guard = ResourceGuard("writing to")
  1033. self._closed = False
  1034. @property
  1035. def _raw_socket(self) -> socket.socket:
  1036. return self._transport.get_extra_info("socket")
  1037. async def receive(self, max_bytes: int = 65536) -> bytes:
  1038. with self._receive_guard:
  1039. if (
  1040. not self._protocol.read_event.is_set()
  1041. and not self._transport.is_closing()
  1042. and not self._protocol.is_at_eof
  1043. ):
  1044. self._transport.resume_reading()
  1045. await self._protocol.read_event.wait()
  1046. self._transport.pause_reading()
  1047. else:
  1048. await AsyncIOBackend.checkpoint()
  1049. try:
  1050. chunk = self._protocol.read_queue.popleft()
  1051. except IndexError:
  1052. if self._closed:
  1053. raise ClosedResourceError from None
  1054. elif self._protocol.exception:
  1055. raise self._protocol.exception from None
  1056. else:
  1057. raise EndOfStream from None
  1058. if len(chunk) > max_bytes:
  1059. # Split the oversized chunk
  1060. chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
  1061. self._protocol.read_queue.appendleft(leftover)
  1062. # If the read queue is empty, clear the flag so that the next call will
  1063. # block until data is available
  1064. if not self._protocol.read_queue:
  1065. self._protocol.read_event.clear()
  1066. return chunk
  1067. async def send(self, item: bytes) -> None:
  1068. with self._send_guard:
  1069. await AsyncIOBackend.checkpoint()
  1070. if self._closed:
  1071. raise ClosedResourceError
  1072. elif self._protocol.exception is not None:
  1073. raise self._protocol.exception
  1074. try:
  1075. self._transport.write(item)
  1076. except RuntimeError as exc:
  1077. if self._transport.is_closing():
  1078. raise BrokenResourceError from exc
  1079. else:
  1080. raise
  1081. await self._protocol.write_event.wait()
  1082. async def send_eof(self) -> None:
  1083. try:
  1084. self._transport.write_eof()
  1085. except OSError:
  1086. pass
  1087. async def aclose(self) -> None:
  1088. self._closed = True
  1089. if not self._transport.is_closing():
  1090. try:
  1091. self._transport.write_eof()
  1092. except OSError:
  1093. pass
  1094. self._transport.close()
  1095. await sleep(0)
  1096. self._transport.abort()
  1097. class _RawSocketMixin:
  1098. _receive_future: asyncio.Future | None = None
  1099. _send_future: asyncio.Future | None = None
  1100. _closing = False
  1101. def __init__(self, raw_socket: socket.socket):
  1102. self.__raw_socket = raw_socket
  1103. self._receive_guard = ResourceGuard("reading from")
  1104. self._send_guard = ResourceGuard("writing to")
  1105. @property
  1106. def _raw_socket(self) -> socket.socket:
  1107. return self.__raw_socket
  1108. def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1109. def callback(f: object) -> None:
  1110. del self._receive_future
  1111. loop.remove_reader(self.__raw_socket)
  1112. f = self._receive_future = asyncio.Future()
  1113. loop.add_reader(self.__raw_socket, f.set_result, None)
  1114. f.add_done_callback(callback)
  1115. return f
  1116. def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
  1117. def callback(f: object) -> None:
  1118. del self._send_future
  1119. loop.remove_writer(self.__raw_socket)
  1120. f = self._send_future = asyncio.Future()
  1121. loop.add_writer(self.__raw_socket, f.set_result, None)
  1122. f.add_done_callback(callback)
  1123. return f
  1124. async def aclose(self) -> None:
  1125. if not self._closing:
  1126. self._closing = True
  1127. if self.__raw_socket.fileno() != -1:
  1128. self.__raw_socket.close()
  1129. if self._receive_future:
  1130. self._receive_future.set_result(None)
  1131. if self._send_future:
  1132. self._send_future.set_result(None)
  1133. class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
  1134. async def send_eof(self) -> None:
  1135. with self._send_guard:
  1136. self._raw_socket.shutdown(socket.SHUT_WR)
  1137. async def receive(self, max_bytes: int = 65536) -> bytes:
  1138. loop = get_running_loop()
  1139. await AsyncIOBackend.checkpoint()
  1140. with self._receive_guard:
  1141. while True:
  1142. try:
  1143. data = self._raw_socket.recv(max_bytes)
  1144. except BlockingIOError:
  1145. await self._wait_until_readable(loop)
  1146. except OSError as exc:
  1147. if self._closing:
  1148. raise ClosedResourceError from None
  1149. else:
  1150. raise BrokenResourceError from exc
  1151. else:
  1152. if not data:
  1153. raise EndOfStream
  1154. return data
  1155. async def send(self, item: bytes) -> None:
  1156. loop = get_running_loop()
  1157. await AsyncIOBackend.checkpoint()
  1158. with self._send_guard:
  1159. view = memoryview(item)
  1160. while view:
  1161. try:
  1162. bytes_sent = self._raw_socket.send(view)
  1163. except BlockingIOError:
  1164. await self._wait_until_writable(loop)
  1165. except OSError as exc:
  1166. if self._closing:
  1167. raise ClosedResourceError from None
  1168. else:
  1169. raise BrokenResourceError from exc
  1170. else:
  1171. view = view[bytes_sent:]
  1172. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  1173. if not isinstance(msglen, int) or msglen < 0:
  1174. raise ValueError("msglen must be a non-negative integer")
  1175. if not isinstance(maxfds, int) or maxfds < 1:
  1176. raise ValueError("maxfds must be a positive integer")
  1177. loop = get_running_loop()
  1178. fds = array.array("i")
  1179. await AsyncIOBackend.checkpoint()
  1180. with self._receive_guard:
  1181. while True:
  1182. try:
  1183. message, ancdata, flags, addr = self._raw_socket.recvmsg(
  1184. msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
  1185. )
  1186. except BlockingIOError:
  1187. await self._wait_until_readable(loop)
  1188. except OSError as exc:
  1189. if self._closing:
  1190. raise ClosedResourceError from None
  1191. else:
  1192. raise BrokenResourceError from exc
  1193. else:
  1194. if not message and not ancdata:
  1195. raise EndOfStream
  1196. break
  1197. for cmsg_level, cmsg_type, cmsg_data in ancdata:
  1198. if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
  1199. raise RuntimeError(
  1200. f"Received unexpected ancillary data; message = {message!r}, "
  1201. f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
  1202. )
  1203. fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
  1204. return message, list(fds)
  1205. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  1206. if not message:
  1207. raise ValueError("message must not be empty")
  1208. if not fds:
  1209. raise ValueError("fds must not be empty")
  1210. loop = get_running_loop()
  1211. filenos: list[int] = []
  1212. for fd in fds:
  1213. if isinstance(fd, int):
  1214. filenos.append(fd)
  1215. elif isinstance(fd, IOBase):
  1216. filenos.append(fd.fileno())
  1217. fdarray = array.array("i", filenos)
  1218. await AsyncIOBackend.checkpoint()
  1219. with self._send_guard:
  1220. while True:
  1221. try:
  1222. # The ignore can be removed after mypy picks up
  1223. # https://github.com/python/typeshed/pull/5545
  1224. self._raw_socket.sendmsg(
  1225. [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
  1226. )
  1227. break
  1228. except BlockingIOError:
  1229. await self._wait_until_writable(loop)
  1230. except OSError as exc:
  1231. if self._closing:
  1232. raise ClosedResourceError from None
  1233. else:
  1234. raise BrokenResourceError from exc
  1235. class TCPSocketListener(abc.SocketListener):
  1236. _accept_scope: CancelScope | None = None
  1237. _closed = False
  1238. def __init__(self, raw_socket: socket.socket):
  1239. self.__raw_socket = raw_socket
  1240. self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
  1241. self._accept_guard = ResourceGuard("accepting connections from")
  1242. @property
  1243. def _raw_socket(self) -> socket.socket:
  1244. return self.__raw_socket
  1245. async def accept(self) -> abc.SocketStream:
  1246. if self._closed:
  1247. raise ClosedResourceError
  1248. with self._accept_guard:
  1249. await AsyncIOBackend.checkpoint()
  1250. with CancelScope() as self._accept_scope:
  1251. try:
  1252. client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
  1253. except asyncio.CancelledError:
  1254. # Workaround for https://bugs.python.org/issue41317
  1255. try:
  1256. self._loop.remove_reader(self._raw_socket)
  1257. except (ValueError, NotImplementedError):
  1258. pass
  1259. if self._closed:
  1260. raise ClosedResourceError from None
  1261. raise
  1262. finally:
  1263. self._accept_scope = None
  1264. client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
  1265. transport, protocol = await self._loop.connect_accepted_socket(
  1266. StreamProtocol, client_sock
  1267. )
  1268. return SocketStream(transport, protocol)
  1269. async def aclose(self) -> None:
  1270. if self._closed:
  1271. return
  1272. self._closed = True
  1273. if self._accept_scope:
  1274. # Workaround for https://bugs.python.org/issue41317
  1275. try:
  1276. self._loop.remove_reader(self._raw_socket)
  1277. except (ValueError, NotImplementedError):
  1278. pass
  1279. self._accept_scope.cancel()
  1280. await sleep(0)
  1281. self._raw_socket.close()
  1282. class UNIXSocketListener(abc.SocketListener):
  1283. def __init__(self, raw_socket: socket.socket):
  1284. self.__raw_socket = raw_socket
  1285. self._loop = get_running_loop()
  1286. self._accept_guard = ResourceGuard("accepting connections from")
  1287. self._closed = False
  1288. async def accept(self) -> abc.SocketStream:
  1289. await AsyncIOBackend.checkpoint()
  1290. with self._accept_guard:
  1291. while True:
  1292. try:
  1293. client_sock, _ = self.__raw_socket.accept()
  1294. client_sock.setblocking(False)
  1295. return UNIXSocketStream(client_sock)
  1296. except BlockingIOError:
  1297. f: asyncio.Future = asyncio.Future()
  1298. self._loop.add_reader(self.__raw_socket, f.set_result, None)
  1299. f.add_done_callback(
  1300. lambda _: self._loop.remove_reader(self.__raw_socket)
  1301. )
  1302. await f
  1303. except OSError as exc:
  1304. if self._closed:
  1305. raise ClosedResourceError from None
  1306. else:
  1307. raise BrokenResourceError from exc
  1308. async def aclose(self) -> None:
  1309. self._closed = True
  1310. self.__raw_socket.close()
  1311. @property
  1312. def _raw_socket(self) -> socket.socket:
  1313. return self.__raw_socket
  1314. class UDPSocket(abc.UDPSocket):
  1315. def __init__(
  1316. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1317. ):
  1318. self._transport = transport
  1319. self._protocol = protocol
  1320. self._receive_guard = ResourceGuard("reading from")
  1321. self._send_guard = ResourceGuard("writing to")
  1322. self._closed = False
  1323. @property
  1324. def _raw_socket(self) -> socket.socket:
  1325. return self._transport.get_extra_info("socket")
  1326. async def aclose(self) -> None:
  1327. if not self._transport.is_closing():
  1328. self._closed = True
  1329. self._transport.close()
  1330. async def receive(self) -> tuple[bytes, IPSockAddrType]:
  1331. with self._receive_guard:
  1332. await AsyncIOBackend.checkpoint()
  1333. # If the buffer is empty, ask for more data
  1334. if not self._protocol.read_queue and not self._transport.is_closing():
  1335. self._protocol.read_event.clear()
  1336. await self._protocol.read_event.wait()
  1337. try:
  1338. return self._protocol.read_queue.popleft()
  1339. except IndexError:
  1340. if self._closed:
  1341. raise ClosedResourceError from None
  1342. else:
  1343. raise BrokenResourceError from None
  1344. async def send(self, item: UDPPacketType) -> None:
  1345. with self._send_guard:
  1346. await AsyncIOBackend.checkpoint()
  1347. await self._protocol.write_event.wait()
  1348. if self._closed:
  1349. raise ClosedResourceError
  1350. elif self._transport.is_closing():
  1351. raise BrokenResourceError
  1352. else:
  1353. self._transport.sendto(*item)
  1354. class ConnectedUDPSocket(abc.ConnectedUDPSocket):
  1355. def __init__(
  1356. self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
  1357. ):
  1358. self._transport = transport
  1359. self._protocol = protocol
  1360. self._receive_guard = ResourceGuard("reading from")
  1361. self._send_guard = ResourceGuard("writing to")
  1362. self._closed = False
  1363. @property
  1364. def _raw_socket(self) -> socket.socket:
  1365. return self._transport.get_extra_info("socket")
  1366. async def aclose(self) -> None:
  1367. if not self._transport.is_closing():
  1368. self._closed = True
  1369. self._transport.close()
  1370. async def receive(self) -> bytes:
  1371. with self._receive_guard:
  1372. await AsyncIOBackend.checkpoint()
  1373. # If the buffer is empty, ask for more data
  1374. if not self._protocol.read_queue and not self._transport.is_closing():
  1375. self._protocol.read_event.clear()
  1376. await self._protocol.read_event.wait()
  1377. try:
  1378. packet = self._protocol.read_queue.popleft()
  1379. except IndexError:
  1380. if self._closed:
  1381. raise ClosedResourceError from None
  1382. else:
  1383. raise BrokenResourceError from None
  1384. return packet[0]
  1385. async def send(self, item: bytes) -> None:
  1386. with self._send_guard:
  1387. await AsyncIOBackend.checkpoint()
  1388. await self._protocol.write_event.wait()
  1389. if self._closed:
  1390. raise ClosedResourceError
  1391. elif self._transport.is_closing():
  1392. raise BrokenResourceError
  1393. else:
  1394. self._transport.sendto(item)
  1395. class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
  1396. async def receive(self) -> UNIXDatagramPacketType:
  1397. loop = get_running_loop()
  1398. await AsyncIOBackend.checkpoint()
  1399. with self._receive_guard:
  1400. while True:
  1401. try:
  1402. data = self._raw_socket.recvfrom(65536)
  1403. except BlockingIOError:
  1404. await self._wait_until_readable(loop)
  1405. except OSError as exc:
  1406. if self._closing:
  1407. raise ClosedResourceError from None
  1408. else:
  1409. raise BrokenResourceError from exc
  1410. else:
  1411. return data
  1412. async def send(self, item: UNIXDatagramPacketType) -> None:
  1413. loop = get_running_loop()
  1414. await AsyncIOBackend.checkpoint()
  1415. with self._send_guard:
  1416. while True:
  1417. try:
  1418. self._raw_socket.sendto(*item)
  1419. except BlockingIOError:
  1420. await self._wait_until_writable(loop)
  1421. except OSError as exc:
  1422. if self._closing:
  1423. raise ClosedResourceError from None
  1424. else:
  1425. raise BrokenResourceError from exc
  1426. else:
  1427. return
  1428. class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
  1429. async def receive(self) -> bytes:
  1430. loop = get_running_loop()
  1431. await AsyncIOBackend.checkpoint()
  1432. with self._receive_guard:
  1433. while True:
  1434. try:
  1435. data = self._raw_socket.recv(65536)
  1436. except BlockingIOError:
  1437. await self._wait_until_readable(loop)
  1438. except OSError as exc:
  1439. if self._closing:
  1440. raise ClosedResourceError from None
  1441. else:
  1442. raise BrokenResourceError from exc
  1443. else:
  1444. return data
  1445. async def send(self, item: bytes) -> None:
  1446. loop = get_running_loop()
  1447. await AsyncIOBackend.checkpoint()
  1448. with self._send_guard:
  1449. while True:
  1450. try:
  1451. self._raw_socket.send(item)
  1452. except BlockingIOError:
  1453. await self._wait_until_writable(loop)
  1454. except OSError as exc:
  1455. if self._closing:
  1456. raise ClosedResourceError from None
  1457. else:
  1458. raise BrokenResourceError from exc
  1459. else:
  1460. return
# fd -> future of the task currently blocked in AsyncIOBackend.wait_readable() /
# wait_writable() on that fd. The future resolves to True when the fd is ready, or to
# False when notify_closing() ends the wait. Stored in RunVars (presumably scoped to
# the current event loop run; TODO confirm).
_read_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("write_events")
  1463. #
  1464. # Synchronization
  1465. #
  1466. class Event(BaseEvent):
  1467. def __new__(cls) -> Event:
  1468. return object.__new__(cls)
  1469. def __init__(self) -> None:
  1470. self._event = asyncio.Event()
  1471. def set(self) -> None:
  1472. self._event.set()
  1473. def is_set(self) -> bool:
  1474. return self._event.is_set()
  1475. async def wait(self) -> None:
  1476. if self.is_set():
  1477. await AsyncIOBackend.checkpoint()
  1478. else:
  1479. await self._event.wait()
  1480. def statistics(self) -> EventStatistics:
  1481. return EventStatistics(len(self._event._waiters))
class Lock(BaseLock):
    """asyncio implementation of :class:`BaseLock`.

    Ownership is tracked per task. Tasks that find the lock busy queue up in FIFO
    order, each paired with a future; ``release()`` hands ownership directly to the
    first waiter whose future is not cancelled.
    """

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        # Allocate directly, bypassing BaseLock.__new__ (presumably a
        # backend-dispatching factory; TODO confirm)
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        # When True, an uncontended acquire() skips the cancel-shielded checkpoint
        self._fast_acquire = fast_acquire
        # Task currently holding the lock, or None when unlocked
        self._owner_task: asyncio.Task | None = None
        # FIFO of (waiting task, future that release() resolves on handover)
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """Acquire the lock, waiting in FIFO order if it is held or contended.

        Raises:
            RuntimeError: if the current task already holds the lock
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            # Uncontended: this only yields if an enclosing cancel scope has already
            # been cancelled
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = task

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        if self._owner_task == task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            # release() may already have made this task the owner before the
            # cancellation landed; pass the lock on so it is not leaked
            if self._owner_task is task:
                self.release()

            raise

        # Ownership was transferred by release(); drop our queue entry
        self._waiters.remove(item)

    def acquire_nowait(self) -> None:
        """Acquire the lock without waiting.

        Raises:
            RuntimeError: if the current task already holds the lock
            WouldBlock: if the lock is held by another task or has queued waiters
        """
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = task
            return

        if self._owner_task is task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        """Return whether some task currently holds the lock."""
        return self._owner_task is not None

    def release(self) -> None:
        """Release the lock, handing it straight to the first non-cancelled waiter.

        Raises:
            RuntimeError: if the current task does not hold the lock
        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        # The new owner removes its own entry from _waiters when it resumes in acquire()
        for task, fut in self._waiters:
            if not fut.cancelled():
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def statistics(self) -> LockStatistics:
        """Return lock state, owner task info (if any) and the waiter count."""
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
class Semaphore(BaseSemaphore):
    """asyncio implementation of :class:`BaseSemaphore`.

    Blocked tasks queue up as futures in FIFO order. ``release()`` hands a unit
    directly to the first non-cancelled waiter instead of incrementing the counter.
    """

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        # Allocate directly, bypassing BaseSemaphore.__new__ (presumably a
        # backend-dispatching factory; TODO confirm)
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> None:
        super().__init__(initial_value, max_value=max_value)
        # Number of units currently available
        self._value = initial_value
        # Upper bound checked by release(); None means unbounded
        self._max_value = max_value
        # When True, an uncontended acquire() skips the cancel-shielded checkpoint
        self._fast_acquire = fast_acquire
        # FIFO of futures belonging to tasks blocked in acquire()
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Take one unit, waiting in FIFO order if none is available."""
        if self._value > 0 and not self._waiters:
            # Uncontended: this only yields if an enclosing cancel scope has already
            # been cancelled
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # release() already removed this future and handed us a unit; give
                # it back so it is not lost
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """Take one unit without waiting.

        Raises:
            WouldBlock: if no units are available
        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """Return one unit, handing it directly to the first non-cancelled waiter.

        Raises:
            ValueError: if ``max_value`` is set and the counter is already at it
        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                # Mutating the deque mid-iteration is safe only because we return
                # immediately afterwards
                self._waiters.remove(fut)
                return

        self._value += 1

    @property
    def value(self) -> int:
        """Number of units currently available."""
        return self._value

    @property
    def max_value(self) -> int | None:
        """Upper bound for the counter, or None if unbounded."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """Return statistics including the number of queued waiters."""
        return SemaphoreStatistics(len(self._waiters))
  1603. class CapacityLimiter(BaseCapacityLimiter):
  1604. _total_tokens: float = 0
  1605. def __new__(cls, total_tokens: float) -> CapacityLimiter:
  1606. return object.__new__(cls)
  1607. def __init__(self, total_tokens: float):
  1608. self._borrowers: set[Any] = set()
  1609. self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
  1610. self.total_tokens = total_tokens
  1611. async def __aenter__(self) -> None:
  1612. await self.acquire()
  1613. async def __aexit__(
  1614. self,
  1615. exc_type: type[BaseException] | None,
  1616. exc_val: BaseException | None,
  1617. exc_tb: TracebackType | None,
  1618. ) -> None:
  1619. self.release()
  1620. @property
  1621. def total_tokens(self) -> float:
  1622. return self._total_tokens
  1623. @total_tokens.setter
  1624. def total_tokens(self, value: float) -> None:
  1625. if not isinstance(value, int) and not math.isinf(value):
  1626. raise TypeError("total_tokens must be an int or math.inf")
  1627. if value < 1:
  1628. raise ValueError("total_tokens must be >= 1")
  1629. waiters_to_notify = max(value - self._total_tokens, 0)
  1630. self._total_tokens = value
  1631. # Notify waiting tasks that they have acquired the limiter
  1632. while self._wait_queue and waiters_to_notify:
  1633. event = self._wait_queue.popitem(last=False)[1]
  1634. event.set()
  1635. waiters_to_notify -= 1
  1636. @property
  1637. def borrowed_tokens(self) -> int:
  1638. return len(self._borrowers)
  1639. @property
  1640. def available_tokens(self) -> float:
  1641. return self._total_tokens - len(self._borrowers)
  1642. def _notify_next_waiter(self) -> None:
  1643. """Notify the next task in line if this limiter has free capacity now."""
  1644. if self._wait_queue and len(self._borrowers) < self._total_tokens:
  1645. event = self._wait_queue.popitem(last=False)[1]
  1646. event.set()
  1647. def acquire_nowait(self) -> None:
  1648. self.acquire_on_behalf_of_nowait(current_task())
  1649. def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
  1650. if borrower in self._borrowers:
  1651. raise RuntimeError(
  1652. "this borrower is already holding one of this CapacityLimiter's tokens"
  1653. )
  1654. if self._wait_queue or len(self._borrowers) >= self._total_tokens:
  1655. raise WouldBlock
  1656. self._borrowers.add(borrower)
  1657. async def acquire(self) -> None:
  1658. return await self.acquire_on_behalf_of(current_task())
  1659. async def acquire_on_behalf_of(self, borrower: object) -> None:
  1660. await AsyncIOBackend.checkpoint_if_cancelled()
  1661. try:
  1662. self.acquire_on_behalf_of_nowait(borrower)
  1663. except WouldBlock:
  1664. event = asyncio.Event()
  1665. self._wait_queue[borrower] = event
  1666. try:
  1667. await event.wait()
  1668. except BaseException:
  1669. self._wait_queue.pop(borrower, None)
  1670. if event.is_set():
  1671. self._notify_next_waiter()
  1672. raise
  1673. self._borrowers.add(borrower)
  1674. else:
  1675. try:
  1676. await AsyncIOBackend.cancel_shielded_checkpoint()
  1677. except BaseException:
  1678. self.release()
  1679. raise
  1680. def release(self) -> None:
  1681. self.release_on_behalf_of(current_task())
  1682. def release_on_behalf_of(self, borrower: object) -> None:
  1683. try:
  1684. self._borrowers.remove(borrower)
  1685. except KeyError:
  1686. raise RuntimeError(
  1687. "this borrower isn't holding any of this CapacityLimiter's tokens"
  1688. ) from None
  1689. self._notify_next_waiter()
  1690. def statistics(self) -> CapacityLimiterStatistics:
  1691. return CapacityLimiterStatistics(
  1692. self.borrowed_tokens,
  1693. self.total_tokens,
  1694. tuple(self._borrowers),
  1695. len(self._wait_queue),
  1696. )
# RunVar holding the default CapacityLimiter for worker threads. It is presumably read
# and lazily initialised by current_default_thread_limiter(), which is defined outside
# this view (TODO confirm).
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
  1698. #
  1699. # Operating system signals
  1700. #
  1701. class _SignalReceiver:
  1702. def __init__(self, signals: tuple[Signals, ...]):
  1703. self._signals = signals
  1704. self._loop = get_running_loop()
  1705. self._signal_queue: deque[Signals] = deque()
  1706. self._future: asyncio.Future = asyncio.Future()
  1707. self._handled_signals: set[Signals] = set()
  1708. def _deliver(self, signum: Signals) -> None:
  1709. self._signal_queue.append(signum)
  1710. if not self._future.done():
  1711. self._future.set_result(None)
  1712. def __enter__(self) -> _SignalReceiver:
  1713. for sig in set(self._signals):
  1714. self._loop.add_signal_handler(sig, self._deliver, sig)
  1715. self._handled_signals.add(sig)
  1716. return self
  1717. def __exit__(
  1718. self,
  1719. exc_type: type[BaseException] | None,
  1720. exc_val: BaseException | None,
  1721. exc_tb: TracebackType | None,
  1722. ) -> None:
  1723. for sig in self._handled_signals:
  1724. self._loop.remove_signal_handler(sig)
  1725. def __aiter__(self) -> _SignalReceiver:
  1726. return self
  1727. async def __anext__(self) -> Signals:
  1728. await AsyncIOBackend.checkpoint()
  1729. if not self._signal_queue:
  1730. self._future = asyncio.Future()
  1731. await self._future
  1732. return self._signal_queue.popleft()
  1733. #
  1734. # Testing and debugging
  1735. #
  1736. class AsyncIOTaskInfo(TaskInfo):
  1737. def __init__(self, task: asyncio.Task):
  1738. task_state = _task_states.get(task)
  1739. if task_state is None:
  1740. parent_id = None
  1741. else:
  1742. parent_id = task_state.parent_id
  1743. coro = task.get_coro()
  1744. assert coro is not None, "created TaskInfo from a completed Task"
  1745. super().__init__(id(task), parent_id, task.get_name(), coro)
  1746. self._task = weakref.ref(task)
  1747. def has_pending_cancellation(self) -> bool:
  1748. if not (task := self._task()):
  1749. # If the task isn't around anymore, it won't have a pending cancellation
  1750. return False
  1751. if task._must_cancel: # type: ignore[attr-defined]
  1752. return True
  1753. elif (
  1754. isinstance(task._fut_waiter, asyncio.Future) # type: ignore[attr-defined]
  1755. and task._fut_waiter.cancelled() # type: ignore[attr-defined]
  1756. ):
  1757. return True
  1758. if task_state := _task_states.get(task):
  1759. if cancel_scope := task_state.cancel_scope:
  1760. return cancel_scope._effectively_cancelled
  1761. return False
class TestRunner(abc.TestRunner):
    """Runs async tests and fixtures (for the pytest integration) on an asyncio Runner.

    Every test/fixture coroutine is awaited inside one long-lived "runner task",
    started lazily by ``_call_in_runner_task``, so they all execute in the same asyncio
    task. Exceptions reported to the loop's exception handler are collected and
    re-raised after each test/fixture step.
    """

    # Submission side of the stream feeding (awaitable, result future) pairs to the
    # runner task; assigned lazily by _call_in_runner_task()
    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop_factory takes precedence over use_uvloop
        if use_uvloop and loop_factory is None:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        # Exceptions captured by _exception_handler(), re-raised by
        # _raise_async_exceptions()
        self._exceptions: list[BaseException] = []
        # Long-lived task that awaits every test/fixture coroutine; started lazily
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        # Keep Exception instances so the current test can re-raise them; anything
        # else (no exception, or a non-Exception BaseException) goes to the default
        # handler
        if isinstance(context.get("exception"), Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        # Re-raise any exceptions raised in asynchronous callbacks
        if self._exceptions:
            exceptions, self._exceptions = self._exceptions, []
            if len(exceptions) == 1:
                raise exceptions[0]
            elif exceptions:  # NOTE(review): always true here; effectively `else`
                raise BaseExceptionGroup(
                    "Multiple exceptions occurred in asynchronous callbacks", exceptions
                )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """Body of the runner task: await each submitted coroutine and report it.

        Each outcome is copied into the paired future unless that future was cancelled.
        Cancellation, and any BaseException that is neither an ``Exception`` nor a
        pytest ``OutcomeException``, ends the loop (and with it the runner task).
        """
        from _pytest.outcomes import OutcomeException

        with receive_stream, self._send_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except CancelledError as exc:
                    if not future.cancelled():
                        future.cancel(*exc.args)

                    raise
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)

                    if not isinstance(exc, (Exception, OutcomeException)):
                        raise
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` inside the runner task and return its result."""
        if not self._runner_task:
            # First call: create the submission stream (buffer size 1) and start the
            # runner task on this loop
            self._send_stream, receive_stream = create_memory_object_stream[
                tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """Drive an async generator fixture: set up, yield its value, then tear down.

        Raises:
            RuntimeError: if the generator yields again during teardown
        """
        asyncgen = fixture_func(**kwargs)
        # Setup: advance the generator to its first yield inside the runner task
        fixturevalue: T_Retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        # Teardown: resume the generator, which is expected to finish now
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            # NOTE: aclose() runs directly on the loop, not inside the runner task
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test coroutine in the runner task, raising its and callbacks' errors."""
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            # Report the test's own failure together with any callback errors
            self._exceptions.append(exc)

        self._raise_async_exceptions()
  1887. class AsyncIOBackend(AsyncBackend):
  1888. @classmethod
  1889. def run(
  1890. cls,
  1891. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  1892. args: tuple[Unpack[PosArgsT]],
  1893. kwargs: dict[str, Any],
  1894. options: dict[str, Any],
  1895. ) -> T_Retval:
  1896. @wraps(func)
  1897. async def wrapper() -> T_Retval:
  1898. task = cast(asyncio.Task, current_task())
  1899. task.set_name(get_callable_name(func))
  1900. _task_states[task] = TaskState(None, None)
  1901. try:
  1902. return await func(*args)
  1903. finally:
  1904. del _task_states[task]
  1905. debug = options.get("debug", None)
  1906. loop_factory = options.get("loop_factory", None)
  1907. if loop_factory is None and options.get("use_uvloop", False):
  1908. import uvloop
  1909. loop_factory = uvloop.new_event_loop
  1910. with Runner(debug=debug, loop_factory=loop_factory) as runner:
  1911. return runner.run(wrapper())
  1912. @classmethod
  1913. def current_token(cls) -> object:
  1914. return get_running_loop()
  1915. @classmethod
  1916. def current_time(cls) -> float:
  1917. return get_running_loop().time()
  1918. @classmethod
  1919. def cancelled_exception_class(cls) -> type[BaseException]:
  1920. return CancelledError
  1921. @classmethod
  1922. async def checkpoint(cls) -> None:
  1923. await sleep(0)
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the event loop only if the current task's cancel scopes require it.

        Walks outwards from the task's innermost cancel scope. If a cancel-called scope
        is reached before any shielded one, ``sleep(0)`` is awaited, repeatedly for as
        long as that scope stays cancel-called. Presumably this is so the scope's
        pending cancellation is delivered as CancelledError (TODO confirm). Otherwise
        it returns without yielding. Tasks not registered in ``_task_states`` (and
        calls with no current task) return immediately.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                # The scope is not advanced here: keep yielding until cancellation
                # is raised into this task
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope
  1940. @classmethod
  1941. async def cancel_shielded_checkpoint(cls) -> None:
  1942. with CancelScope(shield=True):
  1943. await sleep(0)
  1944. @classmethod
  1945. async def sleep(cls, delay: float) -> None:
  1946. await sleep(delay)
  1947. @classmethod
  1948. def create_cancel_scope(
  1949. cls, *, deadline: float = math.inf, shield: bool = False
  1950. ) -> CancelScope:
  1951. return CancelScope(deadline=deadline, shield=shield)
  1952. @classmethod
  1953. def current_effective_deadline(cls) -> float:
  1954. if (task := current_task()) is None:
  1955. return math.inf
  1956. try:
  1957. cancel_scope = _task_states[task].cancel_scope
  1958. except KeyError:
  1959. return math.inf
  1960. deadline = math.inf
  1961. while cancel_scope:
  1962. deadline = min(deadline, cancel_scope.deadline)
  1963. if cancel_scope._cancel_called:
  1964. deadline = -math.inf
  1965. break
  1966. elif cancel_scope.shield:
  1967. break
  1968. else:
  1969. cancel_scope = cancel_scope._parent_scope
  1970. return deadline
  1971. @classmethod
  1972. def create_task_group(cls) -> abc.TaskGroup:
  1973. return TaskGroup()
  1974. @classmethod
  1975. def create_event(cls) -> abc.Event:
  1976. return Event()
  1977. @classmethod
  1978. def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
  1979. return Lock(fast_acquire=fast_acquire)
  1980. @classmethod
  1981. def create_semaphore(
  1982. cls,
  1983. initial_value: int,
  1984. *,
  1985. max_value: int | None = None,
  1986. fast_acquire: bool = False,
  1987. ) -> abc.Semaphore:
  1988. return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)
  1989. @classmethod
  1990. def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
  1991. return CapacityLimiter(total_tokens)
    @classmethod
    async def run_sync_in_worker_thread(  # type: ignore[return]
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and return its result.

        A token is held from *limiter* (or the default thread limiter) for the whole
        call. Unless *abandon_on_cancel* is true, the wait for the result runs in a
        shielded cancel scope, so cancellation does not interrupt it.
        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future = asyncio.Future[T_Retval]()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker: start a new one and have it stopped when the
                    # root task finishes
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(
                        worker.stop, context=contextvars.Context()
                    )
                else:
                    # Reuse the worker at the right end of the idle deque
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # The thread runs func in a copy of this context with sniffio's
                # "current async library" cleared
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                # When the wait is shielded, hand the worker the *outer* scope: the
                # check_cancelled() walk stops at the first shielded scope, so the
                # shielding scope itself would hide outer cancellation from the thread
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future
  2046. @classmethod
  2047. def check_cancelled(cls) -> None:
  2048. scope: CancelScope | None = threadlocals.current_cancel_scope
  2049. while scope is not None:
  2050. if scope.cancel_called:
  2051. raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
  2052. if scope.shield:
  2053. return
  2054. scope = scope._parent_scope
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """From a foreign thread, run ``func(*args)`` on the event loop and block for it.

        Raises:
            RunFinishedError: if the target event loop is already closed
            concurrent.futures.CancelledError: if the coroutine was cancelled on the loop
        """

        async def task_wrapper() -> T_Retval:
            __tracebackhide__ = True
            # Register the new task under the cancel scope recorded for this thread
            # (if any), presumably so cancelling that scope reaches this task too
            if scope is not None:
                task = cast(asyncio.Task, current_task())
                _task_states[task] = TaskState(None, scope)
                scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                # Surface cancellation to the blocking caller as the
                # concurrent.futures flavour of CancelledError
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                if scope is not None:
                    scope._tasks.discard(task)

        loop = cast(
            "AbstractEventLoop", token or threadlocals.current_token.native_token
        )
        if loop.is_closed():
            raise RunFinishedError

        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        # Read by task_wrapper() through its closure; bound before the coroutine runs
        scope = getattr(threadlocals, "current_cancel_scope", None)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, task_wrapper(), loop=loop
        )
        return f.result()
    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """From a foreign thread, run ``func(*args)`` as a loop callback and block for it.

        Raises:
            RunFinishedError: if the target event loop is already closed
        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors (e.g. KeyboardInterrupt) are additionally
                # re-raised inside the event loop callback
                if not isinstance(exc, Exception):
                    raise

        loop = cast(
            "AbstractEventLoop", token or threadlocals.current_token.native_token
        )
        if loop.is_closed():
            raise RunFinishedError

        # Resolved by wrapper() (through its closure) on the event loop thread
        f: concurrent.futures.Future[T_Retval] = Future()
        loop.call_soon_threadsafe(wrapper)
        return f.result()
  2111. @classmethod
  2112. def create_blocking_portal(cls) -> abc.BlockingPortal:
  2113. return BlockingPortal()
  2114. @classmethod
  2115. async def open_process(
  2116. cls,
  2117. command: StrOrBytesPath | Sequence[StrOrBytesPath],
  2118. *,
  2119. stdin: int | IO[Any] | None,
  2120. stdout: int | IO[Any] | None,
  2121. stderr: int | IO[Any] | None,
  2122. **kwargs: Any,
  2123. ) -> Process:
  2124. await cls.checkpoint()
  2125. if isinstance(command, PathLike):
  2126. command = os.fspath(command)
  2127. if isinstance(command, (str, bytes)):
  2128. process = await asyncio.create_subprocess_shell(
  2129. command,
  2130. stdin=stdin,
  2131. stdout=stdout,
  2132. stderr=stderr,
  2133. **kwargs,
  2134. )
  2135. else:
  2136. process = await asyncio.create_subprocess_exec(
  2137. *command,
  2138. stdin=stdin,
  2139. stdout=stdout,
  2140. stderr=stderr,
  2141. **kwargs,
  2142. )
  2143. stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
  2144. stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
  2145. stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
  2146. return Process(process, stdin_stream, stdout_stream, stderr_stream)
  2147. @classmethod
  2148. def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
  2149. create_task(
  2150. _shutdown_process_pool_on_exit(workers),
  2151. name="AnyIO process pool shutdown task",
  2152. )
  2153. find_root_task().add_done_callback(
  2154. partial(_forcibly_shutdown_process_pool_on_exit, workers) # type:ignore[arg-type]
  2155. )
  2156. @classmethod
  2157. async def connect_tcp(
  2158. cls, host: str, port: int, local_address: IPSockAddrType | None = None
  2159. ) -> abc.SocketStream:
  2160. transport, protocol = cast(
  2161. tuple[asyncio.Transport, StreamProtocol],
  2162. await get_running_loop().create_connection(
  2163. StreamProtocol, host, port, local_addr=local_address
  2164. ),
  2165. )
  2166. transport.pause_reading()
  2167. return SocketStream(transport, protocol)
  2168. @classmethod
  2169. async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
  2170. await cls.checkpoint()
  2171. loop = get_running_loop()
  2172. raw_socket = socket.socket(socket.AF_UNIX)
  2173. raw_socket.setblocking(False)
  2174. while True:
  2175. try:
  2176. raw_socket.connect(path)
  2177. except BlockingIOError:
  2178. f: asyncio.Future = asyncio.Future()
  2179. loop.add_writer(raw_socket, f.set_result, None)
  2180. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2181. await f
  2182. except BaseException:
  2183. raw_socket.close()
  2184. raise
  2185. else:
  2186. return UNIXSocketStream(raw_socket)
  2187. @classmethod
  2188. def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
  2189. return TCPSocketListener(sock)
  2190. @classmethod
  2191. def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
  2192. return UNIXSocketListener(sock)
  2193. @classmethod
  2194. async def create_udp_socket(
  2195. cls,
  2196. family: AddressFamily,
  2197. local_address: IPSockAddrType | None,
  2198. remote_address: IPSockAddrType | None,
  2199. reuse_port: bool,
  2200. ) -> UDPSocket | ConnectedUDPSocket:
  2201. transport, protocol = await get_running_loop().create_datagram_endpoint(
  2202. DatagramProtocol,
  2203. local_addr=local_address,
  2204. remote_addr=remote_address,
  2205. family=family,
  2206. reuse_port=reuse_port,
  2207. )
  2208. if protocol.exception:
  2209. transport.close()
  2210. raise protocol.exception
  2211. if not remote_address:
  2212. return UDPSocket(transport, protocol)
  2213. else:
  2214. return ConnectedUDPSocket(transport, protocol)
  2215. @classmethod
  2216. async def create_unix_datagram_socket( # type: ignore[override]
  2217. cls, raw_socket: socket.socket, remote_path: str | bytes | None
  2218. ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
  2219. await cls.checkpoint()
  2220. loop = get_running_loop()
  2221. if remote_path:
  2222. while True:
  2223. try:
  2224. raw_socket.connect(remote_path)
  2225. except BlockingIOError:
  2226. f: asyncio.Future = asyncio.Future()
  2227. loop.add_writer(raw_socket, f.set_result, None)
  2228. f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
  2229. await f
  2230. except BaseException:
  2231. raw_socket.close()
  2232. raise
  2233. else:
  2234. return ConnectedUNIXDatagramSocket(raw_socket)
  2235. else:
  2236. return UNIXDatagramSocket(raw_socket)
  2237. @classmethod
  2238. async def getaddrinfo(
  2239. cls,
  2240. host: bytes | str | None,
  2241. port: str | int | None,
  2242. *,
  2243. family: int | AddressFamily = 0,
  2244. type: int | SocketKind = 0,
  2245. proto: int = 0,
  2246. flags: int = 0,
  2247. ) -> Sequence[
  2248. tuple[
  2249. AddressFamily,
  2250. SocketKind,
  2251. int,
  2252. str,
  2253. tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
  2254. ]
  2255. ]:
  2256. return await get_running_loop().getaddrinfo(
  2257. host, port, family=family, type=type, proto=proto, flags=flags
  2258. )
  2259. @classmethod
  2260. async def getnameinfo(
  2261. cls, sockaddr: IPSockAddrType, flags: int = 0
  2262. ) -> tuple[str, str]:
  2263. return await get_running_loop().getnameinfo(sockaddr, flags)
@classmethod
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
    """Wait until *obj* becomes readable.

    A one-shot reader callback is registered for the descriptor on the
    running event loop. If the loop does not implement ``add_reader()``, the
    callback goes on the selector returned by ``get_selector()`` instead.
    The method then waits for the callback to fire.

    Pending waits are tracked in the ``_read_events`` mapping (fd -> future).
    Only one task may wait for readability of a given descriptor at a time.

    :param obj: a file descriptor, or an object with a ``fileno()`` method
    :raises BusyResourceError: if a read wait is already registered for this
        descriptor
    :raises ClosedResourceError: if the wait was resolved with ``False``
        (``notify_closing()`` does this)
    """
    try:
        read_events = _read_events.get()
    except LookupError:
        # No mapping stored yet: create one and store it for later calls
        read_events = {}
        _read_events.set(read_events)

    fd = obj if isinstance(obj, int) else obj.fileno()
    if read_events.get(fd):
        raise BusyResourceError("reading from")

    loop = get_running_loop()
    fut: asyncio.Future[bool] = loop.create_future()

    def cb() -> None:
        # Readiness callback. Unregister the reader only if this callback is
        # the one that removed the map entry (it may already have been
        # removed, e.g. by notify_closing()).
        # NOTE: ``remove_reader`` is a closure variable that is assigned below,
        # after registration; this relies on the callback not being invoked
        # synchronously from inside add_reader().
        try:
            del read_events[fd]
        except KeyError:
            pass
        else:
            remove_reader(fd)

        # The future may already be resolved (e.g. cancelled), in which case
        # set_result() raises InvalidStateError
        try:
            fut.set_result(True)
        except asyncio.InvalidStateError:
            pass

    try:
        loop.add_reader(fd, cb)
    except NotImplementedError:
        # The loop has no add_reader() (presumably a proactor-style loop —
        # confirm), so fall back to the selector-thread helper
        from anyio._core._asyncio_selector_thread import get_selector

        selector = get_selector()
        selector.add_reader(fd, cb)
        remove_reader = selector.remove_reader
    else:
        remove_reader = loop.remove_reader

    read_events[fd] = fut
    try:
        success = await fut
    finally:
        # Clean up if nobody else did (e.g. this task was cancelled); only
        # unregister when this code is the one that removed the entry
        try:
            del read_events[fd]
        except KeyError:
            pass
        else:
            remove_reader(fd)

    # A False result means the wait was ended by notify_closing()
    if not success:
        raise ClosedResourceError
@classmethod
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
    """Wait until *obj* becomes writable.

    A one-shot writer callback is registered for the descriptor on the
    running event loop. If the loop does not implement ``add_writer()``, the
    callback goes on the selector returned by ``get_selector()`` instead.
    The method then waits for the callback to fire.

    Pending waits are tracked in the ``_write_events`` mapping (fd -> future).
    Only one task may wait for writability of a given descriptor at a time.

    :param obj: a file descriptor, or an object with a ``fileno()`` method
    :raises BusyResourceError: if a write wait is already registered for this
        descriptor
    :raises ClosedResourceError: if the wait was resolved with ``False``
        (``notify_closing()`` does this)
    """
    try:
        write_events = _write_events.get()
    except LookupError:
        # No mapping stored yet: create one and store it for later calls
        write_events = {}
        _write_events.set(write_events)

    fd = obj if isinstance(obj, int) else obj.fileno()
    if write_events.get(fd):
        raise BusyResourceError("writing to")

    loop = get_running_loop()
    fut: asyncio.Future[bool] = loop.create_future()

    def cb() -> None:
        # Readiness callback. Unregister the writer only if this callback is
        # the one that removed the map entry (it may already have been
        # removed, e.g. by notify_closing()).
        # NOTE: ``remove_writer`` is a closure variable that is assigned below,
        # after registration; this relies on the callback not being invoked
        # synchronously from inside add_writer().
        try:
            del write_events[fd]
        except KeyError:
            pass
        else:
            remove_writer(fd)

        # The future may already be resolved (e.g. cancelled), in which case
        # set_result() raises InvalidStateError
        try:
            fut.set_result(True)
        except asyncio.InvalidStateError:
            pass

    try:
        loop.add_writer(fd, cb)
    except NotImplementedError:
        # The loop has no add_writer() (presumably a proactor-style loop —
        # confirm), so fall back to the selector-thread helper
        from anyio._core._asyncio_selector_thread import get_selector

        selector = get_selector()
        selector.add_writer(fd, cb)
        remove_writer = selector.remove_writer
    else:
        remove_writer = loop.remove_writer

    write_events[fd] = fut
    try:
        success = await fut
    finally:
        # Clean up if nobody else did (e.g. this task was cancelled); only
        # unregister when this code is the one that removed the entry
        try:
            del write_events[fd]
        except KeyError:
            pass
        else:
            remove_writer(fd)

    # A False result means the wait was ended by notify_closing()
    if not success:
        raise ClosedResourceError
  2352. @classmethod
  2353. def notify_closing(cls, obj: FileDescriptorLike) -> None:
  2354. fd = obj if isinstance(obj, int) else obj.fileno()
  2355. loop = get_running_loop()
  2356. try:
  2357. write_events = _write_events.get()
  2358. except LookupError:
  2359. pass
  2360. else:
  2361. try:
  2362. fut = write_events.pop(fd)
  2363. except KeyError:
  2364. pass
  2365. else:
  2366. try:
  2367. fut.set_result(False)
  2368. except asyncio.InvalidStateError:
  2369. pass
  2370. try:
  2371. loop.remove_writer(fd)
  2372. except NotImplementedError:
  2373. from anyio._core._asyncio_selector_thread import get_selector
  2374. get_selector().remove_writer(fd)
  2375. try:
  2376. read_events = _read_events.get()
  2377. except LookupError:
  2378. pass
  2379. else:
  2380. try:
  2381. fut = read_events.pop(fd)
  2382. except KeyError:
  2383. pass
  2384. else:
  2385. try:
  2386. fut.set_result(False)
  2387. except asyncio.InvalidStateError:
  2388. pass
  2389. try:
  2390. loop.remove_reader(fd)
  2391. except NotImplementedError:
  2392. from anyio._core._asyncio_selector_thread import get_selector
  2393. get_selector().remove_reader(fd)
  2394. @classmethod
  2395. async def wrap_listener_socket(cls, sock: socket.socket) -> SocketListener:
  2396. return TCPSocketListener(sock)
  2397. @classmethod
  2398. async def wrap_stream_socket(cls, sock: socket.socket) -> SocketStream:
  2399. transport, protocol = await get_running_loop().create_connection(
  2400. StreamProtocol, sock=sock
  2401. )
  2402. return SocketStream(transport, protocol)
  2403. @classmethod
  2404. async def wrap_unix_stream_socket(cls, sock: socket.socket) -> UNIXSocketStream:
  2405. return UNIXSocketStream(sock)
  2406. @classmethod
  2407. async def wrap_udp_socket(cls, sock: socket.socket) -> UDPSocket:
  2408. transport, protocol = await get_running_loop().create_datagram_endpoint(
  2409. DatagramProtocol, sock=sock
  2410. )
  2411. return UDPSocket(transport, protocol)
  2412. @classmethod
  2413. async def wrap_connected_udp_socket(cls, sock: socket.socket) -> ConnectedUDPSocket:
  2414. transport, protocol = await get_running_loop().create_datagram_endpoint(
  2415. DatagramProtocol, sock=sock
  2416. )
  2417. return ConnectedUDPSocket(transport, protocol)
  2418. @classmethod
  2419. async def wrap_unix_datagram_socket(cls, sock: socket.socket) -> UNIXDatagramSocket:
  2420. return UNIXDatagramSocket(sock)
  2421. @classmethod
  2422. async def wrap_connected_unix_datagram_socket(
  2423. cls, sock: socket.socket
  2424. ) -> ConnectedUNIXDatagramSocket:
  2425. return ConnectedUNIXDatagramSocket(sock)
  2426. @classmethod
  2427. def current_default_thread_limiter(cls) -> CapacityLimiter:
  2428. try:
  2429. return _default_thread_limiter.get()
  2430. except LookupError:
  2431. limiter = CapacityLimiter(40)
  2432. _default_thread_limiter.set(limiter)
  2433. return limiter
  2434. @classmethod
  2435. def open_signal_receiver(
  2436. cls, *signals: Signals
  2437. ) -> AbstractContextManager[AsyncIterator[Signals]]:
  2438. return _SignalReceiver(signals)
  2439. @classmethod
  2440. def get_current_task(cls) -> TaskInfo:
  2441. return AsyncIOTaskInfo(current_task()) # type: ignore[arg-type]
  2442. @classmethod
  2443. def get_running_tasks(cls) -> Sequence[TaskInfo]:
  2444. return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]
  2445. @classmethod
  2446. async def wait_all_tasks_blocked(cls) -> None:
  2447. await cls.checkpoint()
  2448. this_task = current_task()
  2449. while True:
  2450. for task in all_tasks():
  2451. if task is this_task:
  2452. continue
  2453. waiter = task._fut_waiter # type: ignore[attr-defined]
  2454. if waiter is None or waiter.done():
  2455. await sleep(0.1)
  2456. break
  2457. else:
  2458. return
  2459. @classmethod
  2460. def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
  2461. return TestRunner(**options)
# Module-level alias for the backend class defined above (presumably what the
# backend-loading machinery looks up by name — confirm against the loader).
backend_class = AsyncIOBackend