config.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. from __future__ import annotations
  2. import asyncio
  3. import inspect
  4. import json
  5. import logging
  6. import logging.config
  7. import os
  8. import socket
  9. import ssl
  10. import sys
  11. from collections.abc import Awaitable
  12. from configparser import RawConfigParser
  13. from pathlib import Path
  14. from typing import IO, Any, Callable, Literal
  15. import click
  16. from uvicorn._types import ASGIApplication
  17. from uvicorn.importer import ImportFromStringError, import_from_string
  18. from uvicorn.logging import TRACE_LOG_LEVEL
  19. from uvicorn.middleware.asgi2 import ASGI2Middleware
  20. from uvicorn.middleware.message_logger import MessageLoggerMiddleware
  21. from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
  22. from uvicorn.middleware.wsgi import WSGIMiddleware
# String names accepted for the pluggable components configured through Config.
HTTPProtocolType = Literal["auto", "h11", "httptools"]
WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"]
LifespanType = Literal["auto", "on", "off"]
LoopFactoryType = Literal["none", "auto", "asyncio", "uvloop"]
InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"]

# Log-level names (looked up by Config.configure_logging) -> numeric logging
# levels. "trace" maps to uvicorn's custom TRACE_LOG_LEVEL.
LOG_LEVELS: dict[str, int] = {
    "critical": logging.CRITICAL,
    "error": logging.ERROR,
    "warning": logging.WARNING,
    "info": logging.INFO,
    "debug": logging.DEBUG,
    "trace": TRACE_LOG_LEVEL,
}

# Built-in HTTP protocol names -> "module:attribute" import strings. Config.load
# resolves them with import_from_string; a name not in this table is passed
# through unchanged, so a custom import string can be supplied instead.
HTTP_PROTOCOLS: dict[str, str] = {
    "auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol",
    "h11": "uvicorn.protocols.http.h11_impl:H11Protocol",
    "httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol",
}

# WebSocket protocol names -> import strings. "none" maps to None, which
# Config.load hands to import_from_string as well (presumably yielding None,
# i.e. no WebSocket support — confirm in uvicorn.importer).
WS_PROTOCOLS: dict[str, str | None] = {
    "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol",
    "none": None,
    "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
    "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol",
    "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
}

# Lifespan modes -> lifespan implementation; "auto" and "on" share LifespanOn.
LIFESPAN: dict[str, str] = {
    "auto": "uvicorn.lifespan.on:LifespanOn",
    "on": "uvicorn.lifespan.on:LifespanOn",
    "off": "uvicorn.lifespan.off:LifespanOff",
}

# Event-loop names -> loop-factory import strings. "none" maps to None, for
# which Config.get_loop_factory returns None.
LOOP_FACTORIES: dict[str, str | None] = {
    "none": None,
    "auto": "uvicorn.loops.auto:auto_loop_factory",
    "asyncio": "uvicorn.loops.asyncio:asyncio_loop_factory",
    "uvloop": "uvicorn.loops.uvloop:uvloop_loop_factory",
}

# All interface names; mirrors the InterfaceType literal above.
INTERFACES: list[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]

# Default for Config(ssl_version=...): a server-side TLS context.
SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER

# Default logging.config.dictConfig() schema used for Config(log_config=...).
# The "default" handler writes to stderr and the "access" handler to stdout.
# NOTE(review): this dict is shared module-level state; any code that writes
# into it (e.g. to set use_colors) changes the default for every later user.
LOGGING_CONFIG: dict[str, Any] = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            # "()" names the formatter factory (dictConfig custom-object syntax).
            "()": "uvicorn.logging.DefaultFormatter",
            "fmt": "%(levelprefix)s %(message)s",
            # None leaves the choice to the formatter (presumably auto-detect).
            "use_colors": None,
        },
        "access": {
            "()": "uvicorn.logging.AccessFormatter",
            "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',  # noqa: E501
        },
    },
    "handlers": {
        "default": {
            "formatter": "default",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stderr",
        },
        "access": {
            "formatter": "access",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
        "uvicorn.error": {"level": "INFO"},
        "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
    },
}

