_signature.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from __future__ import annotations
  2. import dataclasses
  3. from inspect import Parameter, Signature, signature
  4. from typing import TYPE_CHECKING, Any, Callable
  5. from pydantic_core import PydanticUndefined
  6. from ._utils import is_valid_identifier
  7. if TYPE_CHECKING:
  8. from ..config import ExtraValues
  9. from ..fields import FieldInfo
  10. # Copied over from stdlib dataclasses
  11. class _HAS_DEFAULT_FACTORY_CLASS:
  12. def __repr__(self):
  13. return '<factory>'
  14. _HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
  15. def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
  16. """Extract the correct name to use for the field when generating a signature.
  17. Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
  18. First priority is given to the alias, then the validation_alias, then the field name.
  19. Args:
  20. field_name: The name of the field
  21. field_info: The corresponding FieldInfo object.
  22. Returns:
  23. The correct name to use when generating a signature.
  24. """
  25. if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
  26. return field_info.alias
  27. if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
  28. return field_info.validation_alias
  29. return field_name
  30. def _process_param_defaults(param: Parameter) -> Parameter:
  31. """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
  32. Args:
  33. param (Parameter): The parameter
  34. Returns:
  35. Parameter: The custom processed parameter
  36. """
  37. from ..fields import FieldInfo
  38. param_default = param.default
  39. if isinstance(param_default, FieldInfo):
  40. annotation = param.annotation
  41. # Replace the annotation if appropriate
  42. # inspect does "clever" things to show annotations as strings because we have
  43. # `from __future__ import annotations` in main, we don't want that
  44. if annotation == 'Any':
  45. annotation = Any
  46. # Replace the field default
  47. default = param_default.default
  48. if default is PydanticUndefined:
  49. if param_default.default_factory is PydanticUndefined:
  50. default = Signature.empty
  51. else:
  52. # this is used by dataclasses to indicate a factory exists:
  53. default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
  54. return param.replace(
  55. annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
  56. )
  57. return param
  58. def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
  59. init: Callable[..., None],
  60. fields: dict[str, FieldInfo],
  61. validate_by_name: bool,
  62. extra: ExtraValues | None,
  63. ) -> dict[str, Parameter]:
  64. """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
  65. from itertools import islice
  66. present_params = signature(init).parameters.values()
  67. merged_params: dict[str, Parameter] = {}
  68. var_kw = None
  69. use_var_kw = False
  70. for param in islice(present_params, 1, None): # skip self arg
  71. # inspect does "clever" things to show annotations as strings because we have
  72. # `from __future__ import annotations` in main, we don't want that
  73. if fields.get(param.name):
  74. # exclude params with init=False
  75. if getattr(fields[param.name], 'init', True) is False:
  76. continue
  77. param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
  78. if param.annotation == 'Any':
  79. param = param.replace(annotation=Any)
  80. if param.kind is param.VAR_KEYWORD:
  81. var_kw = param
  82. continue
  83. merged_params[param.name] = param
  84. if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
  85. allow_names = validate_by_name
  86. for field_name, field in fields.items():
  87. # when alias is a str it should be used for signature generation
  88. param_name = _field_name_for_signature(field_name, field)
  89. if field_name in merged_params or param_name in merged_params:
  90. continue
  91. if not is_valid_identifier(param_name):
  92. if allow_names:
  93. param_name = field_name
  94. else:
  95. use_var_kw = True
  96. continue
  97. if field.is_required():
  98. default = Parameter.empty
  99. elif field.default_factory is not None:
  100. # Mimics stdlib dataclasses:
  101. default = _HAS_DEFAULT_FACTORY
  102. else:
  103. default = field.default
  104. merged_params[param_name] = Parameter(
  105. param_name,
  106. Parameter.KEYWORD_ONLY,
  107. annotation=field.rebuild_annotation(),
  108. default=default,
  109. )
  110. if extra == 'allow':
  111. use_var_kw = True
  112. if var_kw and use_var_kw:
  113. # Make sure the parameter for extra kwargs
  114. # does not have the same name as a field
  115. default_model_signature = [
  116. ('self', Parameter.POSITIONAL_ONLY),
  117. ('data', Parameter.VAR_KEYWORD),
  118. ]
  119. if [(p.name, p.kind) for p in present_params] == default_model_signature:
  120. # if this is the standard model signature, use extra_data as the extra args name
  121. var_kw_name = 'extra_data'
  122. else:
  123. # else start from var_kw
  124. var_kw_name = var_kw.name
  125. # generate a name that's definitely unique
  126. while var_kw_name in fields:
  127. var_kw_name += '_'
  128. merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
  129. return merged_params
  130. def generate_pydantic_signature(
  131. init: Callable[..., None],
  132. fields: dict[str, FieldInfo],
  133. validate_by_name: bool,
  134. extra: ExtraValues | None,
  135. is_dataclass: bool = False,
  136. ) -> Signature:
  137. """Generate signature for a pydantic BaseModel or dataclass.
  138. Args:
  139. init: The class init.
  140. fields: The model fields.
  141. validate_by_name: The `validate_by_name` value of the config.
  142. extra: The `extra` value of the config.
  143. is_dataclass: Whether the model is a dataclass.
  144. Returns:
  145. The dataclass/BaseModel subclass signature.
  146. """
  147. merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
  148. if is_dataclass:
  149. merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
  150. return Signature(parameters=list(merged_params.values()), return_annotation=None)