| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527 |
- """Base classes and core functionality for pydantic-settings sources."""
- from __future__ import annotations as _annotations
- import json
- import os
- from abc import ABC, abstractmethod
- from dataclasses import asdict, is_dataclass
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Optional, cast
- from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
- from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
- get_origin,
- )
- from pydantic._internal._utils import is_model_class
- from pydantic.fields import FieldInfo
- from typing_extensions import get_args
- from typing_inspection import typing_objects
- from typing_inspection.introspection import is_union_origin
- from ..exceptions import SettingsError
- from ..utils import _lenient_issubclass
- from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
- from .utils import (
- _annotation_is_complex,
- _get_alias_names,
- _get_model_fields,
- _strip_annotated,
- _union_is_complex,
- )
- if TYPE_CHECKING:
- from pydantic_settings.main import BaseSettings
def get_subcommand(
    model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None
) -> Optional[PydanticModel]:
    """
    Return the first CLI subcommand that is set on `model`.

    Args:
        model: The model instance whose subcommand fields are inspected.
        is_required: When `True`, having no subcommand set is an error. Defaults to `True`.
        cli_exit_on_error: Whether a missing subcommand raises `SystemExit` (`True`) or
            `SettingsError` (`False`). When omitted, the model's `model_config['cli_exit_on_error']`
            is used if it is a bool; otherwise this falls back to `True`.

    Returns:
        The value of the first subcommand field that is not `None`, or `None` when no subcommand
        is set and `is_required` is `False`.

    Raises:
        SystemExit: No subcommand is set, `is_required` is `True` and exiting on error is enabled.
        SettingsError: No subcommand is set, `is_required` is `True` and exiting on error is disabled.
    """
    model_cls = type(model)

    # Resolve the effective exit-on-error flag: explicit argument > model config > True.
    if cli_exit_on_error is None and is_model_class(model_cls):
        configured = model_cls.model_config.get('cli_exit_on_error')
        if isinstance(configured, bool):
            cli_exit_on_error = configured
    exit_on_error = True if cli_exit_on_error is None else cli_exit_on_error

    unset_subcommands: list[str] = []
    for name, info in _get_model_fields(model_cls).items():
        if _CliSubCommand not in info.metadata:
            continue
        selected = getattr(model, name)
        if selected is not None:
            return selected
        unset_subcommands.append(name)

    if not is_required:
        return None

    if unset_subcommands:
        message = f'Error: CLI subcommand is required {{{", ".join(unset_subcommands)}}}'
    else:
        message = 'Error: CLI subcommand is required but no subcommands were found.'

    if exit_on_error:
        raise SystemExit(message)
    raise SettingsError(message)
