server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. from __future__ import annotations
  2. import asyncio
  3. import contextlib
  4. import logging
  5. import os
  6. import platform
  7. import signal
  8. import socket
  9. import sys
  10. import threading
  11. import time
  12. from collections.abc import Generator, Sequence
  13. from email.utils import formatdate
  14. from types import FrameType
  15. from typing import TYPE_CHECKING, Union
  16. import click
  17. from uvicorn._compat import asyncio_run
  18. from uvicorn.config import Config
  19. if TYPE_CHECKING:
  20. from uvicorn.protocols.http.h11_impl import H11Protocol
  21. from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
  22. from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
  23. from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
  24. from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
  25. Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol]
  26. HANDLED_SIGNALS = (
  27. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
  28. signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
  29. )
  30. if sys.platform == "win32": # pragma: py-not-win32
  31. HANDLED_SIGNALS += (signal.SIGBREAK,) # Windows signal 21. Sent by Ctrl+Break.
  32. logger = logging.getLogger("uvicorn.error")
  33. class ServerState:
  34. """
  35. Shared servers state that is available between all protocol instances.
  36. """
  37. def __init__(self) -> None:
  38. self.total_requests = 0
  39. self.connections: set[Protocols] = set()
  40. self.tasks: set[asyncio.Task[None]] = set()
  41. self.default_headers: list[tuple[bytes, bytes]] = []
  42. class Server:
  43. def __init__(self, config: Config) -> None:
  44. self.config = config
  45. self.server_state = ServerState()
  46. self.started = False
  47. self.should_exit = False
  48. self.force_exit = False
  49. self.last_notified = 0.0
  50. self._captured_signals: list[int] = []
  51. def run(self, sockets: list[socket.socket] | None = None) -> None:
  52. return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())
  53. async def serve(self, sockets: list[socket.socket] | None = None) -> None:
  54. with self.capture_signals():
  55. await self._serve(sockets)
  56. async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
  57. process_id = os.getpid()
  58. config = self.config
  59. if not config.loaded:
  60. config.load()
  61. self.lifespan = config.lifespan_class(config)
  62. message = "Started server process [%d]"
  63. color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
  64. logger.info(message, process_id, extra={"color_message": color_message})
  65. await self.startup(sockets=sockets)
  66. if self.should_exit:
  67. return
  68. await self.main_loop()
  69. await self.shutdown(sockets=sockets)
  70. message = "Finished server process [%d]"
  71. color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
  72. logger.info(message, process_id, extra={"color_message": color_message})
  73. async def startup(self, sockets: list[socket.socket] | None = None) -> None:
  74. await self.lifespan.startup()
  75. if self.lifespan.should_exit:
  76. self.should_exit = True
  77. return
  78. config = self.config
  79. def create_protocol(
  80. _loop: asyncio.AbstractEventLoop | None = None,
  81. ) -> asyncio.Protocol:
  82. return config.http_protocol_class( # type: ignore[call-arg]
  83. config=config,
  84. server_state=self.server_state,
  85. app_state=self.lifespan.state,
  86. _loop=_loop,
  87. )
  88. loop = asyncio.get_running_loop()
  89. listeners: Sequence[socket.SocketType]
  90. if sockets is not None: # pragma: full coverage
  91. # Explicitly passed a list of open sockets.
  92. # We use this when the server is run from a Gunicorn worker.
  93. def _share_socket(
  94. sock: socket.SocketType,
  95. ) -> socket.SocketType: # pragma py-not-win32
  96. # Windows requires the socket be explicitly shared across
  97. # multiple workers (processes).
  98. from socket import fromshare # type: ignore[attr-defined]
  99. sock_data = sock.share(os.getpid()) # type: ignore[attr-defined]
  100. return fromshare(sock_data)
  101. self.servers: list[asyncio.base_events.Server] = []
  102. for sock in sockets:
  103. is_windows = platform.system() == "Windows"
  104. if config.workers > 1 and is_windows: # pragma: py-not-win32
  105. sock = _share_socket(sock) # type: ignore[assignment]
  106. server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
  107. self.servers.append(server)
  108. listeners = sockets
  109. elif config.fd is not None: # pragma: py-win32
  110. # Use an existing socket, from a file descriptor.
  111. sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
  112. server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
  113. assert server.sockets is not None # mypy
  114. listeners = server.sockets
  115. self.servers = [server]
  116. elif config.uds is not None: # pragma: py-win32
  117. # Create a socket using UNIX domain socket.
  118. uds_perms = 0o666
  119. if os.path.exists(config.uds):
  120. uds_perms = os.stat(config.uds).st_mode # pragma: full coverage
  121. server = await loop.create_unix_server(
  122. create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
  123. )
  124. os.chmod(config.uds, uds_perms)
  125. assert server.sockets is not None # mypy
  126. listeners = server.sockets
  127. self.servers = [server]
  128. else:
  129. # Standard case. Create a socket from a host/port pair.
  130. try:
  131. server = await loop.create_server(
  132. create_protocol,
  133. host=config.host,
  134. port=config.port,
  135. ssl=config.ssl,
  136. backlog=config.backlog,
  137. )
  138. except OSError as exc:
  139. logger.error(exc)
  140. await self.lifespan.shutdown()
  141. sys.exit(1)
  142. assert server.sockets is not None
  143. listeners = server.sockets
  144. self.servers = [server]
  145. if sockets is None:
  146. self._log_started_message(listeners)
  147. else:
  148. # We're most likely running multiple workers, so a message has already been
  149. # logged by `config.bind_socket()`.
  150. pass # pragma: full coverage
  151. self.started = True
  152. def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None:
  153. config = self.config
  154. if config.fd is not None: # pragma: py-win32
  155. sock = listeners[0]
  156. logger.info(
  157. "Uvicorn running on socket %s (Press CTRL+C to quit)",
  158. sock.getsockname(),
  159. )
  160. elif config.uds is not None: # pragma: py-win32
  161. logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)
  162. else:
  163. addr_format = "%s://%s:%d"
  164. host = "0.0.0.0" if config.host is None else config.host
  165. if ":" in host:
  166. # It's an IPv6 address.
  167. addr_format = "%s://[%s]:%d"
  168. port = config.port
  169. if port == 0:
  170. port = listeners[0].getsockname()[1]
  171. protocol_name = "https" if config.ssl else "http"
  172. message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
  173. color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
  174. logger.info(
  175. message,
  176. protocol_name,
  177. host,
  178. port,
  179. extra={"color_message": color_message},
  180. )
  181. async def main_loop(self) -> None:
  182. counter = 0
  183. should_exit = await self.on_tick(counter)
  184. while not should_exit:
  185. counter += 1
  186. counter = counter % 864000
  187. await asyncio.sleep(0.1)
  188. should_exit = await self.on_tick(counter)
  189. async def on_tick(self, counter: int) -> bool:
  190. # Update the default headers, once per second.
  191. if counter % 10 == 0:
  192. current_time = time.time()
  193. current_date = formatdate(current_time, usegmt=True).encode()
  194. if self.config.date_header:
  195. date_header = [(b"date", current_date)]
  196. else:
  197. date_header = []
  198. self.server_state.default_headers = date_header + self.config.encoded_headers
  199. # Callback to `callback_notify` once every `timeout_notify` seconds.
  200. if self.config.callback_notify is not None:
  201. if current_time - self.last_notified > self.config.timeout_notify: # pragma: full coverage
  202. self.last_notified = current_time
  203. await self.config.callback_notify()
  204. # Determine if we should exit.
  205. if self.should_exit:
  206. return True
  207. max_requests = self.config.limit_max_requests
  208. if max_requests is not None and self.server_state.total_requests >= max_requests:
  209. logger.warning(f"Maximum request limit of {max_requests} exceeded. Terminating process.")
  210. return True
  211. return False
  212. async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
  213. logger.info("Shutting down")
  214. # Stop accepting new connections.
  215. for server in self.servers:
  216. server.close()
  217. for sock in sockets or []:
  218. sock.close() # pragma: full coverage
  219. # Request shutdown on all existing connections.
  220. for connection in list(self.server_state.connections):
  221. connection.shutdown()
  222. await asyncio.sleep(0.1)
  223. # When 3.10 is not supported anymore, use `async with asyncio.timeout(...):`.
  224. try:
  225. await asyncio.wait_for(
  226. self._wait_tasks_to_complete(),
  227. timeout=self.config.timeout_graceful_shutdown,
  228. )
  229. except asyncio.TimeoutError:
  230. logger.error(
  231. "Cancel %s running task(s), timeout graceful shutdown exceeded",
  232. len(self.server_state.tasks),
  233. )
  234. for t in self.server_state.tasks:
  235. t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")
  236. # Send the lifespan shutdown event, and wait for application shutdown.
  237. if not self.force_exit:
  238. await self.lifespan.shutdown()
  239. async def _wait_tasks_to_complete(self) -> None:
  240. # Wait for existing connections to finish sending responses.
  241. if self.server_state.connections and not self.force_exit:
  242. msg = "Waiting for connections to close. (CTRL+C to force quit)"
  243. logger.info(msg)
  244. while self.server_state.connections and not self.force_exit:
  245. await asyncio.sleep(0.1)
  246. # Wait for existing tasks to complete.
  247. if self.server_state.tasks and not self.force_exit:
  248. msg = "Waiting for background tasks to complete. (CTRL+C to force quit)"
  249. logger.info(msg)
  250. while self.server_state.tasks and not self.force_exit:
  251. await asyncio.sleep(0.1)
  252. for server in self.servers:
  253. await server.wait_closed()
  254. @contextlib.contextmanager
  255. def capture_signals(self) -> Generator[None, None, None]:
  256. # Signals can only be listened to from the main thread.
  257. if threading.current_thread() is not threading.main_thread():
  258. yield
  259. return
  260. # always use signal.signal, even if loop.add_signal_handler is available
  261. # this allows to restore previous signal handlers later on
  262. original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
  263. try:
  264. yield
  265. finally:
  266. for sig, handler in original_handlers.items():
  267. signal.signal(sig, handler)
  268. # If we did gracefully shut down due to a signal, try to
  269. # trigger the expected behaviour now; multiple signals would be
  270. # done LIFO, see https://stackoverflow.com/questions/48434964
  271. for captured_signal in reversed(self._captured_signals):
  272. signal.raise_signal(captured_signal)
  273. def handle_exit(self, sig: int, frame: FrameType | None) -> None:
  274. self._captured_signals.append(sig)
  275. if self.should_exit and sig == signal.SIGINT:
  276. self.force_exit = True # pragma: full coverage
  277. else:
  278. self.should_exit = True