from_thread.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. from __future__ import annotations
  2. import sys
  3. from collections.abc import Awaitable, Callable, Generator
  4. from concurrent.futures import Future
  5. from contextlib import (
  6. AbstractAsyncContextManager,
  7. AbstractContextManager,
  8. contextmanager,
  9. )
  10. from dataclasses import dataclass, field
  11. from inspect import isawaitable
  12. from threading import Lock, Thread, current_thread, get_ident
  13. from types import TracebackType
  14. from typing import (
  15. Any,
  16. Generic,
  17. TypeVar,
  18. cast,
  19. overload,
  20. )
  21. from ._core._eventloop import (
  22. get_async_backend,
  23. get_cancelled_exc_class,
  24. threadlocals,
  25. )
  26. from ._core._eventloop import run as run_eventloop
  27. from ._core._exceptions import NoEventLoopError
  28. from ._core._synchronization import Event
  29. from ._core._tasks import CancelScope, create_task_group
  30. from .abc._tasks import TaskStatus
  31. from .lowlevel import EventLoopToken
  32. if sys.version_info >= (3, 11):
  33. from typing import TypeVarTuple, Unpack
  34. else:
  35. from typing_extensions import TypeVarTuple, Unpack
  36. T_Retval = TypeVar("T_Retval")
  37. T_co = TypeVar("T_co", covariant=True)
  38. PosArgsT = TypeVarTuple("PosArgsT")
  39. def _token_or_error(token: EventLoopToken | None) -> EventLoopToken:
  40. if token is not None:
  41. return token
  42. try:
  43. return threadlocals.current_token
  44. except AttributeError:
  45. raise NoEventLoopError(
  46. "Not running inside an AnyIO worker thread, and no event loop token was "
  47. "provided"
  48. ) from None
  49. def run(
  50. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  51. *args: Unpack[PosArgsT],
  52. token: EventLoopToken | None = None,
  53. ) -> T_Retval:
  54. """
  55. Call a coroutine function from a worker thread.
  56. :param func: a coroutine function
  57. :param args: positional arguments for the callable
  58. :param token: an event loop token to use to get back to the event loop thread
  59. (required if calling this function from outside an AnyIO worker thread)
  60. :return: the return value of the coroutine function
  61. :raises MissingTokenError: if no token was provided and called from outside an
  62. AnyIO worker thread
  63. :raises RunFinishedError: if the event loop tied to ``token`` is no longer running
  64. .. versionchanged:: 4.11.0
  65. Added the ``token`` parameter.
  66. """
  67. explicit_token = token is not None
  68. token = _token_or_error(token)
  69. return token.backend_class.run_async_from_thread(
  70. func, args, token=token.native_token if explicit_token else None
  71. )
  72. def run_sync(
  73. func: Callable[[Unpack[PosArgsT]], T_Retval],
  74. *args: Unpack[PosArgsT],
  75. token: EventLoopToken | None = None,
  76. ) -> T_Retval:
  77. """
  78. Call a function in the event loop thread from a worker thread.
  79. :param func: a callable
  80. :param args: positional arguments for the callable
  81. :param token: an event loop token to use to get back to the event loop thread
  82. (required if calling this function from outside an AnyIO worker thread)
  83. :return: the return value of the callable
  84. :raises MissingTokenError: if no token was provided and called from outside an
  85. AnyIO worker thread
  86. :raises RunFinishedError: if the event loop tied to ``token`` is no longer running
  87. .. versionchanged:: 4.11.0
  88. Added the ``token`` parameter.
  89. """
  90. explicit_token = token is not None
  91. token = _token_or_error(token)
  92. return token.backend_class.run_sync_from_thread(
  93. func, args, token=token.native_token if explicit_token else None
  94. )
  95. class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
  96. _enter_future: Future[T_co]
  97. _exit_future: Future[bool | None]
  98. _exit_event: Event
  99. _exit_exc_info: tuple[
  100. type[BaseException] | None, BaseException | None, TracebackType | None
  101. ] = (None, None, None)
  102. def __init__(
  103. self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
  104. ):
  105. self._async_cm = async_cm
  106. self._portal = portal
  107. async def run_async_cm(self) -> bool | None:
  108. try:
  109. self._exit_event = Event()
  110. value = await self._async_cm.__aenter__()
  111. except BaseException as exc:
  112. self._enter_future.set_exception(exc)
  113. raise
  114. else:
  115. self._enter_future.set_result(value)
  116. try:
  117. # Wait for the sync context manager to exit.
  118. # This next statement can raise `get_cancelled_exc_class()` if
  119. # something went wrong in a task group in this async context
  120. # manager.
  121. await self._exit_event.wait()
  122. finally:
  123. # In case of cancellation, it could be that we end up here before
  124. # `_BlockingAsyncContextManager.__exit__` is called, and an
  125. # `_exit_exc_info` has been set.
  126. result = await self._async_cm.__aexit__(*self._exit_exc_info)
  127. return result
  128. def __enter__(self) -> T_co:
  129. self._enter_future = Future()
  130. self._exit_future = self._portal.start_task_soon(self.run_async_cm)
  131. return self._enter_future.result()
  132. def __exit__(
  133. self,
  134. __exc_type: type[BaseException] | None,
  135. __exc_value: BaseException | None,
  136. __traceback: TracebackType | None,
  137. ) -> bool | None:
  138. self._exit_exc_info = __exc_type, __exc_value, __traceback
  139. self._portal.call(self._exit_event.set)
  140. return self._exit_future.result()
  141. class _BlockingPortalTaskStatus(TaskStatus):
  142. def __init__(self, future: Future):
  143. self._future = future
  144. def started(self, value: object = None) -> None:
  145. self._future.set_result(value)
  146. class BlockingPortal:
  147. """An object that lets external threads run code in an asynchronous event loop."""
  148. def __new__(cls) -> BlockingPortal:
  149. return get_async_backend().create_blocking_portal()
  150. def __init__(self) -> None:
  151. self._event_loop_thread_id: int | None = get_ident()
  152. self._stop_event = Event()
  153. self._task_group = create_task_group()
  154. self._cancelled_exc_class = get_cancelled_exc_class()
  155. async def __aenter__(self) -> BlockingPortal:
  156. await self._task_group.__aenter__()
  157. return self
  158. async def __aexit__(
  159. self,
  160. exc_type: type[BaseException] | None,
  161. exc_val: BaseException | None,
  162. exc_tb: TracebackType | None,
  163. ) -> bool:
  164. await self.stop()
  165. return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
  166. def _check_running(self) -> None:
  167. if self._event_loop_thread_id is None:
  168. raise RuntimeError("This portal is not running")
  169. if self._event_loop_thread_id == get_ident():
  170. raise RuntimeError(
  171. "This method cannot be called from the event loop thread"
  172. )
  173. async def sleep_until_stopped(self) -> None:
  174. """Sleep until :meth:`stop` is called."""
  175. await self._stop_event.wait()
  176. async def stop(self, cancel_remaining: bool = False) -> None:
  177. """
  178. Signal the portal to shut down.
  179. This marks the portal as no longer accepting new calls and exits from
  180. :meth:`sleep_until_stopped`.
  181. :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
  182. to let them finish before returning
  183. """
  184. self._event_loop_thread_id = None
  185. self._stop_event.set()
  186. if cancel_remaining:
  187. self._task_group.cancel_scope.cancel("the blocking portal is shutting down")
  188. async def _call_func(
  189. self,
  190. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  191. args: tuple[Unpack[PosArgsT]],
  192. kwargs: dict[str, Any],
  193. future: Future[T_Retval],
  194. ) -> None:
  195. def callback(f: Future[T_Retval]) -> None:
  196. if f.cancelled() and self._event_loop_thread_id not in (
  197. None,
  198. get_ident(),
  199. ):
  200. self.call(scope.cancel, "the future was cancelled")
  201. try:
  202. retval_or_awaitable = func(*args, **kwargs)
  203. if isawaitable(retval_or_awaitable):
  204. with CancelScope() as scope:
  205. if future.cancelled():
  206. scope.cancel("the future was cancelled")
  207. else:
  208. future.add_done_callback(callback)
  209. retval = await retval_or_awaitable
  210. else:
  211. retval = retval_or_awaitable
  212. except self._cancelled_exc_class:
  213. future.cancel()
  214. future.set_running_or_notify_cancel()
  215. except BaseException as exc:
  216. if not future.cancelled():
  217. future.set_exception(exc)
  218. # Let base exceptions fall through
  219. if not isinstance(exc, Exception):
  220. raise
  221. else:
  222. if not future.cancelled():
  223. future.set_result(retval)
  224. finally:
  225. scope = None # type: ignore[assignment]
  226. def _spawn_task_from_thread(
  227. self,
  228. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  229. args: tuple[Unpack[PosArgsT]],
  230. kwargs: dict[str, Any],
  231. name: object,
  232. future: Future[T_Retval],
  233. ) -> None:
  234. """
  235. Spawn a new task using the given callable.
  236. Implementers must ensure that the future is resolved when the task finishes.
  237. :param func: a callable
  238. :param args: positional arguments to be passed to the callable
  239. :param kwargs: keyword arguments to be passed to the callable
  240. :param name: name of the task (will be coerced to a string if not ``None``)
  241. :param future: a future that will resolve to the return value of the callable,
  242. or the exception raised during its execution
  243. """
  244. raise NotImplementedError
  245. @overload
  246. def call(
  247. self,
  248. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  249. *args: Unpack[PosArgsT],
  250. ) -> T_Retval: ...
  251. @overload
  252. def call(
  253. self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
  254. ) -> T_Retval: ...
  255. def call(
  256. self,
  257. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  258. *args: Unpack[PosArgsT],
  259. ) -> T_Retval:
  260. """
  261. Call the given function in the event loop thread.
  262. If the callable returns a coroutine object, it is awaited on.
  263. :param func: any callable
  264. :raises RuntimeError: if the portal is not running or if this method is called
  265. from within the event loop thread
  266. """
  267. return cast(T_Retval, self.start_task_soon(func, *args).result())
  268. @overload
  269. def start_task_soon(
  270. self,
  271. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
  272. *args: Unpack[PosArgsT],
  273. name: object = None,
  274. ) -> Future[T_Retval]: ...
  275. @overload
  276. def start_task_soon(
  277. self,
  278. func: Callable[[Unpack[PosArgsT]], T_Retval],
  279. *args: Unpack[PosArgsT],
  280. name: object = None,
  281. ) -> Future[T_Retval]: ...
  282. def start_task_soon(
  283. self,
  284. func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
  285. *args: Unpack[PosArgsT],
  286. name: object = None,
  287. ) -> Future[T_Retval]:
  288. """
  289. Start a task in the portal's task group.
  290. The task will be run inside a cancel scope which can be cancelled by cancelling
  291. the returned future.
  292. :param func: the target function
  293. :param args: positional arguments passed to ``func``
  294. :param name: name of the task (will be coerced to a string if not ``None``)
  295. :return: a future that resolves with the return value of the callable if the
  296. task completes successfully, or with the exception raised in the task
  297. :raises RuntimeError: if the portal is not running or if this method is called
  298. from within the event loop thread
  299. :rtype: concurrent.futures.Future[T_Retval]
  300. .. versionadded:: 3.0
  301. """
  302. self._check_running()
  303. f: Future[T_Retval] = Future()
  304. self._spawn_task_from_thread(func, args, {}, name, f)
  305. return f
  306. def start_task(
  307. self,
  308. func: Callable[..., Awaitable[T_Retval]],
  309. *args: object,
  310. name: object = None,
  311. ) -> tuple[Future[T_Retval], Any]:
  312. """
  313. Start a task in the portal's task group and wait until it signals for readiness.
  314. This method works the same way as :meth:`.abc.TaskGroup.start`.
  315. :param func: the target function
  316. :param args: positional arguments passed to ``func``
  317. :param name: name of the task (will be coerced to a string if not ``None``)
  318. :return: a tuple of (future, task_status_value) where the ``task_status_value``
  319. is the value passed to ``task_status.started()`` from within the target
  320. function
  321. :rtype: tuple[concurrent.futures.Future[T_Retval], Any]
  322. .. versionadded:: 3.0
  323. """
  324. def task_done(future: Future[T_Retval]) -> None:
  325. if not task_status_future.done():
  326. if future.cancelled():
  327. task_status_future.cancel()
  328. elif future.exception():
  329. task_status_future.set_exception(future.exception())
  330. else:
  331. exc = RuntimeError(
  332. "Task exited without calling task_status.started()"
  333. )
  334. task_status_future.set_exception(exc)
  335. self._check_running()
  336. task_status_future: Future = Future()
  337. task_status = _BlockingPortalTaskStatus(task_status_future)
  338. f: Future = Future()
  339. f.add_done_callback(task_done)
  340. self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
  341. return f, task_status_future.result()
  342. def wrap_async_context_manager(
  343. self, cm: AbstractAsyncContextManager[T_co]
  344. ) -> AbstractContextManager[T_co]:
  345. """
  346. Wrap an async context manager as a synchronous context manager via this portal.
  347. Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
  348. in the middle until the synchronous context manager exits.
  349. :param cm: an asynchronous context manager
  350. :return: a synchronous context manager
  351. .. versionadded:: 2.1
  352. """
  353. return _BlockingAsyncContextManager(cm, self)
  354. @dataclass
  355. class BlockingPortalProvider:
  356. """
  357. A manager for a blocking portal. Used as a context manager. The first thread to
  358. enter this context manager causes a blocking portal to be started with the specific
  359. parameters, and the last thread to exit causes the portal to be shut down. Thus,
  360. there will be exactly one blocking portal running in this context as long as at
  361. least one thread has entered this context manager.
  362. The parameters are the same as for :func:`~anyio.run`.
  363. :param backend: name of the backend
  364. :param backend_options: backend options
  365. .. versionadded:: 4.4
  366. """
  367. backend: str = "asyncio"
  368. backend_options: dict[str, Any] | None = None
  369. _lock: Lock = field(init=False, default_factory=Lock)
  370. _leases: int = field(init=False, default=0)
  371. _portal: BlockingPortal = field(init=False)
  372. _portal_cm: AbstractContextManager[BlockingPortal] | None = field(
  373. init=False, default=None
  374. )
  375. def __enter__(self) -> BlockingPortal:
  376. with self._lock:
  377. if self._portal_cm is None:
  378. self._portal_cm = start_blocking_portal(
  379. self.backend, self.backend_options
  380. )
  381. self._portal = self._portal_cm.__enter__()
  382. self._leases += 1
  383. return self._portal
  384. def __exit__(
  385. self,
  386. exc_type: type[BaseException] | None,
  387. exc_val: BaseException | None,
  388. exc_tb: TracebackType | None,
  389. ) -> None:
  390. portal_cm: AbstractContextManager[BlockingPortal] | None = None
  391. with self._lock:
  392. assert self._portal_cm
  393. assert self._leases > 0
  394. self._leases -= 1
  395. if not self._leases:
  396. portal_cm = self._portal_cm
  397. self._portal_cm = None
  398. del self._portal
  399. if portal_cm:
  400. portal_cm.__exit__(None, None, None)
  401. @contextmanager
  402. def start_blocking_portal(
  403. backend: str = "asyncio",
  404. backend_options: dict[str, Any] | None = None,
  405. *,
  406. name: str | None = None,
  407. ) -> Generator[BlockingPortal, Any, None]:
  408. """
  409. Start a new event loop in a new thread and run a blocking portal in its main task.
  410. The parameters are the same as for :func:`~anyio.run`.
  411. :param backend: name of the backend
  412. :param backend_options: backend options
  413. :param name: name of the thread
  414. :return: a context manager that yields a blocking portal
  415. .. versionchanged:: 3.0
  416. Usage as a context manager is now required.
  417. """
  418. async def run_portal() -> None:
  419. async with BlockingPortal() as portal_:
  420. if name is None:
  421. current_thread().name = f"{backend}-portal-{id(portal_):x}"
  422. future.set_result(portal_)
  423. await portal_.sleep_until_stopped()
  424. def run_blocking_portal() -> None:
  425. if future.set_running_or_notify_cancel():
  426. try:
  427. run_eventloop(
  428. run_portal, backend=backend, backend_options=backend_options
  429. )
  430. except BaseException as exc:
  431. if not future.done():
  432. future.set_exception(exc)
  433. future: Future[BlockingPortal] = Future()
  434. thread = Thread(target=run_blocking_portal, daemon=True, name=name)
  435. thread.start()
  436. try:
  437. cancel_remaining_tasks = False
  438. portal = future.result()
  439. try:
  440. yield portal
  441. except BaseException:
  442. cancel_remaining_tasks = True
  443. raise
  444. finally:
  445. try:
  446. portal.call(portal.stop, cancel_remaining_tasks)
  447. except RuntimeError:
  448. pass
  449. finally:
  450. thread.join()
  451. def check_cancelled() -> None:
  452. """
  453. Check if the cancel scope of the host task's running the current worker thread has
  454. been cancelled.
  455. If the host task's current cancel scope has indeed been cancelled, the
  456. backend-specific cancellation exception will be raised.
  457. :raises RuntimeError: if the current thread was not spawned by
  458. :func:`.to_thread.run_sync`
  459. """
  460. try:
  461. token: EventLoopToken = threadlocals.current_token
  462. except AttributeError:
  463. raise NoEventLoopError(
  464. "This function can only be called inside an AnyIO worker thread"
  465. ) from None
  466. token.backend_class.check_cancelled()