aws.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from __future__ import annotations as _annotations # important for BaseSettings import to work
  2. import json
  3. from collections.abc import Mapping
  4. from typing import TYPE_CHECKING, Optional
  5. from ..utils import parse_env_vars
  6. from .env import EnvSettingsSource
  7. if TYPE_CHECKING:
  8. from pydantic_settings.main import BaseSettings
  9. boto3_client = None
  10. SecretsManagerClient = None
  11. def import_aws_secrets_manager() -> None:
  12. global boto3_client
  13. global SecretsManagerClient
  14. try:
  15. from boto3 import client as boto3_client
  16. from mypy_boto3_secretsmanager.client import SecretsManagerClient
  17. except ImportError as e: # pragma: no cover
  18. raise ImportError(
  19. 'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
  20. ) from e
  21. class AWSSecretsManagerSettingsSource(EnvSettingsSource):
  22. _secret_id: str
  23. _secretsmanager_client: SecretsManagerClient # type: ignore
  24. def __init__(
  25. self,
  26. settings_cls: type[BaseSettings],
  27. secret_id: str,
  28. region_name: str | None = None,
  29. endpoint_url: str | None = None,
  30. case_sensitive: bool | None = True,
  31. env_prefix: str | None = None,
  32. env_nested_delimiter: str | None = '--',
  33. env_parse_none_str: str | None = None,
  34. env_parse_enums: bool | None = None,
  35. ) -> None:
  36. import_aws_secrets_manager()
  37. self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore
  38. self._secret_id = secret_id
  39. super().__init__(
  40. settings_cls,
  41. case_sensitive=case_sensitive,
  42. env_prefix=env_prefix,
  43. env_nested_delimiter=env_nested_delimiter,
  44. env_ignore_empty=False,
  45. env_parse_none_str=env_parse_none_str,
  46. env_parse_enums=env_parse_enums,
  47. )
  48. def _load_env_vars(self) -> Mapping[str, Optional[str]]:
  49. response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id) # type: ignore
  50. return parse_env_vars(
  51. json.loads(response['SecretString']),
  52. self.case_sensitive,
  53. self.env_ignore_empty,
  54. self.env_parse_none_str,
  55. )
  56. def __repr__(self) -> str:
  57. return (
  58. f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
  59. f'env_nested_delimiter={self.env_nested_delimiter!r})'
  60. )
  61. __all__ = [
  62. 'AWSSecretsManagerSettingsSource',
  63. ]