wsgi.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from __future__ import annotations
  2. import io
  3. import itertools
  4. import sys
  5. import typing
  6. from .._models import Request, Response
  7. from .._types import SyncByteStream
  8. from .base import BaseTransport
  9. if typing.TYPE_CHECKING:
  10. from _typeshed import OptExcInfo # pragma: no cover
  11. from _typeshed.wsgi import WSGIApplication # pragma: no cover
  __all__ = ["WSGITransport"]

  # Generic element type for the iterables handled by _skip_leading_empty_chunks.
  _T = typing.TypeVar("_T")
  14. def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
  15. body = iter(body)
  16. for chunk in body:
  17. if chunk:
  18. return itertools.chain([chunk], body)
  19. return []
  class WSGIByteStream(SyncByteStream):
      """
      Synchronous byte stream over the iterable returned by a WSGI application.

      Leading empty chunks are consumed as soon as the stream is constructed
      (via ``_skip_leading_empty_chunks``). ``close()`` forwards to the original
      iterable's ``close`` method when it has one, and is a no-op otherwise.
      """

      def __init__(self, result: typing.Iterable[bytes]) -> None:
          # Look up ``close`` on the application's own iterable before wrapping
          # it: the wrapper built on the next line does not expose it.
          self._close = getattr(result, "close", None)
          self._result = _skip_leading_empty_chunks(result)

      def __iter__(self) -> typing.Iterator[bytes]:
          yield from self._result

      def close(self) -> None:
          closer = self._close
          if closer is None:
              return
          closer()
  class WSGITransport(BaseTransport):
      """
      A custom transport that handles sending requests directly to a WSGI app.
      The simplest way to use this functionality is to use the `app` argument.

      ```
      client = httpx.Client(app=app)
      ```

      Alternatively, you can set up the transport instance explicitly.
      This allows you to include any additional configuration arguments specific
      to the WSGITransport class:

      ```
      transport = httpx.WSGITransport(
          app=app,
          script_name="/submount",
          remote_addr="1.2.3.4"
      )
      client = httpx.Client(transport=transport)
      ```

      Arguments:

      * `app` - The WSGI application.
      * `raise_app_exceptions` - Boolean indicating if an exception the application
         reports via `start_response(..., exc_info)` should be raised. Defaults to
         `True`. Can be set to `False` for use cases such as testing the content
         of a client 500 response. Exceptions raised directly out of the
         application call always propagate.
      * `script_name` - The root path on which the WSGI application should be mounted.
      * `remote_addr` - A string indicating the client IP of incoming requests.
      * `wsgi_errors` - Text stream exposed to the application as `wsgi.errors`.
         Falls back to `sys.stderr` when not provided.
      """

      def __init__(
          self,
          app: WSGIApplication,
          raise_app_exceptions: bool = True,
          script_name: str = "",
          remote_addr: str = "127.0.0.1",
          wsgi_errors: typing.TextIO | None = None,
      ) -> None:
          self.app = app
          self.raise_app_exceptions = raise_app_exceptions
          self.script_name = script_name
          self.remote_addr = remote_addr
          # ``None`` means "use sys.stderr"; resolved per request in _build_environ.
          self.wsgi_errors = wsgi_errors

      def _build_environ(self, request: Request) -> dict[str, typing.Any]:
          """
          Build the PEP 3333 ``environ`` dict describing *request*.

          Expects ``request.read()`` to have been called already, because the
          body is handed to the application through ``request.content``.
          """
          # NOTE(review): schemes other than http/https without an explicit port
          # raise KeyError here.
          port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
          environ: dict[str, typing.Any] = {
              "wsgi.version": (1, 0),
              "wsgi.url_scheme": request.url.scheme,
              "wsgi.input": io.BytesIO(request.content),
              "wsgi.errors": self.wsgi_errors or sys.stderr,
              "wsgi.multithread": True,
              "wsgi.multiprocess": False,
              "wsgi.run_once": False,
              "REQUEST_METHOD": request.method,
              "SCRIPT_NAME": self.script_name,
              "PATH_INFO": request.url.path,
              "QUERY_STRING": request.url.query.decode("ascii"),
              "SERVER_NAME": request.url.host,
              "SERVER_PORT": str(port),
              "SERVER_PROTOCOL": "HTTP/1.1",
              "REMOTE_ADDR": self.remote_addr,
          }
          for header_key, header_value in request.headers.raw:
              key = header_key.decode("ascii").upper().replace("-", "_")
              # PEP 3333 represents header values as latin-1 "native strings";
              # unlike ASCII, latin-1 decodes any byte sequence without error.
              value = header_value.decode("latin-1")
              if key in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                  # CGI-style keys without the HTTP_ prefix; last value wins.
                  environ[key] = value
                  continue
              key = "HTTP_" + key
              if key in environ:
                  # Repeated header: join the values with commas (as
                  # wsgiref.simple_server does) instead of keeping only the last.
                  environ[key] += "," + value
              else:
                  environ[key] = value
          return environ

      def handle_request(self, request: Request) -> Response:
          """
          Call the WSGI application for *request* and wrap its output in a Response.

          Raises:
              The exception the application passed to ``start_response`` as
              ``exc_info``, when ``raise_app_exceptions`` is true.
              AssertionError if the application never called ``start_response``.
          """
          request.read()
          environ = self._build_environ(request)

          seen_status: str | None = None
          seen_response_headers: list[tuple[str, str]] | None = None
          seen_exc_info: OptExcInfo | None = None

          def start_response(
              status: str,
              response_headers: list[tuple[str, str]],
              exc_info: OptExcInfo | None = None,
          ) -> typing.Callable[[bytes], typing.Any]:
              # Record whatever the application reports; a later call overwrites
              # an earlier one.
              nonlocal seen_status, seen_response_headers, seen_exc_info
              seen_status = status
              seen_response_headers = response_headers
              seen_exc_info = exc_info
              # NOTE(review): the returned ``write`` callable discards its
              # argument, so data sent through write() never reaches the response.
              return lambda _: None

          result = self.app(environ, start_response)
          # Constructing the stream consumes ``result`` up to its first non-empty
          # chunk, so generator-based apps have called start_response by the time
          # the assertions below run.
          stream = WSGIByteStream(result)

          assert seen_status is not None
          assert seen_response_headers is not None

          if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
              # PEP 3333 requires close() on the application's iterable even when
              # the request ends in an error, so release it before re-raising.
              # The application's exception takes precedence over one from close().
              try:
                  stream.close()
              finally:
                  raise seen_exc_info[1]

          status_code = int(seen_status.split()[0])
          headers = [
              # Header names are tokens (ASCII); values are latin-1 per PEP 3333.
              (key.encode("ascii"), value.encode("latin-1"))
              for key, value in seen_response_headers
          ]
          return Response(status_code, headers=headers, stream=stream)