# _compat.py — compatibility layer bridging Pydantic v1 and Pydantic v2 APIs.
  1. import warnings
  2. from collections import deque
  3. from copy import copy
  4. from dataclasses import dataclass, is_dataclass
  5. from enum import Enum
  6. from functools import lru_cache
  7. from typing import (
  8. Any,
  9. Callable,
  10. Deque,
  11. Dict,
  12. FrozenSet,
  13. List,
  14. Mapping,
  15. Sequence,
  16. Set,
  17. Tuple,
  18. Type,
  19. Union,
  20. cast,
  21. )
  22. from fastapi.exceptions import RequestErrorModel
  23. from fastapi.types import IncEx, ModelNameMap, UnionType
  24. from pydantic import BaseModel, create_model
  25. from pydantic.version import VERSION as PYDANTIC_VERSION
  26. from starlette.datastructures import UploadFile
  27. from typing_extensions import Annotated, Literal, get_args, get_origin
# (major, minor) of the installed Pydantic, e.g. (2, 12); used for feature gating.
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
# True when running under Pydantic v2 — selects which compat branch below is used.
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2

# Maps sequence-like annotations (typing aliases and builtins) to the concrete
# container type used to re-materialize validated values.
sequence_annotation_to_type = {
    Sequence: list,
    List: list,
    list: list,
    Tuple: tuple,
    tuple: tuple,
    Set: set,
    set: set,
    FrozenSet: frozenset,
    frozenset: frozenset,
    Deque: deque,
    deque: deque,
}

# Tuple form of the keys above, usable with issubclass()/isinstance().
sequence_types = tuple(sequence_annotation_to_type.keys())