# Logger used for configuration-time diagnostics in this module.
logger = logging.getLogger("uvicorn.error")
  94. def create_ssl_context(
  95. certfile: str | os.PathLike[str],
  96. keyfile: str | os.PathLike[str] | None,
  97. password: str | None,
  98. ssl_version: int,
  99. cert_reqs: int,
  100. ca_certs: str | os.PathLike[str] | None,
  101. ciphers: str | None,
  102. ) -> ssl.SSLContext:
  103. ctx = ssl.SSLContext(ssl_version)
  104. get_password = (lambda: password) if password else None
  105. ctx.load_cert_chain(certfile, keyfile, get_password)
  106. ctx.verify_mode = ssl.VerifyMode(cert_reqs)
  107. if ca_certs:
  108. ctx.load_verify_locations(ca_certs)
  109. if ciphers:
  110. ctx.set_ciphers(ciphers)
  111. return ctx
  112. def is_dir(path: Path) -> bool:
  113. try:
  114. if not path.is_absolute():
  115. path = path.resolve()
  116. return path.is_dir()
  117. except OSError: # pragma: full coverage
  118. return False
  119. def resolve_reload_patterns(patterns_list: list[str], directories_list: list[str]) -> tuple[list[str], list[Path]]:
  120. directories: list[Path] = list(set(map(Path, directories_list.copy())))
  121. patterns: list[str] = patterns_list.copy()
  122. current_working_directory = Path.cwd()
  123. for pattern in patterns_list:
  124. # Special case for the .* pattern, otherwise this would only match
  125. # hidden directories which is probably undesired
  126. if pattern == ".*":
  127. continue # pragma: py-not-linux
  128. patterns.append(pattern)
  129. if is_dir(Path(pattern)):
  130. directories.append(Path(pattern))
  131. else:
  132. for match in current_working_directory.glob(pattern):
  133. if is_dir(match):
  134. directories.append(match)
  135. directories = list(set(directories))
  136. directories = list(map(Path, directories))
  137. directories = list(map(lambda x: x.resolve(), directories))
  138. directories = list({reload_path for reload_path in directories if is_dir(reload_path)})
  139. children = []
  140. for j in range(len(directories)):
  141. for k in range(j + 1, len(directories)): # pragma: full coverage
  142. if directories[j] in directories[k].parents:
  143. children.append(directories[k])
  144. elif directories[k] in directories[j].parents:
  145. children.append(directories[j])
  146. directories = list(set(directories).difference(set(children)))
  147. return list(set(patterns)), directories
  148. def _normalize_dirs(dirs: list[str] | str | None) -> list[str]:
  149. if dirs is None:
  150. return []
  151. if isinstance(dirs, str):
  152. return [dirs]
  153. return list(set(dirs))
  154. class Config:
  155. def __init__(
  156. self,
  157. app: ASGIApplication | Callable[..., Any] | str,
  158. host: str = "127.0.0.1",
  159. port: int = 8000,
  160. uds: str | None = None,
  161. fd: int | None = None,
  162. loop: LoopFactoryType | str = "auto",
  163. http: type[asyncio.Protocol] | HTTPProtocolType | str = "auto",
  164. ws: type[asyncio.Protocol] | WSProtocolType | str = "auto",
  165. ws_max_size: int = 16 * 1024 * 1024,
  166. ws_max_queue: int = 32,
  167. ws_ping_interval: float | None = 20.0,
  168. ws_ping_timeout: float | None = 20.0,
  169. ws_per_message_deflate: bool = True,
  170. lifespan: LifespanType = "auto",
  171. env_file: str | os.PathLike[str] | None = None,
  172. log_config: dict[str, Any] | str | RawConfigParser | IO[Any] | None = LOGGING_CONFIG,
  173. log_level: str | int | None = None,
  174. access_log: bool = True,
  175. use_colors: bool | None = None,
  176. interface: InterfaceType = "auto",
  177. reload: bool = False,
  178. reload_dirs: list[str] | str | None = None,
  179. reload_delay: float = 0.25,
  180. reload_includes: list[str] | str | None = None,
  181. reload_excludes: list[str] | str | None = None,
  182. workers: int | None = None,
  183. proxy_headers: bool = True,
  184. server_header: bool = True,
  185. date_header: bool = True,
  186. forwarded_allow_ips: list[str] | str | None = None,
  187. root_path: str = "",
  188. limit_concurrency: int | None = None,
  189. limit_max_requests: int | None = None,
  190. backlog: int = 2048,
  191. timeout_keep_alive: int = 5,
  192. timeout_notify: int = 30,
  193. timeout_graceful_shutdown: int | None = None,
  194. timeout_worker_healthcheck: int = 5,
  195. callback_notify: Callable[..., Awaitable[None]] | None = None,
  196. ssl_keyfile: str | os.PathLike[str] | None = None,
  197. ssl_certfile: str | os.PathLike[str] | None = None,
  198. ssl_keyfile_password: str | None = None,
  199. ssl_version: int = SSL_PROTOCOL_VERSION,
  200. ssl_cert_reqs: int = ssl.CERT_NONE,
  201. ssl_ca_certs: str | os.PathLike[str] | None = None,
  202. ssl_ciphers: str = "TLSv1",
  203. headers: list[tuple[str, str]] | None = None,
  204. factory: bool = False,
  205. h11_max_incomplete_event_size: int | None = None,
  206. ):
  207. self.app = app
  208. self.host = host
  209. self.port = port
  210. self.uds = uds
  211. self.fd = fd
  212. self.loop = loop
  213. self.http = http
  214. self.ws = ws
  215. self.ws_max_size = ws_max_size
  216. self.ws_max_queue = ws_max_queue
  217. self.ws_ping_interval = ws_ping_interval
  218. self.ws_ping_timeout = ws_ping_timeout
  219. self.ws_per_message_deflate = ws_per_message_deflate
  220. self.lifespan = lifespan
  221. self.log_config = log_config
  222. self.log_level = log_level
  223. self.access_log = access_log
  224. self.use_colors = use_colors
  225. self.interface = interface
  226. self.reload = reload
  227. self.reload_delay = reload_delay
  228. self.workers = workers or 1
  229. self.proxy_headers = proxy_headers
  230. self.server_header = server_header
  231. self.date_header = date_header
  232. self.root_path = root_path
  233. self.limit_concurrency = limit_concurrency
  234. self.limit_max_requests = limit_max_requests
  235. self.backlog = backlog
  236. self.timeout_keep_alive = timeout_keep_alive
  237. self.timeout_notify = timeout_notify
  238. self.timeout_graceful_shutdown = timeout_graceful_shutdown
  239. self.timeout_worker_healthcheck = timeout_worker_healthcheck
  240. self.callback_notify = callback_notify
  241. self.ssl_keyfile = ssl_keyfile
  242. self.ssl_certfile = ssl_certfile
  243. self.ssl_keyfile_password = ssl_keyfile_password
  244. self.ssl_version = ssl_version
  245. self.ssl_cert_reqs = ssl_cert_reqs
  246. self.ssl_ca_certs = ssl_ca_certs
  247. self.ssl_ciphers = ssl_ciphers
  248. self.headers: list[tuple[str, str]] = headers or []
  249. self.encoded_headers: list[tuple[bytes, bytes]] = []
  250. self.factory = factory
  251. self.h11_max_incomplete_event_size = h11_max_incomplete_event_size
  252. self.loaded = False
  253. self.configure_logging()
  254. self.reload_dirs: list[Path] = []
  255. self.reload_dirs_excludes: list[Path] = []
  256. self.reload_includes: list[str] = []
  257. self.reload_excludes: list[str] = []
  258. if (reload_dirs or reload_includes or reload_excludes) and not self.should_reload:
  259. logger.warning(
  260. "Current configuration will not reload as not all conditions are met, please refer to documentation."
  261. )
  262. if self.should_reload:
  263. reload_dirs = _normalize_dirs(reload_dirs)
  264. reload_includes = _normalize_dirs(reload_includes)
  265. reload_excludes = _normalize_dirs(reload_excludes)
  266. self.reload_includes, self.reload_dirs = resolve_reload_patterns(reload_includes, reload_dirs)
  267. self.reload_excludes, self.reload_dirs_excludes = resolve_reload_patterns(reload_excludes, [])
  268. reload_dirs_tmp = self.reload_dirs.copy()
  269. for directory in self.reload_dirs_excludes:
  270. for reload_directory in reload_dirs_tmp:
  271. if directory == reload_directory or directory in reload_directory.parents:
  272. try:
  273. self.reload_dirs.remove(reload_directory)
  274. except ValueError: # pragma: full coverage
  275. pass
  276. for pattern in self.reload_excludes:
  277. if pattern in self.reload_includes:
  278. self.reload_includes.remove(pattern) # pragma: full coverage
  279. if not self.reload_dirs:
  280. if reload_dirs:
  281. logger.warning(
  282. "Provided reload directories %s did not contain valid "
  283. + "directories, watching current working directory.",
  284. reload_dirs,
  285. )
  286. self.reload_dirs = [Path.cwd()]
  287. logger.info(
  288. "Will watch for changes in these directories: %s",
  289. sorted(list(map(str, self.reload_dirs))),
  290. )
  291. if env_file is not None:
  292. from dotenv import load_dotenv
  293. logger.info("Loading environment from '%s'", env_file)
  294. load_dotenv(dotenv_path=env_file)
  295. if workers is None and "WEB_CONCURRENCY" in os.environ:
  296. self.workers = int(os.environ["WEB_CONCURRENCY"])
  297. self.forwarded_allow_ips: list[str] | str
  298. if forwarded_allow_ips is None:
  299. self.forwarded_allow_ips = os.environ.get("FORWARDED_ALLOW_IPS", "127.0.0.1")
  300. else:
  301. self.forwarded_allow_ips = forwarded_allow_ips # pragma: full coverage
  302. if self.reload and self.workers > 1:
  303. logger.warning('"workers" flag is ignored when reloading is enabled.')
  304. @property
  305. def asgi_version(self) -> Literal["2.0", "3.0"]:
  306. mapping: dict[str, Literal["2.0", "3.0"]] = {
  307. "asgi2": "2.0",
  308. "asgi3": "3.0",
  309. "wsgi": "3.0",
  310. }
  311. return mapping[self.interface]
  312. @property
  313. def is_ssl(self) -> bool:
  314. return bool(self.ssl_keyfile or self.ssl_certfile)
  315. @property
  316. def use_subprocess(self) -> bool:
  317. return bool(self.reload or self.workers > 1)
  318. def configure_logging(self) -> None:
  319. logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")
  320. if self.log_config is not None:
  321. if isinstance(self.log_config, dict):
  322. if self.use_colors in (True, False):
  323. self.log_config["formatters"]["default"]["use_colors"] = self.use_colors
  324. self.log_config["formatters"]["access"]["use_colors"] = self.use_colors
  325. logging.config.dictConfig(self.log_config)
  326. elif isinstance(self.log_config, str) and self.log_config.endswith(".json"):
  327. with open(self.log_config) as file:
  328. loaded_config = json.load(file)
  329. logging.config.dictConfig(loaded_config)
  330. elif isinstance(self.log_config, str) and self.log_config.endswith((".yaml", ".yml")):
  331. # Install the PyYAML package or the uvicorn[standard] optional
  332. # dependencies to enable this functionality.
  333. import yaml
  334. with open(self.log_config) as file:
  335. loaded_config = yaml.safe_load(file)
  336. logging.config.dictConfig(loaded_config)
  337. else:
  338. # See the note about fileConfig() here:
  339. # https://docs.python.org/3/library/logging.config.html#configuration-file-format
  340. logging.config.fileConfig(self.log_config, disable_existing_loggers=False)
  341. if self.log_level is not None:
  342. if isinstance(self.log_level, str):
  343. log_level = LOG_LEVELS[self.log_level]
  344. else:
  345. log_level = self.log_level
  346. logging.getLogger("uvicorn.error").setLevel(log_level)
  347. logging.getLogger("uvicorn.access").setLevel(log_level)
  348. logging.getLogger("uvicorn.asgi").setLevel(log_level)
  349. if self.access_log is False:
  350. logging.getLogger("uvicorn.access").handlers = []
  351. logging.getLogger("uvicorn.access").propagate = False
  352. def load(self) -> None:
  353. assert not self.loaded
  354. if self.is_ssl:
  355. assert self.ssl_certfile
  356. self.ssl: ssl.SSLContext | None = create_ssl_context(
  357. keyfile=self.ssl_keyfile,
  358. certfile=self.ssl_certfile,
  359. password=self.ssl_keyfile_password,
  360. ssl_version=self.ssl_version,
  361. cert_reqs=self.ssl_cert_reqs,
  362. ca_certs=self.ssl_ca_certs,
  363. ciphers=self.ssl_ciphers,
  364. )
  365. else:
  366. self.ssl = None
  367. encoded_headers = [(key.lower().encode("latin1"), value.encode("latin1")) for key, value in self.headers]
  368. self.encoded_headers = (
  369. [(b"server", b"uvicorn")] + encoded_headers
  370. if b"server" not in dict(encoded_headers) and self.server_header
  371. else encoded_headers
  372. )
  373. if isinstance(self.http, str):
  374. http_protocol_class = import_from_string(HTTP_PROTOCOLS.get(self.http, self.http))
  375. self.http_protocol_class: type[asyncio.Protocol] = http_protocol_class
  376. else:
  377. self.http_protocol_class = self.http
  378. if isinstance(self.ws, str):
  379. ws_protocol_class = import_from_string(WS_PROTOCOLS.get(self.ws, self.ws))
  380. self.ws_protocol_class: type[asyncio.Protocol] | None = ws_protocol_class
  381. else:
  382. self.ws_protocol_class = self.ws
  383. self.lifespan_class = import_from_string(LIFESPAN[self.lifespan])
  384. try:
  385. self.loaded_app = import_from_string(self.app)
  386. except ImportFromStringError as exc:
  387. logger.error("Error loading ASGI app. %s" % exc)
  388. sys.exit(1)
  389. try:
  390. self.loaded_app = self.loaded_app()
  391. except TypeError as exc:
  392. if self.factory:
  393. logger.error("Error loading ASGI app factory: %s", exc)
  394. sys.exit(1)
  395. else:
  396. if not self.factory:
  397. logger.warning(
  398. "ASGI app factory detected. Using it, but please consider setting the --factory flag explicitly."
  399. )
  400. if self.interface == "auto":
  401. if inspect.isclass(self.loaded_app):
  402. use_asgi_3 = hasattr(self.loaded_app, "__await__")
  403. elif inspect.isfunction(self.loaded_app):
  404. use_asgi_3 = inspect.iscoroutinefunction(self.loaded_app)
  405. else:
  406. call = getattr(self.loaded_app, "__call__", None)
  407. use_asgi_3 = inspect.iscoroutinefunction(call)
  408. self.interface = "asgi3" if use_asgi_3 else "asgi2"
  409. if self.interface == "wsgi":
  410. self.loaded_app = WSGIMiddleware(self.loaded_app)
  411. self.ws_protocol_class = None
  412. elif self.interface == "asgi2":
  413. self.loaded_app = ASGI2Middleware(self.loaded_app)
  414. if logger.getEffectiveLevel() <= TRACE_LOG_LEVEL:
  415. self.loaded_app = MessageLoggerMiddleware(self.loaded_app)
  416. if self.proxy_headers:
  417. self.loaded_app = ProxyHeadersMiddleware(self.loaded_app, trusted_hosts=self.forwarded_allow_ips)
  418. self.loaded = True
  419. def setup_event_loop(self) -> None:
  420. raise AttributeError(
  421. "The `setup_event_loop` method was replaced by `get_loop_factory` in uvicorn 0.36.0.\n"
  422. "None of those methods are supposed to be used directly. If you are doing it, please let me know here: "
  423. "https://github.com/Kludex/uvicorn/discussions/2706. Thank you, and sorry for the inconvenience."
  424. )
  425. def get_loop_factory(self) -> Callable[[], asyncio.AbstractEventLoop] | None:
  426. if self.loop in LOOP_FACTORIES:
  427. loop_factory: Callable[..., Any] | None = import_from_string(LOOP_FACTORIES[self.loop])
  428. else:
  429. try:
  430. return import_from_string(self.loop)
  431. except ImportFromStringError as exc:
  432. logger.error("Error loading custom loop setup function. %s" % exc)
  433. sys.exit(1)
  434. if loop_factory is None:
  435. return None
  436. return loop_factory(use_subprocess=self.use_subprocess)
  437. def bind_socket(self) -> socket.socket:
  438. logger_args: list[str | int]
  439. if self.uds: # pragma: py-win32
  440. path = self.uds
  441. sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  442. try:
  443. sock.bind(path)
  444. uds_perms = 0o666
  445. os.chmod(self.uds, uds_perms)
  446. except OSError as exc: # pragma: full coverage
  447. logger.error(exc)
  448. sys.exit(1)
  449. message = "Uvicorn running on unix socket %s (Press CTRL+C to quit)"
  450. sock_name_format = "%s"
  451. color_message = "Uvicorn running on " + click.style(sock_name_format, bold=True) + " (Press CTRL+C to quit)"
  452. logger_args = [self.uds]
  453. elif self.fd: # pragma: py-win32
  454. sock = socket.fromfd(self.fd, socket.AF_UNIX, socket.SOCK_STREAM)
  455. message = "Uvicorn running on socket %s (Press CTRL+C to quit)"
  456. fd_name_format = "%s"
  457. color_message = "Uvicorn running on " + click.style(fd_name_format, bold=True) + " (Press CTRL+C to quit)"
  458. logger_args = [sock.getsockname()]
  459. else:
  460. family = socket.AF_INET
  461. addr_format = "%s://%s:%d"
  462. if self.host and ":" in self.host: # pragma: full coverage
  463. # It's an IPv6 address.
  464. family = socket.AF_INET6
  465. addr_format = "%s://[%s]:%d"
  466. sock = socket.socket(family=family)
  467. sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  468. try:
  469. sock.bind((self.host, self.port))
  470. except OSError as exc: # pragma: full coverage
  471. logger.error(exc)
  472. sys.exit(1)
  473. message = f"Uvicorn running on {addr_format} (Press CTRL+C to quit)"
  474. color_message = "Uvicorn running on " + click.style(addr_format, bold=True) + " (Press CTRL+C to quit)"
  475. protocol_name = "https" if self.is_ssl else "http"
  476. logger_args = [protocol_name, self.host, sock.getsockname()[1]]
  477. logger.info(message, *logger_args, extra={"color_message": color_message})
  478. sock.set_inheritable(True)
  479. return sock
  480. @property
  481. def should_reload(self) -> bool:
  482. return isinstance(self.app, str) and self.reload