utils.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. import inspect
  2. import sys
  3. from contextlib import AsyncExitStack, contextmanager
  4. from copy import copy, deepcopy
  5. from dataclasses import dataclass
  6. from typing import (
  7. Any,
  8. Callable,
  9. Coroutine,
  10. Dict,
  11. ForwardRef,
  12. List,
  13. Mapping,
  14. Optional,
  15. Sequence,
  16. Tuple,
  17. Type,
  18. Union,
  19. cast,
  20. )
  21. import anyio
  22. from fastapi import params
  23. from fastapi._compat import (
  24. PYDANTIC_V2,
  25. ErrorWrapper,
  26. ModelField,
  27. RequiredParam,
  28. Undefined,
  29. _regenerate_error_with_loc,
  30. copy_field_info,
  31. create_body_model,
  32. evaluate_forwardref,
  33. field_annotation_is_scalar,
  34. get_annotation_from_field_info,
  35. get_cached_model_fields,
  36. get_missing_field_error,
  37. is_bytes_field,
  38. is_bytes_sequence_field,
  39. is_scalar_field,
  40. is_scalar_sequence_field,
  41. is_sequence_field,
  42. is_uploadfile_or_nonable_uploadfile_annotation,
  43. is_uploadfile_sequence_annotation,
  44. lenient_issubclass,
  45. sequence_types,
  46. serialize_sequence_value,
  47. value_is_sequence,
  48. )
  49. from fastapi.background import BackgroundTasks
  50. from fastapi.concurrency import (
  51. asynccontextmanager,
  52. contextmanager_in_threadpool,
  53. )
  54. from fastapi.dependencies.models import Dependant, SecurityRequirement
  55. from fastapi.logger import logger
  56. from fastapi.security.base import SecurityBase
  57. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  58. from fastapi.security.open_id_connect_url import OpenIdConnect
  59. from fastapi.utils import create_model_field, get_path_param_names
  60. from pydantic import BaseModel
  61. from pydantic.fields import FieldInfo
  62. from starlette.background import BackgroundTasks as StarletteBackgroundTasks
  63. from starlette.concurrency import run_in_threadpool
  64. from starlette.datastructures import (
  65. FormData,
  66. Headers,
  67. ImmutableMultiDict,
  68. QueryParams,
  69. UploadFile,
  70. )
  71. from starlette.requests import HTTPConnection, Request
  72. from starlette.responses import Response
  73. from starlette.websockets import WebSocket
  74. from typing_extensions import Annotated, get_args, get_origin
  75. if sys.version_info >= (3, 13): # pragma: no cover
  76. from inspect import iscoroutinefunction
  77. else: # pragma: no cover
  78. from asyncio import iscoroutinefunction
  79. multipart_not_installed_error = (
  80. 'Form data requires "python-multipart" to be installed. \n'
  81. 'You can install "python-multipart" with: \n\n'
  82. "pip install python-multipart\n"
  83. )
  84. multipart_incorrect_install_error = (
  85. 'Form data requires "python-multipart" to be installed. '
  86. 'It seems you installed "multipart" instead. \n'
  87. 'You can remove "multipart" with: \n\n'
  88. "pip uninstall multipart\n\n"
  89. 'And then install "python-multipart" with: \n\n'
  90. "pip install python-multipart\n"
  91. )
  92. def ensure_multipart_is_installed() -> None:
  93. try:
  94. from python_multipart import __version__
  95. # Import an attribute that can be mocked/deleted in testing
  96. assert __version__ > "0.0.12"
  97. except (ImportError, AssertionError):
  98. try:
  99. # __version__ is available in both multiparts, and can be mocked
  100. from multipart import __version__ # type: ignore[no-redef,import-untyped]
  101. assert __version__
  102. try:
  103. # parse_options_header is only available in the right multipart
  104. from multipart.multipart import ( # type: ignore[import-untyped]
  105. parse_options_header,
  106. )
  107. assert parse_options_header
  108. except ImportError:
  109. logger.error(multipart_incorrect_install_error)
  110. raise RuntimeError(multipart_incorrect_install_error) from None
  111. except ImportError:
  112. logger.error(multipart_not_installed_error)
  113. raise RuntimeError(multipart_not_installed_error) from None
  114. def get_param_sub_dependant(
  115. *,
  116. param_name: str,
  117. depends: params.Depends,
  118. path: str,
  119. security_scopes: Optional[List[str]] = None,
  120. ) -> Dependant:
  121. assert depends.dependency
  122. return get_sub_dependant(
  123. depends=depends,
  124. dependency=depends.dependency,
  125. path=path,
  126. name=param_name,
  127. security_scopes=security_scopes,
  128. )
  129. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  130. assert callable(depends.dependency), (
  131. "A parameter-less dependency must have a callable dependency"
  132. )
  133. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  134. def get_sub_dependant(
  135. *,
  136. depends: params.Depends,
  137. dependency: Callable[..., Any],
  138. path: str,
  139. name: Optional[str] = None,
  140. security_scopes: Optional[List[str]] = None,
  141. ) -> Dependant:
  142. security_requirement = None
  143. security_scopes = security_scopes or []
  144. if isinstance(depends, params.Security):
  145. dependency_scopes = depends.scopes
  146. security_scopes.extend(dependency_scopes)
  147. if isinstance(dependency, SecurityBase):
  148. use_scopes: List[str] = []
  149. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  150. use_scopes = security_scopes
  151. security_requirement = SecurityRequirement(
  152. security_scheme=dependency, scopes=use_scopes
  153. )
  154. sub_dependant = get_dependant(
  155. path=path,
  156. call=dependency,
  157. name=name,
  158. security_scopes=security_scopes,
  159. use_cache=depends.use_cache,
  160. )
  161. if security_requirement:
  162. sub_dependant.security_requirements.append(security_requirement)
  163. return sub_dependant
  164. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  165. def get_flat_dependant(
  166. dependant: Dependant,
  167. *,
  168. skip_repeats: bool = False,
  169. visited: Optional[List[CacheKey]] = None,
  170. ) -> Dependant:
  171. if visited is None:
  172. visited = []
  173. visited.append(dependant.cache_key)
  174. flat_dependant = Dependant(
  175. path_params=dependant.path_params.copy(),
  176. query_params=dependant.query_params.copy(),
  177. header_params=dependant.header_params.copy(),
  178. cookie_params=dependant.cookie_params.copy(),
  179. body_params=dependant.body_params.copy(),
  180. security_requirements=dependant.security_requirements.copy(),
  181. use_cache=dependant.use_cache,
  182. path=dependant.path,
  183. )
  184. for sub_dependant in dependant.dependencies:
  185. if skip_repeats and sub_dependant.cache_key in visited:
  186. continue
  187. flat_sub = get_flat_dependant(
  188. sub_dependant, skip_repeats=skip_repeats, visited=visited
  189. )
  190. flat_dependant.path_params.extend(flat_sub.path_params)
  191. flat_dependant.query_params.extend(flat_sub.query_params)
  192. flat_dependant.header_params.extend(flat_sub.header_params)
  193. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  194. flat_dependant.body_params.extend(flat_sub.body_params)
  195. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  196. return flat_dependant
  197. def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
  198. if not fields:
  199. return fields
  200. first_field = fields[0]
  201. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  202. fields_to_extract = get_cached_model_fields(first_field.type_)
  203. return fields_to_extract
  204. return fields
  205. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  206. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  207. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  208. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  209. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  210. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  211. return path_params + query_params + header_params + cookie_params
  212. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  213. signature = inspect.signature(call)
  214. globalns = getattr(call, "__globals__", {})
  215. typed_params = [
  216. inspect.Parameter(
  217. name=param.name,
  218. kind=param.kind,
  219. default=param.default,
  220. annotation=get_typed_annotation(param.annotation, globalns),
  221. )
  222. for param in signature.parameters.values()
  223. ]
  224. typed_signature = inspect.Signature(typed_params)
  225. return typed_signature
  226. def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
  227. if isinstance(annotation, str):
  228. annotation = ForwardRef(annotation)
  229. annotation = evaluate_forwardref(annotation, globalns, globalns)
  230. if annotation is type(None):
  231. return None
  232. return annotation
  233. def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
  234. signature = inspect.signature(call)
  235. annotation = signature.return_annotation
  236. if annotation is inspect.Signature.empty:
  237. return None
  238. globalns = getattr(call, "__globals__", {})
  239. return get_typed_annotation(annotation, globalns)
  240. def get_dependant(
  241. *,
  242. path: str,
  243. call: Callable[..., Any],
  244. name: Optional[str] = None,
  245. security_scopes: Optional[List[str]] = None,
  246. use_cache: bool = True,
  247. ) -> Dependant:
  248. path_param_names = get_path_param_names(path)
  249. endpoint_signature = get_typed_signature(call)
  250. signature_params = endpoint_signature.parameters
  251. dependant = Dependant(
  252. call=call,
  253. name=name,
  254. path=path,
  255. security_scopes=security_scopes,
  256. use_cache=use_cache,
  257. )
  258. for param_name, param in signature_params.items():
  259. is_path_param = param_name in path_param_names
  260. param_details = analyze_param(
  261. param_name=param_name,
  262. annotation=param.annotation,
  263. value=param.default,
  264. is_path_param=is_path_param,
  265. )
  266. if param_details.depends is not None:
  267. sub_dependant = get_param_sub_dependant(
  268. param_name=param_name,
  269. depends=param_details.depends,
  270. path=path,
  271. security_scopes=security_scopes,
  272. )
  273. dependant.dependencies.append(sub_dependant)
  274. continue
  275. if add_non_field_param_to_dependency(
  276. param_name=param_name,
  277. type_annotation=param_details.type_annotation,
  278. dependant=dependant,
  279. ):
  280. assert param_details.field is None, (
  281. f"Cannot specify multiple FastAPI annotations for {param_name!r}"
  282. )
  283. continue
  284. assert param_details.field is not None
  285. if isinstance(param_details.field.field_info, params.Body):
  286. dependant.body_params.append(param_details.field)
  287. else:
  288. add_param_to_fields(field=param_details.field, dependant=dependant)
  289. return dependant
  290. def add_non_field_param_to_dependency(
  291. *, param_name: str, type_annotation: Any, dependant: Dependant
  292. ) -> Optional[bool]:
  293. if lenient_issubclass(type_annotation, Request):
  294. dependant.request_param_name = param_name
  295. return True
  296. elif lenient_issubclass(type_annotation, WebSocket):
  297. dependant.websocket_param_name = param_name
  298. return True
  299. elif lenient_issubclass(type_annotation, HTTPConnection):
  300. dependant.http_connection_param_name = param_name
  301. return True
  302. elif lenient_issubclass(type_annotation, Response):
  303. dependant.response_param_name = param_name
  304. return True
  305. elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
  306. dependant.background_tasks_param_name = param_name
  307. return True
  308. elif lenient_issubclass(type_annotation, SecurityScopes):
  309. dependant.security_scopes_param_name = param_name
  310. return True
  311. return None