# Bound to pydantic_core.Url (v2) or pydantic.AnyUrl (v1) in the branches below.
Url: Type[Any]
if PYDANTIC_V2:
    # --- Pydantic v2 implementation of the compatibility layer ---
    from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
    from pydantic import TypeAdapter
    from pydantic import ValidationError as ValidationError
    from pydantic._internal._schema_generation_shared import (  # type: ignore[attr-defined]
        GetJsonSchemaHandler as GetJsonSchemaHandler,
    )
    from pydantic._internal._typing_extra import eval_type_lenient
    from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
    from pydantic.fields import FieldInfo
    from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
    from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
    from pydantic_core import CoreSchema as CoreSchema
    from pydantic_core import PydanticUndefined, PydanticUndefinedType
    from pydantic_core import Url as Url

    try:
        # Newer pydantic-core name; fall back to the old name on older releases.
        from pydantic_core.core_schema import (
            with_info_plain_validator_function as with_info_plain_validator_function,
        )
    except ImportError:  # pragma: no cover
        from pydantic_core.core_schema import (
            general_plain_validator_function as with_info_plain_validator_function,  # noqa: F401
        )

    # Aliases mirroring the Pydantic v1 names that the rest of FastAPI imports.
    RequiredParam = PydanticUndefined
    Undefined = PydanticUndefined
    UndefinedType = PydanticUndefinedType
    evaluate_forwardref = eval_type_lenient
    Validator = Any

    class BaseConfig:
        """Placeholder for Pydantic v1's ``BaseConfig``; unused under v2."""

        pass

    class ErrorWrapper(Exception):
        """Placeholder for Pydantic v1's ``ErrorWrapper``; unused under v2."""

        pass

    @dataclass
    class ModelField:
        """Minimal stand-in for Pydantic v1's ``ModelField``.

        Wraps a v2 ``FieldInfo`` and lazily builds a ``TypeAdapter`` for
        validation and serialization of the field's annotated type.
        """

        field_info: FieldInfo
        name: str
        # Whether JSON Schema is generated for input (validation) or output
        # (serialization) of this field.
        mode: Literal["validation", "serialization"] = "validation"

        @property
        def alias(self) -> str:
            # Fall back to the field name when no explicit alias is set.
            a = self.field_info.alias
            return a if a is not None else self.name

        @property
        def required(self) -> bool:
            return self.field_info.is_required()

        @property
        def default(self) -> Any:
            return self.get_default()

        @property
        def type_(self) -> Any:
            # v1-style name for the field's annotation.
            return self.field_info.annotation

        def __post_init__(self) -> None:
            with warnings.catch_warnings():
                # Pydantic >= 2.12.0 warns about field specific metadata that is unused
                # (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
                # end up building the type adapter from a model field annotation so we
                # need to ignore the warning:
                if PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
                    from pydantic.warnings import UnsupportedFieldAttributeWarning

                    warnings.simplefilter(
                        "ignore", category=UnsupportedFieldAttributeWarning
                    )
                self._type_adapter: TypeAdapter[Any] = TypeAdapter(
                    Annotated[self.field_info.annotation, self.field_info]
                )

        def get_default(self) -> Any:
            """Return the field default, or ``Undefined`` for required fields."""
            if self.field_info.is_required():
                return Undefined
            return self.field_info.get_default(call_default_factory=True)

        def validate(
            self,
            value: Any,
            values: Dict[str, Any] = {},  # noqa: B006
            *,
            loc: Tuple[Union[int, str], ...] = (),
        ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
            """Validate *value*; return ``(validated, None)`` or ``(None, errors)``
            with each error's ``loc`` prefixed by *loc*."""
            try:
                return (
                    self._type_adapter.validate_python(value, from_attributes=True),
                    None,
                )
            except ValidationError as exc:
                return None, _regenerate_error_with_loc(
                    errors=exc.errors(include_url=False), loc_prefix=loc
                )

        def serialize(
            self,
            value: Any,
            *,
            mode: Literal["json", "python"] = "json",
            include: Union[IncEx, None] = None,
            exclude: Union[IncEx, None] = None,
            by_alias: bool = True,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
        ) -> Any:
            """Dump *value* through the field's ``TypeAdapter``."""
            # What calls this code passes a value that already called
            # self._type_adapter.validate_python(value)
            return self._type_adapter.dump_python(
                value,
                mode=mode,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        def __hash__(self) -> int:
            # Each ModelField is unique for our purposes, to allow making a dict from
            # ModelField to its JSON Schema.
            return id(self)

    def get_annotation_from_field_info(
        annotation: Any, field_info: FieldInfo, field_name: str
    ) -> Any:
        """No-op under v2; v1 merged FieldInfo constraints into the annotation."""
        return annotation

    def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
        """v2 errors are already plain dicts; return them unchanged."""
        return errors  # type: ignore[return-value]

    def _model_rebuild(model: Type[BaseModel]) -> None:
        model.model_rebuild()

    def _model_dump(
        model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
    ) -> Any:
        return model.model_dump(mode=mode, **kwargs)

    def _get_model_config(model: BaseModel) -> Any:
        return model.model_config

    def get_schema_from_model_field(
        *,
        field: ModelField,
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        field_mapping: Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        separate_input_output_schemas: bool = True,
    ) -> Dict[str, Any]:
        """Look up the pre-generated JSON Schema for *field* in *field_mapping*."""
        override_mode: Union[Literal["validation"], None] = (
            None if separate_input_output_schemas else "validation"
        )
        # This expects that GenerateJsonSchema was already used to generate the definitions
        json_schema = field_mapping[(field, override_mode or field.mode)]
        if "$ref" not in json_schema:
            # TODO remove when deprecating Pydantic v1
            # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
            json_schema["title"] = (
                field.field_info.title or field.alias.title().replace("_", " ")
            )
        return json_schema

    def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
        """Unused under v2 (naming is handled by GenerateJsonSchema)."""
        return {}

    def get_definitions(
        *,
        fields: List[ModelField],
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        separate_input_output_schemas: bool = True,
    ) -> Tuple[
        Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        Dict[str, Dict[str, Any]],
    ]:
        """Generate the shared ``$defs`` and a per-(field, mode) schema mapping."""
        override_mode: Union[Literal["validation"], None] = (
            None if separate_input_output_schemas else "validation"
        )
        inputs = [
            (field, override_mode or field.mode, field._type_adapter.core_schema)
            for field in fields
        ]
        field_mapping, definitions = schema_generator.generate_definitions(
            inputs=inputs
        )
        for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
            if "description" in item_def:
                # Truncate docstring-derived descriptions at the form-feed marker.
                item_description = cast(str, item_def["description"]).split("\f")[0]
                item_def["description"] = item_description
        return field_mapping, definitions  # type: ignore[return-value]

    def is_scalar_field(field: ModelField) -> bool:
        # Local import to avoid a circular import with fastapi.params.
        from fastapi import params

        return field_annotation_is_scalar(
            field.field_info.annotation
        ) and not isinstance(field.field_info, params.Body)

    def is_sequence_field(field: ModelField) -> bool:
        return field_annotation_is_sequence(field.field_info.annotation)

    def is_scalar_sequence_field(field: ModelField) -> bool:
        return field_annotation_is_scalar_sequence(field.field_info.annotation)

    def is_bytes_field(field: ModelField) -> bool:
        return is_bytes_or_nonable_bytes_annotation(field.type_)

    def is_bytes_sequence_field(field: ModelField) -> bool:
        return is_bytes_sequence_annotation(field.type_)

    def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
        """Copy *field_info*, merging in metadata derived from *annotation*."""
        cls = type(field_info)
        merged_field_info = cls.from_annotation(annotation)
        new_field_info = copy(field_info)
        new_field_info.metadata = merged_field_info.metadata
        new_field_info.annotation = merged_field_info.annotation
        return new_field_info

    def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
        """Re-materialize *value* as the concrete container type of the field."""
        origin_type = (
            get_origin(field.field_info.annotation) or field.field_info.annotation
        )
        assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
        return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]

    def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
        """Build a v2-style "missing" error dict for the given location."""
        error = ValidationError.from_exception_data(
            "Field required", [{"type": "missing", "loc": loc, "input": {}}]
        ).errors(include_url=False)[0]
        error["input"] = None
        return error  # type: ignore[return-value]

    def create_body_model(
        *, fields: Sequence[ModelField], model_name: str
    ) -> Type[BaseModel]:
        """Create an ad-hoc BaseModel whose fields are the given ModelFields."""
        field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
        BodyModel: Type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
        return BodyModel

    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
        return [
            ModelField(field_info=field_info, name=name)
            for name, field_info in model.model_fields.items()
        ]
else:
    # --- Pydantic v1 implementation of the compatibility layer ---
    from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
    from pydantic import AnyUrl as Url  # noqa: F401
    from pydantic import (  # type: ignore[assignment]
        BaseConfig as BaseConfig,  # noqa: F401
    )
    from pydantic import ValidationError as ValidationError  # noqa: F401
    from pydantic.class_validators import (  # type: ignore[no-redef]
        Validator as Validator,  # noqa: F401
    )
    from pydantic.error_wrappers import (  # type: ignore[no-redef]
        ErrorWrapper as ErrorWrapper,  # noqa: F401
    )
    from pydantic.errors import MissingError
    from pydantic.fields import (  # type: ignore[attr-defined]
        SHAPE_FROZENSET,
        SHAPE_LIST,
        SHAPE_SEQUENCE,
        SHAPE_SET,
        SHAPE_SINGLETON,
        SHAPE_TUPLE,
        SHAPE_TUPLE_ELLIPSIS,
    )
    from pydantic.fields import FieldInfo as FieldInfo
    from pydantic.fields import (  # type: ignore[no-redef,attr-defined]
        ModelField as ModelField,  # noqa: F401
    )

    # Keeping old "Required" functionality from Pydantic V1, without
    # shadowing typing.Required.
    RequiredParam: Any = Ellipsis  # type: ignore[no-redef]
    from pydantic.fields import (  # type: ignore[no-redef,attr-defined]
        Undefined as Undefined,
    )
    from pydantic.fields import (  # type: ignore[no-redef, attr-defined]
        UndefinedType as UndefinedType,  # noqa: F401
    )
    from pydantic.schema import (
        field_schema,
        get_flat_models_from_fields,
        get_model_name_map,
        model_process_schema,
    )
    from pydantic.schema import (  # type: ignore[no-redef] # noqa: F401
        get_annotation_from_field_info as get_annotation_from_field_info,
    )
    from pydantic.typing import (  # type: ignore[no-redef]
        evaluate_forwardref as evaluate_forwardref,  # noqa: F401
    )
    from pydantic.utils import (  # type: ignore[no-redef]
        lenient_issubclass as lenient_issubclass,  # noqa: F401
    )

    # v2-only names, stubbed so annotations elsewhere still resolve under v1.
    GetJsonSchemaHandler = Any  # type: ignore[assignment,misc]
    JsonSchemaValue = Dict[str, Any]  # type: ignore[misc]
    CoreSchema = Any  # type: ignore[assignment,misc]

    # v1 field "shapes" that represent sequence containers.
    sequence_shapes = {
        SHAPE_LIST,
        SHAPE_SET,
        SHAPE_FROZENSET,
        SHAPE_TUPLE,
        SHAPE_SEQUENCE,
        SHAPE_TUPLE_ELLIPSIS,
    }
    # Concrete container type to rebuild for each sequence shape.
    sequence_shape_to_type = {
        SHAPE_LIST: list,
        SHAPE_SET: set,
        SHAPE_TUPLE: tuple,
        SHAPE_SEQUENCE: list,
        SHAPE_TUPLE_ELLIPSIS: list,
    }

    @dataclass
    class GenerateJsonSchema:  # type: ignore[no-redef]
        """Placeholder for v2's ``GenerateJsonSchema``; only carries the template."""

        ref_template: str

    class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
        """Placeholder for v2's ``PydanticSchemaGenerationError``."""

        pass

    def with_info_plain_validator_function(  # type: ignore[misc]
        function: Callable[..., Any],
        *,
        ref: Union[str, None] = None,
        metadata: Any = None,
        serialization: Any = None,
    ) -> Any:
        """No-op stub of the pydantic-core helper; v1 has no core schemas."""
        return {}

    def get_model_definitions(
        *,
        flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
        model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
    ) -> Dict[str, Any]:
        """Build the OpenAPI ``definitions`` dict for all referenced models."""
        definitions: Dict[str, Dict[str, Any]] = {}
        for model in flat_models:
            m_schema, m_definitions, m_nested_models = model_process_schema(
                model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
            )
            definitions.update(m_definitions)
            model_name = model_name_map[model]
            definitions[model_name] = m_schema
        for m_schema in definitions.values():
            if "description" in m_schema:
                # Truncate docstring-derived descriptions at the form-feed marker.
                m_schema["description"] = m_schema["description"].split("\f")[0]
        return definitions

    def is_pv1_scalar_field(field: ModelField) -> bool:
        """True when the v1 field is a plain scalar (not model/dict/sequence/body)."""
        # Local import to avoid a circular import with fastapi.params.
        from fastapi import params

        field_info = field.field_info
        if not (
            field.shape == SHAPE_SINGLETON  # type: ignore[attr-defined]
            and not lenient_issubclass(field.type_, BaseModel)
            and not lenient_issubclass(field.type_, dict)
            and not field_annotation_is_sequence(field.type_)
            and not is_dataclass(field.type_)
            and not isinstance(field_info, params.Body)
        ):
            return False
        if field.sub_fields:  # type: ignore[attr-defined]
            # e.g. Union members must all be scalar too.
            if not all(
                is_pv1_scalar_field(f)
                for f in field.sub_fields  # type: ignore[attr-defined]
            ):
                return False
        return True

    def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
        """True when the v1 field is a sequence whose items are all scalars."""
        if (field.shape in sequence_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
            field.type_, BaseModel
        ):
            if field.sub_fields is not None:  # type: ignore[attr-defined]
                for sub_field in field.sub_fields:  # type: ignore[attr-defined]
                    if not is_pv1_scalar_field(sub_field):
                        return False
            return True
        if _annotation_is_sequence(field.type_):
            return True
        return False

    def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
        """Flatten v1 ``ErrorWrapper``s (and nested lists) into plain error dicts."""
        use_errors: List[Any] = []
        for error in errors:
            if isinstance(error, ErrorWrapper):
                new_errors = ValidationError(  # type: ignore[call-arg]
                    errors=[error], model=RequestErrorModel
                ).errors()
                use_errors.extend(new_errors)
            elif isinstance(error, list):
                use_errors.extend(_normalize_errors(error))
            else:
                use_errors.append(error)
        return use_errors

    def _model_rebuild(model: Type[BaseModel]) -> None:
        model.update_forward_refs()

    def _model_dump(
        model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
    ) -> Any:
        # v1's .dict() has no "mode"; the parameter exists only for API parity.
        return model.dict(**kwargs)

    def _get_model_config(model: BaseModel) -> Any:
        return model.__config__  # type: ignore[attr-defined]

    def get_schema_from_model_field(
        *,
        field: ModelField,
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        field_mapping: Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        separate_input_output_schemas: bool = True,
    ) -> Dict[str, Any]:
        """Generate the JSON Schema for *field* with v1's ``field_schema``."""
        # This expects that GenerateJsonSchema was already used to generate the definitions
        return field_schema(  # type: ignore[no-any-return]
            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )[0]

    def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
        models = get_flat_models_from_fields(fields, known_models=set())
        return get_model_name_map(models)  # type: ignore[no-any-return]

    def get_definitions(
        *,
        fields: List[ModelField],
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        separate_input_output_schemas: bool = True,
    ) -> Tuple[
        Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        Dict[str, Dict[str, Any]],
    ]:
        # v1 has no per-(field, mode) mapping, so the first element is empty.
        models = get_flat_models_from_fields(fields, known_models=set())
        return {}, get_model_definitions(
            flat_models=models, model_name_map=model_name_map
        )

    def is_scalar_field(field: ModelField) -> bool:
        return is_pv1_scalar_field(field)

    def is_sequence_field(field: ModelField) -> bool:
        return field.shape in sequence_shapes or _annotation_is_sequence(field.type_)  # type: ignore[attr-defined]

    def is_scalar_sequence_field(field: ModelField) -> bool:
        return is_pv1_scalar_sequence_field(field)

    def is_bytes_field(field: ModelField) -> bool:
        return lenient_issubclass(field.type_, bytes)

    def is_bytes_sequence_field(field: ModelField) -> bool:
        return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)  # type: ignore[attr-defined]

    def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
        # v1 keeps constraints on the FieldInfo itself; a plain copy suffices.
        return copy(field_info)

    def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
        return sequence_shape_to_type[field.shape](value)  # type: ignore[no-any-return,attr-defined]

    def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
        """Build a v1-style "missing field" error dict for the given location."""
        missing_field_error = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
        new_error = ValidationError([missing_field_error], RequestErrorModel)
        return new_error.errors()[0]  # type: ignore[return-value]

    def create_body_model(
        *, fields: Sequence[ModelField], model_name: str
    ) -> Type[BaseModel]:
        """Create an ad-hoc v1 BaseModel by injecting ModelFields directly."""
        BodyModel = create_model(model_name)
        for f in fields:
            BodyModel.__fields__[f.name] = f  # type: ignore[index]
        return BodyModel

    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
        return list(model.__fields__.values())  # type: ignore[attr-defined]
  476. def _regenerate_error_with_loc(
  477. *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
  478. ) -> List[Dict[str, Any]]:
  479. updated_loc_errors: List[Any] = [
  480. {**err, "loc": loc_prefix + err.get("loc", ())}
  481. for err in _normalize_errors(errors)
  482. ]
  483. return updated_loc_errors
  484. def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
  485. if lenient_issubclass(annotation, (str, bytes)):
  486. return False
  487. return lenient_issubclass(annotation, sequence_types)
  488. def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
  489. origin = get_origin(annotation)
  490. if origin is Union or origin is UnionType:
  491. for arg in get_args(annotation):
  492. if field_annotation_is_sequence(arg):
  493. return True
  494. return False
  495. return _annotation_is_sequence(annotation) or _annotation_is_sequence(
  496. get_origin(annotation)
  497. )
  498. def value_is_sequence(value: Any) -> bool:
  499. return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
  500. def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
  501. return (
  502. lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
  503. or _annotation_is_sequence(annotation)
  504. or is_dataclass(annotation)
  505. )
  506. def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
  507. origin = get_origin(annotation)
  508. if origin is Union or origin is UnionType:
  509. return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
  510. if origin is Annotated:
  511. return field_annotation_is_complex(get_args(annotation)[0])
  512. return (
  513. _annotation_is_complex(annotation)
  514. or _annotation_is_complex(origin)
  515. or hasattr(origin, "__pydantic_core_schema__")
  516. or hasattr(origin, "__get_pydantic_core_schema__")
  517. )
  518. def field_annotation_is_scalar(annotation: Any) -> bool:
  519. # handle Ellipsis here to make tuple[int, ...] work nicely
  520. return annotation is Ellipsis or not field_annotation_is_complex(annotation)
  521. def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
  522. origin = get_origin(annotation)
  523. if origin is Union or origin is UnionType:
  524. at_least_one_scalar_sequence = False
  525. for arg in get_args(annotation):
  526. if field_annotation_is_scalar_sequence(arg):
  527. at_least_one_scalar_sequence = True
  528. continue
  529. elif not field_annotation_is_scalar(arg):
  530. return False
  531. return at_least_one_scalar_sequence
  532. return field_annotation_is_sequence(annotation) and all(
  533. field_annotation_is_scalar(sub_annotation)
  534. for sub_annotation in get_args(annotation)
  535. )
  536. def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
  537. if lenient_issubclass(annotation, bytes):
  538. return True
  539. origin = get_origin(annotation)
  540. if origin is Union or origin is UnionType:
  541. for arg in get_args(annotation):
  542. if lenient_issubclass(arg, bytes):
  543. return True
  544. return False
  545. def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
  546. if lenient_issubclass(annotation, UploadFile):
  547. return True
  548. origin = get_origin(annotation)
  549. if origin is Union or origin is UnionType:
  550. for arg in get_args(annotation):
  551. if lenient_issubclass(arg, UploadFile):
  552. return True
  553. return False
  554. def is_bytes_sequence_annotation(annotation: Any) -> bool:
  555. origin = get_origin(annotation)
  556. if origin is Union or origin is UnionType:
  557. at_least_one = False
  558. for arg in get_args(annotation):
  559. if is_bytes_sequence_annotation(arg):
  560. at_least_one = True
  561. continue
  562. return at_least_one
  563. return field_annotation_is_sequence(annotation) and all(
  564. is_bytes_or_nonable_bytes_annotation(sub_annotation)
  565. for sub_annotation in get_args(annotation)
  566. )
  567. def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
  568. origin = get_origin(annotation)
  569. if origin is Union or origin is UnionType:
  570. at_least_one = False
  571. for arg in get_args(annotation):
  572. if is_uploadfile_sequence_annotation(arg):
  573. at_least_one = True
  574. continue
  575. return at_least_one
  576. return field_annotation_is_sequence(annotation) and all(
  577. is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
  578. for sub_annotation in get_args(annotation)
  579. )
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Memoized wrapper around ``get_model_fields``.

    NOTE(review): the cache is unbounded and keyed on the model class, so every
    model class seen (and its field list) stays alive for the process lifetime —
    presumably acceptable since models are normally module-level; confirm.
    """
    return get_model_fields(model)