class PydanticBaseSettingsSource(ABC):
    """
    Abstract base that every settings source class must inherit from.

    A source is callable and returns a mapping of field keys to raw values. Subclasses must
    implement `get_field_value` and `__call__`.
    """

    def __init__(self, settings_cls: type[BaseSettings]):
        self.settings_cls = settings_cls
        self.config = settings_cls.model_config
        # Populated through `_set_current_state` / `_set_settings_sources_data`.
        self._current_state: dict[str, Any] = {}
        self._settings_sources_data: dict[str, dict[str, Any]] = {}

    def _set_current_state(self, state: dict[str, Any]) -> None:
        """
        Store the settings gathered by the preceding sources. Meant to be called
        immediately before `__call__`.
        """
        self._current_state = state

    def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
        """
        Store the per-source data gathered by all preceding sources. Meant to be called
        immediately before `__call__`.
        """
        self._settings_sources_data = states

    @property
    def current_state(self) -> dict[str, Any]:
        """
        Settings collected so far by the preceding sources.
        """
        return self._current_state

    @property
    def settings_sources_data(self) -> dict[str, dict[str, Any]]:
        """
        Data collected by every preceding source, keyed by source.
        """
        return self._settings_sources_data

    @abstractmethod
    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Look up a field's value; every concrete settings source must override this.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A `(value, key, value_is_complex)` tuple: the raw value, the key to use for model
            creation, and whether the value is complex.
        """

    def field_is_complex(self, field: FieldInfo) -> bool:
        """
        Tell whether a field is complex, i.e. whether its raw value should be JSON-decoded.

        Args:
            field: The field.

        Returns:
            `True` when the field's annotation/metadata is considered complex.
        """
        return _annotation_is_complex(field.annotation, field.metadata)

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """
        Prepare a raw field value before it is handed to the model.

        Args:
            field_name: The field name.
            field: The field.
            value: The raw value to prepare.
            value_is_complex: Whether the source already flagged the value as complex.

        Returns:
            The decoded value for non-`None` complex values, otherwise `value` untouched.
        """
        if value is None:
            return value
        needs_decoding = self.field_is_complex(field) or value_is_complex
        return self.decode_complex_value(field_name, field, value) if needs_decoding else value

    def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
        """
        Decode the raw value of a complex field.

        Decoding is skipped when the field carries `NoDecode`, or when `enable_decoding` is
        explicitly `False` in the config and the field does not carry `ForceDecode`.

        Args:
            field_name: The field name.
            field: The field.
            value: The raw value to decode.

        Returns:
            The JSON-decoded value, or `value` itself when decoding is skipped.
        """
        if not field:
            return json.loads(value)
        if NoDecode in field.metadata:
            return value
        decoding_disabled = self.config.get('enable_decoding') is False
        if decoding_disabled and ForceDecode not in field.metadata:
            return value
        return json.loads(value)

    @abstractmethod
    def __call__(self) -> dict[str, Any]:
        ...
class ConfigFileSourceMixin(ABC):
    """
    Mixin that merges the contents of one or more configuration files into a single dict.

    Subclasses implement `_read_file` to parse a single existing file.
    """

    def _read_files(self, files: PathType | None) -> dict[str, Any]:
        """
        Read and merge every existing file in `files`.

        `files` may be `None`, a single path (`str` or `os.PathLike`), or an iterable of paths.
        `~` is expanded, paths that are not regular files are skipped, and keys from later
        files override keys from earlier ones.
        """
        if files is None:
            return {}
        candidates = [files] if isinstance(files, (str, os.PathLike)) else files
        merged: dict[str, Any] = {}
        for candidate in candidates:
            path = Path(candidate).expanduser()
            if not path.is_file():
                continue
            merged.update(self._read_file(path))
        return merged

    @abstractmethod
    def _read_file(self, path: Path) -> dict[str, Any]:
        """Parse the single file at `path` into a dict."""
class DefaultSettingsSource(PydanticBaseSettingsSource):
    """
    Source that exposes field default objects as plain data.

    When nested-model partial updates are enabled, each field whose default is a dataclass or
    a model instance is dumped to a dict and keyed by the field's preferred alias.

    Args:
        settings_cls: The Settings class.
        nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
            Defaults to the `nested_model_default_partial_update` config value, or `False`.
    """

    def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
        super().__init__(settings_cls)
        self.defaults: dict[str, Any] = {}
        if nested_model_default_partial_update is None:
            nested_model_default_partial_update = self.config.get('nested_model_default_partial_update', False)
        self.nested_model_default_partial_update = nested_model_default_partial_update

        if self.nested_model_default_partial_update:
            for name, info in settings_cls.model_fields.items():
                aliases, *_ = _get_alias_names(name, info)
                key = aliases[0]
                default = info.default
                if is_dataclass(type(default)):
                    self.defaults[key] = asdict(default)
                elif is_model_class(type(default)):
                    self.defaults[key] = default.model_dump()

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # Unused by this source; present only to satisfy the abstract interface (and mypy).
        return None, '', False

    def __call__(self) -> dict[str, Any]:
        return self.defaults

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f'{name}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
class InitSettingsSource(PydanticBaseSettingsSource):
    """
    Source that serves the keyword arguments passed when the settings class is instantiated.

    Keyword arguments that match one of a field's alias names are re-keyed under that field's
    preferred (first) alias. Unmatched keyword arguments are kept under their original names.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        init_kwargs: dict[str, Any],
        nested_model_default_partial_update: bool | None = None,
    ):
        remaining = set(init_kwargs)
        normalized = {}
        for name, info in settings_cls.model_fields.items():
            aliases, *_ = _get_alias_names(name, info)
            provided = remaining & set(aliases)
            if not provided:
                continue
            # When several aliases were supplied, the earliest one in alias order wins.
            chosen = next(alias for alias in aliases if alias in provided)
            remaining -= provided
            normalized[aliases[0]] = init_kwargs[chosen]
        normalized.update({key: val for key, val in init_kwargs.items() if key in remaining})
        self.init_kwargs = normalized

        super().__init__(settings_cls)
        if nested_model_default_partial_update is None:
            nested_model_default_partial_update = self.config.get('nested_model_default_partial_update', False)
        self.nested_model_default_partial_update = nested_model_default_partial_update

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # Unused by this source; present only to satisfy the abstract interface (and mypy).
        return None, '', False

    def __call__(self) -> dict[str, Any]:
        if not self.nested_model_default_partial_update:
            return self.init_kwargs
        # Dump nested objects to plain data so they can be merged with other sources.
        return TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f'{name}(init_kwargs={self.init_kwargs!r})'
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
    """
    Base class for sources that look field values up by (optionally prefixed) names, such as
    environment variables.

    Subclasses implement `get_field_value`. `__call__` then prepares and decodes every field
    value and, when not case sensitive, remaps nested dict keys onto the model's field names.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(settings_cls)
        # Explicit arguments win; otherwise fall back to the settings class's model_config.
        self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
        self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
        self.env_ignore_empty = (
            env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
        )
        self.env_parse_none_str = (
            env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
        )
        self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')

    def _apply_case_sensitive(self, value: str) -> str:
        """Lower-case `value` unless this source is case sensitive."""
        return value.lower() if not self.case_sensitive else value

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        """
        Extracts field info. This info is used to get the value of field from environment variables.

        It returns a list of tuples, each tuple contains:
            * field_key: The key of field that has to be used in model creation.
            * env_name: The environment variable name of the field.
            * value_is_complex: A flag to determine whether the value from environment variable
              is complex and has to be parsed.

        Args:
            field (FieldInfo): The field.
            field_name (str): The field name.

        Returns:
            list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
        """
        field_info: list[tuple[str, str, bool]] = []
        if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
            v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
        else:
            v_alias = field.validation_alias

        if v_alias:
            if isinstance(v_alias, list):  # AliasChoices, AliasPath
                for alias in v_alias:
                    if isinstance(alias, str):  # AliasPath
                        # NOTE(review): this measures the length of the path component string,
                        # not the number of path elements -- confirm this is intended.
                        field_info.append((alias, self._apply_case_sensitive(alias), len(alias) > 1))
                    elif isinstance(alias, list):  # AliasChoices
                        first_arg = cast(str, alias[0])  # first item of an AliasChoices must be a str
                        field_info.append((first_arg, self._apply_case_sensitive(first_arg), len(alias) > 1))
            else:  # string validation alias
                field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))

        if not v_alias or self.config.get('populate_by_name', False):
            annotation = field.annotation
            if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
                annotation = _strip_annotated(annotation.__value__)  # type: ignore[union-attr]
            # Unions that contain a complex member are flagged so their value gets JSON-decoded.
            is_complex_union = is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata)
            field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), is_complex_union))

        return field_info

    @staticmethod
    def _unwrap_optional_annotation(annotation: Any) -> Any:
        """
        Return `X` for a two-member union of `X` and `None` (`Optional[X]`, `X | None`,
        `Union[None, X]`); return any other annotation unchanged.
        """
        if is_union_origin(get_origin(annotation)):
            args = get_args(annotation)
            # `get_args` yields the `NoneType` class, not the `None` singleton, so compare
            # against `type(None)`.
            if len(args) == 2 and type(None) in args:
                return next((arg for arg in args if arg is not type(None)), annotation)
        return annotation

    def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
        """
        Replace field names in values dict by looking in models fields insensitively.

        By having the following models:

            ```py
            class SubSubSub(BaseModel):
                VaL3: str

            class SubSub(BaseModel):
                Val2: str
                SUB_sub_SuB: SubSubSub

            class Sub(BaseModel):
                VAL1: str
                SUB_sub: SubSub

            class Settings(BaseSettings):
                nested: Sub

                model_config = SettingsConfigDict(env_nested_delimiter='__')
            ```

        Then:
            _replace_field_names_case_insensitively(
                field,
                {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
            )
            Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}

        `Optional[...]` model annotations (in either union order) are unwrapped, both for
        `field` itself and for nested sub-model fields. Keys that match no field are kept as-is.
        """
        annotation = self._unwrap_optional_annotation(field.annotation)

        # Only model-like annotations expose `model_fields` to match keys against; anything
        # else is returned as a shallow copy with its keys untouched.
        if not annotation or not hasattr(annotation, 'model_fields'):
            return dict(field_values)

        model_fields: dict[str, FieldInfo] = annotation.model_fields
        values: dict[str, Any] = {}
        for name, value in field_values.items():
            lowered_name = name.lower()
            field_key: str | None = None
            sub_model_field: FieldInfo | None = None
            # Find the sub-model field one of whose aliases matches `name` case-insensitively.
            for sub_model_field_name, candidate in model_fields.items():
                aliases, _ = _get_alias_names(sub_model_field_name, candidate)
                field_key = next((alias for alias in aliases if alias.lower() == lowered_name), None)
                if field_key:
                    sub_model_field = candidate
                    break

            if not field_key:
                values[name] = value
                continue

            if (
                sub_model_field is not None
                and isinstance(value, dict)
                and _lenient_issubclass(self._unwrap_optional_annotation(sub_model_field.annotation), BaseModel)
            ):
                values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
            else:
                values[field_key] = value

        return values

    def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
        """
        Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
        """
        values: dict[str, Any] = {}
        for key, value in field_value.items():
            if isinstance(value, EnvNoneType):
                values[key] = None
            elif isinstance(value, dict):
                values[key] = self._replace_env_none_type_values(value)
            else:
                values[key] = value
        return values

    def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value, the preferred alias key for model creation, and a flag to determine whether value
        is complex.

        Note:
            In V3, this method should either be made public, or, this method should be removed and the
            abstract method get_field_value should be updated to include a "use_preferred_alias" flag.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value, preferred key and a flag to determine whether value is complex.
        """
        field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
        if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))):
            # Re-key under the first (preferred) name produced by `_extract_field_info`.
            field_infos = self._extract_field_info(field, field_name)
            preferred_key, *_ = field_infos[0]
            return field_value, preferred_key, value_is_complex
        return field_value, field_key, value_is_complex

    def __call__(self) -> dict[str, Any]:
        """
        Collect, decode and normalize the value of every field of the settings class.

        Raises:
            SettingsError: When looking up a field value fails, or when preparing it raises
                `ValueError` (e.g. invalid JSON for a complex field).
        """
        data: dict[str, Any] = {}

        for field_name, field in self.settings_cls.model_fields.items():
            try:
                field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
            except Exception as e:
                raise SettingsError(
                    f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            try:
                field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
            except ValueError as e:
                raise SettingsError(
                    f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            if field_value is None:
                continue

            if self.env_parse_none_str is not None:
                if isinstance(field_value, dict):
                    field_value = self._replace_env_none_type_values(field_value)
                elif isinstance(field_value, EnvNoneType):
                    field_value = None

            if not self.case_sensitive and isinstance(field_value, dict):
                data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
            else:
                data[field_key] = field_value

        return data
# Public names of this module; `SettingsError` is re-exported from `..exceptions`.
__all__ = [
    'ConfigFileSourceMixin',
    'DefaultSettingsSource',
    'InitSettingsSource',
    'PydanticBaseEnvSettingsSource',
    'PydanticBaseSettingsSource',
    'SettingsError',
]
|