_utils.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. import functools
  3. import sys
  4. from collections.abc import Awaitable, Generator
  5. from contextlib import AbstractAsyncContextManager, contextmanager
  6. from typing import Any, Callable, Generic, Protocol, TypeVar, overload
  7. from starlette.types import Scope
  8. if sys.version_info >= (3, 13): # pragma: no cover
  9. from inspect import iscoroutinefunction
  10. from typing import TypeIs
  11. else: # pragma: no cover
  12. from asyncio import iscoroutinefunction
  13. from typing_extensions import TypeIs
  14. has_exceptiongroups = True
  15. if sys.version_info < (3, 11): # pragma: no cover
  16. try:
  17. from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
  18. except ImportError:
  19. has_exceptiongroups = False
  20. T = TypeVar("T")
  21. AwaitableCallable = Callable[..., Awaitable[T]]
  22. @overload
  23. def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
  24. @overload
  25. def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
  26. def is_async_callable(obj: Any) -> Any:
  27. while isinstance(obj, functools.partial):
  28. obj = obj.func
  29. return iscoroutinefunction(obj) or (callable(obj) and iscoroutinefunction(obj.__call__))
  30. T_co = TypeVar("T_co", covariant=True)
  31. class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
  32. class SupportsAsyncClose(Protocol):
  33. async def close(self) -> None: ... # pragma: no cover
  34. SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
  35. class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
  36. __slots__ = ("aw", "entered")
  37. def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
  38. self.aw = aw
  39. def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
  40. return self.aw.__await__()
  41. async def __aenter__(self) -> SupportsAsyncCloseType:
  42. self.entered = await self.aw
  43. return self.entered
  44. async def __aexit__(self, *args: Any) -> None | bool:
  45. await self.entered.close()
  46. return None
  47. @contextmanager
  48. def collapse_excgroups() -> Generator[None, None, None]:
  49. try:
  50. yield
  51. except BaseException as exc:
  52. if has_exceptiongroups: # pragma: no cover
  53. while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
  54. exc = exc.exceptions[0]
  55. raise exc
  56. def get_route_path(scope: Scope) -> str:
  57. path: str = scope["path"]
  58. root_path = scope.get("root_path", "")
  59. if not root_path:
  60. return path
  61. if not path.startswith(root_path):
  62. return path
  63. if path == root_path:
  64. return ""
  65. if path[len(root_path)] == "/":
  66. return path[len(root_path) :]
  67. return path