asgi.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from __future__ import annotations
  2. import typing
  3. from .._models import Request, Response
  4. from .._types import AsyncByteStream
  5. from .base import AsyncBaseTransport
  6. if typing.TYPE_CHECKING: # pragma: no cover
  7. import asyncio
  8. import trio
  9. Event = typing.Union[asyncio.Event, trio.Event]
  10. _Message = typing.MutableMapping[str, typing.Any]
  11. _Receive = typing.Callable[[], typing.Awaitable[_Message]]
  12. _Send = typing.Callable[
  13. [typing.MutableMapping[str, typing.Any]], typing.Awaitable[None]
  14. ]
  15. _ASGIApp = typing.Callable[
  16. [typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
  17. ]
# Public API of this module: only the transport class is exported; the
# helpers and type aliases below are internal.
__all__ = ["ASGITransport"]
  19. def is_running_trio() -> bool:
  20. try:
  21. # sniffio is a dependency of trio.
  22. # See https://github.com/python-trio/trio/issues/2802
  23. import sniffio
  24. if sniffio.current_async_library() == "trio":
  25. return True
  26. except ImportError: # pragma: nocover
  27. pass
  28. return False
  29. def create_event() -> Event:
  30. if is_running_trio():
  31. import trio
  32. return trio.Event()
  33. import asyncio
  34. return asyncio.Event()
  35. class ASGIResponseStream(AsyncByteStream):
  36. def __init__(self, body: list[bytes]) -> None:
  37. self._body = body
  38. async def __aiter__(self) -> typing.AsyncIterator[bytes]:
  39. yield b"".join(self._body)
  40. class ASGITransport(AsyncBaseTransport):
  41. """
  42. A custom AsyncTransport that handles sending requests directly to an ASGI app.
  43. ```python
  44. transport = httpx.ASGITransport(
  45. app=app,
  46. root_path="/submount",
  47. client=("1.2.3.4", 123)
  48. )
  49. client = httpx.AsyncClient(transport=transport)
  50. ```
  51. Arguments:
  52. * `app` - The ASGI application.
  53. * `raise_app_exceptions` - Boolean indicating if exceptions in the application
  54. should be raised. Default to `True`. Can be set to `False` for use cases
  55. such as testing the content of a client 500 response.
  56. * `root_path` - The root path on which the ASGI application should be mounted.
  57. * `client` - A two-tuple indicating the client IP and port of incoming requests.
  58. ```
  59. """
  60. def __init__(
  61. self,
  62. app: _ASGIApp,
  63. raise_app_exceptions: bool = True,
  64. root_path: str = "",
  65. client: tuple[str, int] = ("127.0.0.1", 123),
  66. ) -> None:
  67. self.app = app
  68. self.raise_app_exceptions = raise_app_exceptions
  69. self.root_path = root_path
  70. self.client = client
  71. async def handle_async_request(
  72. self,
  73. request: Request,
  74. ) -> Response:
  75. assert isinstance(request.stream, AsyncByteStream)
  76. # ASGI scope.
  77. scope = {
  78. "type": "http",
  79. "asgi": {"version": "3.0"},
  80. "http_version": "1.1",
  81. "method": request.method,
  82. "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
  83. "scheme": request.url.scheme,
  84. "path": request.url.path,
  85. "raw_path": request.url.raw_path.split(b"?")[0],
  86. "query_string": request.url.query,
  87. "server": (request.url.host, request.url.port),
  88. "client": self.client,
  89. "root_path": self.root_path,
  90. }
  91. # Request.
  92. request_body_chunks = request.stream.__aiter__()
  93. request_complete = False
  94. # Response.
  95. status_code = None
  96. response_headers = None
  97. body_parts = []
  98. response_started = False
  99. response_complete = create_event()
  100. # ASGI callables.
  101. async def receive() -> dict[str, typing.Any]:
  102. nonlocal request_complete
  103. if request_complete:
  104. await response_complete.wait()
  105. return {"type": "http.disconnect"}
  106. try:
  107. body = await request_body_chunks.__anext__()
  108. except StopAsyncIteration:
  109. request_complete = True
  110. return {"type": "http.request", "body": b"", "more_body": False}
  111. return {"type": "http.request", "body": body, "more_body": True}
  112. async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
  113. nonlocal status_code, response_headers, response_started
  114. if message["type"] == "http.response.start":
  115. assert not response_started
  116. status_code = message["status"]
  117. response_headers = message.get("headers", [])
  118. response_started = True
  119. elif message["type"] == "http.response.body":
  120. assert not response_complete.is_set()
  121. body = message.get("body", b"")
  122. more_body = message.get("more_body", False)
  123. if body and request.method != "HEAD":
  124. body_parts.append(body)
  125. if not more_body:
  126. response_complete.set()
  127. try:
  128. await self.app(scope, receive, send)
  129. except Exception: # noqa: PIE-786
  130. if self.raise_app_exceptions:
  131. raise
  132. response_complete.set()
  133. if status_code is None:
  134. status_code = 500
  135. if response_headers is None:
  136. response_headers = {}
  137. assert response_complete.is_set()
  138. assert status_code is not None
  139. assert response_headers is not None
  140. stream = ASGIResponseStream(body_parts)
  141. return Response(status_code, headers=response_headers, stream=stream)