versioned.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """DNS Versioned Zones."""
  3. import collections
  4. import threading
  5. from typing import Callable, Deque, Set, cast
  6. import dns.exception
  7. import dns.name
  8. import dns.node
  9. import dns.rdataclass
  10. import dns.rdataset
  11. import dns.rdatatype
  12. import dns.rdtypes.ANY.SOA
  13. import dns.zone
class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""

    # Raised by this module's versioned Zone when a caller tries to modify it
    # directly (e.g. delete_node(), or find_node()/find_rdataset()/
    # get_rdataset() with create=True) instead of through a writer()
    # transaction.
    # NOTE(review): dns.exception.DNSException presumably uses the class
    # docstring as the default exception message, so the docstring above is
    # user-visible text; keep it stable.
  16. # Backwards compatibility
  17. Node = dns.zone.VersionedNode
  18. ImmutableNode = dns.zone.ImmutableVersionedNode
  19. Version = dns.zone.Version
  20. WritableVersion = dns.zone.WritableVersion
  21. ImmutableVersion = dns.zone.ImmutableVersion
  22. Transaction = dns.zone.Transaction
  23. class Zone(dns.zone.Zone): # lgtm[py/missing-equals]
  24. __slots__ = [
  25. "_versions",
  26. "_versions_lock",
  27. "_write_txn",
  28. "_write_waiters",
  29. "_write_event",
  30. "_pruning_policy",
  31. "_readers",
  32. ]
  33. node_factory: Callable[[], dns.node.Node] = Node
  34. def __init__(
  35. self,
  36. origin: dns.name.Name | str | None,
  37. rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
  38. relativize: bool = True,
  39. pruning_policy: Callable[["Zone", Version], bool | None] | None = None,
  40. ):
  41. """Initialize a versioned zone object.
  42. *origin* is the origin of the zone. It may be a ``dns.name.Name``,
  43. a ``str``, or ``None``. If ``None``, then the zone's origin will
  44. be set by the first ``$ORIGIN`` line in a zone file.
  45. *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
  46. *relativize*, a ``bool``, determine's whether domain names are
  47. relativized to the zone's origin. The default is ``True``.
  48. *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning
  49. a ``bool``, or ``None``. Should the version be pruned? If ``None``,
  50. the default policy, which retains one version is used.
  51. """
  52. super().__init__(origin, rdclass, relativize)
  53. self._versions: Deque[Version] = collections.deque()
  54. self._version_lock = threading.Lock()
  55. if pruning_policy is None:
  56. self._pruning_policy = self._default_pruning_policy
  57. else:
  58. self._pruning_policy = pruning_policy
  59. self._write_txn: Transaction | None = None
  60. self._write_event: threading.Event | None = None
  61. self._write_waiters: Deque[threading.Event] = collections.deque()
  62. self._readers: Set[Transaction] = set()
  63. self._commit_version_unlocked(
  64. None, WritableVersion(self, replacement=True), origin
  65. )
  66. def reader(
  67. self, id: int | None = None, serial: int | None = None
  68. ) -> Transaction: # pylint: disable=arguments-differ
  69. if id is not None and serial is not None:
  70. raise ValueError("cannot specify both id and serial")
  71. with self._version_lock:
  72. if id is not None:
  73. version = None
  74. for v in reversed(self._versions):
  75. if v.id == id:
  76. version = v
  77. break
  78. if version is None:
  79. raise KeyError("version not found")
  80. elif serial is not None:
  81. if self.relativize:
  82. oname = dns.name.empty
  83. else:
  84. assert self.origin is not None
  85. oname = self.origin
  86. version = None
  87. for v in reversed(self._versions):
  88. n = v.nodes.get(oname)
  89. if n:
  90. rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
  91. if rds is None:
  92. continue
  93. soa = cast(dns.rdtypes.ANY.SOA.SOA, rds[0])
  94. if rds and soa.serial == serial:
  95. version = v
  96. break
  97. if version is None:
  98. raise KeyError("serial not found")
  99. else:
  100. version = self._versions[-1]
  101. txn = Transaction(self, False, version)
  102. self._readers.add(txn)
  103. return txn
  104. def writer(self, replacement: bool = False) -> Transaction:
  105. event = None
  106. while True:
  107. with self._version_lock:
  108. # Checking event == self._write_event ensures that either
  109. # no one was waiting before we got lucky and found no write
  110. # txn, or we were the one who was waiting and got woken up.
  111. # This prevents "taking cuts" when creating a write txn.
  112. if self._write_txn is None and event == self._write_event:
  113. # Creating the transaction defers version setup
  114. # (i.e. copying the nodes dictionary) until we
  115. # give up the lock, so that we hold the lock as
  116. # short a time as possible. This is why we call
  117. # _setup_version() below.
  118. self._write_txn = Transaction(
  119. self, replacement, make_immutable=True
  120. )
  121. # give up our exclusive right to make a Transaction
  122. self._write_event = None
  123. break
  124. # Someone else is writing already, so we will have to
  125. # wait, but we want to do the actual wait outside the
  126. # lock.
  127. event = threading.Event()
  128. self._write_waiters.append(event)
  129. # wait (note we gave up the lock!)
  130. #
  131. # We only wake one sleeper at a time, so it's important
  132. # that no event waiter can exit this method (e.g. via
  133. # cancellation) without returning a transaction or waking
  134. # someone else up.
  135. #
  136. # This is not a problem with Threading module threads as
  137. # they cannot be canceled, but could be an issue with trio
  138. # tasks when we do the async version of writer().
  139. # I.e. we'd need to do something like:
  140. #
  141. # try:
  142. # event.wait()
  143. # except trio.Cancelled:
  144. # with self._version_lock:
  145. # self._maybe_wakeup_one_waiter_unlocked()
  146. # raise
  147. #
  148. event.wait()
  149. # Do the deferred version setup.
  150. self._write_txn._setup_version()
  151. return self._write_txn
  152. def _maybe_wakeup_one_waiter_unlocked(self):
  153. if len(self._write_waiters) > 0:
  154. self._write_event = self._write_waiters.popleft()
  155. self._write_event.set()
  156. # pylint: disable=unused-argument
  157. def _default_pruning_policy(self, zone, version):
  158. return True
  159. # pylint: enable=unused-argument
  160. def _prune_versions_unlocked(self):
  161. assert len(self._versions) > 0
  162. # Don't ever prune a version greater than or equal to one that
  163. # a reader has open. This pins versions in memory while the
  164. # reader is open, and importantly lets the reader open a txn on
  165. # a successor version (e.g. if generating an IXFR).
  166. #
  167. # Note our definition of least_kept also ensures we do not try to
  168. # delete the greatest version.
  169. if len(self._readers) > 0:
  170. least_kept = min(txn.version.id for txn in self._readers) # pyright: ignore
  171. else:
  172. least_kept = self._versions[-1].id
  173. while self._versions[0].id < least_kept and self._pruning_policy(
  174. self, self._versions[0]
  175. ):
  176. self._versions.popleft()
  177. def set_max_versions(self, max_versions: int | None) -> None:
  178. """Set a pruning policy that retains up to the specified number
  179. of versions
  180. """
  181. if max_versions is not None and max_versions < 1:
  182. raise ValueError("max versions must be at least 1")
  183. if max_versions is None:
  184. # pylint: disable=unused-argument
  185. def policy(zone, _): # pyright: ignore
  186. return False
  187. else:
  188. def policy(zone, _):
  189. return len(zone._versions) > max_versions
  190. self.set_pruning_policy(policy)
  191. def set_pruning_policy(
  192. self, policy: Callable[["Zone", Version], bool | None] | None
  193. ) -> None:
  194. """Set the pruning policy for the zone.
  195. The *policy* function takes a `Version` and returns `True` if
  196. the version should be pruned, and `False` otherwise. `None`
  197. may also be specified for policy, in which case the default policy
  198. is used.
  199. Pruning checking proceeds from the least version and the first
  200. time the function returns `False`, the checking stops. I.e. the
  201. retained versions are always a consecutive sequence.
  202. """
  203. if policy is None:
  204. policy = self._default_pruning_policy
  205. with self._version_lock:
  206. self._pruning_policy = policy
  207. self._prune_versions_unlocked()
  208. def _end_read(self, txn):
  209. with self._version_lock:
  210. self._readers.remove(txn)
  211. self._prune_versions_unlocked()
  212. def _end_write_unlocked(self, txn):
  213. assert self._write_txn == txn
  214. self._write_txn = None
  215. self._maybe_wakeup_one_waiter_unlocked()
  216. def _end_write(self, txn):
  217. with self._version_lock:
  218. self._end_write_unlocked(txn)
  219. def _commit_version_unlocked(self, txn, version, origin):
  220. self._versions.append(version)
  221. self._prune_versions_unlocked()
  222. self.nodes = version.nodes
  223. if self.origin is None:
  224. self.origin = origin
  225. # txn can be None in __init__ when we make the empty version.
  226. if txn is not None:
  227. self._end_write_unlocked(txn)
  228. def _commit_version(self, txn, version, origin):
  229. with self._version_lock:
  230. self._commit_version_unlocked(txn, version, origin)
  231. def _get_next_version_id(self):
  232. if len(self._versions) > 0:
  233. id = self._versions[-1].id + 1
  234. else:
  235. id = 1
  236. return id
  237. def find_node(
  238. self, name: dns.name.Name | str, create: bool = False
  239. ) -> dns.node.Node:
  240. if create:
  241. raise UseTransaction
  242. return super().find_node(name)
  243. def delete_node(self, name: dns.name.Name | str) -> None:
  244. raise UseTransaction
  245. def find_rdataset(
  246. self,
  247. name: dns.name.Name | str,
  248. rdtype: dns.rdatatype.RdataType | str,
  249. covers: dns.rdatatype.RdataType | str = dns.rdatatype.NONE,
  250. create: bool = False,
  251. ) -> dns.rdataset.Rdataset:
  252. if create:
  253. raise UseTransaction
  254. rdataset = super().find_rdataset(name, rdtype, covers)
  255. return dns.rdataset.ImmutableRdataset(rdataset)
  256. def get_rdataset(
  257. self,
  258. name: dns.name.Name | str,
  259. rdtype: dns.rdatatype.RdataType | str,
  260. covers: dns.rdatatype.RdataType | str = dns.rdatatype.NONE,
  261. create: bool = False,
  262. ) -> dns.rdataset.Rdataset | None:
  263. if create:
  264. raise UseTransaction
  265. rdataset = super().get_rdataset(name, rdtype, covers)
  266. if rdataset is not None:
  267. return dns.rdataset.ImmutableRdataset(rdataset)
  268. else:
  269. return None
  270. def delete_rdataset(
  271. self,
  272. name: dns.name.Name | str,
  273. rdtype: dns.rdatatype.RdataType | str,
  274. covers: dns.rdatatype.RdataType | str = dns.rdatatype.NONE,
  275. ) -> None:
  276. raise UseTransaction
  277. def replace_rdataset(
  278. self, name: dns.name.Name | str, replacement: dns.rdataset.Rdataset
  279. ) -> None:
  280. raise UseTransaction