| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- from __future__ import annotations
- import asyncio
- import contextlib
- import logging
- import os
- import platform
- import signal
- import socket
- import sys
- import threading
- import time
- from collections.abc import Generator, Sequence
- from email.utils import formatdate
- from types import FrameType
- from typing import TYPE_CHECKING, Union
- import click
- from uvicorn._compat import asyncio_run
- from uvicorn.config import Config
if TYPE_CHECKING:
    # These imports sit behind TYPE_CHECKING, so this module does not load
    # any protocol implementation at runtime.
    from uvicorn.protocols.http.h11_impl import H11Protocol
    from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
    from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
    from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol
    from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

    # Element type of `ServerState.connections`.
    Protocols = Union[
        H11Protocol,
        HttpToolsProtocol,
        WSProtocol,
        WebSocketProtocol,
        WebSocketsSansIOProtocol,
    ]

# Signals for which `Server.capture_signals` installs `Server.handle_exit`:
#   SIGINT   - Unix signal 2. Sent by Ctrl+C.
#   SIGTERM  - Unix signal 15. Sent by `kill <pid>`.
#   SIGBREAK - Windows signal 21. Sent by Ctrl+Break (win32 only).
HANDLED_SIGNALS = (signal.SIGINT, signal.SIGTERM)
if sys.platform == "win32":  # pragma: py-not-win32
    HANDLED_SIGNALS += (signal.SIGBREAK,)

logger = logging.getLogger("uvicorn.error")
class ServerState:
    """
    Shared server state that is available between all protocol instances.

    A single instance is created by `Server.__init__` and passed to every
    protocol object built in `Server.startup`.
    """

    def __init__(self) -> None:
        # Compared against `config.limit_max_requests` in `Server.on_tick`.
        self.total_requests = 0
        # Open protocol instances; `Server.shutdown` calls `.shutdown()` on each.
        self.connections: set[Protocols] = set()
        # Tasks that `Server.shutdown` waits for, then cancels on timeout.
        self.tasks: set[asyncio.Task[None]] = set()
        # Header pairs rebuilt periodically by `Server.on_tick`.
        self.default_headers: list[tuple[bytes, bytes]] = []
