query.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. # Copyright (C) 2003-2017 Nominum, Inc.
  3. #
  4. # Permission to use, copy, modify, and distribute this software and its
  5. # documentation for any purpose with or without fee is hereby granted,
  6. # provided that the above copyright notice and this permission notice
  7. # appear in all copies.
  8. #
  9. # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
  10. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  11. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
  12. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  13. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  14. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
  15. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  16. """Talk to a DNS server."""
  17. import base64
  18. import contextlib
  19. import enum
  20. import errno
  21. import os
  22. import random
  23. import selectors
  24. import socket
  25. import struct
  26. import time
  27. import urllib.parse
  28. from typing import Any, Callable, Dict, Optional, Tuple, cast
  29. import dns._features
  30. import dns._tls_util
  31. import dns.exception
  32. import dns.inet
  33. import dns.message
  34. import dns.name
  35. import dns.quic
  36. import dns.rdata
  37. import dns.rdataclass
  38. import dns.rdatatype
  39. import dns.transaction
  40. import dns.tsig
  41. import dns.xfr
  42. try:
  43. import ssl
  44. except ImportError:
  45. import dns._no_ssl as ssl # type: ignore
  46. def _remaining(expiration):
  47. if expiration is None:
  48. return None
  49. timeout = expiration - time.time()
  50. if timeout <= 0.0:
  51. raise dns.exception.Timeout
  52. return timeout
  53. def _expiration_for_this_attempt(timeout, expiration):
  54. if expiration is None:
  55. return None
  56. return min(time.time() + timeout, expiration)
  57. _have_httpx = dns._features.have("doh")
  58. if _have_httpx:
  59. import httpcore._backends.sync
  60. import httpx
  61. _CoreNetworkBackend = httpcore.NetworkBackend
  62. _CoreSyncStream = httpcore._backends.sync.SyncStream
  63. class _NetworkBackend(_CoreNetworkBackend):
  64. def __init__(self, resolver, local_port, bootstrap_address, family):
  65. super().__init__()
  66. self._local_port = local_port
  67. self._resolver = resolver
  68. self._bootstrap_address = bootstrap_address
  69. self._family = family
  70. def connect_tcp(
  71. self, host, port, timeout=None, local_address=None, socket_options=None
  72. ): # pylint: disable=signature-differs
  73. addresses = []
  74. _, expiration = _compute_times(timeout)
  75. if dns.inet.is_address(host):
  76. addresses.append(host)
  77. elif self._bootstrap_address is not None:
  78. addresses.append(self._bootstrap_address)
  79. else:
  80. timeout = _remaining(expiration)
  81. family = self._family
  82. if local_address:
  83. family = dns.inet.af_for_address(local_address)
  84. answers = self._resolver.resolve_name(
  85. host, family=family, lifetime=timeout
  86. )
  87. addresses = answers.addresses()
  88. for address in addresses:
  89. af = dns.inet.af_for_address(address)
  90. if local_address is not None or self._local_port != 0:
  91. if local_address is None:
  92. local_address = "0.0.0.0"
  93. source = dns.inet.low_level_address_tuple(
  94. (local_address, self._local_port), af
  95. )
  96. else:
  97. source = None
  98. try:
  99. sock = make_socket(af, socket.SOCK_STREAM, source)
  100. attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
  101. _connect(
  102. sock,
  103. dns.inet.low_level_address_tuple((address, port), af),
  104. attempt_expiration,
  105. )
  106. return _CoreSyncStream(sock)
  107. except Exception:
  108. pass
  109. raise httpcore.ConnectError
  110. def connect_unix_socket(
  111. self, path, timeout=None, socket_options=None
  112. ): # pylint: disable=signature-differs
  113. raise NotImplementedError
  114. class _HTTPTransport(httpx.HTTPTransport): # pyright: ignore
  115. def __init__(
  116. self,
  117. *args,
  118. local_port=0,
  119. bootstrap_address=None,
  120. resolver=None,
  121. family=socket.AF_UNSPEC,
  122. **kwargs,
  123. ):
  124. if resolver is None and bootstrap_address is None:
  125. # pylint: disable=import-outside-toplevel,redefined-outer-name
  126. import dns.resolver
  127. resolver = dns.resolver.Resolver()
  128. super().__init__(*args, **kwargs)
  129. self._pool._network_backend = _NetworkBackend(
  130. resolver, local_port, bootstrap_address, family
  131. )
  132. else:
  133. class _HTTPTransport: # type: ignore
  134. def __init__(
  135. self,
  136. *args,
  137. local_port=0,
  138. bootstrap_address=None,
  139. resolver=None,
  140. family=socket.AF_UNSPEC,
  141. **kwargs,
  142. ):
  143. pass
  144. def connect_tcp(self, host, port, timeout, local_address):
  145. raise NotImplementedError
  146. have_doh = _have_httpx
  147. def default_socket_factory(
  148. af: socket.AddressFamily | int,
  149. kind: socket.SocketKind,
  150. proto: int,
  151. ) -> socket.socket:
  152. return socket.socket(af, kind, proto)
  153. # Function used to create a socket. Can be overridden if needed in special
  154. # situations.
  155. socket_factory: Callable[
  156. [socket.AddressFamily | int, socket.SocketKind, int], socket.socket
  157. ] = default_socket_factory
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""

    # Raised by _matches_destination() when a reply's source does not match
    # the query's destination and unexpected sources are not being ignored.
    # NOTE(review): DNSException appears to use the class docstring as the
    # default message, so treat the docstring as user-visible - confirm.
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""

    # Raised by https() and _http3() when q.is_response(r) is false for the
    # parsed reply.
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the httpx module is not
    available."""

    # Raised (bare) by https() when the httpx path is needed but httpx is
    # missing, and by _http3() with an explicit message when QUIC support
    # (dns.quic.have_quic) is unavailable.
class NoDOQ(dns.exception.DNSException):
    """DNS over QUIC (DOQ) was requested but the aioquic module is not
    available."""

    # NOTE(review): not raised in the visible part of this module; presumably
    # used by the DNS-over-QUIC entry point further down - confirm.
# for backwards compatibility: this module's TransferError is the same class
# object as dns.xfr.TransferError, so existing "except TransferError" code
# written against this module keeps catching zone-transfer failures.
TransferError = dns.xfr.TransferError
  170. def _compute_times(timeout):
  171. now = time.time()
  172. if timeout is None:
  173. return (now, None)
  174. else:
  175. return (now, now + timeout)
  176. def _wait_for(fd, readable, writable, _, expiration):
  177. # Use the selected selector class to wait for any of the specified
  178. # events. An "expiration" absolute time is converted into a relative
  179. # timeout.
  180. #
  181. # The unused parameter is 'error', which is always set when
  182. # selecting for read or write, and we have no error-only selects.
  183. if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
  184. return True
  185. with selectors.DefaultSelector() as sel:
  186. events = 0
  187. if readable:
  188. events |= selectors.EVENT_READ
  189. if writable:
  190. events |= selectors.EVENT_WRITE
  191. if events:
  192. sel.register(fd, events) # pyright: ignore
  193. if expiration is None:
  194. timeout = None
  195. else:
  196. timeout = expiration - time.time()
  197. if timeout <= 0.0:
  198. raise dns.exception.Timeout
  199. if not sel.select(timeout):
  200. raise dns.exception.Timeout
  201. def _wait_for_readable(s, expiration):
  202. _wait_for(s, True, False, True, expiration)
  203. def _wait_for_writable(s, expiration):
  204. _wait_for(s, False, True, True, expiration)
  205. def _addresses_equal(af, a1, a2):
  206. # Convert the first value of the tuple, which is a textual format
  207. # address into binary form, so that we are not confused by different
  208. # textual representations of the same address
  209. try:
  210. n1 = dns.inet.inet_pton(af, a1[0])
  211. n2 = dns.inet.inet_pton(af, a2[0])
  212. except dns.exception.SyntaxError:
  213. return False
  214. return n1 == n2 and a1[1:] == a2[1:]
  215. def _matches_destination(af, from_address, destination, ignore_unexpected):
  216. # Check that from_address is appropriate for a response to a query
  217. # sent to destination.
  218. if not destination:
  219. return True
  220. if _addresses_equal(af, from_address, destination) or (
  221. dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]
  222. ):
  223. return True
  224. elif ignore_unexpected:
  225. return False
  226. raise UnexpectedSource(
  227. f"got a response from {from_address} instead of " f"{destination}"
  228. )
  229. def _destination_and_source(
  230. where, port, source, source_port, where_must_be_address=True
  231. ):
  232. # Apply defaults and compute destination and source tuples
  233. # suitable for use in connect(), sendto(), or bind().
  234. af = None
  235. destination = None
  236. try:
  237. af = dns.inet.af_for_address(where)
  238. destination = where
  239. except Exception:
  240. if where_must_be_address:
  241. raise
  242. # URLs are ok so eat the exception
  243. if source:
  244. saf = dns.inet.af_for_address(source)
  245. if af:
  246. # We know the destination af, so source had better agree!
  247. if saf != af:
  248. raise ValueError(
  249. "different address families for source and destination"
  250. )
  251. else:
  252. # We didn't know the destination af, but we know the source,
  253. # so that's our af.
  254. af = saf
  255. if source_port and not source:
  256. # Caller has specified a source_port but not an address, so we
  257. # need to return a source, and we need to use the appropriate
  258. # wildcard address as the address.
  259. try:
  260. source = dns.inet.any_for_af(af)
  261. except Exception:
  262. # we catch this and raise ValueError for backwards compatibility
  263. raise ValueError("source_port specified but address family is unknown")
  264. # Convert high-level (address, port) tuples into low-level address
  265. # tuples.
  266. if destination:
  267. destination = dns.inet.low_level_address_tuple((destination, port), af)
  268. if source:
  269. source = dns.inet.low_level_address_tuple((source, source_port), af)
  270. return (af, destination, source)
  271. def make_socket(
  272. af: socket.AddressFamily | int,
  273. type: socket.SocketKind,
  274. source: Any | None = None,
  275. ) -> socket.socket:
  276. """Make a socket.
  277. This function uses the module's ``socket_factory`` to make a socket of the
  278. specified address family and type.
  279. *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
  280. ``socket.AF_INET`` or ``socket.AF_INET6``.
  281. *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
  282. a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the
  283. ``proto`` attribute of a socket is always zero with this API, so a datagram socket
  284. will always be a UDP socket, and a stream socket will always be a TCP socket.
  285. *source* is the source address and port to bind to, if any. The default is
  286. ``None`` which will bind to the wildcard address and a randomly chosen port.
  287. If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
  288. """
  289. s = socket_factory(af, type, 0)
  290. try:
  291. s.setblocking(False)
  292. if source is not None:
  293. s.bind(source)
  294. return s
  295. except Exception:
  296. s.close()
  297. raise
  298. def make_ssl_socket(
  299. af: socket.AddressFamily | int,
  300. type: socket.SocketKind,
  301. ssl_context: ssl.SSLContext,
  302. server_hostname: dns.name.Name | str | None = None,
  303. source: Any | None = None,
  304. ) -> ssl.SSLSocket:
  305. """Make a socket.
  306. This function uses the module's ``socket_factory`` to make a socket of the
  307. specified address family and type.
  308. *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
  309. ``socket.AF_INET`` or ``socket.AF_INET6``.
  310. *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
  311. a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the
  312. ``proto`` attribute of a socket is always zero with this API, so a datagram socket
  313. will always be a UDP socket, and a stream socket will always be a TCP socket.
  314. If *ssl_context* is not ``None``, then it specifies the SSL context to use,
  315. typically created with ``make_ssl_context()``.
  316. If *server_hostname* is not ``None``, then it is the hostname to use for server
  317. certificate validation. A valid hostname must be supplied if *ssl_context*
  318. requires hostname checking.
  319. *source* is the source address and port to bind to, if any. The default is
  320. ``None`` which will bind to the wildcard address and a randomly chosen port.
  321. If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
  322. """
  323. sock = make_socket(af, type, source)
  324. if isinstance(server_hostname, dns.name.Name):
  325. server_hostname = server_hostname.to_text()
  326. # LGTM gets a false positive here, as our default context is OK
  327. return ssl_context.wrap_socket(
  328. sock,
  329. do_handshake_on_connect=False, # lgtm[py/insecure-protocol]
  330. server_hostname=server_hostname,
  331. )
  332. # for backwards compatibility
  333. def _make_socket(
  334. af,
  335. type,
  336. source,
  337. ssl_context,
  338. server_hostname,
  339. ):
  340. if ssl_context is not None:
  341. return make_ssl_socket(af, type, ssl_context, server_hostname, source)
  342. else:
  343. return make_socket(af, type, source)
  344. def _maybe_get_resolver(
  345. resolver: Optional["dns.resolver.Resolver"], # pyright: ignore
  346. ) -> "dns.resolver.Resolver": # pyright: ignore
  347. # We need a separate method for this to avoid overriding the global
  348. # variable "dns" with the as-yet undefined local variable "dns"
  349. # in https().
  350. if resolver is None:
  351. # pylint: disable=import-outside-toplevel,redefined-outer-name
  352. import dns.resolver
  353. resolver = dns.resolver.Resolver()
  354. return resolver
class HTTPVersion(enum.IntEnum):
    """Which version of HTTP should be used?

    DEFAULT will select the first version from the list [2, 1.1, 3] that
    is available.
    """

    # Let the library choose (see the docstring above).
    DEFAULT = 0
    # HTTP/1.1. H1 shares the value, so it is an enum alias of HTTP_1.
    HTTP_1 = 1
    H1 = 1
    # HTTP/2. H2 is an alias of HTTP_2.
    HTTP_2 = 2
    H2 = 2
    # HTTP/3 (over QUIC). H3 is an alias of HTTP_3.
    HTTP_3 = 3
    H3 = 3
def https(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 443,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    session: Any | None = None,
    path: str = "/dns-query",
    post: bool = True,
    bootstrap_address: str | None = None,
    verify: bool | str | ssl.SSLContext = True,
    resolver: Optional["dns.resolver.Resolver"] = None,  # pyright: ignore
    family: int = socket.AF_UNSPEC,
    http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
    given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 443. A port
    given explicitly in a URL *where* takes precedence when no *bootstrap_address*
    was supplied.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *session*, an ``httpx.Client`` (or, for HTTP/3, a
    ``dns.quic.SyncQuicConnection``). If provided, the client session to use to send
    the queries.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used; otherwise
    GET is used with the query base64url-encoded in the ``dns`` parameter.

    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.

    *verify*, a ``bool``, ``str``, or ``ssl.SSLContext``. If ``True``, then TLS
    certificate verification of the server is done using the default CA bundle; if
    ``False``, then no verification is done; if a ``str`` then it specifies the path
    to a certificate file or directory which will be used for verification. The value
    is passed through unchanged to the HTTP layer.

    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
    resolution of hostnames in URLs. If not specified, a new resolver with a default
    configuration will be used; note this is *not* the default resolver as that resolver
    might have been configured to use DoH causing a chicken-and-egg problem. This
    parameter only has an effect if the HTTP library is httpx.

    *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
    and AAAA records will be retrieved.

    *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.
    ``H3``, or ``DEFAULT`` when httpx is unavailable, sends the query over
    HTTP/3; otherwise httpx is used with HTTP/1.1 and/or HTTP/2 enabled.

    Returns a ``dns.message.Message``.

    Raises ``ValueError`` if the URL has no hostname, if *session* has the wrong
    type, or if the server answers with a non-2xx status; ``NoDOH`` if httpx is
    required but unavailable; ``BadResponse`` if the reply does not answer *q*.
    """

    (af, _, the_source) = _destination_and_source(
        where, port, source, source_port, False
    )
    # we bind url and then override as pyright can't figure out all paths bind.
    url = where
    if af is not None and dns.inet.is_address(where):
        if af == socket.AF_INET:
            url = f"https://{where}:{port}{path}"
        elif af == socket.AF_INET6:
            url = f"https://[{where}]:{port}{path}"

    # Per-request httpx extensions (only sni_hostname is ever set here).
    extensions = {}
    if bootstrap_address is None:
        # pylint: disable=possibly-used-before-assignment
        parsed = urllib.parse.urlparse(url)
        if parsed.hostname is None:
            raise ValueError("no hostname in URL")
        if dns.inet.is_address(parsed.hostname):
            # The URL already names an address, so no resolution is needed.
            bootstrap_address = parsed.hostname
            extensions["sni_hostname"] = parsed.hostname
        if parsed.port is not None:
            port = parsed.port

    if http_version == HTTPVersion.H3 or (
        http_version == HTTPVersion.DEFAULT and not have_doh
    ):
        # HTTP/3 path: the QUIC code needs a concrete address, so resolve the
        # URL hostname here (picking one answer at random) if necessary.
        if bootstrap_address is None:
            resolver = _maybe_get_resolver(resolver)
            assert parsed.hostname is not None  # pyright: ignore
            answers = resolver.resolve_name(parsed.hostname, family)  # pyright: ignore
            bootstrap_address = random.choice(list(answers.addresses()))
        if session and not isinstance(
            session, dns.quic.SyncQuicConnection
        ):  # pyright: ignore
            raise ValueError("session parameter must be a dns.quic.SyncQuicConnection.")
        return _http3(
            q,
            bootstrap_address,
            url,  # pyright: ignore
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            verify=verify,
            post=post,
            connection=session,
        )

    if not have_doh:
        raise NoDOH  # pragma: no cover
    if session and not isinstance(session, httpx.Client):  # pyright: ignore
        raise ValueError("session parameter must be an httpx.Client")

    wire = q.to_wire()
    headers = {"accept": "application/dns-message"}

    # DEFAULT enables both HTTP/1.1 and HTTP/2 and lets httpx negotiate.
    h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
    h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)

    # set source port and source address

    if the_source is None:
        local_address = None
        local_port = 0
    else:
        local_address = the_source[0]
        local_port = the_source[1]

    if session:
        # Caller owns the session; nullcontext keeps the "with" below from
        # closing it.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
    else:
        transport = _HTTPTransport(
            local_address=local_address,
            http1=h1,
            http2=h2,
            verify=verify,
            local_port=local_port,
            bootstrap_address=bootstrap_address,
            resolver=resolver,
            family=family,  # pyright: ignore
        )

        cm = httpx.Client(  # type: ignore
            http1=h1, http2=h2, verify=verify, transport=transport  # type: ignore
        )
    with cm as session:
        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        assert session is not None
        if post:
            headers.update(
                {
                    "content-type": "application/dns-message",
                    "content-length": str(len(wire)),
                }
            )
            response = session.post(
                url,
                headers=headers,
                content=wire,
                timeout=timeout,
                extensions=extensions,
            )
        else:
            # GET carries the query base64url-encoded without "=" padding.
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            twire = wire.decode()  # httpx does a repr() if we give it bytes
            response = session.get(
                url,
                headers=headers,
                timeout=timeout,
                params={"dns": twire},
                extensions=extensions,
            )

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError(
            f"{where} responded with status code {response.status_code}"
            f"\nResponse body: {response.content}"
        )
    r = dns.message.from_wire(
        response.content,
        keyring=q.keyring,
        request_mac=q.request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    # Elapsed time as measured by httpx for this request.
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r
  544. def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
  545. if headers is None:
  546. raise KeyError
  547. for header, value in headers:
  548. if header == name:
  549. return value
  550. raise KeyError
  551. def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
  552. value = _find_header(headers, b":status")
  553. if value is None:
  554. raise SyntaxError("no :status header in response")
  555. status = int(value)
  556. if status < 0:
  557. raise SyntaxError("status is negative")
  558. if status < 200 or status > 299:
  559. error = ""
  560. if len(wire) > 0:
  561. try:
  562. error = ": " + wire.decode()
  563. except Exception:
  564. pass
  565. raise ValueError(f"{peer} responded with status code {status}{error}")
  566. def _http3(
  567. q: dns.message.Message,
  568. where: str,
  569. url: str,
  570. timeout: float | None = None,
  571. port: int = 443,
  572. source: str | None = None,
  573. source_port: int = 0,
  574. one_rr_per_rrset: bool = False,
  575. ignore_trailing: bool = False,
  576. verify: bool | str | ssl.SSLContext = True,
  577. post: bool = True,
  578. connection: dns.quic.SyncQuicConnection | None = None,
  579. ) -> dns.message.Message:
  580. if not dns.quic.have_quic:
  581. raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
  582. url_parts = urllib.parse.urlparse(url)
  583. hostname = url_parts.hostname
  584. assert hostname is not None
  585. if url_parts.port is not None:
  586. port = url_parts.port
  587. q.id = 0
  588. wire = q.to_wire()
  589. the_connection: dns.quic.SyncQuicConnection
  590. the_manager: dns.quic.SyncQuicManager
  591. if connection:
  592. manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
  593. else:
  594. manager = dns.quic.SyncQuicManager(
  595. verify_mode=verify, server_name=hostname, h3=True # pyright: ignore
  596. )
  597. the_manager = manager # for type checking happiness
  598. with manager:
  599. if connection:
  600. the_connection = connection
  601. else:
  602. the_connection = the_manager.connect( # pyright: ignore
  603. where, port, source, source_port
  604. )
  605. (start, expiration) = _compute_times(timeout)
  606. with the_connection.make_stream(timeout) as stream: # pyright: ignore
  607. stream.send_h3(url, wire, post)
  608. wire = stream.receive(_remaining(expiration))
  609. _check_status(stream.headers(), where, wire)
  610. finish = time.time()
  611. r = dns.message.from_wire(
  612. wire,
  613. keyring=q.keyring,
  614. request_mac=q.request_mac,
  615. one_rr_per_rrset=one_rr_per_rrset,
  616. ignore_trailing=ignore_trailing,
  617. )
  618. r.time = max(finish - start, 0.0)
  619. if not q.is_response(r):
  620. raise BadResponse
  621. return r
  622. def _udp_recv(sock, max_size, expiration):
  623. """Reads a datagram from the socket.
  624. A Timeout exception will be raised if the operation is not completed
  625. by the expiration time.
  626. """
  627. while True:
  628. try:
  629. return sock.recvfrom(max_size)
  630. except BlockingIOError:
  631. _wait_for_readable(sock, expiration)
  632. def _udp_send(sock, data, destination, expiration):
  633. """Sends the specified datagram to destination over the socket.
  634. A Timeout exception will be raised if the operation is not completed
  635. by the expiration time.
  636. """
  637. while True:
  638. try:
  639. if destination:
  640. return sock.sendto(data, destination)
  641. else:
  642. return sock.send(data)
  643. except BlockingIOError: # pragma: no cover
  644. _wait_for_writable(sock, expiration)
  645. def send_udp(
  646. sock: Any,
  647. what: dns.message.Message | bytes,
  648. destination: Any,
  649. expiration: float | None = None,
  650. ) -> Tuple[int, float]:
  651. """Send a DNS message to the specified UDP socket.
  652. *sock*, a ``socket``.
  653. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  654. *destination*, a destination tuple appropriate for the address family
  655. of the socket, specifying where to send the query.
  656. *expiration*, a ``float`` or ``None``, the absolute time at which
  657. a timeout exception should be raised. If ``None``, no timeout will
  658. occur.
  659. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  660. """
  661. if isinstance(what, dns.message.Message):
  662. what = what.to_wire()
  663. sent_time = time.time()
  664. n = _udp_send(sock, what, destination, expiration)
  665. return (n, sent_time)
  666. def receive_udp(
  667. sock: Any,
  668. destination: Any | None = None,
  669. expiration: float | None = None,
  670. ignore_unexpected: bool = False,
  671. one_rr_per_rrset: bool = False,
  672. keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
  673. request_mac: bytes | None = b"",
  674. ignore_trailing: bool = False,
  675. raise_on_truncation: bool = False,
  676. ignore_errors: bool = False,
  677. query: dns.message.Message | None = None,
  678. ) -> Any:
  679. """Read a DNS message from a UDP socket.
  680. *sock*, a ``socket``.
  681. *destination*, a destination tuple appropriate for the address family
  682. of the socket, specifying where the message is expected to arrive from.
  683. When receiving a response, this would be where the associated query was
  684. sent.
  685. *expiration*, a ``float`` or ``None``, the absolute time at which
  686. a timeout exception should be raised. If ``None``, no timeout will
  687. occur.
  688. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  689. unexpected sources.
  690. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  691. RRset.
  692. *keyring*, a ``dict``, the keyring to use for TSIG.
  693. *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
  694. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  695. junk at end of the received message.
  696. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  697. the TC bit is set.
  698. Raises if the message is malformed, if network errors occur, of if
  699. there is a timeout.
  700. If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
  701. tuple of the received message and the received time.
  702. If *destination* is ``None``, returns a
  703. ``(dns.message.Message, float, tuple)``
  704. tuple of the received message, the received time, and the address where
  705. the message arrived from.
  706. *ignore_errors*, a ``bool``. If various format errors or response
  707. mismatches occur, ignore them and keep listening for a valid response.
  708. The default is ``False``.
  709. *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
  710. *ignore_errors* is ``True``, check that the received message is a response
  711. to this query, and if not keep listening for a valid response.
  712. """
  713. wire = b""
  714. while True:
  715. (wire, from_address) = _udp_recv(sock, 65535, expiration)
  716. if not _matches_destination(
  717. sock.family, from_address, destination, ignore_unexpected
  718. ):
  719. continue
  720. received_time = time.time()
  721. try:
  722. r = dns.message.from_wire(
  723. wire,
  724. keyring=keyring,
  725. request_mac=request_mac,
  726. one_rr_per_rrset=one_rr_per_rrset,
  727. ignore_trailing=ignore_trailing,
  728. raise_on_truncation=raise_on_truncation,
  729. )
  730. except dns.message.Truncated as e:
  731. # If we got Truncated and not FORMERR, we at least got the header with TC
  732. # set, and very likely the question section, so we'll re-raise if the
  733. # message seems to be a response as we need to know when truncation happens.
  734. # We need to check that it seems to be a response as we don't want a random
  735. # injected message with TC set to cause us to bail out.
  736. if (
  737. ignore_errors
  738. and query is not None
  739. and not query.is_response(e.message())
  740. ):
  741. continue
  742. else:
  743. raise
  744. except Exception:
  745. if ignore_errors:
  746. continue
  747. else:
  748. raise
  749. if ignore_errors and query is not None and not query.is_response(r):
  750. continue
  751. if destination:
  752. return (r, received_time)
  753. else:
  754. return (r, received_time, from_address)
  755. def udp(
  756. q: dns.message.Message,
  757. where: str,
  758. timeout: float | None = None,
  759. port: int = 53,
  760. source: str | None = None,
  761. source_port: int = 0,
  762. ignore_unexpected: bool = False,
  763. one_rr_per_rrset: bool = False,
  764. ignore_trailing: bool = False,
  765. raise_on_truncation: bool = False,
  766. sock: Any | None = None,
  767. ignore_errors: bool = False,
  768. ) -> dns.message.Message:
  769. """Return the response obtained after sending a query via UDP.
  770. *q*, a ``dns.message.Message``, the query to send
  771. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  772. to send the message.
  773. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  774. query times out. If ``None``, the default, wait forever.
  775. *port*, an ``int``, the port send the message to. The default is 53.
  776. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  777. the source address. The default is the wildcard address.
  778. *source_port*, an ``int``, the port from which to send the message.
  779. The default is 0.
  780. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
  781. unexpected sources.
  782. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  783. RRset.
  784. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  785. junk at end of the received message.
  786. *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
  787. the TC bit is set.
  788. *sock*, a ``socket.socket``, or ``None``, the socket to use for the
  789. query. If ``None``, the default, a socket is created. Note that
  790. if a socket is provided, it must be a nonblocking datagram socket,
  791. and the *source* and *source_port* are ignored.
  792. *ignore_errors*, a ``bool``. If various format errors or response
  793. mismatches occur, ignore them and keep listening for a valid response.
  794. The default is ``False``.
  795. Returns a ``dns.message.Message``.
  796. """
  797. wire = q.to_wire()
  798. (af, destination, source) = _destination_and_source(
  799. where, port, source, source_port, True
  800. )
  801. (begin_time, expiration) = _compute_times(timeout)
  802. if sock:
  803. cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
  804. else:
  805. assert af is not None
  806. cm = make_socket(af, socket.SOCK_DGRAM, source)
  807. with cm as s:
  808. send_udp(s, wire, destination, expiration)
  809. (r, received_time) = receive_udp(
  810. s,
  811. destination,
  812. expiration,
  813. ignore_unexpected,
  814. one_rr_per_rrset,
  815. q.keyring,
  816. q.mac,
  817. ignore_trailing,
  818. raise_on_truncation,
  819. ignore_errors,
  820. q,
  821. )
  822. r.time = received_time - begin_time
  823. # We don't need to check q.is_response() if we are in ignore_errors mode
  824. # as receive_udp() will have checked it.
  825. if not (ignore_errors or q.is_response(r)):
  826. raise BadResponse
  827. return r
  828. assert (
  829. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  830. )
  831. def udp_with_fallback(
  832. q: dns.message.Message,
  833. where: str,
  834. timeout: float | None = None,
  835. port: int = 53,
  836. source: str | None = None,
  837. source_port: int = 0,
  838. ignore_unexpected: bool = False,
  839. one_rr_per_rrset: bool = False,
  840. ignore_trailing: bool = False,
  841. udp_sock: Any | None = None,
  842. tcp_sock: Any | None = None,
  843. ignore_errors: bool = False,
  844. ) -> Tuple[dns.message.Message, bool]:
  845. """Return the response to the query, trying UDP first and falling back
  846. to TCP if UDP results in a truncated response.
  847. *q*, a ``dns.message.Message``, the query to send
  848. *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.
  849. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
  850. times out. If ``None``, the default, wait forever.
  851. *port*, an ``int``, the port send the message to. The default is 53.
  852. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
  853. address. The default is the wildcard address.
  854. *source_port*, an ``int``, the port from which to send the message. The default is
  855. 0.
  856. *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
  857. sources.
  858. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
  859. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
  860. received message.
  861. *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
  862. If ``None``, the default, a socket is created. Note that if a socket is provided,
  863. it must be a nonblocking datagram socket, and the *source* and *source_port* are
  864. ignored for the UDP query.
  865. *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  866. TCP query. If ``None``, the default, a socket is created. Note that if a socket is
  867. provided, it must be a nonblocking connected stream socket, and *where*, *source*
  868. and *source_port* are ignored for the TCP query.
  869. *ignore_errors*, a ``bool``. If various format errors or response mismatches occur
  870. while listening for UDP, ignore them and keep listening for a valid response. The
  871. default is ``False``.
  872. Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
  873. TCP was used.
  874. """
  875. try:
  876. response = udp(
  877. q,
  878. where,
  879. timeout,
  880. port,
  881. source,
  882. source_port,
  883. ignore_unexpected,
  884. one_rr_per_rrset,
  885. ignore_trailing,
  886. True,
  887. udp_sock,
  888. ignore_errors,
  889. )
  890. return (response, False)
  891. except dns.message.Truncated:
  892. response = tcp(
  893. q,
  894. where,
  895. timeout,
  896. port,
  897. source,
  898. source_port,
  899. one_rr_per_rrset,
  900. ignore_trailing,
  901. tcp_sock,
  902. )
  903. return (response, True)
  904. def _net_read(sock, count, expiration):
  905. """Read the specified number of bytes from sock. Keep trying until we
  906. either get the desired amount, or we hit EOF.
  907. A Timeout exception will be raised if the operation is not completed
  908. by the expiration time.
  909. """
  910. s = b""
  911. while count > 0:
  912. try:
  913. n = sock.recv(count)
  914. if n == b"":
  915. raise EOFError("EOF")
  916. count -= len(n)
  917. s += n
  918. except (BlockingIOError, ssl.SSLWantReadError):
  919. _wait_for_readable(sock, expiration)
  920. except ssl.SSLWantWriteError: # pragma: no cover
  921. _wait_for_writable(sock, expiration)
  922. return s
  923. def _net_write(sock, data, expiration):
  924. """Write the specified data to the socket.
  925. A Timeout exception will be raised if the operation is not completed
  926. by the expiration time.
  927. """
  928. current = 0
  929. l = len(data)
  930. while current < l:
  931. try:
  932. current += sock.send(data[current:])
  933. except (BlockingIOError, ssl.SSLWantWriteError):
  934. _wait_for_writable(sock, expiration)
  935. except ssl.SSLWantReadError: # pragma: no cover
  936. _wait_for_readable(sock, expiration)
  937. def send_tcp(
  938. sock: Any,
  939. what: dns.message.Message | bytes,
  940. expiration: float | None = None,
  941. ) -> Tuple[int, float]:
  942. """Send a DNS message to the specified TCP socket.
  943. *sock*, a ``socket``.
  944. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  945. *expiration*, a ``float`` or ``None``, the absolute time at which
  946. a timeout exception should be raised. If ``None``, no timeout will
  947. occur.
  948. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  949. """
  950. if isinstance(what, dns.message.Message):
  951. tcpmsg = what.to_wire(prepend_length=True)
  952. else:
  953. # copying the wire into tcpmsg is inefficient, but lets us
  954. # avoid writev() or doing a short write that would get pushed
  955. # onto the net
  956. tcpmsg = len(what).to_bytes(2, "big") + what
  957. sent_time = time.time()
  958. _net_write(sock, tcpmsg, expiration)
  959. return (len(tcpmsg), sent_time)
  960. def receive_tcp(
  961. sock: Any,
  962. expiration: float | None = None,
  963. one_rr_per_rrset: bool = False,
  964. keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
  965. request_mac: bytes | None = b"",
  966. ignore_trailing: bool = False,
  967. ) -> Tuple[dns.message.Message, float]:
  968. """Read a DNS message from a TCP socket.
  969. *sock*, a ``socket``.
  970. *expiration*, a ``float`` or ``None``, the absolute time at which
  971. a timeout exception should be raised. If ``None``, no timeout will
  972. occur.
  973. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  974. RRset.
  975. *keyring*, a ``dict``, the keyring to use for TSIG.
  976. *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
  977. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  978. junk at end of the received message.
  979. Raises if the message is malformed, if network errors occur, of if
  980. there is a timeout.
  981. Returns a ``(dns.message.Message, float)`` tuple of the received message
  982. and the received time.
  983. """
  984. ldata = _net_read(sock, 2, expiration)
  985. (l,) = struct.unpack("!H", ldata)
  986. wire = _net_read(sock, l, expiration)
  987. received_time = time.time()
  988. r = dns.message.from_wire(
  989. wire,
  990. keyring=keyring,
  991. request_mac=request_mac,
  992. one_rr_per_rrset=one_rr_per_rrset,
  993. ignore_trailing=ignore_trailing,
  994. )
  995. return (r, received_time)
  996. def _connect(s, address, expiration):
  997. err = s.connect_ex(address)
  998. if err == 0:
  999. return
  1000. if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY):
  1001. _wait_for_writable(s, expiration)
  1002. err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
  1003. if err != 0:
  1004. raise OSError(err, os.strerror(err))
  1005. def tcp(
  1006. q: dns.message.Message,
  1007. where: str,
  1008. timeout: float | None = None,
  1009. port: int = 53,
  1010. source: str | None = None,
  1011. source_port: int = 0,
  1012. one_rr_per_rrset: bool = False,
  1013. ignore_trailing: bool = False,
  1014. sock: Any | None = None,
  1015. ) -> dns.message.Message:
  1016. """Return the response obtained after sending a query via TCP.
  1017. *q*, a ``dns.message.Message``, the query to send
  1018. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1019. to send the message.
  1020. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  1021. query times out. If ``None``, the default, wait forever.
  1022. *port*, an ``int``, the port send the message to. The default is 53.
  1023. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1024. the source address. The default is the wildcard address.
  1025. *source_port*, an ``int``, the port from which to send the message.
  1026. The default is 0.
  1027. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  1028. RRset.
  1029. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  1030. junk at end of the received message.
  1031. *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
  1032. query. If ``None``, the default, a socket is created. Note that
  1033. if a socket is provided, it must be a nonblocking connected stream
  1034. socket, and *where*, *port*, *source* and *source_port* are ignored.
  1035. Returns a ``dns.message.Message``.
  1036. """
  1037. wire = q.to_wire()
  1038. (begin_time, expiration) = _compute_times(timeout)
  1039. if sock:
  1040. cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
  1041. else:
  1042. (af, destination, source) = _destination_and_source(
  1043. where, port, source, source_port, True
  1044. )
  1045. assert af is not None
  1046. cm = make_socket(af, socket.SOCK_STREAM, source)
  1047. with cm as s:
  1048. if not sock:
  1049. # pylint: disable=possibly-used-before-assignment
  1050. _connect(s, destination, expiration) # pyright: ignore
  1051. send_tcp(s, wire, expiration)
  1052. (r, received_time) = receive_tcp(
  1053. s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
  1054. )
  1055. r.time = received_time - begin_time
  1056. if not q.is_response(r):
  1057. raise BadResponse
  1058. return r
  1059. assert (
  1060. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  1061. )
  1062. def _tls_handshake(s, expiration):
  1063. while True:
  1064. try:
  1065. s.do_handshake()
  1066. return
  1067. except ssl.SSLWantReadError:
  1068. _wait_for_readable(s, expiration)
  1069. except ssl.SSLWantWriteError: # pragma: no cover
  1070. _wait_for_writable(s, expiration)
  1071. def make_ssl_context(
  1072. verify: bool | str = True,
  1073. check_hostname: bool = True,
  1074. alpns: list[str] | None = None,
  1075. ) -> ssl.SSLContext:
  1076. """Make an SSL context
  1077. If *verify* is ``True``, the default, then certificate verification will occur using
  1078. the standard CA roots. If *verify* is ``False``, then certificate verification will
  1079. be disabled. If *verify* is a string which is a valid pathname, then if the
  1080. pathname is a regular file, the CA roots will be taken from the file, otherwise if
  1081. the pathname is a directory roots will be taken from the directory.
  1082. If *check_hostname* is ``True``, the default, then the hostname of the server must
  1083. be specified when connecting and the server's certificate must authorize the
  1084. hostname. If ``False``, then hostname checking is disabled.
  1085. *aplns* is ``None`` or a list of TLS ALPN (Application Layer Protocol Negotiation)
  1086. strings to use in negotiation. For DNS-over-TLS, the right value is `["dot"]`.
  1087. """
  1088. cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(verify)
  1089. ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
  1090. # the pyright ignores below are because it gets confused between the
  1091. # _no_ssl compatibility types and the real ones.
  1092. ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 # type: ignore
  1093. ssl_context.check_hostname = check_hostname
  1094. if verify is False:
  1095. ssl_context.verify_mode = ssl.CERT_NONE # type: ignore
  1096. if alpns is not None:
  1097. ssl_context.set_alpn_protocols(alpns)
  1098. return ssl_context # type: ignore
  1099. # for backwards compatibility
  1100. def _make_dot_ssl_context(
  1101. server_hostname: str | None, verify: bool | str
  1102. ) -> ssl.SSLContext:
  1103. return make_ssl_context(verify, server_hostname is not None, ["dot"])
  1104. def tls(
  1105. q: dns.message.Message,
  1106. where: str,
  1107. timeout: float | None = None,
  1108. port: int = 853,
  1109. source: str | None = None,
  1110. source_port: int = 0,
  1111. one_rr_per_rrset: bool = False,
  1112. ignore_trailing: bool = False,
  1113. sock: ssl.SSLSocket | None = None,
  1114. ssl_context: ssl.SSLContext | None = None,
  1115. server_hostname: str | None = None,
  1116. verify: bool | str = True,
  1117. ) -> dns.message.Message:
  1118. """Return the response obtained after sending a query via TLS.
  1119. *q*, a ``dns.message.Message``, the query to send
  1120. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1121. to send the message.
  1122. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
  1123. query times out. If ``None``, the default, wait forever.
  1124. *port*, an ``int``, the port send the message to. The default is 853.
  1125. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1126. the source address. The default is the wildcard address.
  1127. *source_port*, an ``int``, the port from which to send the message.
  1128. The default is 0.
  1129. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
  1130. RRset.
  1131. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
  1132. junk at end of the received message.
  1133. *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
  1134. the query. If ``None``, the default, a socket is created. Note
  1135. that if a socket is provided, it must be a nonblocking connected
  1136. SSL stream socket, and *where*, *port*, *source*, *source_port*,
  1137. and *ssl_context* are ignored.
  1138. *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
  1139. a TLS connection. If ``None``, the default, creates one with the default
  1140. configuration.
  1141. *server_hostname*, a ``str`` containing the server's hostname. The
  1142. default is ``None``, which means that no hostname is known, and if an
  1143. SSL context is created, hostname checking will be disabled.
  1144. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
  1145. of the server is done using the default CA bundle; if ``False``, then no
  1146. verification is done; if a `str` then it specifies the path to a certificate file or
  1147. directory which will be used for verification.
  1148. Returns a ``dns.message.Message``.
  1149. """
  1150. if sock:
  1151. #
  1152. # If a socket was provided, there's no special TLS handling needed.
  1153. #
  1154. return tcp(
  1155. q,
  1156. where,
  1157. timeout,
  1158. port,
  1159. source,
  1160. source_port,
  1161. one_rr_per_rrset,
  1162. ignore_trailing,
  1163. sock,
  1164. )
  1165. wire = q.to_wire()
  1166. (begin_time, expiration) = _compute_times(timeout)
  1167. (af, destination, source) = _destination_and_source(
  1168. where, port, source, source_port, True
  1169. )
  1170. assert af is not None # where must be an address
  1171. if ssl_context is None:
  1172. ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
  1173. with make_ssl_socket(
  1174. af,
  1175. socket.SOCK_STREAM,
  1176. ssl_context=ssl_context,
  1177. server_hostname=server_hostname,
  1178. source=source,
  1179. ) as s:
  1180. _connect(s, destination, expiration)
  1181. _tls_handshake(s, expiration)
  1182. send_tcp(s, wire, expiration)
  1183. (r, received_time) = receive_tcp(
  1184. s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
  1185. )
  1186. r.time = received_time - begin_time
  1187. if not q.is_response(r):
  1188. raise BadResponse
  1189. return r
  1190. assert (
  1191. False # help mypy figure out we can't get here lgtm[py/unreachable-statement]
  1192. )
  1193. def quic(
  1194. q: dns.message.Message,
  1195. where: str,
  1196. timeout: float | None = None,
  1197. port: int = 853,
  1198. source: str | None = None,
  1199. source_port: int = 0,
  1200. one_rr_per_rrset: bool = False,
  1201. ignore_trailing: bool = False,
  1202. connection: dns.quic.SyncQuicConnection | None = None,
  1203. verify: bool | str = True,
  1204. hostname: str | None = None,
  1205. server_hostname: str | None = None,
  1206. ) -> dns.message.Message:
  1207. """Return the response obtained after sending a query via DNS-over-QUIC.
  1208. *q*, a ``dns.message.Message``, the query to send.
  1209. *where*, a ``str``, the nameserver IP address.
  1210. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
  1211. times out. If ``None``, the default, wait forever.
  1212. *port*, a ``int``, the port to send the query to. The default is 853.
  1213. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
  1214. address. The default is the wildcard address.
  1215. *source_port*, an ``int``, the port from which to send the message. The default is
  1216. 0.
  1217. *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
  1218. *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
  1219. received message.
  1220. *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
  1221. to send the query.
  1222. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
  1223. of the server is done using the default CA bundle; if ``False``, then no
  1224. verification is done; if a `str` then it specifies the path to a certificate file or
  1225. directory which will be used for verification.
  1226. *hostname*, a ``str`` containing the server's hostname or ``None``. The default is
  1227. ``None``, which means that no hostname is known, and if an SSL context is created,
  1228. hostname checking will be disabled. This value is ignored if *url* is not
  1229. ``None``.
  1230. *server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
  1231. only, and has the same meaning as *hostname*.
  1232. Returns a ``dns.message.Message``.
  1233. """
  1234. if not dns.quic.have_quic:
  1235. raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
  1236. if server_hostname is not None and hostname is None:
  1237. hostname = server_hostname
  1238. q.id = 0
  1239. wire = q.to_wire()
  1240. the_connection: dns.quic.SyncQuicConnection
  1241. the_manager: dns.quic.SyncQuicManager
  1242. if connection:
  1243. manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
  1244. the_connection = connection
  1245. else:
  1246. manager = dns.quic.SyncQuicManager(
  1247. verify_mode=verify, server_name=hostname # pyright: ignore
  1248. )
  1249. the_manager = manager # for type checking happiness
  1250. with manager:
  1251. if not connection:
  1252. the_connection = the_manager.connect( # pyright: ignore
  1253. where, port, source, source_port
  1254. )
  1255. (start, expiration) = _compute_times(timeout)
  1256. with the_connection.make_stream(timeout) as stream: # pyright: ignore
  1257. stream.send(wire, True)
  1258. wire = stream.receive(_remaining(expiration))
  1259. finish = time.time()
  1260. r = dns.message.from_wire(
  1261. wire,
  1262. keyring=q.keyring,
  1263. request_mac=q.request_mac,
  1264. one_rr_per_rrset=one_rr_per_rrset,
  1265. ignore_trailing=ignore_trailing,
  1266. )
  1267. r.time = max(finish - start, 0.0)
  1268. if not q.is_response(r):
  1269. raise BadResponse
  1270. return r
  1271. class UDPMode(enum.IntEnum):
  1272. """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?
  1273. NEVER means "never use UDP; always use TCP"
  1274. TRY_FIRST means "try to use UDP but fall back to TCP if needed"
  1275. ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"
  1276. """
  1277. NEVER = 0
  1278. TRY_FIRST = 1
  1279. ONLY = 2
  1280. def _inbound_xfr(
  1281. txn_manager: dns.transaction.TransactionManager,
  1282. s: socket.socket | ssl.SSLSocket,
  1283. query: dns.message.Message,
  1284. serial: int | None,
  1285. timeout: float | None,
  1286. expiration: float | None,
  1287. ) -> Any:
  1288. """Given a socket, does the zone transfer."""
  1289. rdtype = query.question[0].rdtype
  1290. is_ixfr = rdtype == dns.rdatatype.IXFR
  1291. origin = txn_manager.from_wire_origin()
  1292. wire = query.to_wire()
  1293. is_udp = isinstance(s, socket.socket) and s.type == socket.SOCK_DGRAM
  1294. if is_udp:
  1295. _udp_send(s, wire, None, expiration)
  1296. else:
  1297. tcpmsg = struct.pack("!H", len(wire)) + wire
  1298. _net_write(s, tcpmsg, expiration)
  1299. with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
  1300. done = False
  1301. tsig_ctx = None
  1302. r: dns.message.Message | None = None
  1303. while not done:
  1304. (_, mexpiration) = _compute_times(timeout)
  1305. if mexpiration is None or (
  1306. expiration is not None and mexpiration > expiration
  1307. ):
  1308. mexpiration = expiration
  1309. if is_udp:
  1310. (rwire, _) = _udp_recv(s, 65535, mexpiration)
  1311. else:
  1312. ldata = _net_read(s, 2, mexpiration)
  1313. (l,) = struct.unpack("!H", ldata)
  1314. rwire = _net_read(s, l, mexpiration)
  1315. r = dns.message.from_wire(
  1316. rwire,
  1317. keyring=query.keyring,
  1318. request_mac=query.mac,
  1319. xfr=True,
  1320. origin=origin,
  1321. tsig_ctx=tsig_ctx,
  1322. multi=(not is_udp),
  1323. one_rr_per_rrset=is_ixfr,
  1324. )
  1325. done = inbound.process_message(r)
  1326. yield r
  1327. tsig_ctx = r.tsig_ctx
  1328. if query.keyring and r is not None and not r.had_tsig:
  1329. raise dns.exception.FormError("missing TSIG")
  1330. def xfr(
  1331. where: str,
  1332. zone: dns.name.Name | str,
  1333. rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.AXFR,
  1334. rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
  1335. timeout: float | None = None,
  1336. port: int = 53,
  1337. keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
  1338. keyname: dns.name.Name | str | None = None,
  1339. relativize: bool = True,
  1340. lifetime: float | None = None,
  1341. source: str | None = None,
  1342. source_port: int = 0,
  1343. serial: int = 0,
  1344. use_udp: bool = False,
  1345. keyalgorithm: dns.name.Name | str = dns.tsig.default_algorithm,
  1346. ) -> Any:
  1347. """Return a generator for the responses to a zone transfer.
  1348. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1349. to send the message.
  1350. *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.
  1351. *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
  1352. default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
  1353. used to do an incremental transfer instead.
  1354. *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
  1355. The default is ``dns.rdataclass.IN``.
  1356. *timeout*, a ``float``, the number of seconds to wait for each
  1357. response message. If None, the default, wait forever.
  1358. *port*, an ``int``, the port send the message to. The default is 53.
  1359. *keyring*, a ``dict``, the keyring to use for TSIG.
  1360. *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
  1361. key to use.
  1362. *relativize*, a ``bool``. If ``True``, all names in the zone will be
  1363. relativized to the zone origin. It is essential that the
  1364. relativize setting matches the one specified to
  1365. ``dns.zone.from_xfr()`` if using this generator to make a zone.
  1366. *lifetime*, a ``float``, the total number of seconds to spend
  1367. doing the transfer. If ``None``, the default, then there is no
  1368. limit on the time the transfer may take.
  1369. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1370. the source address. The default is the wildcard address.
  1371. *source_port*, an ``int``, the port from which to send the message.
  1372. The default is 0.
  1373. *serial*, an ``int``, the SOA serial number to use as the base for
  1374. an IXFR diff sequence (only meaningful if *rdtype* is
  1375. ``dns.rdatatype.IXFR``).
  1376. *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).
  1377. *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.
  1378. Raises on errors, and so does the generator.
  1379. Returns a generator of ``dns.message.Message`` objects.
  1380. """
  1381. class DummyTransactionManager(dns.transaction.TransactionManager):
  1382. def __init__(self, origin, relativize):
  1383. self.info = (origin, relativize, dns.name.empty if relativize else origin)
  1384. def origin_information(self):
  1385. return self.info
  1386. def get_class(self) -> dns.rdataclass.RdataClass:
  1387. raise NotImplementedError # pragma: no cover
  1388. def reader(self):
  1389. raise NotImplementedError # pragma: no cover
  1390. def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
  1391. class DummyTransaction:
  1392. def nop(self, *args, **kw):
  1393. pass
  1394. def __getattr__(self, _):
  1395. return self.nop
  1396. return cast(dns.transaction.Transaction, DummyTransaction())
  1397. if isinstance(zone, str):
  1398. zone = dns.name.from_text(zone)
  1399. rdtype = dns.rdatatype.RdataType.make(rdtype)
  1400. q = dns.message.make_query(zone, rdtype, rdclass)
  1401. if rdtype == dns.rdatatype.IXFR:
  1402. rrset = q.find_rrset(
  1403. q.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
  1404. )
  1405. soa = dns.rdata.from_text("IN", "SOA", f". . {serial} 0 0 0 0")
  1406. rrset.add(soa, 0)
  1407. if keyring is not None:
  1408. q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
  1409. (af, destination, source) = _destination_and_source(
  1410. where, port, source, source_port, True
  1411. )
  1412. assert af is not None
  1413. (_, expiration) = _compute_times(lifetime)
  1414. tm = DummyTransactionManager(zone, relativize)
  1415. if use_udp and rdtype != dns.rdatatype.IXFR:
  1416. raise ValueError("cannot do a UDP AXFR")
  1417. sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
  1418. with make_socket(af, sock_type, source) as s:
  1419. _connect(s, destination, expiration)
  1420. yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
  1421. def inbound_xfr(
  1422. where: str,
  1423. txn_manager: dns.transaction.TransactionManager,
  1424. query: dns.message.Message | None = None,
  1425. port: int = 53,
  1426. timeout: float | None = None,
  1427. lifetime: float | None = None,
  1428. source: str | None = None,
  1429. source_port: int = 0,
  1430. udp_mode: UDPMode = UDPMode.NEVER,
  1431. ) -> None:
  1432. """Conduct an inbound transfer and apply it via a transaction from the
  1433. txn_manager.
  1434. *where*, a ``str`` containing an IPv4 or IPv6 address, where
  1435. to send the message.
  1436. *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
  1437. for this transfer (typically a ``dns.zone.Zone``).
  1438. *query*, the query to send. If not supplied, a default query is
  1439. constructed using information from the *txn_manager*.
  1440. *port*, an ``int``, the port send the message to. The default is 53.
  1441. *timeout*, a ``float``, the number of seconds to wait for each
  1442. response message. If None, the default, wait forever.
  1443. *lifetime*, a ``float``, the total number of seconds to spend
  1444. doing the transfer. If ``None``, the default, then there is no
  1445. limit on the time the transfer may take.
  1446. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
  1447. the source address. The default is the wildcard address.
  1448. *source_port*, an ``int``, the port from which to send the message.
  1449. The default is 0.
  1450. *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
  1451. for IXFRs. The default is ``dns.query.UDPMode.NEVER``, i.e. only use
  1452. TCP. Other possibilities are ``dns.query.UDPMode.TRY_FIRST``, which
  1453. means "try UDP but fallback to TCP if needed", and
  1454. ``dns.query.UDPMode.ONLY``, which means "try UDP and raise
  1455. ``dns.xfr.UseTCP`` if it does not succeed.
  1456. Raises on errors.
  1457. """
  1458. if query is None:
  1459. (query, serial) = dns.xfr.make_query(txn_manager)
  1460. else:
  1461. serial = dns.xfr.extract_serial_from_query(query)
  1462. (af, destination, source) = _destination_and_source(
  1463. where, port, source, source_port, True
  1464. )
  1465. assert af is not None
  1466. (_, expiration) = _compute_times(lifetime)
  1467. if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
  1468. with make_socket(af, socket.SOCK_DGRAM, source) as s:
  1469. _connect(s, destination, expiration)
  1470. try:
  1471. for _ in _inbound_xfr(
  1472. txn_manager, s, query, serial, timeout, expiration
  1473. ):
  1474. pass
  1475. return
  1476. except dns.xfr.UseTCP:
  1477. if udp_mode == UDPMode.ONLY:
  1478. raise
  1479. with make_socket(af, socket.SOCK_STREAM, source) as s:
  1480. _connect(s, destination, expiration)
  1481. for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
  1482. pass