base.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. """Base classes and core functionality for pydantic-settings sources."""
  2. from __future__ import annotations as _annotations
  3. import json
  4. import os
  5. from abc import ABC, abstractmethod
  6. from dataclasses import asdict, is_dataclass
  7. from pathlib import Path
  8. from typing import TYPE_CHECKING, Any, Optional, cast
  9. from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
  10. from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
  11. get_origin,
  12. )
  13. from pydantic._internal._utils import is_model_class
  14. from pydantic.fields import FieldInfo
  15. from typing_extensions import get_args
  16. from typing_inspection import typing_objects
  17. from typing_inspection.introspection import is_union_origin
  18. from ..exceptions import SettingsError
  19. from ..utils import _lenient_issubclass
  20. from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
  21. from .utils import (
  22. _annotation_is_complex,
  23. _get_alias_names,
  24. _get_model_fields,
  25. _strip_annotated,
  26. _union_is_complex,
  27. )
  28. if TYPE_CHECKING:
  29. from pydantic_settings.main import BaseSettings
  30. def get_subcommand(
  31. model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None
  32. ) -> Optional[PydanticModel]:
  33. """
  34. Get the subcommand from a model.
  35. Args:
  36. model: The model to get the subcommand from.
  37. is_required: Determines whether a model must have subcommand set and raises error if not
  38. found. Defaults to `True`.
  39. cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
  40. Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
  41. Returns:
  42. The subcommand model if found, otherwise `None`.
  43. Raises:
  44. SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
  45. (the default).
  46. SettingsError: When no subcommand is found and is_required=`True` and
  47. cli_exit_on_error=`False`.
  48. """
  49. model_cls = type(model)
  50. if cli_exit_on_error is None and is_model_class(model_cls):
  51. model_default = model_cls.model_config.get('cli_exit_on_error')
  52. if isinstance(model_default, bool):
  53. cli_exit_on_error = model_default
  54. if cli_exit_on_error is None:
  55. cli_exit_on_error = True
  56. subcommands: list[str] = []
  57. for field_name, field_info in _get_model_fields(model_cls).items():
  58. if _CliSubCommand in field_info.metadata:
  59. if getattr(model, field_name) is not None:
  60. return getattr(model, field_name)
  61. subcommands.append(field_name)
  62. if is_required:
  63. error_message = (
  64. f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
  65. if subcommands
  66. else 'Error: CLI subcommand is required but no subcommands were found.'
  67. )
  68. raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)
  69. return None
  70. class PydanticBaseSettingsSource(ABC):
  71. """
  72. Abstract base class for settings sources, every settings source classes should inherit from it.
  73. """
  74. def __init__(self, settings_cls: type[BaseSettings]):
  75. self.settings_cls = settings_cls
  76. self.config = settings_cls.model_config
  77. self._current_state: dict[str, Any] = {}
  78. self._settings_sources_data: dict[str, dict[str, Any]] = {}
  79. def _set_current_state(self, state: dict[str, Any]) -> None:
  80. """
  81. Record the state of settings from the previous settings sources. This should
  82. be called right before __call__.
  83. """
  84. self._current_state = state
  85. def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
  86. """
  87. Record the state of settings from all previous settings sources. This should
  88. be called right before __call__.
  89. """
  90. self._settings_sources_data = states
  91. @property
  92. def current_state(self) -> dict[str, Any]:
  93. """
  94. The current state of the settings, populated by the previous settings sources.
  95. """
  96. return self._current_state
  97. @property
  98. def settings_sources_data(self) -> dict[str, dict[str, Any]]:
  99. """
  100. The state of all previous settings sources.
  101. """
  102. return self._settings_sources_data
  103. @abstractmethod
  104. def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
  105. """
  106. Gets the value, the key for model creation, and a flag to determine whether value is complex.
  107. This is an abstract method that should be overridden in every settings source classes.
  108. Args:
  109. field: The field.
  110. field_name: The field name.
  111. Returns:
  112. A tuple that contains the value, key and a flag to determine whether value is complex.
  113. """
  114. pass
  115. def field_is_complex(self, field: FieldInfo) -> bool:
  116. """
  117. Checks whether a field is complex, in which case it will attempt to be parsed as JSON.
  118. Args:
  119. field: The field.
  120. Returns:
  121. Whether the field is complex.
  122. """
  123. return _annotation_is_complex(field.annotation, field.metadata)
  124. def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
  125. """
  126. Prepares the value of a field.
  127. Args:
  128. field_name: The field name.
  129. field: The field.
  130. value: The value of the field that has to be prepared.
  131. value_is_complex: A flag to determine whether value is complex.
  132. Returns:
  133. The prepared value.
  134. """
  135. if value is not None and (self.field_is_complex(field) or value_is_complex):
  136. return self.decode_complex_value(field_name, field, value)
  137. return value
  138. def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
  139. """
  140. Decode the value for a complex field
  141. Args:
  142. field_name: The field name.
  143. field: The field.
  144. value: The value of the field that has to be prepared.
  145. Returns:
  146. The decoded value for further preparation
  147. """
  148. if field and (
  149. NoDecode in field.metadata
  150. or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata)
  151. ):
  152. return value
  153. return json.loads(value)
  154. @abstractmethod
  155. def __call__(self) -> dict[str, Any]:
  156. pass
  157. class ConfigFileSourceMixin(ABC):
  158. def _read_files(self, files: PathType | None) -> dict[str, Any]:
  159. if files is None:
  160. return {}
  161. if isinstance(files, (str, os.PathLike)):
  162. files = [files]
  163. vars: dict[str, Any] = {}
  164. for file in files:
  165. file_path = Path(file).expanduser()
  166. if file_path.is_file():
  167. vars.update(self._read_file(file_path))
  168. return vars
  169. @abstractmethod
  170. def _read_file(self, path: Path) -> dict[str, Any]:
  171. pass
  172. class DefaultSettingsSource(PydanticBaseSettingsSource):
  173. """
  174. Source class for loading default object values.
  175. Args:
  176. settings_cls: The Settings class.
  177. nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
  178. Defaults to `False`.
  179. """
  180. def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
  181. super().__init__(settings_cls)
  182. self.defaults: dict[str, Any] = {}
  183. self.nested_model_default_partial_update = (
  184. nested_model_default_partial_update
  185. if nested_model_default_partial_update is not None
  186. else self.config.get('nested_model_default_partial_update', False)
  187. )
  188. if self.nested_model_default_partial_update:
  189. for field_name, field_info in settings_cls.model_fields.items():
  190. alias_names, *_ = _get_alias_names(field_name, field_info)
  191. preferred_alias = alias_names[0]
  192. if is_dataclass(type(field_info.default)):
  193. self.defaults[preferred_alias] = asdict(field_info.default)
  194. elif is_model_class(type(field_info.default)):
  195. self.defaults[preferred_alias] = field_info.default.model_dump()
  196. def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
  197. # Nothing to do here. Only implement the return statement to make mypy happy
  198. return None, '', False
  199. def __call__(self) -> dict[str, Any]:
  200. return self.defaults
  201. def __repr__(self) -> str:
  202. return (
  203. f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
  204. )
  205. class InitSettingsSource(PydanticBaseSettingsSource):
  206. """
  207. Source class for loading values provided during settings class initialization.
  208. """
  209. def __init__(
  210. self,
  211. settings_cls: type[BaseSettings],
  212. init_kwargs: dict[str, Any],
  213. nested_model_default_partial_update: bool | None = None,
  214. ):
  215. self.init_kwargs = {}
  216. init_kwarg_names = set(init_kwargs.keys())
  217. for field_name, field_info in settings_cls.model_fields.items():
  218. alias_names, *_ = _get_alias_names(field_name, field_info)
  219. init_kwarg_name = init_kwarg_names & set(alias_names)
  220. if init_kwarg_name:
  221. preferred_alias = alias_names[0]
  222. preferred_set_alias = next(alias for alias in alias_names if alias in init_kwarg_name)
  223. init_kwarg_names -= init_kwarg_name
  224. self.init_kwargs[preferred_alias] = init_kwargs[preferred_set_alias]
  225. self.init_kwargs.update({key: val for key, val in init_kwargs.items() if key in init_kwarg_names})
  226. super().__init__(settings_cls)
  227. self.nested_model_default_partial_update = (
  228. nested_model_default_partial_update
  229. if nested_model_default_partial_update is not None
  230. else self.config.get('nested_model_default_partial_update', False)
  231. )
  232. def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
  233. # Nothing to do here. Only implement the return statement to make mypy happy
  234. return None, '', False
  235. def __call__(self) -> dict[str, Any]:
  236. return (
  237. TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)
  238. if self.nested_model_default_partial_update
  239. else self.init_kwargs
  240. )
  241. def __repr__(self) -> str:
  242. return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})'
  243. class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
  244. def __init__(
  245. self,
  246. settings_cls: type[BaseSettings],
  247. case_sensitive: bool | None = None,
  248. env_prefix: str | None = None,
  249. env_ignore_empty: bool | None = None,
  250. env_parse_none_str: str | None = None,
  251. env_parse_enums: bool | None = None,
  252. ) -> None:
  253. super().__init__(settings_cls)
  254. self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
  255. self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
  256. self.env_ignore_empty = (
  257. env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
  258. )
  259. self.env_parse_none_str = (
  260. env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
  261. )
  262. self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')
  263. def _apply_case_sensitive(self, value: str) -> str:
  264. return value.lower() if not self.case_sensitive else value
  265. def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
  266. """
  267. Extracts field info. This info is used to get the value of field from environment variables.
  268. It returns a list of tuples, each tuple contains:
  269. * field_key: The key of field that has to be used in model creation.
  270. * env_name: The environment variable name of the field.
  271. * value_is_complex: A flag to determine whether the value from environment variable
  272. is complex and has to be parsed.
  273. Args:
  274. field (FieldInfo): The field.
  275. field_name (str): The field name.
  276. Returns:
  277. list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
  278. """
  279. field_info: list[tuple[str, str, bool]] = []
  280. if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
  281. v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
  282. else:
  283. v_alias = field.validation_alias
  284. if v_alias:
  285. if isinstance(v_alias, list): # AliasChoices, AliasPath
  286. for alias in v_alias:
  287. if isinstance(alias, str): # AliasPath
  288. field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
  289. elif isinstance(alias, list): # AliasChoices
  290. first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str
  291. field_info.append(
  292. (first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
  293. )
  294. else: # string validation alias
  295. field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
  296. if not v_alias or self.config.get('populate_by_name', False):
  297. annotation = field.annotation
  298. if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
  299. annotation = _strip_annotated(annotation.__value__) # type: ignore[union-attr]
  300. if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata):
  301. field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
  302. else:
  303. field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
  304. return field_info
  305. def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
  306. """
  307. Replace field names in values dict by looking in models fields insensitively.
  308. By having the following models:
  309. ```py
  310. class SubSubSub(BaseModel):
  311. VaL3: str
  312. class SubSub(BaseModel):
  313. Val2: str
  314. SUB_sub_SuB: SubSubSub
  315. class Sub(BaseModel):
  316. VAL1: str
  317. SUB_sub: SubSub
  318. class Settings(BaseSettings):
  319. nested: Sub
  320. model_config = SettingsConfigDict(env_nested_delimiter='__')
  321. ```
  322. Then:
  323. _replace_field_names_case_insensitively(
  324. field,
  325. {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
  326. )
  327. Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
  328. """
  329. values: dict[str, Any] = {}
  330. for name, value in field_values.items():
  331. sub_model_field: FieldInfo | None = None
  332. annotation = field.annotation
  333. # If field is Optional, we need to find the actual type
  334. if is_union_origin(get_origin(field.annotation)):
  335. args = get_args(annotation)
  336. if len(args) == 2 and type(None) in args:
  337. for arg in args:
  338. if arg is not None:
  339. annotation = arg
  340. break
  341. # This is here to make mypy happy
  342. # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
  343. if not annotation or not hasattr(annotation, 'model_fields'):
  344. values[name] = value
  345. continue
  346. else:
  347. model_fields: dict[str, FieldInfo] = annotation.model_fields
  348. # Find field in sub model by looking in fields case insensitively
  349. field_key: str | None = None
  350. for sub_model_field_name, sub_model_field in model_fields.items():
  351. aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field)
  352. _search = (alias for alias in aliases if alias.lower() == name.lower())
  353. if field_key := next(_search, None):
  354. break
  355. if not field_key:
  356. values[name] = value
  357. continue
  358. if (
  359. sub_model_field is not None
  360. and _lenient_issubclass(sub_model_field.annotation, BaseModel)
  361. and isinstance(value, dict)
  362. ):
  363. values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
  364. else:
  365. values[field_key] = value
  366. return values
  367. def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
  368. """
  369. Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
  370. """
  371. values: dict[str, Any] = {}
  372. for key, value in field_value.items():
  373. if not isinstance(value, EnvNoneType):
  374. values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value)
  375. else:
  376. values[key] = None
  377. return values
  378. def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
  379. """
  380. Gets the value, the preferred alias key for model creation, and a flag to determine whether value
  381. is complex.
  382. Note:
  383. In V3, this method should either be made public, or, this method should be removed and the
  384. abstract method get_field_value should be updated to include a "use_preferred_alias" flag.
  385. Args:
  386. field: The field.
  387. field_name: The field name.
  388. Returns:
  389. A tuple that contains the value, preferred key and a flag to determine whether value is complex.
  390. """
  391. field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
  392. if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))):
  393. field_infos = self._extract_field_info(field, field_name)
  394. preferred_key, *_ = field_infos[0]
  395. return field_value, preferred_key, value_is_complex
  396. return field_value, field_key, value_is_complex
  397. def __call__(self) -> dict[str, Any]:
  398. data: dict[str, Any] = {}
  399. for field_name, field in self.settings_cls.model_fields.items():
  400. try:
  401. field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
  402. except Exception as e:
  403. raise SettingsError(
  404. f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
  405. ) from e
  406. try:
  407. field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
  408. except ValueError as e:
  409. raise SettingsError(
  410. f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
  411. ) from e
  412. if field_value is not None:
  413. if self.env_parse_none_str is not None:
  414. if isinstance(field_value, dict):
  415. field_value = self._replace_env_none_type_values(field_value)
  416. elif isinstance(field_value, EnvNoneType):
  417. field_value = None
  418. if (
  419. not self.case_sensitive
  420. # and _lenient_issubclass(field.annotation, BaseModel)
  421. and isinstance(field_value, dict)
  422. ):
  423. data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
  424. else:
  425. data[field_key] = field_value
  426. return data
  427. __all__ = [
  428. 'ConfigFileSourceMixin',
  429. 'DefaultSettingsSource',
  430. 'InitSettingsSource',
  431. 'PydanticBaseEnvSettingsSource',
  432. 'PydanticBaseSettingsSource',
  433. 'SettingsError',
  434. ]