env.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from __future__ import annotations as _annotations
  2. import os
  3. from collections.abc import Mapping
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. )
  8. from pydantic._internal._utils import deep_update, is_model_class
  9. from pydantic.dataclasses import is_pydantic_dataclass
  10. from pydantic.fields import FieldInfo
  11. from typing_extensions import get_args, get_origin
  12. from typing_inspection.introspection import is_union_origin
  13. from ...utils import _lenient_issubclass
  14. from ..base import PydanticBaseEnvSettingsSource
  15. from ..types import EnvNoneType
  16. from ..utils import (
  17. _annotation_enum_name_to_val,
  18. _get_model_fields,
  19. _union_is_complex,
  20. parse_env_vars,
  21. )
  22. if TYPE_CHECKING:
  23. from pydantic_settings.main import BaseSettings
  class EnvSettingsSource(PydanticBaseEnvSettingsSource):
      """
      Source class for loading settings values from environment variables.

      The environment is read once, in `__init__`, into `self.env_vars` (via `_load_env_vars`);
      later lookups use that mapping. When `env_nested_delimiter` is set, variables such as
      `<field env name><delimiter><sub key>...` are exploded into nested dictionaries for
      complex fields (see `explode_env_vars`).
      """

      def __init__(
          self,
          settings_cls: type[BaseSettings],
          case_sensitive: bool | None = None,
          env_prefix: str | None = None,
          env_nested_delimiter: str | None = None,
          env_nested_max_split: int | None = None,
          env_ignore_empty: bool | None = None,
          env_parse_none_str: str | None = None,
          env_parse_enums: bool | None = None,
      ) -> None:
          super().__init__(
              settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
          )
          # Explicit constructor arguments win; otherwise fall back to the settings class config.
          self.env_nested_delimiter = (
              env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
          )
          self.env_nested_max_split = (
              env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
          )
          # `str.split(sep, k)` yields at most k + 1 parts, so this caps the nested name (after the
          # prefix) at `env_nested_max_split` parts. Unset/0 gives -1, i.e. split without limit.
          self.maxsplit = (self.env_nested_max_split or 0) - 1
          self.env_prefix_len = len(self.env_prefix)
          self.env_vars = self._load_env_vars()

      def _load_env_vars(self) -> Mapping[str, str | None]:
          """Read `os.environ` through `parse_env_vars`, applying this source's parsing options."""
          return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)

      def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
          """
          Gets the value for field from environment variables and a flag to determine whether value is complex.

          Candidate environment names come from `_extract_field_info`; the first candidate whose
          value in `self.env_vars` is not `None` wins.

          Args:
              field: The field.
              field_name: The field name.

          Returns:
              A tuple that contains the value (`None` if not found), key, and
                  a flag to determine whether value is complex. When no value is found, the key
                  and flag are those of the last candidate examined.
          """
          env_val: str | None = None
          # NOTE(review): `field_key` / `value_is_complex` are only bound when `_extract_field_info`
          # yields at least one candidate -- presumably always true; confirm in the base class.
          for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
              env_val = self.env_vars.get(env_name)
              if env_val is not None:
                  break
          return env_val, field_key, value_is_complex

      def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
          """
          Prepare value for the field.

          * Map enum member names to values when `env_parse_enums` is enabled.
          * Extract value for nested field.
          * Deserialize value to python object for complex field.

          Args:
              field_name: The field name.
              field: The field.
              value: The raw value found for the field (`None` if not found).
              value_is_complex: Whether the matched key marks the value as complex.

          Returns:
              The prepared value for the field, or `None` when there is nothing to set.

          Raises:
              ValueError: When there is an error in deserializing value for complex field, unless the
                  field is a complex union (see `_field_is_complex`), in which case the raw value is kept.
          """
          is_complex, allow_parse_failure = self._field_is_complex(field)
          if self.env_parse_enums:
              enum_val = _annotation_enum_name_to_val(field.annotation, value)
              value = value if enum_val is None else enum_val

          if not (is_complex or value_is_complex):
              # simplest case, field is not complex: pass through the value (`None` if not found)
              return value

          if isinstance(value, EnvNoneType):
              return value

          if value is None:
              # field is complex but no value found so far, try explode_env_vars;
              # an empty result means nothing was found at all.
              return self.explode_env_vars(field_name, field, self.env_vars) or None

          # field is complex and there's a value, decode that as JSON, then add explode_env_vars
          try:
              value = self.decode_complex_value(field_name, field, value)
          except ValueError:
              if not allow_parse_failure:
                  raise
          if isinstance(value, dict):
              return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
          return value

      def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
          """
          Find out if a field is complex, and if so whether JSON errors should be ignored.

          Returns:
              `(is_complex, allow_parse_failure)`. Fields that are complex in their own right must
              decode successfully; unions containing a complex member tolerate decode failures
              (callers then keep the raw value).
          """
          if self.field_is_complex(field):
              allow_parse_failure = False
          elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
              allow_parse_failure = True
          else:
              return False, False
          return True, allow_parse_failure

      # Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
      # We have to change the method to a non-static method and use
      # `self.case_sensitive` instead in V3.
      def next_field(
          self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
      ) -> FieldInfo | None:
          """
          Find the field in a sub model by key(env name)

          By having the following models:

              ```py
              class SubSubModel(BaseSettings):
                  dvals: Dict

              class SubModel(BaseSettings):
                  vals: list[str]
                  sub_sub_model: SubSubModel

              class Cfg(BaseSettings):
                  sub_model: SubModel
              ```

          Then:
              next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
              next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class

          Type arguments of the annotation (e.g. union members) are searched first, recursively.

          Args:
              field: The field.
              key: The key (env name).
              case_sensitive: Whether to search for key case sensitively. `None` behaves like `True`.

          Returns:
              Field if it finds the next field otherwise `None`.
          """
          if not field:
              return None

          annotation = field.annotation if isinstance(field, FieldInfo) else field
          for type_ in get_args(annotation):
              type_has_key = self.next_field(type_, key, case_sensitive)
              if type_has_key:
                  return type_has_key

          if is_model_class(annotation) or is_pydantic_dataclass(annotation):  # type: ignore[arg-type]
              fields = _get_model_fields(annotation)
              # `case_sensitive is None` is here to be compatible with the old behavior.
              # Has to be removed in V3.
              exact_match = case_sensitive is None or case_sensitive
              # Loop-invariant: only needed for the case-insensitive comparison.
              lowered_key = key.lower()
              for field_name, f in fields.items():
                  for _, env_name, _ in self._extract_field_info(f, field_name):
                      if exact_match:
                          if key in (field_name, env_name):
                              return f
                      elif lowered_key in (field_name.lower(), env_name.lower()):
                          return f

          return None

      def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
          """
          Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

          This is applied to a single field, hence filtering by env_var prefix.

          Args:
              field_name: The field name.
              field: The field.
              env_vars: Environment variables.

          Returns:
              A dictionary contains extracted values from nested env values
              (empty when `env_nested_delimiter` is not set or nothing matches).

          Raises:
              ValueError: When a nested value for a non-union complex sub-field fails to decode.
          """
          if not self.env_nested_delimiter:
              return {}

          ann = field.annotation
          # When the top-level field is a dict, nested keys without a matching sub-field are
          # still decoded (leniently) as complex values further below.
          is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)

          prefixes = [
              f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
          ]
          result: dict[str, Any] = {}
          for env_name, env_val in env_vars.items():
              prefix = next((p for p in prefixes if env_name.startswith(p)), None)
              if prefix is None:
                  # not a nested variable of this field
                  continue
              # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
              env_name_without_prefix = env_name[len(prefix) :]
              *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)

              # Descend (creating dicts as needed) while tracking the matching sub-field. If an
              # intermediate key already holds a non-dict value, descent stops and this variable
              # is not recorded (see the final `isinstance` check).
              env_var = result
              target_field: FieldInfo | None = field
              for key in keys:
                  target_field = self.next_field(target_field, key, self.case_sensitive)
                  if isinstance(env_var, dict):
                      env_var = env_var.setdefault(key, {})

              # get proper field with last_key
              target_field = self.next_field(target_field, last_key, self.case_sensitive)

              # check if env_val maps to a complex field and if so, parse the env_val
              if (target_field or is_dict) and env_val:
                  if target_field:
                      is_complex, allow_json_failure = self._field_is_complex(target_field)
                      if self.env_parse_enums:
                          enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
                          env_val = env_val if enum_val is None else enum_val
                  else:
                      # nested field type is dict
                      is_complex, allow_json_failure = True, True
                  if is_complex:
                      try:
                          env_val = self.decode_complex_value(last_key, target_field, env_val)  # type: ignore
                      except ValueError:
                          if not allow_json_failure:
                              raise

              if isinstance(env_var, dict):
                  # An EnvNoneType value only replaces an existing entry when that entry is `{}`.
                  if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
                      env_var[last_key] = env_val

          return result

      def __repr__(self) -> str:
          return (
              f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
              f'env_prefix_len={self.env_prefix_len!r})'
          )
  # Names exported by `from ... import *` on this module.
  __all__ = ["EnvSettingsSource"]