utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. import http.client
  2. import inspect
  3. import warnings
  4. from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
  5. from fastapi import routing
  6. from fastapi._compat import (
  7. GenerateJsonSchema,
  8. JsonSchemaValue,
  9. ModelField,
  10. Undefined,
  11. get_compat_model_name_map,
  12. get_definitions,
  13. get_schema_from_model_field,
  14. lenient_issubclass,
  15. )
  16. from fastapi.datastructures import DefaultPlaceholder
  17. from fastapi.dependencies.models import Dependant
  18. from fastapi.dependencies.utils import (
  19. _get_flat_fields_from_params,
  20. get_flat_dependant,
  21. get_flat_params,
  22. )
  23. from fastapi.encoders import jsonable_encoder
  24. from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
  25. from fastapi.openapi.models import OpenAPI
  26. from fastapi.params import Body, ParamTypes
  27. from fastapi.responses import Response
  28. from fastapi.types import ModelNameMap
  29. from fastapi.utils import (
  30. deep_dict_update,
  31. generate_operation_id_for_path,
  32. is_body_allowed_for_status_code,
  33. )
  34. from pydantic import BaseModel
  35. from starlette.responses import JSONResponse
  36. from starlette.routing import BaseRoute
  37. from typing_extensions import Literal
  38. validation_error_definition = {
  39. "title": "ValidationError",
  40. "type": "object",
  41. "properties": {
  42. "loc": {
  43. "title": "Location",
  44. "type": "array",
  45. "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
  46. },
  47. "msg": {"title": "Message", "type": "string"},
  48. "type": {"title": "Error Type", "type": "string"},
  49. },
  50. "required": ["loc", "msg", "type"],
  51. }
  52. validation_error_response_definition = {
  53. "title": "HTTPValidationError",
  54. "type": "object",
  55. "properties": {
  56. "detail": {
  57. "title": "Detail",
  58. "type": "array",
  59. "items": {"$ref": REF_PREFIX + "ValidationError"},
  60. }
  61. },
  62. }
  63. status_code_ranges: Dict[str, str] = {
  64. "1XX": "Information",
  65. "2XX": "Success",
  66. "3XX": "Redirection",
  67. "4XX": "Client Error",
  68. "5XX": "Server Error",
  69. "DEFAULT": "Default Response",
  70. }
  71. def get_openapi_security_definitions(
  72. flat_dependant: Dependant,
  73. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  74. security_definitions = {}
  75. operation_security = []
  76. for security_requirement in flat_dependant.security_requirements:
  77. security_definition = jsonable_encoder(
  78. security_requirement.security_scheme.model,
  79. by_alias=True,
  80. exclude_none=True,
  81. )
  82. security_name = security_requirement.security_scheme.scheme_name
  83. security_definitions[security_name] = security_definition
  84. operation_security.append({security_name: security_requirement.scopes})
  85. return security_definitions, operation_security
  86. def _get_openapi_operation_parameters(
  87. *,
  88. dependant: Dependant,
  89. schema_generator: GenerateJsonSchema,
  90. model_name_map: ModelNameMap,
  91. field_mapping: Dict[
  92. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  93. ],
  94. separate_input_output_schemas: bool = True,
  95. ) -> List[Dict[str, Any]]:
  96. parameters = []
  97. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  98. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  99. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  100. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  101. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  102. parameter_groups = [
  103. (ParamTypes.path, path_params),
  104. (ParamTypes.query, query_params),
  105. (ParamTypes.header, header_params),
  106. (ParamTypes.cookie, cookie_params),
  107. ]
  108. default_convert_underscores = True
  109. if len(flat_dependant.header_params) == 1:
  110. first_field = flat_dependant.header_params[0]
  111. if lenient_issubclass(first_field.type_, BaseModel):
  112. default_convert_underscores = getattr(
  113. first_field.field_info, "convert_underscores", True
  114. )
  115. for param_type, param_group in parameter_groups:
  116. for param in param_group:
  117. field_info = param.field_info
  118. # field_info = cast(Param, field_info)
  119. if not getattr(field_info, "include_in_schema", True):
  120. continue
  121. param_schema = get_schema_from_model_field(
  122. field=param,
  123. schema_generator=schema_generator,
  124. model_name_map=model_name_map,
  125. field_mapping=field_mapping,
  126. separate_input_output_schemas=separate_input_output_schemas,
  127. )
  128. name = param.alias
  129. convert_underscores = getattr(
  130. param.field_info,
  131. "convert_underscores",
  132. default_convert_underscores,
  133. )
  134. if (
  135. param_type == ParamTypes.header
  136. and param.alias == param.name
  137. and convert_underscores
  138. ):
  139. name = param.name.replace("_", "-")
  140. parameter = {
  141. "name": name,
  142. "in": param_type.value,
  143. "required": param.required,
  144. "schema": param_schema,
  145. }
  146. if field_info.description:
  147. parameter["description"] = field_info.description
  148. openapi_examples = getattr(field_info, "openapi_examples", None)
  149. example = getattr(field_info, "example", None)
  150. if openapi_examples:
  151. parameter["examples"] = jsonable_encoder(openapi_examples)
  152. elif example != Undefined:
  153. parameter["example"] = jsonable_encoder(example)
  154. if getattr(field_info, "deprecated", None):
  155. parameter["deprecated"] = True
  156. parameters.append(parameter)
  157. return parameters
  158. def get_openapi_operation_request_body(
  159. *,
  160. body_field: Optional[ModelField],
  161. schema_generator: GenerateJsonSchema,
  162. model_name_map: ModelNameMap,
  163. field_mapping: Dict[
  164. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  165. ],
  166. separate_input_output_schemas: bool = True,
  167. ) -> Optional[Dict[str, Any]]:
  168. if not body_field:
  169. return None
  170. assert isinstance(body_field, ModelField)
  171. body_schema = get_schema_from_model_field(
  172. field=body_field,
  173. schema_generator=schema_generator,
  174. model_name_map=model_name_map,
  175. field_mapping=field_mapping,
  176. separate_input_output_schemas=separate_input_output_schemas,
  177. )
  178. field_info = cast(Body, body_field.field_info)
  179. request_media_type = field_info.media_type
  180. required = body_field.required
  181. request_body_oai: Dict[str, Any] = {}
  182. if required:
  183. request_body_oai["required"] = required
  184. request_media_content: Dict[str, Any] = {"schema": body_schema}
  185. if field_info.openapi_examples:
  186. request_media_content["examples"] = jsonable_encoder(
  187. field_info.openapi_examples
  188. )
  189. elif field_info.example != Undefined:
  190. request_media_content["example"] = jsonable_encoder(field_info.example)
  191. request_body_oai["content"] = {request_media_type: request_media_content}
  192. return request_body_oai
  193. def generate_operation_id(
  194. *, route: routing.APIRoute, method: str
  195. ) -> str: # pragma: nocover
  196. warnings.warn(
  197. "fastapi.openapi.utils.generate_operation_id() was deprecated, "
  198. "it is not used internally, and will be removed soon",
  199. DeprecationWarning,
  200. stacklevel=2,
  201. )
  202. if route.operation_id:
  203. return route.operation_id
  204. path: str = route.path_format
  205. return generate_operation_id_for_path(name=route.name, path=path, method=method)
  206. def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
  207. if route.summary:
  208. return route.summary
  209. return route.name.replace("_", " ").title()
  210. def get_openapi_operation_metadata(
  211. *, route: routing.APIRoute, method: str, operation_ids: Set[str]
  212. ) -> Dict[str, Any]:
  213. operation: Dict[str, Any] = {}
  214. if route.tags:
  215. operation["tags"] = route.tags
  216. operation["summary"] = generate_operation_summary(route=route, method=method)
  217. if route.description:
  218. operation["description"] = route.description
  219. operation_id = route.operation_id or route.unique_id
  220. if operation_id in operation_ids:
  221. message = (
  222. f"Duplicate Operation ID {operation_id} for function "
  223. + f"{route.endpoint.__name__}"
  224. )
  225. file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
  226. if file_name:
  227. message += f" at {file_name}"
  228. warnings.warn(message, stacklevel=1)
  229. operation_ids.add(operation_id)
  230. operation["operationId"] = operation_id
  231. if route.deprecated:
  232. operation["deprecated"] = route.deprecated
  233. return operation
  234. def get_openapi_path(
  235. *,
  236. route: routing.APIRoute,
  237. operation_ids: Set[str],
  238. schema_generator: GenerateJsonSchema,
  239. model_name_map: ModelNameMap,
  240. field_mapping: Dict[
  241. Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
  242. ],
  243. separate_input_output_schemas: bool = True,
  244. ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
  245. path = {}
  246. security_schemes: Dict[str, Any] = {}
  247. definitions: Dict[str, Any] = {}
  248. assert route.methods is not None, "Methods must be a list"
  249. if isinstance(route.response_class, DefaultPlaceholder):
  250. current_response_class: Type[Response] = route.response_class.value
  251. else:
  252. current_response_class = route.response_class
  253. assert current_response_class, "A response class is needed to generate OpenAPI"
  254. route_response_media_type: Optional[str] = current_response_class.media_type
  255. if route.include_in_schema:
  256. for method in route.methods:
  257. operation = get_openapi_operation_metadata(
  258. route=route, method=method, operation_ids=operation_ids
  259. )
  260. parameters: List[Dict[str, Any]] = []
  261. flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
  262. security_definitions, operation_security = get_openapi_security_definitions(
  263. flat_dependant=flat_dependant
  264. )
  265. if operation_security:
  266. operation.setdefault("security", []).extend(operation_security)
  267. if security_definitions:
  268. security_schemes.update(security_definitions)
  269. operation_parameters = _get_openapi_operation_parameters(
  270. dependant=route.dependant,
  271. schema_generator=schema_generator,
  272. model_name_map=model_name_map,
  273. field_mapping=field_mapping,
  274. separate_input_output_schemas=separate_input_output_schemas,
  275. )
  276. parameters.extend(operation_parameters)
  277. if parameters:
  278. all_parameters = {
  279. (param["in"], param["name"]): param for param in parameters
  280. }
  281. required_parameters = {
  282. (param["in"], param["name"]): param
  283. for param in parameters
  284. if param.get("required")
  285. }
  286. # Make sure required definitions of the same parameter take precedence
  287. # over non-required definitions
  288. all_parameters.update(required_parameters)
  289. operation["parameters"] = list(all_parameters.values())
  290. if method in METHODS_WITH_BODY:
  291. request_body_oai = get_openapi_operation_request_body(
  292. body_field=route.body_field,
  293. schema_generator=schema_generator,
  294. model_name_map=model_name_map,
  295. field_mapping=field_mapping,
  296. separate_input_output_schemas=separate_input_output_schemas,
  297. )
  298. if request_body_oai:
  299. operation["requestBody"] = request_body_oai
  300. if route.callbacks:
  301. callbacks = {}
  302. for callback in route.callbacks:
  303. if isinstance(callback, routing.APIRoute):
  304. (
  305. cb_path,
  306. cb_security_schemes,
  307. cb_definitions,
  308. ) = get_openapi_path(
  309. route=callback,
  310. operation_ids=operation_ids,
  311. schema_generator=schema_generator,
  312. model_name_map=model_name_map,
  313. field_mapping=field_mapping,
  314. separate_input_output_schemas=separate_input_output_schemas,
  315. )
  316. callbacks[callback.name] = {callback.path: cb_path}
  317. operation["callbacks"] = callbacks
  318. if route.status_code is not None:
  319. status_code = str(route.status_code)
  320. else:
  321. # It would probably make more sense for all response classes to have an
  322. # explicit default status_code, and to extract it from them, instead of
  323. # doing this inspection tricks, that would probably be in the future
  324. # TODO: probably make status_code a default class attribute for all
  325. # responses in Starlette
  326. response_signature = inspect.signature(current_response_class.__init__)
  327. status_code_param = response_signature.parameters.get("status_code")
  328. if status_code_param is not None:
  329. if isinstance(status_code_param.default, int):
  330. status_code = str(status_code_param.default)
  331. operation.setdefault("responses", {}).setdefault(status_code, {})[
  332. "description"
  333. ] = route.response_description
  334. if route_response_media_type and is_body_allowed_for_status_code(
  335. route.status_code
  336. ):
  337. response_schema = {"type": "string"}
  338. if lenient_issubclass(current_response_class, JSONResponse):
  339. if route.response_field:
  340. response_schema = get_schema_from_model_field(
  341. field=route.response_field,
  342. schema_generator=schema_generator,
  343. model_name_map=model_name_map,
  344. field_mapping=field_mapping,
  345. separate_input_output_schemas=separate_input_output_schemas,
  346. )
  347. else:
  348. response_schema = {}
  349. operation.setdefault("responses", {}).setdefault(
  350. status_code, {}
  351. ).setdefault("content", {}).setdefault(route_response_media_type, {})[
  352. "schema"
  353. ] = response_schema
  354. if route.responses:
  355. operation_responses = operation.setdefault("responses", {})
  356. for (
  357. additional_status_code,
  358. additional_response,
  359. ) in route.responses.items():
  360. process_response = additional_response.copy()
  361. process_response.pop("model", None)
  362. status_code_key = str(additional_status_code).upper()
  363. if status_code_key == "DEFAULT":
  364. status_code_key = "default"
  365. openapi_response = operation_responses.setdefault(
  366. status_code_key, {}
  367. )
  368. assert isinstance(process_response, dict), (
  369. "An additional response must be a dict"
  370. )
  371. field = route.response_fields.get(additional_status_code)
  372. additional_field_schema: Optional[Dict[str, Any]] = None
  373. if field:
  374. additional_field_schema = get_schema_from_model_field(
  375. field=field,
  376. schema_generator=schema_generator,
  377. model_name_map=model_name_map,
  378. field_mapping=field_mapping,
  379. separate_input_output_schemas=separate_input_output_schemas,
  380. )
  381. media_type = route_response_media_type or "application/json"
  382. additional_schema = (
  383. process_response.setdefault("content", {})
  384. .setdefault(media_type, {})
  385. .setdefault("schema", {})
  386. )
  387. deep_dict_update(additional_schema, additional_field_schema)
  388. status_text: Optional[str] = status_code_ranges.get(
  389. str(additional_status_code).upper()
  390. ) or http.client.responses.get(int(additional_status_code))
  391. description = (
  392. process_response.get("description")
  393. or openapi_response.get("description")
  394. or status_text
  395. or "Additional Response"
  396. )
  397. deep_dict_update(openapi_response, process_response)
  398. openapi_response["description"] = description
  399. http422 = "422"
  400. all_route_params = get_flat_params(route.dependant)
  401. if (all_route_params or route.body_field) and not any(
  402. status in operation["responses"]
  403. for status in [http422, "4XX", "default"]
  404. ):
  405. operation["responses"][http422] = {
  406. "description": "Validation Error",
  407. "content": {
  408. "application/json": {
  409. "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
  410. }
  411. },
  412. }
  413. if "ValidationError" not in definitions:
  414. definitions.update(
  415. {
  416. "ValidationError": validation_error_definition,
  417. "HTTPValidationError": validation_error_response_definition,
  418. }
  419. )
  420. if route.openapi_extra:
  421. deep_dict_update(operation, route.openapi_extra)
  422. path[method.lower()] = operation
  423. return path, security_schemes, definitions
  424. def get_fields_from_routes(
  425. routes: Sequence[BaseRoute],
  426. ) -> List[ModelField]:
  427. body_fields_from_routes: List[ModelField] = []
  428. responses_from_routes: List[ModelField] = []
  429. request_fields_from_routes: List[ModelField] = []
  430. callback_flat_models: List[ModelField] = []
  431. for route in routes:
  432. if getattr(route, "include_in_schema", None) and isinstance(
  433. route, routing.APIRoute
  434. ):
  435. if route.body_field:
  436. assert isinstance(route.body_field, ModelField), (
  437. "A request body must be a Pydantic Field"
  438. )
  439. body_fields_from_routes.append(route.body_field)
  440. if route.response_field:
  441. responses_from_routes.append(route.response_field)
  442. if route.response_fields:
  443. responses_from_routes.extend(route.response_fields.values())
  444. if route.callbacks:
  445. callback_flat_models.extend(get_fields_from_routes(route.callbacks))
  446. params = get_flat_params(route.dependant)
  447. request_fields_from_routes.extend(params)
  448. flat_models = callback_flat_models + list(
  449. body_fields_from_routes + responses_from_routes + request_fields_from_routes
  450. )
  451. return flat_models
  452. def get_openapi(
  453. *,
  454. title: str,
  455. version: str,
  456. openapi_version: str = "3.1.0",
  457. summary: Optional[str] = None,
  458. description: Optional[str] = None,
  459. routes: Sequence[BaseRoute],
  460. webhooks: Optional[Sequence[BaseRoute]] = None,
  461. tags: Optional[List[Dict[str, Any]]] = None,
  462. servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
  463. terms_of_service: Optional[str] = None,
  464. contact: Optional[Dict[str, Union[str, Any]]] = None,
  465. license_info: Optional[Dict[str, Union[str, Any]]] = None,
  466. separate_input_output_schemas: bool = True,
  467. external_docs: Optional[Dict[str, Any]] = None,
  468. ) -> Dict[str, Any]:
  469. info: Dict[str, Any] = {"title": title, "version": version}
  470. if summary:
  471. info["summary"] = summary
  472. if description:
  473. info["description"] = description
  474. if terms_of_service:
  475. info["termsOfService"] = terms_of_service
  476. if contact:
  477. info["contact"] = contact
  478. if license_info:
  479. info["license"] = license_info
  480. output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
  481. if servers:
  482. output["servers"] = servers
  483. components: Dict[str, Dict[str, Any]] = {}
  484. paths: Dict[str, Dict[str, Any]] = {}
  485. webhook_paths: Dict[str, Dict[str, Any]] = {}
  486. operation_ids: Set[str] = set()
  487. all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
  488. model_name_map = get_compat_model_name_map(all_fields)
  489. schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
  490. field_mapping, definitions = get_definitions(
  491. fields=all_fields,
  492. schema_generator=schema_generator,
  493. model_name_map=model_name_map,
  494. separate_input_output_schemas=separate_input_output_schemas,
  495. )
  496. for route in routes or []:
  497. if isinstance(route, routing.APIRoute):
  498. result = get_openapi_path(
  499. route=route,
  500. operation_ids=operation_ids,
  501. schema_generator=schema_generator,
  502. model_name_map=model_name_map,
  503. field_mapping=field_mapping,
  504. separate_input_output_schemas=separate_input_output_schemas,
  505. )
  506. if result:
  507. path, security_schemes, path_definitions = result
  508. if path:
  509. paths.setdefault(route.path_format, {}).update(path)
  510. if security_schemes:
  511. components.setdefault("securitySchemes", {}).update(
  512. security_schemes
  513. )
  514. if path_definitions:
  515. definitions.update(path_definitions)
  516. for webhook in webhooks or []:
  517. if isinstance(webhook, routing.APIRoute):
  518. result = get_openapi_path(
  519. route=webhook,
  520. operation_ids=operation_ids,
  521. schema_generator=schema_generator,
  522. model_name_map=model_name_map,
  523. field_mapping=field_mapping,
  524. separate_input_output_schemas=separate_input_output_schemas,
  525. )
  526. if result:
  527. path, security_schemes, path_definitions = result
  528. if path:
  529. webhook_paths.setdefault(webhook.path_format, {}).update(path)
  530. if security_schemes:
  531. components.setdefault("securitySchemes", {}).update(
  532. security_schemes
  533. )
  534. if path_definitions:
  535. definitions.update(path_definitions)
  536. if definitions:
  537. components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
  538. if components:
  539. output["components"] = components
  540. output["paths"] = paths
  541. if webhook_paths:
  542. output["webhooks"] = webhook_paths
  543. if tags:
  544. output["tags"] = tags
  545. if external_docs:
  546. output["externalDocs"] = external_docs
  547. return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore