"""Support for running synchronous functions in subinterpreters (requires Python 3.13+)."""
  1. from __future__ import annotations
  2. import atexit
  3. import os
  4. import sys
  5. from collections import deque
  6. from collections.abc import Callable
  7. from typing import Any, Final, TypeVar
  8. from . import current_time, to_thread
  9. from ._core._exceptions import BrokenWorkerInterpreter
  10. from ._core._synchronization import CapacityLimiter
  11. from .lowlevel import RunVar
  12. if sys.version_info >= (3, 11):
  13. from typing import TypeVarTuple, Unpack
  14. else:
  15. from typing_extensions import TypeVarTuple, Unpack
if sys.version_info >= (3, 14):
    # Python 3.14+: use the public concurrent.interpreters API.
    from concurrent.interpreters import ExecutionFailed, create

    def _interp_call(func: Callable[..., Any], args: tuple[Any, ...]):
        # Executed inside the subinterpreter.  Packages the outcome as
        # (value, is_exception) so that a user-raised exception can be
        # distinguished from a normal return value and re-raised by the caller.
        try:
            retval = func(*args)
        except BaseException as exc:
            return exc, True
        else:
            return retval, False

    class Worker:
        # Wraps a single subinterpreter; instances are pooled and reused.
        # Event-loop timestamp of the most recent completed call (set by run_sync);
        # used to decide when an idle worker is eligible for pruning.
        last_used: float = 0

        def __init__(self) -> None:
            self._interpreter = create()

        def destroy(self) -> None:
            self._interpreter.close()

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            """
            Run ``func(*args)`` in this worker's subinterpreter and return the
            result, re-raising any exception the function raised there.

            :raises BrokenWorkerInterpreter: if the subinterpreter machinery
                itself failed (as opposed to the user function raising)

            """
            try:
                res, is_exception = self._interpreter.call(_interp_call, func, args)
            except ExecutionFailed as exc:
                # The call could not be carried out at all (e.g. unpicklable
                # callable or a crash in the target interpreter)
                raise BrokenWorkerInterpreter(exc.excinfo) from exc

            if is_exception:
                # _interp_call captured an exception from the user function
                raise res

            return res

elif sys.version_info >= (3, 13):
    # Python 3.13: concurrent.interpreters is not yet available, so fall back
    # to the undocumented private modules that back it in the stdlib.
    import _interpqueues
    import _interpreters

    UNBOUND: Final = 2  # I have no clue how this works, but it was used in the stdlib
    FMT_UNPICKLED: Final = 0
    FMT_PICKLED: Final = 1
    # (format, unbound) argument pairs passed to the _interpqueues calls
    QUEUE_PICKLE_ARGS: Final = (FMT_PICKLED, UNBOUND)
    QUEUE_UNPICKLE_ARGS: Final = (FMT_UNPICKLED, UNBOUND)
    # Script executed inside the subinterpreter for each call: it pulls the
    # pickled (func, args) item off the queue, runs it, and puts back
    # (retval, is_exception) — pickling the result only if it is not shareable.
    _run_func = compile(
        """
import _interpqueues
from _interpreters import NotShareableError
from pickle import loads, dumps, HIGHEST_PROTOCOL
QUEUE_PICKLE_ARGS = (1, 2)
QUEUE_UNPICKLE_ARGS = (0, 2)
item = _interpqueues.get(queue_id)[0]
try:
    func, args = loads(item)
    retval = func(*args)
except BaseException as exc:
    is_exception = True
    retval = exc
else:
    is_exception = False
try:
    _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_UNPICKLE_ARGS)
except NotShareableError:
    retval = dumps(retval, HIGHEST_PROTOCOL)
    _interpqueues.put(queue_id, (retval, is_exception), *QUEUE_PICKLE_ARGS)
""",
        "<string>",
        "exec",
    )

    class Worker:
        # Wraps a single subinterpreter plus a one-slot queue used to pass the
        # call in and the result back; instances are pooled and reused.
        # Event-loop timestamp of the most recent completed call (set by run_sync).
        last_used: float = 0

        def __init__(self) -> None:
            self._interpreter_id = _interpreters.create()
            self._queue_id = _interpqueues.create(1, *QUEUE_UNPICKLE_ARGS)
            # Expose the queue id as a global inside the subinterpreter so
            # _run_func can find it
            _interpreters.set___main___attrs(
                self._interpreter_id, {"queue_id": self._queue_id}
            )

        def destroy(self) -> None:
            _interpqueues.destroy(self._queue_id)
            _interpreters.destroy(self._interpreter_id)

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            """
            Run ``func(*args)`` in this worker's subinterpreter and return the
            result, re-raising any exception the function raised there.

            :raises BrokenWorkerInterpreter: if executing the helper script in
                the subinterpreter failed

            """
            import pickle

            # Ship the call to the subinterpreter through the queue
            item = pickle.dumps((func, args), pickle.HIGHEST_PROTOCOL)
            _interpqueues.put(self._queue_id, item, *QUEUE_PICKLE_ARGS)
            exc_info = _interpreters.exec(self._interpreter_id, _run_func)
            if exc_info:
                raise BrokenWorkerInterpreter(exc_info)

            # Retrieve (retval, is_exception) plus the format marker telling us
            # whether the result had to be pickled
            res = _interpqueues.get(self._queue_id)
            (res, is_exception), fmt = res[:2]
            if fmt == FMT_PICKLED:
                res = pickle.loads(res)

            if is_exception:
                raise res

            return res