class Server:
    """
    Drive the lifecycle of a server described by a `Config`.

    The lifecycle is: lifespan startup, creation of the listening servers,
    a periodic tick loop, and a graceful shutdown. While `serve()` runs in
    the main thread, the signals in `HANDLED_SIGNALS` are routed to
    `handle_exit`.
    """

    def __init__(self, config: Config) -> None:
        """Store *config* and initialise the run-time flags."""
        self.config = config
        self.server_state = ServerState()

        # Set to True at the end of `startup()`.
        self.started = False
        # Request a graceful exit; polled by `on_tick()`.
        self.should_exit = False
        # Stop waiting for connections/tasks and skip lifespan shutdown.
        self.force_exit = False
        # `time.time()` of the most recent `callback_notify` invocation.
        self.last_notified = 0.0

        # Signals seen by `handle_exit`, re-raised by `capture_signals`.
        self._captured_signals: list[int] = []

    def run(self, sockets: list[socket.socket] | None = None) -> None:
        """Run `serve()` to completion on a loop from `config.get_loop_factory()`."""
        return asyncio_run(self.serve(sockets=sockets), loop_factory=self.config.get_loop_factory())

    async def serve(self, sockets: list[socket.socket] | None = None) -> None:
        """Serve until asked to exit, with the signal handlers installed."""
        with self.capture_signals():
            await self._serve(sockets)

    async def _serve(self, sockets: list[socket.socket] | None = None) -> None:
        """
        Load the config if needed, then run startup, main loop and shutdown.

        Returns right after `startup()` when it set `should_exit`; in that
        case neither the main loop nor `shutdown()` runs.
        """
        process_id = os.getpid()

        config = self.config
        if not config.loaded:
            config.load()

        self.lifespan = config.lifespan_class(config)

        message = "Started server process [%d]"
        color_message = "Started server process [" + click.style("%d", fg="cyan") + "]"
        logger.info(message, process_id, extra={"color_message": color_message})

        await self.startup(sockets=sockets)
        if self.should_exit:
            return
        await self.main_loop()
        await self.shutdown(sockets=sockets)

        message = "Finished server process [%d]"
        color_message = "Finished server process [" + click.style("%d", fg="cyan") + "]"
        logger.info(message, process_id, extra={"color_message": color_message})

    async def startup(self, sockets: list[socket.socket] | None = None) -> None:
        """
        Run lifespan startup and create the listening server(s).

        The listening source is chosen in this order: the explicit *sockets*
        list, `config.fd`, `config.uds`, then `config.host`/`config.port`.
        On success `self.servers` is populated and `self.started` is True.

        Sets `should_exit` and returns early when lifespan startup asks to
        exit.

        Raises:
            SystemExit: with code 1 when binding host/port raises `OSError`;
                the error is logged and lifespan shutdown runs first.
        """
        await self.lifespan.startup()
        if self.lifespan.should_exit:
            self.should_exit = True
            return

        config = self.config

        def create_protocol(
            _loop: asyncio.AbstractEventLoop | None = None,
        ) -> asyncio.Protocol:
            return config.http_protocol_class(  # type: ignore[call-arg]
                config=config,
                server_state=self.server_state,
                app_state=self.lifespan.state,
                _loop=_loop,
            )

        loop = asyncio.get_running_loop()

        listeners: Sequence[socket.SocketType]
        if sockets is not None:  # pragma: full coverage
            # Explicitly passed a list of open sockets.
            # We use this when the server is run from a Gunicorn worker.

            def _share_socket(
                sock: socket.SocketType,
            ) -> socket.SocketType:  # pragma py-not-win32
                # Windows requires the socket be explicitly shared across
                # multiple workers (processes).
                from socket import fromshare  # type: ignore[attr-defined]

                sock_data = sock.share(os.getpid())  # type: ignore[attr-defined]
                return fromshare(sock_data)

            # The platform and worker count do not change between sockets,
            # so decide once whether each socket has to be re-shared.
            share_sockets = config.workers > 1 and platform.system() == "Windows"

            self.servers: list[asyncio.base_events.Server] = []
            for sock in sockets:
                if share_sockets:  # pragma: py-not-win32
                    sock = _share_socket(sock)  # type: ignore[assignment]
                server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
                self.servers.append(server)
            listeners = sockets

        elif config.fd is not None:  # pragma: py-win32
            # Use an existing socket, from a file descriptor.
            sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
            server = await loop.create_server(create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog)
            assert server.sockets is not None  # mypy
            listeners = server.sockets
            self.servers = [server]

        elif config.uds is not None:  # pragma: py-win32
            # Create a socket using UNIX domain socket.
            # Keep the mode of an already existing path, otherwise use 0o666.
            uds_perms = 0o666
            if os.path.exists(config.uds):
                uds_perms = os.stat(config.uds).st_mode  # pragma: full coverage
            server = await loop.create_unix_server(
                create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
            )
            os.chmod(config.uds, uds_perms)
            assert server.sockets is not None  # mypy
            listeners = server.sockets
            self.servers = [server]

        else:
            # Standard case. Create a socket from a host/port pair.
            try:
                server = await loop.create_server(
                    create_protocol,
                    host=config.host,
                    port=config.port,
                    ssl=config.ssl,
                    backlog=config.backlog,
                )
            except OSError as exc:
                logger.error(exc)
                await self.lifespan.shutdown()
                sys.exit(1)

            assert server.sockets is not None
            listeners = server.sockets
            self.servers = [server]

        if sockets is None:
            self._log_started_message(listeners)
        else:
            # We're most likely running multiple workers, so a message has already been
            # logged by `config.bind_socket()`.
            pass  # pragma: full coverage

        self.started = True

    def _log_started_message(self, listeners: Sequence[socket.SocketType]) -> None:
        """Log where the server is listening: fd socket, unix socket or URL."""
        config = self.config

        if config.fd is not None:  # pragma: py-win32
            sock = listeners[0]
            logger.info(
                "Uvicorn running on socket %s (Press CTRL+C to quit)",
                sock.getsockname(),
            )

        elif config.uds is not None:  # pragma: py-win32
            logger.info("Uvicorn running on unix socket %s (Press CTRL+C to quit)", config.uds)

        else:
            addr_format = "%s://%s:%d"
            host = "0.0.0.0" if config.host is None else config.host
            if ":" in host:
                # It's an IPv6 address.
                addr_format = "%s://[%s]:%d"

            port = config.port
            if port == 0:
                # Port 0 was requested, so report the port of the first listener.
                port = listeners[0].getsockname()[1]

            protocol_name = "https" if config.ssl else "http"
            message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
            color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
            logger.info(
                message,
                protocol_name,
                host,
                port,
                extra={"color_message": color_message},
            )

    async def main_loop(self) -> None:
        """Call `on_tick()` every 0.1 seconds until it returns True."""
        counter = 0
        should_exit = await self.on_tick(counter)
        while not should_exit:
            counter += 1
            # Wrap after 864000 ticks (24 hours at 0.1 s per tick).
            counter = counter % 864000
            await asyncio.sleep(0.1)
            should_exit = await self.on_tick(counter)

    async def on_tick(self, counter: int) -> bool:
        """
        Do the periodic housekeeping and report whether the loop should stop.

        On every tenth tick the default headers are rebuilt and, when
        configured, `callback_notify` is awaited if more than
        `timeout_notify` seconds have passed since the last notification.

        Returns:
            True when `should_exit` is set or `limit_max_requests` has been
            reached, otherwise False.
        """
        # Update the default headers, once per second.
        if counter % 10 == 0:
            current_time = time.time()
            current_date = formatdate(current_time, usegmt=True).encode()

            if self.config.date_header:
                date_header = [(b"date", current_date)]
            else:
                date_header = []

            self.server_state.default_headers = date_header + self.config.encoded_headers

            # Callback to `callback_notify` once every `timeout_notify` seconds.
            if self.config.callback_notify is not None:
                if current_time - self.last_notified > self.config.timeout_notify:  # pragma: full coverage
                    self.last_notified = current_time
                    await self.config.callback_notify()

        # Determine if we should exit.
        if self.should_exit:
            return True

        max_requests = self.config.limit_max_requests
        if max_requests is not None and self.server_state.total_requests >= max_requests:
            # Pass the value as a lazy %-arg, like the other log calls here.
            logger.warning("Maximum request limit of %d exceeded. Terminating process.", max_requests)
            return True

        return False

    async def shutdown(self, sockets: list[socket.socket] | None = None) -> None:
        """
        Stop listening, drain connections and tasks, then run lifespan shutdown.

        Waiting is bounded by `config.timeout_graceful_shutdown`; tasks still
        running after that are cancelled. Lifespan shutdown is skipped when
        `force_exit` is set.
        """
        logger.info("Shutting down")

        # Stop accepting new connections.
        for server in self.servers:
            server.close()
        for sock in sockets or []:
            sock.close()  # pragma: full coverage

        # Request shutdown on all existing connections.
        for connection in list(self.server_state.connections):
            connection.shutdown()
        await asyncio.sleep(0.1)

        # When 3.10 is not supported anymore, use `async with asyncio.timeout(...):`.
        try:
            await asyncio.wait_for(
                self._wait_tasks_to_complete(),
                timeout=self.config.timeout_graceful_shutdown,
            )
        except asyncio.TimeoutError:
            logger.error(
                "Cancel %s running task(s), timeout graceful shutdown exceeded",
                len(self.server_state.tasks),
            )
            # Cancel over a snapshot, as is done for connections above, so the
            # loop never iterates a set that is being modified.
            for t in list(self.server_state.tasks):
                t.cancel(msg="Task cancelled, timeout graceful shutdown exceeded")

        # Send the lifespan shutdown event, and wait for application shutdown.
        if not self.force_exit:
            await self.lifespan.shutdown()

    async def _wait_tasks_to_complete(self) -> None:
        """
        Wait for connections, then tasks, to drain, then for servers to close.

        Both polling loops stop early as soon as `force_exit` is set.
        """
        # Wait for existing connections to finish sending responses.
        if self.server_state.connections and not self.force_exit:
            msg = "Waiting for connections to close. (CTRL+C to force quit)"
            logger.info(msg)
            while self.server_state.connections and not self.force_exit:
                await asyncio.sleep(0.1)

        # Wait for existing tasks to complete.
        if self.server_state.tasks and not self.force_exit:
            msg = "Waiting for background tasks to complete. (CTRL+C to force quit)"
            logger.info(msg)
            while self.server_state.tasks and not self.force_exit:
                await asyncio.sleep(0.1)

        for server in self.servers:
            await server.wait_closed()

    @contextlib.contextmanager
    def capture_signals(self) -> Generator[None, None, None]:
        """
        Route `HANDLED_SIGNALS` to `handle_exit` for the duration of the block.

        The previous handlers are always restored on exit. After a normal
        exit, every captured signal is raised again so that the restored
        handlers see it. Outside the main thread this is a no-op.
        """
        # Signals can only be listened to from the main thread.
        if threading.current_thread() is not threading.main_thread():
            yield
            return
        # always use signal.signal, even if loop.add_signal_handler is available
        # this allows to restore previous signal handlers later on
        original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS}
        try:
            yield
        finally:
            for sig, handler in original_handlers.items():
                signal.signal(sig, handler)
        # If we did gracefully shut down due to a signal, try to
        # trigger the expected behaviour now; multiple signals would be
        # done LIFO, see https://stackoverflow.com/questions/48434964
        for captured_signal in reversed(self._captured_signals):
            signal.raise_signal(captured_signal)

    def handle_exit(self, sig: int, frame: FrameType | None) -> None:
        """
        Signal handler: record *sig* and request an exit.

        The first signal sets `should_exit`. A SIGINT received once
        `should_exit` is already set escalates to `force_exit`.
        """
        self._captured_signals.append(sig)
        if self.should_exit and sig == signal.SIGINT:
            self.force_exit = True  # pragma: full coverage
        else:
            self.should_exit = True
|