transaction.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import collections
  3. from typing import Any, Callable, Iterator, List, Tuple
  4. import dns.exception
  5. import dns.name
  6. import dns.node
  7. import dns.rdata
  8. import dns.rdataclass
  9. import dns.rdataset
  10. import dns.rdatatype
  11. import dns.rrset
  12. import dns.serial
  13. import dns.ttl
  14. class TransactionManager:
  15. def reader(self) -> "Transaction":
  16. """Begin a read-only transaction."""
  17. raise NotImplementedError # pragma: no cover
  18. def writer(self, replacement: bool = False) -> "Transaction":
  19. """Begin a writable transaction.
  20. *replacement*, a ``bool``. If `True`, the content of the
  21. transaction completely replaces any prior content. If False,
  22. the default, then the content of the transaction updates the
  23. existing content.
  24. """
  25. raise NotImplementedError # pragma: no cover
  26. def origin_information(
  27. self,
  28. ) -> Tuple[dns.name.Name | None, bool, dns.name.Name | None]:
  29. """Returns a tuple
  30. (absolute_origin, relativize, effective_origin)
  31. giving the absolute name of the default origin for any
  32. relative domain names, the "effective origin", and whether
  33. names should be relativized. The "effective origin" is the
  34. absolute origin if relativize is False, and the empty name if
  35. relativize is true. (The effective origin is provided even
  36. though it can be computed from the absolute_origin and
  37. relativize setting because it avoids a lot of code
  38. duplication.)
  39. If the returned names are `None`, then no origin information is
  40. available.
  41. This information is used by code working with transactions to
  42. allow it to coordinate relativization. The transaction code
  43. itself takes what it gets (i.e. does not change name
  44. relativity).
  45. """
  46. raise NotImplementedError # pragma: no cover
  47. def get_class(self) -> dns.rdataclass.RdataClass:
  48. """The class of the transaction manager."""
  49. raise NotImplementedError # pragma: no cover
  50. def from_wire_origin(self) -> dns.name.Name | None:
  51. """Origin to use in from_wire() calls."""
  52. (absolute_origin, relativize, _) = self.origin_information()
  53. if relativize:
  54. return absolute_origin
  55. else:
  56. return None
# Raised by Transaction.delete_exact() when the rdatas, the rdataset, or
# the name it was asked to delete is not present in the existing content.
# NOTE(review): DNSException presumably uses the class docstring as the
# default exception message, so the docstring is left unchanged.
class DeleteNotExact(dns.exception.DNSException):
    """Existing data did not match data specified by an exact delete."""
# Raised when a mutating Transaction method (add(), replace(), delete(),
# delete_exact()) is called on a transaction created with read_only=True.
# NOTE(review): DNSException presumably uses the class docstring as the
# default exception message, so the docstring is left unchanged.
class ReadOnly(dns.exception.DNSException):
    """Tried to write to a read-only transaction."""
# Raised by Transaction methods that check for it once commit() or
# rollback() has ended the transaction.
# NOTE(review): DNSException presumably uses the class docstring as the
# default exception message, so the docstring is left unchanged.
class AlreadyEnded(dns.exception.DNSException):
    """Tried to use an already-ended transaction."""
  63. def _ensure_immutable_rdataset(rdataset):
  64. if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset):
  65. return rdataset
  66. return dns.rdataset.ImmutableRdataset(rdataset)
  67. def _ensure_immutable_node(node):
  68. if node is None or node.is_immutable():
  69. return node
  70. return dns.node.ImmutableNode(node)
  71. CheckPutRdatasetType = Callable[
  72. ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None
  73. ]
  74. CheckDeleteRdatasetType = Callable[
  75. ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType],
  76. None,
  77. ]
  78. CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None]
  79. class Transaction:
  80. def __init__(
  81. self,
  82. manager: TransactionManager,
  83. replacement: bool = False,
  84. read_only: bool = False,
  85. ):
  86. self.manager = manager
  87. self.replacement = replacement
  88. self.read_only = read_only
  89. self._ended = False
  90. self._check_put_rdataset: List[CheckPutRdatasetType] = []
  91. self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
  92. self._check_delete_name: List[CheckDeleteNameType] = []
  93. #
  94. # This is the high level API
  95. #
  96. # Note that we currently use non-immutable types in the return type signature to
  97. # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be
  98. # unhappy if we return an ImmutableRdataset.
  99. def get(
  100. self,
  101. name: dns.name.Name | str | None,
  102. rdtype: dns.rdatatype.RdataType | str,
  103. covers: dns.rdatatype.RdataType | str = dns.rdatatype.NONE,
  104. ) -> dns.rdataset.Rdataset:
  105. """Return the rdataset associated with *name*, *rdtype*, and *covers*,
  106. or `None` if not found.
  107. Note that the returned rdataset is immutable.
  108. """
  109. self._check_ended()
  110. if isinstance(name, str):
  111. name = dns.name.from_text(name, None)
  112. rdtype = dns.rdatatype.RdataType.make(rdtype)
  113. covers = dns.rdatatype.RdataType.make(covers)
  114. rdataset = self._get_rdataset(name, rdtype, covers)
  115. return _ensure_immutable_rdataset(rdataset)
  116. def get_node(self, name: dns.name.Name) -> dns.node.Node | None:
  117. """Return the node at *name*, if any.
  118. Returns an immutable node or ``None``.
  119. """
  120. return _ensure_immutable_node(self._get_node(name))
  121. def _check_read_only(self) -> None:
  122. if self.read_only:
  123. raise ReadOnly
  124. def add(self, *args: Any) -> None:
  125. """Add records.
  126. The arguments may be:
  127. - rrset
  128. - name, rdataset...
  129. - name, ttl, rdata...
  130. """
  131. self._check_ended()
  132. self._check_read_only()
  133. self._add(False, args)
  134. def replace(self, *args: Any) -> None:
  135. """Replace the existing rdataset at the name with the specified
  136. rdataset, or add the specified rdataset if there was no existing
  137. rdataset.
  138. The arguments may be:
  139. - rrset
  140. - name, rdataset...
  141. - name, ttl, rdata...
  142. Note that if you want to replace the entire node, you should do
  143. a delete of the name followed by one or more calls to add() or
  144. replace().
  145. """
  146. self._check_ended()
  147. self._check_read_only()
  148. self._add(True, args)
  149. def delete(self, *args: Any) -> None:
  150. """Delete records.
  151. It is not an error if some of the records are not in the existing
  152. set.
  153. The arguments may be:
  154. - rrset
  155. - name
  156. - name, rdatatype, [covers]
  157. - name, rdataset...
  158. - name, rdata...
  159. """
  160. self._check_ended()
  161. self._check_read_only()
  162. self._delete(False, args)
  163. def delete_exact(self, *args: Any) -> None:
  164. """Delete records.
  165. The arguments may be:
  166. - rrset
  167. - name
  168. - name, rdatatype, [covers]
  169. - name, rdataset...
  170. - name, rdata...
  171. Raises dns.transaction.DeleteNotExact if some of the records
  172. are not in the existing set.
  173. """
  174. self._check_ended()
  175. self._check_read_only()
  176. self._delete(True, args)
  177. def name_exists(self, name: dns.name.Name | str) -> bool:
  178. """Does the specified name exist?"""
  179. self._check_ended()
  180. if isinstance(name, str):
  181. name = dns.name.from_text(name, None)
  182. return self._name_exists(name)
  183. def update_serial(
  184. self,
  185. value: int = 1,
  186. relative: bool = True,
  187. name: dns.name.Name = dns.name.empty,
  188. ) -> None:
  189. """Update the serial number.
  190. *value*, an `int`, is an increment if *relative* is `True`, or the
  191. actual value to set if *relative* is `False`.
  192. Raises `KeyError` if there is no SOA rdataset at *name*.
  193. Raises `ValueError` if *value* is negative or if the increment is
  194. so large that it would cause the new serial to be less than the
  195. prior value.
  196. """
  197. self._check_ended()
  198. if value < 0:
  199. raise ValueError("negative update_serial() value")
  200. if isinstance(name, str):
  201. name = dns.name.from_text(name, None)
  202. rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE)
  203. if rdataset is None or len(rdataset) == 0:
  204. raise KeyError
  205. if relative:
  206. serial = dns.serial.Serial(rdataset[0].serial) + value
  207. else:
  208. serial = dns.serial.Serial(value)
  209. serial = serial.value # convert back to int
  210. if serial == 0:
  211. serial = 1
  212. rdata = rdataset[0].replace(serial=serial)
  213. new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
  214. self.replace(name, new_rdataset)
  215. def __iter__(self):
  216. self._check_ended()
  217. return self._iterate_rdatasets()
  218. def changed(self) -> bool:
  219. """Has this transaction changed anything?
  220. For read-only transactions, the result is always `False`.
  221. For writable transactions, the result is `True` if at some time
  222. during the life of the transaction, the content was changed.
  223. """
  224. self._check_ended()
  225. return self._changed()
  226. def commit(self) -> None:
  227. """Commit the transaction.
  228. Normally transactions are used as context managers and commit
  229. or rollback automatically, but it may be done explicitly if needed.
  230. A ``dns.transaction.Ended`` exception will be raised if you try
  231. to use a transaction after it has been committed or rolled back.
  232. Raises an exception if the commit fails (in which case the transaction
  233. is also rolled back.
  234. """
  235. self._end(True)
  236. def rollback(self) -> None:
  237. """Rollback the transaction.
  238. Normally transactions are used as context managers and commit
  239. or rollback automatically, but it may be done explicitly if needed.
  240. A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
  241. to use a transaction after it has been committed or rolled back.
  242. Rollback cannot otherwise fail.
  243. """
  244. self._end(False)
  245. def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
  246. """Call *check* before putting (storing) an rdataset.
  247. The function is called with the transaction, the name, and the rdataset.
  248. The check function may safely make non-mutating transaction method
  249. calls, but behavior is undefined if mutating transaction methods are
  250. called. The check function should raise an exception if it objects to
  251. the put, and otherwise should return ``None``.
  252. """
  253. self._check_put_rdataset.append(check)
  254. def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
  255. """Call *check* before deleting an rdataset.
  256. The function is called with the transaction, the name, the rdatatype,
  257. and the covered rdatatype.
  258. The check function may safely make non-mutating transaction method
  259. calls, but behavior is undefined if mutating transaction methods are
  260. called. The check function should raise an exception if it objects to
  261. the put, and otherwise should return ``None``.
  262. """
  263. self._check_delete_rdataset.append(check)
  264. def check_delete_name(self, check: CheckDeleteNameType) -> None:
  265. """Call *check* before putting (storing) an rdataset.
  266. The function is called with the transaction and the name.
  267. The check function may safely make non-mutating transaction method
  268. calls, but behavior is undefined if mutating transaction methods are
  269. called. The check function should raise an exception if it objects to
  270. the put, and otherwise should return ``None``.
  271. """
  272. self._check_delete_name.append(check)
  273. def iterate_rdatasets(
  274. self,
  275. ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
  276. """Iterate all the rdatasets in the transaction, returning
  277. (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
  278. Note that as is usual with python iterators, adding or removing items
  279. while iterating will invalidate the iterator and may raise `RuntimeError`
  280. or fail to iterate over all entries."""
  281. self._check_ended()
  282. return self._iterate_rdatasets()
  283. def iterate_names(self) -> Iterator[dns.name.Name]:
  284. """Iterate all the names in the transaction.
  285. Note that as is usual with python iterators, adding or removing names
  286. while iterating will invalidate the iterator and may raise `RuntimeError`
  287. or fail to iterate over all entries."""
  288. self._check_ended()
  289. return self._iterate_names()
  290. #
  291. # Helper methods
  292. #
  293. def _raise_if_not_empty(self, method, args):
  294. if len(args) != 0:
  295. raise TypeError(f"extra parameters to {method}")
  296. def _rdataset_from_args(self, method, deleting, args):
  297. try:
  298. arg = args.popleft()
  299. if isinstance(arg, dns.rrset.RRset):
  300. rdataset = arg.to_rdataset()
  301. elif isinstance(arg, dns.rdataset.Rdataset):
  302. rdataset = arg
  303. else:
  304. if deleting:
  305. ttl = 0
  306. else:
  307. if isinstance(arg, int):
  308. ttl = arg
  309. if ttl > dns.ttl.MAX_TTL:
  310. raise ValueError(f"{method}: TTL value too big")
  311. else:
  312. raise TypeError(f"{method}: expected a TTL")
  313. arg = args.popleft()
  314. if isinstance(arg, dns.rdata.Rdata):
  315. rdataset = dns.rdataset.from_rdata(ttl, arg)
  316. else:
  317. raise TypeError(f"{method}: expected an Rdata")
  318. return rdataset
  319. except IndexError:
  320. if deleting:
  321. return None
  322. else:
  323. # reraise
  324. raise TypeError(f"{method}: expected more arguments")
  325. def _add(self, replace, args):
  326. if replace:
  327. method = "replace()"
  328. else:
  329. method = "add()"
  330. try:
  331. args = collections.deque(args)
  332. arg = args.popleft()
  333. if isinstance(arg, str):
  334. arg = dns.name.from_text(arg, None)
  335. if isinstance(arg, dns.name.Name):
  336. name = arg
  337. rdataset = self._rdataset_from_args(method, False, args)
  338. elif isinstance(arg, dns.rrset.RRset):
  339. rrset = arg
  340. name = rrset.name
  341. # rrsets are also rdatasets, but they don't print the
  342. # same and can't be stored in nodes, so convert.
  343. rdataset = rrset.to_rdataset()
  344. else:
  345. raise TypeError(
  346. f"{method} requires a name or RRset as the first argument"
  347. )
  348. assert rdataset is not None # for type checkers
  349. if rdataset.rdclass != self.manager.get_class():
  350. raise ValueError(f"{method} has objects of wrong RdataClass")
  351. if rdataset.rdtype == dns.rdatatype.SOA:
  352. (_, _, origin) = self._origin_information()
  353. if name != origin:
  354. raise ValueError(f"{method} has non-origin SOA")
  355. self._raise_if_not_empty(method, args)
  356. if not replace:
  357. existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
  358. if existing is not None:
  359. if isinstance(existing, dns.rdataset.ImmutableRdataset):
  360. trds = dns.rdataset.Rdataset(
  361. existing.rdclass, existing.rdtype, existing.covers
  362. )
  363. trds.update(existing)
  364. existing = trds
  365. rdataset = existing.union(rdataset)
  366. self._checked_put_rdataset(name, rdataset)
  367. except IndexError:
  368. raise TypeError(f"not enough parameters to {method}")
  369. def _delete(self, exact, args):
  370. if exact:
  371. method = "delete_exact()"
  372. else:
  373. method = "delete()"
  374. try:
  375. args = collections.deque(args)
  376. arg = args.popleft()
  377. if isinstance(arg, str):
  378. arg = dns.name.from_text(arg, None)
  379. if isinstance(arg, dns.name.Name):
  380. name = arg
  381. if len(args) > 0 and (
  382. isinstance(args[0], int) or isinstance(args[0], str)
  383. ):
  384. # deleting by type and (optionally) covers
  385. rdtype = dns.rdatatype.RdataType.make(args.popleft())
  386. if len(args) > 0:
  387. covers = dns.rdatatype.RdataType.make(args.popleft())
  388. else:
  389. covers = dns.rdatatype.NONE
  390. self._raise_if_not_empty(method, args)
  391. existing = self._get_rdataset(name, rdtype, covers)
  392. if existing is None:
  393. if exact:
  394. raise DeleteNotExact(f"{method}: missing rdataset")
  395. else:
  396. self._checked_delete_rdataset(name, rdtype, covers)
  397. return
  398. else:
  399. rdataset = self._rdataset_from_args(method, True, args)
  400. elif isinstance(arg, dns.rrset.RRset):
  401. rdataset = arg # rrsets are also rdatasets
  402. name = rdataset.name
  403. else:
  404. raise TypeError(
  405. f"{method} requires a name or RRset as the first argument"
  406. )
  407. self._raise_if_not_empty(method, args)
  408. if rdataset:
  409. if rdataset.rdclass != self.manager.get_class():
  410. raise ValueError(f"{method} has objects of wrong RdataClass")
  411. existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
  412. if existing is not None:
  413. if exact:
  414. intersection = existing.intersection(rdataset)
  415. if intersection != rdataset:
  416. raise DeleteNotExact(f"{method}: missing rdatas")
  417. rdataset = existing.difference(rdataset)
  418. if len(rdataset) == 0:
  419. self._checked_delete_rdataset(
  420. name, rdataset.rdtype, rdataset.covers
  421. )
  422. else:
  423. self._checked_put_rdataset(name, rdataset)
  424. elif exact:
  425. raise DeleteNotExact(f"{method}: missing rdataset")
  426. else:
  427. if exact and not self._name_exists(name):
  428. raise DeleteNotExact(f"{method}: name not known")
  429. self._checked_delete_name(name)
  430. except IndexError:
  431. raise TypeError(f"not enough parameters to {method}")
  432. def _check_ended(self):
  433. if self._ended:
  434. raise AlreadyEnded
  435. def _end(self, commit):
  436. self._check_ended()
  437. try:
  438. self._end_transaction(commit)
  439. finally:
  440. self._ended = True
  441. def _checked_put_rdataset(self, name, rdataset):
  442. for check in self._check_put_rdataset:
  443. check(self, name, rdataset)
  444. self._put_rdataset(name, rdataset)
  445. def _checked_delete_rdataset(self, name, rdtype, covers):
  446. for check in self._check_delete_rdataset:
  447. check(self, name, rdtype, covers)
  448. self._delete_rdataset(name, rdtype, covers)
  449. def _checked_delete_name(self, name):
  450. for check in self._check_delete_name:
  451. check(self, name)
  452. self._delete_name(name)
  453. #
  454. # Transactions are context managers.
  455. #
  456. def __enter__(self):
  457. return self
  458. def __exit__(self, exc_type, exc_val, exc_tb):
  459. if not self._ended:
  460. if exc_type is None:
  461. self.commit()
  462. else:
  463. self.rollback()
  464. return False
  465. #
  466. # This is the low level API, which must be implemented by subclasses
  467. # of Transaction.
  468. #
  469. def _get_rdataset(self, name, rdtype, covers):
  470. """Return the rdataset associated with *name*, *rdtype*, and *covers*,
  471. or `None` if not found.
  472. """
  473. raise NotImplementedError # pragma: no cover
  474. def _put_rdataset(self, name, rdataset):
  475. """Store the rdataset."""
  476. raise NotImplementedError # pragma: no cover
  477. def _delete_name(self, name):
  478. """Delete all data associated with *name*.
  479. It is not an error if the name does not exist.
  480. """
  481. raise NotImplementedError # pragma: no cover
  482. def _delete_rdataset(self, name, rdtype, covers):
  483. """Delete all data associated with *name*, *rdtype*, and *covers*.
  484. It is not an error if the rdataset does not exist.
  485. """
  486. raise NotImplementedError # pragma: no cover
  487. def _name_exists(self, name):
  488. """Does name exist?
  489. Returns a bool.
  490. """
  491. raise NotImplementedError # pragma: no cover
  492. def _changed(self):
  493. """Has this transaction changed anything?"""
  494. raise NotImplementedError # pragma: no cover
  495. def _end_transaction(self, commit):
  496. """End the transaction.
  497. *commit*, a bool. If ``True``, commit the transaction, otherwise
  498. roll it back.
  499. If committing and the commit fails, then roll back and raise an
  500. exception.
  501. """
  502. raise NotImplementedError # pragma: no cover
  503. def _set_origin(self, origin):
  504. """Set the origin.
  505. This method is called when reading a possibly relativized
  506. source, and an origin setting operation occurs (e.g. $ORIGIN
  507. in a zone file).
  508. """
  509. raise NotImplementedError # pragma: no cover
  510. def _iterate_rdatasets(self):
  511. """Return an iterator that yields (name, rdataset) tuples."""
  512. raise NotImplementedError # pragma: no cover
  513. def _iterate_names(self):
  514. """Return an iterator that yields a name."""
  515. raise NotImplementedError # pragma: no cover
  516. def _get_node(self, name):
  517. """Return the node at *name*, if any.
  518. Returns a node or ``None``.
  519. """
  520. raise NotImplementedError # pragma: no cover
  521. #
  522. # Low-level API with a default implementation, in case a subclass needs
  523. # to override.
  524. #
  525. def _origin_information(self):
  526. # This is only used by _add()
  527. return self.manager.origin_information()