_generics.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. from __future__ import annotations
  2. import operator
  3. import sys
  4. import types
  5. import typing
  6. from collections import ChainMap
  7. from collections.abc import Iterator, Mapping
  8. from contextlib import contextmanager
  9. from contextvars import ContextVar
  10. from functools import reduce
  11. from itertools import zip_longest
  12. from types import prepare_class
  13. from typing import TYPE_CHECKING, Annotated, Any, TypedDict, TypeVar, cast
  14. from weakref import WeakValueDictionary
  15. import typing_extensions
  16. from typing_inspection import typing_objects
  17. from typing_inspection.introspection import is_union_origin
  18. from . import _typing_extra
  19. from ._core_utils import get_type_ref
  20. from ._forward_ref import PydanticRecursiveRef
  21. from ._utils import all_identical, is_model_class
  22. if TYPE_CHECKING:
  23. from ..main import BaseModel
  24. GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
  25. # Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
  26. # Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
  27. # By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references,
  28. # while also retaining a limited number of types even without references. This is generally enough to build
  29. # specific recursive generic models without losing required items out of the cache.
  30. KT = TypeVar('KT')
  31. VT = TypeVar('VT')
  32. _LIMITED_DICT_SIZE = 100
  33. class LimitedDict(dict[KT, VT]):
  34. def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
  35. self.size_limit = size_limit
  36. super().__init__()
  37. def __setitem__(self, key: KT, value: VT, /) -> None:
  38. super().__setitem__(key, value)
  39. if len(self) > self.size_limit:
  40. excess = len(self) - self.size_limit + self.size_limit // 10
  41. to_remove = list(self.keys())[:excess]
  42. for k in to_remove:
  43. del self[k]
  44. # weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
  45. # once they are no longer referenced by the caller.
  46. GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
  47. if TYPE_CHECKING:
  48. class DeepChainMap(ChainMap[KT, VT]): # type: ignore
  49. ...
  50. else:
  51. class DeepChainMap(ChainMap):
  52. """Variant of ChainMap that allows direct updates to inner scopes.
  53. Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
  54. with some light modifications for this use case.
  55. """
  56. def clear(self) -> None:
  57. for mapping in self.maps:
  58. mapping.clear()
  59. def __setitem__(self, key: KT, value: VT) -> None:
  60. for mapping in self.maps:
  61. mapping[key] = value
  62. def __delitem__(self, key: KT) -> None:
  63. hit = False
  64. for mapping in self.maps:
  65. if key in mapping:
  66. del mapping[key]
  67. hit = True
  68. if not hit:
  69. raise KeyError(key)
  70. # Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
  71. # and discover later on that we need to re-add all this infrastructure...
  72. # _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
  73. _GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
  74. class PydanticGenericMetadata(TypedDict):
  75. origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
  76. args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
  77. parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
    model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
    """Dynamically create a submodel of a provided (generic) BaseModel.

    This is used when producing concrete parametrizations of generic models. This function
    only *creates* the new subclass; the schema/validators/serialization must be updated to
    reflect a concrete parametrization elsewhere.

    The new class copies `__module__` from `origin`. If the frame two levels above this
    function (i.e. whoever called this function's caller) is executing at module level, the
    new model is also stored in the globals of `origin`'s module, so that it can be pickled
    by reference. If `model_name` is already bound there to a different object, underscores
    are appended to the name until a free slot (or this same model) is found.

    Args:
        model_name: The name of the newly created model.
        origin: The base class for the new model to inherit from.
        args: A tuple of generic metadata arguments.
        params: A tuple of generic metadata parameters.

    Returns:
        The created submodel.
    """
    namespace: dict[str, Any] = {'__module__': origin.__module__}
    bases = (origin,)
    # Resolve the metaclass, its prepared namespace and extra keywords, as a `class` statement would.
    meta, ns, kwds = prepare_class(model_name, bases)
    namespace.update(ns)
    created_model = meta(
        model_name,
        bases,
        namespace,
        __pydantic_generic_metadata__={
            'origin': origin,
            'args': args,
            'parameters': params,
        },
        __pydantic_reset_parent_namespace__=False,
        **kwds,
    )

    # NOTE: `depth=3` is relative to `_get_caller_frame_info`'s own frame (0 = that function,
    # 1 = this function, 2 = our caller, 3 = our caller's caller). Adding or removing a stack
    # frame anywhere on this path changes which frame is inspected. `model_module` is unused here.
    model_module, called_globally = _get_caller_frame_info(depth=3)
    if called_globally:  # create global reference and therefore allow pickling
        object_by_reference = None
        reference_name = model_name
        reference_module_globals = sys.modules[created_model.__module__].__dict__
        # `setdefault` returns whatever is already bound to the name; keep appending '_'
        # until the object stored under `reference_name` is the model we just created.
        while object_by_reference is not created_model:
            object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
            reference_name += '_'
    return created_model
  118. def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
  119. """Used inside a function to check whether it was called globally.
  120. Args:
  121. depth: The depth to get the frame.
  122. Returns:
  123. A tuple contains `module_name` and `called_globally`.
  124. Raises:
  125. RuntimeError: If the function is not called inside a function.
  126. """
  127. try:
  128. previous_caller_frame = sys._getframe(depth)
  129. except ValueError as e:
  130. raise RuntimeError('This function must be used inside another function') from e
  131. except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
  132. return None, False
  133. frame_globals = previous_caller_frame.f_globals
  134. return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
  135. DictValues: type[Any] = {}.values().__class__
  136. def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
  137. """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
  138. This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
  139. since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
  140. """
  141. if isinstance(v, TypeVar):
  142. yield v
  143. elif is_model_class(v):
  144. yield from v.__pydantic_generic_metadata__['parameters']
  145. elif isinstance(v, (DictValues, list)):
  146. for var in v:
  147. yield from iter_contained_typevars(var)
  148. else:
  149. args = get_args(v)
  150. for arg in args:
  151. yield from iter_contained_typevars(arg)
  152. def get_args(v: Any) -> Any:
  153. pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
  154. if pydantic_generic_metadata:
  155. return pydantic_generic_metadata.get('args')
  156. return typing_extensions.get_args(v)
  157. def get_origin(v: Any) -> Any:
  158. pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
  159. if pydantic_generic_metadata:
  160. return pydantic_generic_metadata.get('origin')
  161. return typing_extensions.get_origin(v)
  162. def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
  163. """Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
  164. `replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
  165. """
  166. origin = get_origin(cls)
  167. if origin is None:
  168. return None
  169. if not hasattr(origin, '__parameters__'):
  170. return None
  171. # In this case, we know that cls is a _GenericAlias, and origin is the generic type
  172. # So it is safe to access cls.__args__ and origin.__parameters__
  173. args: tuple[Any, ...] = cls.__args__ # type: ignore
  174. parameters: tuple[TypeVar, ...] = origin.__parameters__
  175. return dict(zip(parameters, args))
  176. def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
  177. """Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
  178. with the `replace_types` function.
  179. Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
  180. stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
  181. """
  182. # TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
  183. # in the __origin__, __args__, and __parameters__ attributes of the model.
  184. generic_metadata = cls.__pydantic_generic_metadata__
  185. origin = generic_metadata['origin']
  186. args = generic_metadata['args']
  187. if not args:
  188. # No need to go into `iter_contained_typevars`:
  189. return {}
  190. return dict(zip(iter_contained_typevars(origin), args))
  191. def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
  192. """Return type with all occurrences of `type_map` keys recursively replaced with their values.
  193. Args:
  194. type_: The class or generic alias.
  195. type_map: Mapping from `TypeVar` instance to concrete types.
  196. Returns:
  197. A new type representing the basic structure of `type_` with all
  198. `typevar_map` keys recursively replaced.
  199. Example:
  200. ```python
  201. from typing import Union
  202. from pydantic._internal._generics import replace_types
  203. replace_types(tuple[str, Union[list[str], float]], {str: int})
  204. #> tuple[int, Union[list[int], float]]
  205. ```
  206. """
  207. if not type_map:
  208. return type_
  209. type_args = get_args(type_)
  210. origin_type = get_origin(type_)
  211. if typing_objects.is_annotated(origin_type):
  212. annotated_type, *annotations = type_args
  213. annotated_type = replace_types(annotated_type, type_map)
  214. # TODO remove parentheses when we drop support for Python 3.10:
  215. return Annotated[(annotated_type, *annotations)]
  216. # Having type args is a good indicator that this is a typing special form
  217. # instance or a generic alias of some sort.
  218. if type_args:
  219. resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
  220. if all_identical(type_args, resolved_type_args):
  221. # If all arguments are the same, there is no need to modify the
  222. # type or create a new object at all
  223. return type_
  224. if (
  225. origin_type is not None
  226. and isinstance(type_, _typing_extra.typing_base)
  227. and not isinstance(origin_type, _typing_extra.typing_base)
  228. and getattr(type_, '_name', None) is not None
  229. ):
  230. # In python < 3.9 generic aliases don't exist so any of these like `list`,
  231. # `type` or `collections.abc.Callable` need to be translated.
  232. # See: https://www.python.org/dev/peps/pep-0585
  233. origin_type = getattr(typing, type_._name)
  234. assert origin_type is not None
  235. if is_union_origin(origin_type):
  236. if any(typing_objects.is_any(arg) for arg in resolved_type_args):
  237. # `Any | T` ~ `Any`:
  238. resolved_type_args = (Any,)
  239. # `Never | T` ~ `T`:
  240. resolved_type_args = tuple(
  241. arg
  242. for arg in resolved_type_args
  243. if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
  244. )
  245. # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
  246. # We also cannot use isinstance() since we have to compare types.
  247. if sys.version_info >= (3, 10) and origin_type is types.UnionType:
  248. return reduce(operator.or_, resolved_type_args)
  249. # NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
  250. return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
  251. # We handle pydantic generic models separately as they don't have the same
  252. # semantics as "typing" classes or generic aliases
  253. if not origin_type and is_model_class(type_):
  254. parameters = type_.__pydantic_generic_metadata__['parameters']
  255. if not parameters:
  256. return type_
  257. resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
  258. if all_identical(parameters, resolved_type_args):
  259. return type_
  260. return type_[resolved_type_args]
  261. # Handle special case for typehints that can have lists as arguments.
  262. # `typing.Callable[[int, str], int]` is an example for this.
  263. if isinstance(type_, list):
  264. resolved_list = [replace_types(element, type_map) for element in type_]
  265. if all_identical(type_, resolved_list):
  266. return type_
  267. return resolved_list
  268. # If all else fails, we try to resolve the type directly and otherwise just
  269. # return the input with no modifications.
  270. return type_map.get(type_, type_)
  271. def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
  272. """Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
  273. Raises:
  274. TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
  275. Example:
  276. ```python {test="skip" lint="skip"}
  277. class Model[T, U, V = int](BaseModel): ...
  278. map_generic_model_arguments(Model, (str, bytes))
  279. #> {T: str, U: bytes, V: int}
  280. map_generic_model_arguments(Model, (str,))
  281. #> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
  282. map_generic_model_arguments(Model, (str, bytes, int, complex))
  283. #> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
  284. ```
  285. Note:
  286. This function is analogous to the private `typing._check_generic_specialization` function.
  287. """
  288. parameters = cls.__pydantic_generic_metadata__['parameters']
  289. expected_len = len(parameters)
  290. typevars_map: dict[TypeVar, Any] = {}
  291. _missing = object()
  292. for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
  293. if parameter is _missing:
  294. raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
  295. if argument is _missing:
  296. param = cast(TypeVar, parameter)
  297. try:
  298. has_default = param.has_default() # pyright: ignore[reportAttributeAccessIssue]
  299. except AttributeError:
  300. # Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
  301. has_default = False
  302. if has_default:
  303. # The default might refer to other type parameters. For an example, see:
  304. # https://typing.python.org/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
  305. typevars_map[param] = replace_types(param.__default__, typevars_map) # pyright: ignore[reportAttributeAccessIssue]
  306. else:
  307. expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters) # pyright: ignore[reportAttributeAccessIssue]
  308. raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
  309. else:
  310. param = cast(TypeVar, parameter)
  311. typevars_map[param] = argument
  312. return typevars_map
  313. _generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
  314. @contextmanager
  315. def generic_recursion_self_type(
  316. origin: type[BaseModel], args: tuple[Any, ...]
  317. ) -> Iterator[PydanticRecursiveRef | None]:
  318. """This contextmanager should be placed around the recursive calls used to build a generic type,
  319. and accept as arguments the generic origin type and the type arguments being passed to it.
  320. If the same origin and arguments are observed twice, it implies that a self-reference placeholder
  321. can be used while building the core schema, and will produce a schema_ref that will be valid in the
  322. final parent schema.
  323. """
  324. previously_seen_type_refs = _generic_recursion_cache.get()
  325. if previously_seen_type_refs is None:
  326. previously_seen_type_refs = set()
  327. token = _generic_recursion_cache.set(previously_seen_type_refs)
  328. else:
  329. token = None
  330. try:
  331. type_ref = get_type_ref(origin, args_override=args)
  332. if type_ref in previously_seen_type_refs:
  333. self_type = PydanticRecursiveRef(type_ref=type_ref)
  334. yield self_type
  335. else:
  336. previously_seen_type_refs.add(type_ref)
  337. yield
  338. previously_seen_type_refs.remove(type_ref)
  339. finally:
  340. if token:
  341. _generic_recursion_cache.reset(token)
  342. def recursively_defined_type_refs() -> set[str]:
  343. visited = _generic_recursion_cache.get()
  344. if not visited:
  345. return set() # not in a generic recursion, so there are no types
  346. return visited.copy() # don't allow modifications
  347. def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
  348. """The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
  349. repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
  350. while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
  351. As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
  352. The approach could be modified to not use two different cache keys at different points, but the
  353. _early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
  354. _late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
  355. same after resolving the type arguments will always produce cache hits.
  356. If we wanted to move to only using a single cache key per type, we would either need to always use the
  357. slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
  358. that Model[List[T]][int] is a different type than Model[List[T]][int]. Because we rely on subclass relationships
  359. during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
  360. equal.
  361. """
  362. generic_types_cache = _GENERIC_TYPES_CACHE.get()
  363. if generic_types_cache is None:
  364. generic_types_cache = GenericTypesCache()
  365. _GENERIC_TYPES_CACHE.set(generic_types_cache)
  366. return generic_types_cache.get(_early_cache_key(parent, typevar_values))
  367. def get_cached_generic_type_late(
  368. parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
  369. ) -> type[BaseModel] | None:
  370. """See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
  371. generic_types_cache = _GENERIC_TYPES_CACHE.get()
  372. if (
  373. generic_types_cache is None
  374. ): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
  375. generic_types_cache = GenericTypesCache()
  376. _GENERIC_TYPES_CACHE.set(generic_types_cache)
  377. cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
  378. if cached is not None:
  379. set_cached_generic_type(parent, typevar_values, cached, origin, args)
  380. return cached
  381. def set_cached_generic_type(
  382. parent: type[BaseModel],
  383. typevar_values: tuple[Any, ...],
  384. type_: type[BaseModel],
  385. origin: type[BaseModel] | None = None,
  386. args: tuple[Any, ...] | None = None,
  387. ) -> None:
  388. """See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
  389. two different keys.
  390. """
  391. generic_types_cache = _GENERIC_TYPES_CACHE.get()
  392. if (
  393. generic_types_cache is None
  394. ): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
  395. generic_types_cache = GenericTypesCache()
  396. _GENERIC_TYPES_CACHE.set(generic_types_cache)
  397. generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
  398. if len(typevar_values) == 1:
  399. generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
  400. if origin and args:
  401. generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
  402. def _union_orderings_key(typevar_values: Any) -> Any:
  403. """This is intended to help differentiate between Union types with the same arguments in different order.
  404. Thanks to caching internal to the `typing` module, it is not possible to distinguish between
  405. List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
  406. because `typing` considers Union[int, float] to be equal to Union[float, int].
  407. However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
  408. Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
  409. if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
  410. get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
  411. (See https://github.com/python/cpython/issues/86483 for reference.)
  412. """
  413. if isinstance(typevar_values, tuple):
  414. return tuple(_union_orderings_key(value) for value in typevar_values)
  415. elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
  416. return get_args(typevar_values)
  417. else:
  418. return ()
  419. def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
  420. """This is intended for minimal computational overhead during lookups of cached types.
  421. Note that this is overly simplistic, and it's possible that two different cls/typevar_values
  422. inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
  423. To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
  424. lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
  425. would result in the same type.
  426. """
  427. return cls, typevar_values, _union_orderings_key(typevar_values)
  428. def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
  429. """This is intended for use later in the process of creating a new type, when we have more information
  430. about the exact args that will be passed. If it turns out that a different set of inputs to
  431. __class_getitem__ resulted in the same inputs to the generic type creation process, we can still
  432. return the cached type, and update the cache with the _early_cache_key as well.
  433. """
  434. # The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
  435. # _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
  436. # whereas this function will always produce a tuple as the first item in the key.
  437. return _union_orderings_key(typevar_values), origin, args