@dataclass
class ParamDetails:
    """Classification of one signature parameter, as produced by ``analyze_param``."""

    # The parameter's type with any `Annotated[...]` wrapper stripped
    # (`Any` when the parameter has no annotation).
    type_annotation: Any
    # The `Depends`/`Security` marker to resolve, or None if not a dependency.
    depends: Optional[params.Depends]
    # The request-data field to extract, or None for dependencies and special
    # injected types such as Request.
    field: Optional[ModelField]
  317. def analyze_param(
  318. *,
  319. param_name: str,
  320. annotation: Any,
  321. value: Any,
  322. is_path_param: bool,
  323. ) -> ParamDetails:
  324. field_info = None
  325. depends = None
  326. type_annotation: Any = Any
  327. use_annotation: Any = Any
  328. if annotation is not inspect.Signature.empty:
  329. use_annotation = annotation
  330. type_annotation = annotation
  331. # Extract Annotated info
  332. if get_origin(use_annotation) is Annotated:
  333. annotated_args = get_args(annotation)
  334. type_annotation = annotated_args[0]
  335. fastapi_annotations = [
  336. arg
  337. for arg in annotated_args[1:]
  338. if isinstance(arg, (FieldInfo, params.Depends))
  339. ]
  340. fastapi_specific_annotations = [
  341. arg
  342. for arg in fastapi_annotations
  343. if isinstance(arg, (params.Param, params.Body, params.Depends))
  344. ]
  345. if fastapi_specific_annotations:
  346. fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
  347. fastapi_specific_annotations[-1]
  348. )
  349. else:
  350. fastapi_annotation = None
  351. # Set default for Annotated FieldInfo
  352. if isinstance(fastapi_annotation, FieldInfo):
  353. # Copy `field_info` because we mutate `field_info.default` below.
  354. field_info = copy_field_info(
  355. field_info=fastapi_annotation, annotation=use_annotation
  356. )
  357. assert (
  358. field_info.default is Undefined or field_info.default is RequiredParam
  359. ), (
  360. f"`{field_info.__class__.__name__}` default value cannot be set in"
  361. f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
  362. )
  363. if value is not inspect.Signature.empty:
  364. assert not is_path_param, "Path parameters cannot have default values"
  365. field_info.default = value
  366. else:
  367. field_info.default = RequiredParam
  368. # Get Annotated Depends
  369. elif isinstance(fastapi_annotation, params.Depends):
  370. depends = fastapi_annotation
  371. # Get Depends from default value
  372. if isinstance(value, params.Depends):
  373. assert depends is None, (
  374. "Cannot specify `Depends` in `Annotated` and default value"
  375. f" together for {param_name!r}"
  376. )
  377. assert field_info is None, (
  378. "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
  379. f" default value together for {param_name!r}"
  380. )
  381. depends = value
  382. # Get FieldInfo from default value
  383. elif isinstance(value, FieldInfo):
  384. assert field_info is None, (
  385. "Cannot specify FastAPI annotations in `Annotated` and default value"
  386. f" together for {param_name!r}"
  387. )
  388. field_info = value
  389. if PYDANTIC_V2:
  390. field_info.annotation = type_annotation
  391. # Get Depends from type annotation
  392. if depends is not None and depends.dependency is None:
  393. # Copy `depends` before mutating it
  394. depends = copy(depends)
  395. depends.dependency = type_annotation
  396. # Handle non-param type annotations like Request
  397. if lenient_issubclass(
  398. type_annotation,
  399. (
  400. Request,
  401. WebSocket,
  402. HTTPConnection,
  403. Response,
  404. StarletteBackgroundTasks,
  405. SecurityScopes,
  406. ),
  407. ):
  408. assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
  409. assert field_info is None, (
  410. f"Cannot specify FastAPI annotation for type {type_annotation!r}"
  411. )
  412. # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
  413. elif field_info is None and depends is None:
  414. default_value = value if value is not inspect.Signature.empty else RequiredParam
  415. if is_path_param:
  416. # We might check here that `default_value is RequiredParam`, but the fact is that the same
  417. # parameter might sometimes be a path parameter and sometimes not. See
  418. # `tests/test_infer_param_optionality.py` for an example.
  419. field_info = params.Path(annotation=use_annotation)
  420. elif is_uploadfile_or_nonable_uploadfile_annotation(
  421. type_annotation
  422. ) or is_uploadfile_sequence_annotation(type_annotation):
  423. field_info = params.File(annotation=use_annotation, default=default_value)
  424. elif not field_annotation_is_scalar(annotation=type_annotation):
  425. field_info = params.Body(annotation=use_annotation, default=default_value)
  426. else:
  427. field_info = params.Query(annotation=use_annotation, default=default_value)
  428. field = None
  429. # It's a field_info, not a dependency
  430. if field_info is not None:
  431. # Handle field_info.in_
  432. if is_path_param:
  433. assert isinstance(field_info, params.Path), (
  434. f"Cannot use `{field_info.__class__.__name__}` for path param"
  435. f" {param_name!r}"
  436. )
  437. elif (
  438. isinstance(field_info, params.Param)
  439. and getattr(field_info, "in_", None) is None
  440. ):
  441. field_info.in_ = params.ParamTypes.query
  442. use_annotation_from_field_info = get_annotation_from_field_info(
  443. use_annotation,
  444. field_info,
  445. param_name,
  446. )
  447. if isinstance(field_info, params.Form):
  448. ensure_multipart_is_installed()
  449. if not field_info.alias and getattr(field_info, "convert_underscores", None):
  450. alias = param_name.replace("_", "-")
  451. else:
  452. alias = field_info.alias or param_name
  453. field_info.alias = alias
  454. field = create_model_field(
  455. name=param_name,
  456. type_=use_annotation_from_field_info,
  457. default=field_info.default,
  458. alias=alias,
  459. required=field_info.default in (RequiredParam, Undefined),
  460. field_info=field_info,
  461. )
  462. if is_path_param:
  463. assert is_scalar_field(field=field), (
  464. "Path params must be of one of the supported types"
  465. )
  466. elif isinstance(field_info, params.Query):
  467. assert (
  468. is_scalar_field(field)
  469. or is_scalar_sequence_field(field)
  470. or (
  471. lenient_issubclass(field.type_, BaseModel)
  472. # For Pydantic v1
  473. and getattr(field, "shape", 1) == 1
  474. )
  475. )
  476. return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
  477. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  478. field_info = field.field_info
  479. field_info_in = getattr(field_info, "in_", None)
  480. if field_info_in == params.ParamTypes.path:
  481. dependant.path_params.append(field)
  482. elif field_info_in == params.ParamTypes.query:
  483. dependant.query_params.append(field)
  484. elif field_info_in == params.ParamTypes.header:
  485. dependant.header_params.append(field)
  486. else:
  487. assert field_info_in == params.ParamTypes.cookie, (
  488. f"non-body parameters must be in path, query, header or cookie: {field.name}"
  489. )
  490. dependant.cookie_params.append(field)
  491. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  492. if inspect.isroutine(call):
  493. return iscoroutinefunction(call)
  494. if inspect.isclass(call):
  495. return False
  496. dunder_call = getattr(call, "__call__", None) # noqa: B004
  497. return iscoroutinefunction(dunder_call)
  498. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  499. if inspect.isasyncgenfunction(call):
  500. return True
  501. dunder_call = getattr(call, "__call__", None) # noqa: B004
  502. return inspect.isasyncgenfunction(dunder_call)
  503. def is_gen_callable(call: Callable[..., Any]) -> bool:
  504. if inspect.isgeneratorfunction(call):
  505. return True
  506. dunder_call = getattr(call, "__call__", None) # noqa: B004
  507. return inspect.isgeneratorfunction(dunder_call)
  508. async def solve_generator(
  509. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  510. ) -> Any:
  511. if is_gen_callable(call):
  512. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  513. elif is_async_gen_callable(call):
  514. cm = asynccontextmanager(call)(**sub_values)
  515. return await stack.enter_async_context(cm)
