utils.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """Utility functions for pydantic-settings sources."""
  2. from __future__ import annotations as _annotations
  3. from collections import deque
  4. from collections.abc import Mapping, Sequence
  5. from dataclasses import is_dataclass
  6. from enum import Enum
  7. from typing import Any, Optional, cast
  8. from pydantic import BaseModel, Json, RootModel, Secret
  9. from pydantic._internal._utils import is_model_class
  10. from pydantic.dataclasses import is_pydantic_dataclass
  11. from typing_extensions import get_args, get_origin
  12. from typing_inspection import typing_objects
  13. from ..exceptions import SettingsError
  14. from ..utils import _lenient_issubclass
  15. from .types import EnvNoneType
  16. def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
  17. return key if case_sensitive else key.lower()
  18. def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType:
  19. return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value)
  20. def parse_env_vars(
  21. env_vars: Mapping[str, str | None],
  22. case_sensitive: bool = False,
  23. ignore_empty: bool = False,
  24. parse_none_str: str | None = None,
  25. ) -> Mapping[str, str | None]:
  26. return {
  27. _get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str)
  28. for k, v in env_vars.items()
  29. if not (ignore_empty and v == '')
  30. }
  31. def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool:
  32. # If the model is a root model, the root annotation should be used to
  33. # evaluate the complexity.
  34. if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
  35. annotation = annotation.__value__
  36. if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
  37. annotation = cast('type[RootModel[Any]]', annotation)
  38. root_annotation = annotation.model_fields['root'].annotation
  39. if root_annotation is not None: # pragma: no branch
  40. annotation = root_annotation
  41. if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
  42. return False
  43. origin = get_origin(annotation)
  44. # Check if annotation is of the form Annotated[type, metadata].
  45. if typing_objects.is_annotated(origin):
  46. # Return result of recursive call on inner type.
  47. inner, *meta = get_args(annotation)
  48. return _annotation_is_complex(inner, meta)
  49. if origin is Secret:
  50. return False
  51. return (
  52. _annotation_is_complex_inner(annotation)
  53. or _annotation_is_complex_inner(origin)
  54. or hasattr(origin, '__pydantic_core_schema__')
  55. or hasattr(origin, '__get_pydantic_core_schema__')
  56. )
  57. def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
  58. if _lenient_issubclass(annotation, (str, bytes)):
  59. return False
  60. return _lenient_issubclass(
  61. annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
  62. ) or is_dataclass(annotation)
  63. def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
  64. """Check if a union type contains any complex types."""
  65. return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
  66. def _annotation_contains_types(
  67. annotation: type[Any] | None,
  68. types: tuple[Any, ...],
  69. is_include_origin: bool = True,
  70. is_strip_annotated: bool = False,
  71. ) -> bool:
  72. """Check if a type annotation contains any of the specified types."""
  73. if is_strip_annotated:
  74. annotation = _strip_annotated(annotation)
  75. if is_include_origin is True and get_origin(annotation) in types:
  76. return True
  77. for type_ in get_args(annotation):
  78. if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated):
  79. return True
  80. return annotation in types
  81. def _strip_annotated(annotation: Any) -> Any:
  82. if typing_objects.is_annotated(get_origin(annotation)):
  83. return annotation.__origin__
  84. else:
  85. return annotation
  86. def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
  87. for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
  88. if _lenient_issubclass(type_, Enum):
  89. if value in tuple(val.value for val in type_):
  90. return type_(value).name
  91. return None
  92. def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
  93. for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
  94. if _lenient_issubclass(type_, Enum):
  95. if name in tuple(val.name for val in type_):
  96. return type_[name]
  97. return None
  98. def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]:
  99. """Get fields from a pydantic model or dataclass."""
  100. if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
  101. return model_cls.__pydantic_fields__
  102. if is_model_class(model_cls):
  103. return model_cls.model_fields
  104. raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
  105. def _get_alias_names(
  106. field_name: str,
  107. field_info: Any,
  108. alias_path_args: Optional[dict[str, Optional[int]]] = None,
  109. case_sensitive: bool = True,
  110. ) -> tuple[tuple[str, ...], bool]:
  111. """Get alias names for a field, handling alias paths and case sensitivity."""
  112. from pydantic import AliasChoices, AliasPath
  113. alias_names: list[str] = []
  114. is_alias_path_only: bool = True
  115. if not any((field_info.alias, field_info.validation_alias)):
  116. alias_names += [field_name]
  117. is_alias_path_only = False
  118. else:
  119. new_alias_paths: list[AliasPath] = []
  120. for alias in (field_info.alias, field_info.validation_alias):
  121. if alias is None:
  122. continue
  123. elif isinstance(alias, str):
  124. alias_names.append(alias)
  125. is_alias_path_only = False
  126. elif isinstance(alias, AliasChoices):
  127. for name in alias.choices:
  128. if isinstance(name, str):
  129. alias_names.append(name)
  130. is_alias_path_only = False
  131. else:
  132. new_alias_paths.append(name)
  133. else:
  134. new_alias_paths.append(alias)
  135. for alias_path in new_alias_paths:
  136. name = cast(str, alias_path.path[0])
  137. name = name.lower() if not case_sensitive else name
  138. if alias_path_args is not None:
  139. alias_path_args[name] = (
  140. alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None
  141. )
  142. if not alias_names and is_alias_path_only:
  143. alias_names.append(name)
  144. if not case_sensitive:
  145. alias_names = [alias_name.lower() for alias_name in alias_names]
  146. return tuple(dict.fromkeys(alias_names)), is_alias_path_only
  147. def _is_function(obj: Any) -> bool:
  148. """Check if an object is a function."""
  149. from types import BuiltinFunctionType, FunctionType
  150. return isinstance(obj, (FunctionType, BuiltinFunctionType))
# Explicit export list for ``from <this module> import *``. It includes
# underscore-prefixed helpers, which a star-import would otherwise skip.
__all__ = [
    '_annotation_contains_types',
    '_annotation_enum_name_to_val',
    '_annotation_enum_val_to_name',
    '_annotation_is_complex',
    '_annotation_is_complex_inner',
    '_get_alias_names',
    '_get_env_var_key',
    '_get_model_fields',
    '_is_function',
    '_parse_env_none_str',
    '_strip_annotated',
    '_union_is_complex',
    'parse_env_vars',
]