yaml.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """YAML file settings source."""
  2. from __future__ import annotations as _annotations
  3. from pathlib import Path
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. )
  8. from ..base import ConfigFileSourceMixin, InitSettingsSource
  9. from ..types import DEFAULT_PATH, PathType
  10. if TYPE_CHECKING:
  11. import yaml
  12. from pydantic_settings.main import BaseSettings
  13. else:
  14. yaml = None
  15. def import_yaml() -> None:
  16. global yaml
  17. if yaml is not None:
  18. return
  19. try:
  20. import yaml
  21. except ImportError as e:
  22. raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
  23. class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
  24. """
  25. A source class that loads variables from a yaml file
  26. """
  27. def __init__(
  28. self,
  29. settings_cls: type[BaseSettings],
  30. yaml_file: PathType | None = DEFAULT_PATH,
  31. yaml_file_encoding: str | None = None,
  32. yaml_config_section: str | None = None,
  33. ):
  34. self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
  35. self.yaml_file_encoding = (
  36. yaml_file_encoding
  37. if yaml_file_encoding is not None
  38. else settings_cls.model_config.get('yaml_file_encoding')
  39. )
  40. self.yaml_config_section = (
  41. yaml_config_section
  42. if yaml_config_section is not None
  43. else settings_cls.model_config.get('yaml_config_section')
  44. )
  45. self.yaml_data = self._read_files(self.yaml_file_path)
  46. if self.yaml_config_section:
  47. try:
  48. self.yaml_data = self.yaml_data[self.yaml_config_section]
  49. except KeyError:
  50. raise KeyError(
  51. f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
  52. )
  53. super().__init__(settings_cls, self.yaml_data)
  54. def _read_file(self, file_path: Path) -> dict[str, Any]:
  55. import_yaml()
  56. with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
  57. return yaml.safe_load(yaml_file) or {}
  58. def __repr__(self) -> str:
  59. return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
# Public names of this module (what ``from ... import *`` exports).
__all__ = ['YamlConfigSettingsSource']