@dataclass
class SolvedDependency:
    """Result returned by ``solve_dependencies`` for one ``Dependant``."""

    # Keyword arguments for the dependant's call, keyed by parameter name.
    values: Dict[str, Any]
    # Errors collected from this dependant and its sub-dependencies.
    errors: List[Any]
    # The BackgroundTasks passed in or created on demand, if any.
    background_tasks: Optional[StarletteBackgroundTasks]
    # The Response injected into parameters annotated with Response.
    response: Response
    # Solved sub-dependency values keyed by cache key, shared across the
    # recursion. NOTE(review): the key's second element is annotated Tuple[str]
    # (a 1-tuple) while CacheKey uses Tuple[str, ...]; likely meant the latter.
    dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
    embed_body_fields: bool,
) -> SolvedDependency:
    """Solve every sub-dependency of *dependant* and collect its call arguments.

    Sub-dependencies are solved recursively in declaration order. Overrides are
    applied when a provider is given, and cached results are reused when a
    sub-dependant's ``use_cache`` is set. Then the dependant's own path, query,
    header, cookie and body parameters are extracted and validated. Finally the
    special objects (request/websocket, background tasks, response, security
    scopes) are injected.

    Args:
        request: The incoming ``Request`` or ``WebSocket``.
        dependant: The dependant whose arguments are being solved.
        body: The already-read request body, if any.
        background_tasks: Shared background tasks; a ``BackgroundTasks`` is
            created on demand if a parameter asks for one.
        response: Shared response; a placeholder is created when omitted.
        dependency_overrides_provider: Object whose ``dependency_overrides``
            mapping replaces original dependency callables.
        dependency_cache: Solved sub-dependency values, shared across the
            recursion. A new dict is used when omitted.
        async_exit_stack: Stack that generator dependencies are entered on.
        embed_body_fields: Passed through to ``request_body_to_args``.

    Returns:
        A ``SolvedDependency``. When ``errors`` is non-empty, ``values`` may be
        incomplete, because a sub-dependency whose own arguments failed is not
        called.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        # Placeholder response. Its default content-length header and status
        # code are cleared, presumably so only values set explicitly by
        # dependencies are merged later (confirm in the router).
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    if dependency_cache is None:
        dependency_cache = {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        # `cast` is a no-op at runtime; these lines only narrow types.
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Replace the callable with its override (if any) and re-analyze
            # the override's own signature. The cache key below still comes
            # from the original sub_dependant.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
            embed_body_fields=embed_body_fields,
        )
        background_tasks = solved_result.background_tasks
        if solved_result.errors:
            # The sub-dependency's own arguments failed: record the errors and
            # do not call it.
            errors.extend(solved_result.errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=solved_result.values
            )
        elif is_coroutine_callable(call):
            solved = await call(**solved_result.values)
        else:
            # Plain sync callables run in a worker thread.
            solved = await run_in_threadpool(call, **solved_result.values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        # The first solved value for a key is kept, even when use_cache is off.
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            body_fields=dependant.body_params,
            received_body=body,
            embed_body_fields=embed_body_fields,
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject the special objects requested by this dependant's parameters.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return SolvedDependency(
        values=values,
        errors=errors,
        background_tasks=background_tasks,
        response=response,
        dependency_cache=dependency_cache,
    )
  646. def _validate_value_with_model_field(
  647. *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
  648. ) -> Tuple[Any, List[Any]]:
  649. if value is None:
  650. if field.required:
  651. return None, [get_missing_field_error(loc=loc)]
  652. else:
  653. return deepcopy(field.default), []
  654. v_, errors_ = field.validate(value, values, loc=loc)
  655. if isinstance(errors_, ErrorWrapper):
  656. return None, [errors_]
  657. elif isinstance(errors_, list):
  658. new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
  659. return None, new_errors
  660. else:
  661. return v_, []
  662. def _get_multidict_value(
  663. field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
  664. ) -> Any:
  665. alias = alias or field.alias
  666. if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
  667. value = values.getlist(alias)
  668. else:
  669. value = values.get(alias, None)
  670. if (
  671. value is None
  672. or (
  673. isinstance(field.field_info, params.Form)
  674. and isinstance(value, str) # For type checks
  675. and value == ""
  676. )
  677. or (is_sequence_field(field) and len(value) == 0)
  678. ):
  679. if field.required:
  680. return
  681. else:
  682. return deepcopy(field.default)
  683. return value
  684. def request_params_to_args(
  685. fields: Sequence[ModelField],
  686. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  687. ) -> Tuple[Dict[str, Any], List[Any]]:
  688. values: Dict[str, Any] = {}
  689. errors: List[Dict[str, Any]] = []
  690. if not fields:
  691. return values, errors
  692. first_field = fields[0]
  693. fields_to_extract = fields
  694. single_not_embedded_field = False
  695. default_convert_underscores = True
  696. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  697. fields_to_extract = get_cached_model_fields(first_field.type_)
  698. single_not_embedded_field = True
  699. # If headers are in a Pydantic model, the way to disable convert_underscores
  700. # would be with Header(convert_underscores=False) at the Pydantic model level
  701. default_convert_underscores = getattr(
  702. first_field.field_info, "convert_underscores", True
  703. )
  704. params_to_process: Dict[str, Any] = {}
  705. processed_keys = set()
  706. for field in fields_to_extract:
  707. alias = None
  708. if isinstance(received_params, Headers):
  709. # Handle fields extracted from a Pydantic Model for a header, each field
  710. # doesn't have a FieldInfo of type Header with the default convert_underscores=True
  711. convert_underscores = getattr(
  712. field.field_info, "convert_underscores", default_convert_underscores
  713. )
  714. if convert_underscores:
  715. alias = (
  716. field.alias
  717. if field.alias != field.name
  718. else field.name.replace("_", "-")
  719. )
  720. value = _get_multidict_value(field, received_params, alias=alias)
  721. if value is not None:
  722. params_to_process[field.name] = value
  723. processed_keys.add(alias or field.alias)
  724. processed_keys.add(field.name)
  725. for key, value in received_params.items():
  726. if key not in processed_keys:
  727. params_to_process[key] = value
  728. if single_not_embedded_field:
  729. field_info = first_field.field_info
  730. assert isinstance(field_info, params.Param), (
  731. "Params must be subclasses of Param"
  732. )
  733. loc: Tuple[str, ...] = (field_info.in_.value,)
  734. v_, errors_ = _validate_value_with_model_field(
  735. field=first_field, value=params_to_process, values=values, loc=loc
  736. )
  737. return {first_field.name: v_}, errors_
  738. for field in fields:
  739. value = _get_multidict_value(field, received_params)
  740. field_info = field.field_info
  741. assert isinstance(field_info, params.Param), (
  742. "Params must be subclasses of Param"
  743. )
  744. loc = (field_info.in_.value, field.alias)
  745. v_, errors_ = _validate_value_with_model_field(
  746. field=field, value=value, values=values, loc=loc
  747. )
  748. if errors_:
  749. errors.extend(errors_)
  750. else:
  751. values[field.name] = v_
  752. return values, errors
  753. def is_union_of_base_models(field_type: Any) -> bool:
  754. """Check if field type is a Union where all members are BaseModel subclasses."""
  755. from fastapi.types import UnionType
  756. origin = get_origin(field_type)
  757. # Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
  758. if origin is not Union and origin is not UnionType:
  759. return False
  760. union_args = get_args(field_type)
  761. for arg in union_args:
  762. if not lenient_issubclass(arg, BaseModel):
  763. return False
  764. return True
  765. def _should_embed_body_fields(fields: List[ModelField]) -> bool:
  766. if not fields:
  767. return False
  768. # More than one dependency could have the same field, it would show up as multiple
  769. # fields but it's the same one, so count them by name
  770. body_param_names_set = {field.name for field in fields}
  771. # A top level field has to be a single field, not multiple
  772. if len(body_param_names_set) > 1:
  773. return True
  774. first_field = fields[0]
  775. # If it explicitly specifies it is embedded, it has to be embedded
  776. if getattr(first_field.field_info, "embed", None):
  777. return True
  778. # If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
  779. # otherwise it has to be embedded, so that the key value pair can be extracted
  780. if (
  781. isinstance(first_field.field_info, params.Form)
  782. and not lenient_issubclass(first_field.type_, BaseModel)
  783. and not is_union_of_base_models(first_field.type_)
  784. ):
  785. return True
  786. return False
  787. async def _extract_form_body(
  788. body_fields: List[ModelField],
  789. received_body: FormData,
  790. ) -> Dict[str, Any]:
  791. values = {}
  792. for field in body_fields:
  793. value = _get_multidict_value(field, received_body)
  794. field_info = field.field_info
  795. if (
  796. isinstance(field_info, params.File)
  797. and is_bytes_field(field)
  798. and isinstance(value, UploadFile)
  799. ):
  800. value = await value.read()
  801. elif (
  802. is_bytes_sequence_field(field)
  803. and isinstance(field_info, params.File)
  804. and value_is_sequence(value)
  805. ):
  806. # For types
  807. assert isinstance(value, sequence_types) # type: ignore[arg-type]
  808. results: List[Union[bytes, str]] = []
  809. async def process_fn(
  810. fn: Callable[[], Coroutine[Any, Any, Any]],
  811. ) -> None:
  812. result = await fn()
  813. results.append(result) # noqa: B023
  814. async with anyio.create_task_group() as tg:
  815. for sub_value in value:
  816. tg.start_soon(process_fn, sub_value.read)
  817. value = serialize_sequence_value(field=field, value=results)
  818. if value is not None:
  819. values[field.alias] = value
  820. for key, value in received_body.items():
  821. if key not in values:
  822. values[key] = value
  823. return values
  824. async def request_body_to_args(
  825. body_fields: List[ModelField],
  826. received_body: Optional[Union[Dict[str, Any], FormData]],
  827. embed_body_fields: bool,
  828. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  829. values: Dict[str, Any] = {}
  830. errors: List[Dict[str, Any]] = []
  831. assert body_fields, "request_body_to_args() should be called with fields"
  832. single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
  833. first_field = body_fields[0]
  834. body_to_process = received_body
  835. fields_to_extract: List[ModelField] = body_fields
  836. if (
  837. single_not_embedded_field
  838. and lenient_issubclass(first_field.type_, BaseModel)
  839. and isinstance(received_body, FormData)
  840. ):
  841. fields_to_extract = get_cached_model_fields(first_field.type_)
  842. if isinstance(received_body, FormData):
  843. body_to_process = await _extract_form_body(fields_to_extract, received_body)
  844. if single_not_embedded_field:
  845. loc: Tuple[str, ...] = ("body",)
  846. v_, errors_ = _validate_value_with_model_field(
  847. field=first_field, value=body_to_process, values=values, loc=loc
  848. )
  849. return {first_field.name: v_}, errors_
  850. for field in body_fields:
  851. loc = ("body", field.alias)
  852. value: Optional[Any] = None
  853. if body_to_process is not None:
  854. try:
  855. value = body_to_process.get(field.alias)
  856. # If the received body is a list, not a dict
  857. except AttributeError:
  858. errors.append(get_missing_field_error(loc))
  859. continue
  860. v_, errors_ = _validate_value_with_model_field(
  861. field=field, value=value, values=values, loc=loc
  862. )
  863. if errors_:
  864. errors.extend(errors_)
  865. else:
  866. values[field.name] = v_
  867. return values, errors
  868. def get_body_field(
  869. *, flat_dependant: Dependant, name: str, embed_body_fields: bool
  870. ) -> Optional[ModelField]:
  871. """
  872. Get a ModelField representing the request body for a path operation, combining
  873. all body parameters into a single field if necessary.
  874. Used to check if it's form data (with `isinstance(body_field, params.Form)`)
  875. or JSON and to generate the JSON Schema for a request body.
  876. This is **not** used to validate/parse the request body, that's done with each
  877. individual body parameter.
  878. """
  879. if not flat_dependant.body_params:
  880. return None
  881. first_param = flat_dependant.body_params[0]
  882. if not embed_body_fields:
  883. return first_param
  884. model_name = "Body_" + name
  885. BodyModel = create_body_model(
  886. fields=flat_dependant.body_params, model_name=model_name
  887. )
  888. required = any(True for f in flat_dependant.body_params if f.required)
  889. BodyFieldInfo_kwargs: Dict[str, Any] = {
  890. "annotation": BodyModel,
  891. "alias": "body",
  892. }
  893. if not required:
  894. BodyFieldInfo_kwargs["default"] = None
  895. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  896. BodyFieldInfo: Type[params.Body] = params.File
  897. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  898. BodyFieldInfo = params.Form
  899. else:
  900. BodyFieldInfo = params.Body
  901. body_param_media_types = [
  902. f.field_info.media_type
  903. for f in flat_dependant.body_params
  904. if isinstance(f.field_info, params.Body)
  905. ]
  906. if len(set(body_param_media_types)) == 1:
  907. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  908. final_field = create_model_field(
  909. name="body",
  910. type_=BodyModel,
  911. required=required,
  912. alias="body",
  913. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  914. )
  915. return final_field