else:
    # Python < 3.13: subinterpreter support is unavailable; this stub exists
    # only so the module imports cleanly, and fails on first use.
    class Worker:
        last_used: float = 0

        def __init__(self) -> None:
            raise RuntimeError("subinterpreters require at least Python 3.13")

        def call(
            self,
            func: Callable[..., T_Retval],
            args: tuple[Any, ...],
        ) -> T_Retval:
            raise NotImplementedError

        def destroy(self) -> None:
            pass
DEFAULT_CPU_COUNT: Final = 8  # this is just an arbitrarily selected value
MAX_WORKER_IDLE_TIME = (
    30  # seconds a subinterpreter can be idle before becoming eligible for pruning
)

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")

# Per-event-loop pool of idle workers; most recently used workers sit at the
# right end, so pruning scans from the left (least recently used)
_idle_workers = RunVar[deque[Worker]]("_available_workers")
# Per-event-loop default limiter, created lazily by
# current_default_interpreter_limiter()
_default_interpreter_limiter = RunVar[CapacityLimiter]("_default_interpreter_limiter")
  126. def _stop_workers(workers: deque[Worker]) -> None:
  127. for worker in workers:
  128. worker.destroy()
  129. workers.clear()
async def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval],
    *args: Unpack[PosArgsT],
    limiter: CapacityLimiter | None = None,
) -> T_Retval:
    """
    Call the given function with the given arguments in a subinterpreter.

    .. warning:: On Python 3.13, the :mod:`concurrent.interpreters` module was not yet
        available, so the code path for that Python version relies on an undocumented,
        private API. As such, it is recommended to not rely on this function for anything
        mission-critical on Python 3.13.

    :param func: a callable
    :param args: the positional arguments for the callable
    :param limiter: capacity limiter to use to limit the total number of subinterpreters
        running (if omitted, the default limiter is used)
    :return: the result of the call
    :raises BrokenWorkerInterpreter: if there's an internal error in a subinterpreter

    """
    if limiter is None:
        limiter = current_default_interpreter_limiter()

    try:
        idle_workers = _idle_workers.get()
    except LookupError:
        # First call on this event loop: create the idle-worker pool and make
        # sure all pooled workers get torn down at interpreter shutdown
        idle_workers = deque()
        _idle_workers.set(idle_workers)
        atexit.register(_stop_workers, idle_workers)

    async with limiter:
        # Reuse the most recently used idle worker, or spawn a fresh one
        try:
            worker = idle_workers.pop()
        except IndexError:
            worker = Worker()

        try:
            # Worker.call blocks, so dispatch it to a worker thread
            return await to_thread.run_sync(
                worker.call,
                func,
                args,
                limiter=limiter,
            )
        finally:
            # Prune workers that have been idle for too long
            now = current_time()
            while idle_workers:
                # Workers are ordered oldest-first, so stop at the first one
                # that is still fresh enough
                if now - idle_workers[0].last_used <= MAX_WORKER_IDLE_TIME:
                    break

                await to_thread.run_sync(idle_workers.popleft().destroy, limiter=limiter)

            # Return the worker we just used to the pool
            worker.last_used = current_time()
            idle_workers.append(worker)
  177. def current_default_interpreter_limiter() -> CapacityLimiter:
  178. """
  179. Return the capacity limiter used by default to limit the number of concurrently
  180. running subinterpreters.
  181. Defaults to the number of CPU cores.
  182. :return: a capacity limiter object
  183. """
  184. try:
  185. return _default_interpreter_limiter.get()
  186. except LookupError:
  187. limiter = CapacityLimiter(os.cpu_count() or DEFAULT_CPU_COUNT)
  188. _default_interpreter_limiter.set(limiter)
  189. return limiter