wire.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import contextlib
  3. import struct
  4. from typing import Iterator, Optional, Tuple
  5. import dns.exception
  6. import dns.name
  7. class Parser:
  8. """Helper class for parsing DNS wire format."""
  9. def __init__(self, wire: bytes, current: int = 0):
  10. """Initialize a Parser
  11. *wire*, a ``bytes`` contains the data to be parsed, and possibly other data.
  12. Typically it is the whole message or a slice of it.
  13. *current*, an `int`, the offset within *wire* where parsing should begin.
  14. """
  15. self.wire = wire
  16. self.current = 0
  17. self.end = len(self.wire)
  18. if current:
  19. self.seek(current)
  20. self.furthest = current
  21. def remaining(self) -> int:
  22. return self.end - self.current
  23. def get_bytes(self, size: int) -> bytes:
  24. assert size >= 0
  25. if size > self.remaining():
  26. raise dns.exception.FormError
  27. output = self.wire[self.current : self.current + size]
  28. self.current += size
  29. self.furthest = max(self.furthest, self.current)
  30. return output
  31. def get_counted_bytes(self, length_size: int = 1) -> bytes:
  32. length = int.from_bytes(self.get_bytes(length_size), "big")
  33. return self.get_bytes(length)
  34. def get_remaining(self) -> bytes:
  35. return self.get_bytes(self.remaining())
  36. def get_uint8(self) -> int:
  37. return struct.unpack("!B", self.get_bytes(1))[0]
  38. def get_uint16(self) -> int:
  39. return struct.unpack("!H", self.get_bytes(2))[0]
  40. def get_uint32(self) -> int:
  41. return struct.unpack("!I", self.get_bytes(4))[0]
  42. def get_uint48(self) -> int:
  43. return int.from_bytes(self.get_bytes(6), "big")
  44. def get_struct(self, format: str) -> Tuple:
  45. return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
  46. def get_name(self, origin: Optional["dns.name.Name"] = None) -> "dns.name.Name":
  47. name = dns.name.from_wire_parser(self)
  48. if origin:
  49. name = name.relativize(origin)
  50. return name
  51. def seek(self, where: int) -> None:
  52. # Note that seeking to the end is OK! (If you try to read
  53. # after such a seek, you'll get an exception as expected.)
  54. if where < 0 or where > self.end:
  55. raise dns.exception.FormError
  56. self.current = where
  57. @contextlib.contextmanager
  58. def restrict_to(self, size: int) -> Iterator:
  59. assert size >= 0
  60. if size > self.remaining():
  61. raise dns.exception.FormError
  62. saved_end = self.end
  63. try:
  64. self.end = self.current + size
  65. yield
  66. # We make this check here and not in the finally as we
  67. # don't want to raise if we're already raising for some
  68. # other reason.
  69. if self.current != self.end:
  70. raise dns.exception.FormError
  71. finally:
  72. self.end = saved_end
  73. @contextlib.contextmanager
  74. def restore_furthest(self) -> Iterator:
  75. try:
  76. yield None
  77. finally:
  78. self.current = self.furthest