azure.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """Azure Key Vault settings source."""
  2. from __future__ import annotations as _annotations
  3. from collections.abc import Iterator, Mapping
  4. from typing import TYPE_CHECKING, Optional
  5. from pydantic.alias_generators import to_snake
  6. from pydantic.fields import FieldInfo
  7. from .env import EnvSettingsSource
  8. if TYPE_CHECKING:
  9. from azure.core.credentials import TokenCredential
  10. from azure.core.exceptions import ResourceNotFoundError
  11. from azure.keyvault.secrets import SecretClient
  12. from pydantic_settings.main import BaseSettings
  13. else:
  14. TokenCredential = None
  15. ResourceNotFoundError = None
  16. SecretClient = None
  17. def import_azure_key_vault() -> None:
  18. global TokenCredential
  19. global SecretClient
  20. global ResourceNotFoundError
  21. try:
  22. from azure.core.credentials import TokenCredential
  23. from azure.core.exceptions import ResourceNotFoundError
  24. from azure.keyvault.secrets import SecretClient
  25. except ImportError as e: # pragma: no cover
  26. raise ImportError(
  27. 'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
  28. ) from e
  29. class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
  30. _loaded_secrets: dict[str, str | None]
  31. _secret_client: SecretClient
  32. _secret_names: list[str]
  33. def __init__(
  34. self,
  35. secret_client: SecretClient,
  36. case_sensitive: bool,
  37. snake_case_conversion: bool,
  38. ) -> None:
  39. self._loaded_secrets = {}
  40. self._secret_client = secret_client
  41. self._case_sensitive = case_sensitive
  42. self._snake_case_conversion = snake_case_conversion
  43. self._secret_map: dict[str, str] = self._load_remote()
  44. def _load_remote(self) -> dict[str, str]:
  45. secret_names: Iterator[str] = (
  46. secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
  47. )
  48. if self._snake_case_conversion:
  49. return {to_snake(name): name for name in secret_names}
  50. if self._case_sensitive:
  51. return {name: name for name in secret_names}
  52. return {name.lower(): name for name in secret_names}
  53. def __getitem__(self, key: str) -> str | None:
  54. new_key = key
  55. if self._snake_case_conversion:
  56. new_key = to_snake(key)
  57. elif not self._case_sensitive:
  58. new_key = key.lower()
  59. if new_key not in self._loaded_secrets:
  60. if new_key in self._secret_map:
  61. self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
  62. else:
  63. raise KeyError(key)
  64. return self._loaded_secrets[new_key]
  65. def __len__(self) -> int:
  66. return len(self._secret_map)
  67. def __iter__(self) -> Iterator[str]:
  68. return iter(self._secret_map.keys())
  69. class AzureKeyVaultSettingsSource(EnvSettingsSource):
  70. _url: str
  71. _credential: TokenCredential
  72. def __init__(
  73. self,
  74. settings_cls: type[BaseSettings],
  75. url: str,
  76. credential: TokenCredential,
  77. dash_to_underscore: bool = False,
  78. case_sensitive: bool | None = None,
  79. snake_case_conversion: bool = False,
  80. env_prefix: str | None = None,
  81. env_parse_none_str: str | None = None,
  82. env_parse_enums: bool | None = None,
  83. ) -> None:
  84. import_azure_key_vault()
  85. self._url = url
  86. self._credential = credential
  87. self._dash_to_underscore = dash_to_underscore
  88. self._snake_case_conversion = snake_case_conversion
  89. super().__init__(
  90. settings_cls,
  91. case_sensitive=False if snake_case_conversion else case_sensitive,
  92. env_prefix=env_prefix,
  93. env_nested_delimiter='__' if snake_case_conversion else '--',
  94. env_ignore_empty=False,
  95. env_parse_none_str=env_parse_none_str,
  96. env_parse_enums=env_parse_enums,
  97. )
  98. def _load_env_vars(self) -> Mapping[str, Optional[str]]:
  99. secret_client = SecretClient(vault_url=self._url, credential=self._credential)
  100. return AzureKeyVaultMapping(
  101. secret_client=secret_client,
  102. case_sensitive=self.case_sensitive,
  103. snake_case_conversion=self._snake_case_conversion,
  104. )
  105. def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
  106. if self._snake_case_conversion:
  107. return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
  108. if self._dash_to_underscore:
  109. return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
  110. return super()._extract_field_info(field, field_name)
  111. def __repr__(self) -> str:
  112. return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
  113. __all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']