lowlevel.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from __future__ import annotations
  2. import enum
  3. from dataclasses import dataclass
  4. from typing import Any, Generic, Literal, TypeVar, final, overload
  5. from weakref import WeakKeyDictionary
  6. from ._core._eventloop import get_async_backend
  7. from .abc import AsyncBackend
  8. T = TypeVar("T")
  9. D = TypeVar("D")
  10. async def checkpoint() -> None:
  11. """
  12. Check for cancellation and allow the scheduler to switch to another task.
  13. Equivalent to (but more efficient than)::
  14. await checkpoint_if_cancelled()
  15. await cancel_shielded_checkpoint()
  16. .. versionadded:: 3.0
  17. """
  18. await get_async_backend().checkpoint()
  19. async def checkpoint_if_cancelled() -> None:
  20. """
  21. Enter a checkpoint if the enclosing cancel scope has been cancelled.
  22. This does not allow the scheduler to switch to a different task.
  23. .. versionadded:: 3.0
  24. """
  25. await get_async_backend().checkpoint_if_cancelled()
  26. async def cancel_shielded_checkpoint() -> None:
  27. """
  28. Allow the scheduler to switch to another task but without checking for cancellation.
  29. Equivalent to (but potentially more efficient than)::
  30. with CancelScope(shield=True):
  31. await checkpoint()
  32. .. versionadded:: 3.0
  33. """
  34. await get_async_backend().cancel_shielded_checkpoint()
  35. @final
  36. @dataclass(frozen=True, repr=False)
  37. class EventLoopToken:
  38. """
  39. An opaque object that holds a reference to an event loop.
  40. .. versionadded:: 4.11.0
  41. """
  42. backend_class: type[AsyncBackend]
  43. native_token: object
  44. def current_token() -> EventLoopToken:
  45. """
  46. Return a token object that can be used to call code in the current event loop from
  47. another thread.
  48. .. versionadded:: 4.11.0
  49. """
  50. backend_class = get_async_backend()
  51. raw_token = backend_class.current_token()
  52. return EventLoopToken(backend_class, raw_token)
  53. _run_vars: WeakKeyDictionary[object, dict[RunVar[Any], Any]] = WeakKeyDictionary()
  54. class _NoValueSet(enum.Enum):
  55. NO_VALUE_SET = enum.auto()
  56. class RunvarToken(Generic[T]):
  57. __slots__ = "_var", "_value", "_redeemed"
  58. def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
  59. self._var = var
  60. self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
  61. self._redeemed = False
  62. class RunVar(Generic[T]):
  63. """
  64. Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
  65. """
  66. __slots__ = "_name", "_default"
  67. NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
  68. def __init__(
  69. self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  70. ):
  71. self._name = name
  72. self._default = default
  73. @property
  74. def _current_vars(self) -> dict[RunVar[T], T]:
  75. native_token = current_token().native_token
  76. try:
  77. return _run_vars[native_token]
  78. except KeyError:
  79. run_vars = _run_vars[native_token] = {}
  80. return run_vars
  81. @overload
  82. def get(self, default: D) -> T | D: ...
  83. @overload
  84. def get(self) -> T: ...
  85. def get(
  86. self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
  87. ) -> T | D:
  88. try:
  89. return self._current_vars[self]
  90. except KeyError:
  91. if default is not RunVar.NO_VALUE_SET:
  92. return default
  93. elif self._default is not RunVar.NO_VALUE_SET:
  94. return self._default
  95. raise LookupError(
  96. f'Run variable "{self._name}" has no value and no default set'
  97. )
  98. def set(self, value: T) -> RunvarToken[T]:
  99. current_vars = self._current_vars
  100. token = RunvarToken(self, current_vars.get(self, RunVar.NO_VALUE_SET))
  101. current_vars[self] = value
  102. return token
  103. def reset(self, token: RunvarToken[T]) -> None:
  104. if token._var is not self:
  105. raise ValueError("This token does not belong to this RunVar")
  106. if token._redeemed:
  107. raise ValueError("This token has already been used")
  108. if token._value is _NoValueSet.NO_VALUE_SET:
  109. try:
  110. del self._current_vars[self]
  111. except KeyError:
  112. pass
  113. else:
  114. self._current_vars[self] = token._value
  115. token._redeemed = True
  116. def __repr__(self) -> str:
  117. return f"<RunVar name={self._name!r}>"