asyncquery.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. # Copyright (C) 2003-2017 Nominum, Inc.
  3. #
  4. # Permission to use, copy, modify, and distribute this software and its
  5. # documentation for any purpose with or without fee is hereby granted,
  6. # provided that the above copyright notice and this permission notice
  7. # appear in all copies.
  8. #
  9. # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
  10. # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  11. # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
  12. # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  13. # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  14. # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
  15. # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  16. """Talk to a DNS server."""
  17. import base64
  18. import contextlib
  19. import random
  20. import socket
  21. import struct
  22. import time
  23. import urllib.parse
  24. from typing import Any, Dict, Optional, Tuple, cast
  25. import dns.asyncbackend
  26. import dns.exception
  27. import dns.inet
  28. import dns.message
  29. import dns.name
  30. import dns.quic
  31. import dns.rdatatype
  32. import dns.transaction
  33. import dns.tsig
  34. import dns.xfr
  35. from dns._asyncbackend import NullContext
  36. from dns.query import (
  37. BadResponse,
  38. HTTPVersion,
  39. NoDOH,
  40. NoDOQ,
  41. UDPMode,
  42. _check_status,
  43. _compute_times,
  44. _matches_destination,
  45. _remaining,
  46. have_doh,
  47. make_ssl_context,
  48. )
  49. try:
  50. import ssl
  51. except ImportError:
  52. import dns._no_ssl as ssl # type: ignore
  53. if have_doh:
  54. import httpx
# for brevity: short alias for dns.inet.low_level_address_tuple, used by udp()
# to turn a high-level (address, port) pair into the family-dependent tuple
# handed to the socket layer.
_lltuple = dns.inet.low_level_address_tuple
  57. def _source_tuple(af, address, port):
  58. # Make a high level source tuple, or return None if address and port
  59. # are both None
  60. if address or port:
  61. if address is None:
  62. if af == socket.AF_INET:
  63. address = "0.0.0.0"
  64. elif af == socket.AF_INET6:
  65. address = "::"
  66. else:
  67. raise NotImplementedError(f"unknown address family {af}")
  68. return (address, port)
  69. else:
  70. return None
  71. def _timeout(expiration, now=None):
  72. if expiration is not None:
  73. if not now:
  74. now = time.time()
  75. return max(expiration - now, 0)
  76. else:
  77. return None
  78. async def send_udp(
  79. sock: dns.asyncbackend.DatagramSocket,
  80. what: dns.message.Message | bytes,
  81. destination: Any,
  82. expiration: float | None = None,
  83. ) -> Tuple[int, float]:
  84. """Send a DNS message to the specified UDP socket.
  85. *sock*, a ``dns.asyncbackend.DatagramSocket``.
  86. *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
  87. *destination*, a destination tuple appropriate for the address family
  88. of the socket, specifying where to send the query.
  89. *expiration*, a ``float`` or ``None``, the absolute time at which
  90. a timeout exception should be raised. If ``None``, no timeout will
  91. occur. The expiration value is meaningless for the asyncio backend, as
  92. asyncio's transport sendto() never blocks.
  93. Returns an ``(int, float)`` tuple of bytes sent and the sent time.
  94. """
  95. if isinstance(what, dns.message.Message):
  96. what = what.to_wire()
  97. sent_time = time.time()
  98. n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
  99. return (n, sent_time)
  100. async def receive_udp(
  101. sock: dns.asyncbackend.DatagramSocket,
  102. destination: Any | None = None,
  103. expiration: float | None = None,
  104. ignore_unexpected: bool = False,
  105. one_rr_per_rrset: bool = False,
  106. keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
  107. request_mac: bytes | None = b"",
  108. ignore_trailing: bool = False,
  109. raise_on_truncation: bool = False,
  110. ignore_errors: bool = False,
  111. query: dns.message.Message | None = None,
  112. ) -> Any:
  113. """Read a DNS message from a UDP socket.
  114. *sock*, a ``dns.asyncbackend.DatagramSocket``.
  115. See :py:func:`dns.query.receive_udp()` for the documentation of the other
  116. parameters, and exceptions.
  117. Returns a ``(dns.message.Message, float, tuple)`` tuple of the received message, the
  118. received time, and the address where the message arrived from.
  119. """
  120. wire = b""
  121. while True:
  122. (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration))
  123. if not _matches_destination(
  124. sock.family, from_address, destination, ignore_unexpected
  125. ):
  126. continue
  127. received_time = time.time()
  128. try:
  129. r = dns.message.from_wire(
  130. wire,
  131. keyring=keyring,
  132. request_mac=request_mac,
  133. one_rr_per_rrset=one_rr_per_rrset,
  134. ignore_trailing=ignore_trailing,
  135. raise_on_truncation=raise_on_truncation,
  136. )
  137. except dns.message.Truncated as e:
  138. # See the comment in query.py for details.
  139. if (
  140. ignore_errors
  141. and query is not None
  142. and not query.is_response(e.message())
  143. ):
  144. continue
  145. else:
  146. raise
  147. except Exception:
  148. if ignore_errors:
  149. continue
  150. else:
  151. raise
  152. if ignore_errors and query is not None and not query.is_response(r):
  153. continue
  154. return (r, received_time, from_address)
  155. async def udp(
  156. q: dns.message.Message,
  157. where: str,
  158. timeout: float | None = None,
  159. port: int = 53,
  160. source: str | None = None,
  161. source_port: int = 0,
  162. ignore_unexpected: bool = False,
  163. one_rr_per_rrset: bool = False,
  164. ignore_trailing: bool = False,
  165. raise_on_truncation: bool = False,
  166. sock: dns.asyncbackend.DatagramSocket | None = None,
  167. backend: dns.asyncbackend.Backend | None = None,
  168. ignore_errors: bool = False,
  169. ) -> dns.message.Message:
  170. """Return the response obtained after sending a query via UDP.
  171. *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
  172. the socket to use for the query. If ``None``, the default, a
  173. socket is created. Note that if a socket is provided, the
  174. *source*, *source_port*, and *backend* are ignored.
  175. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  176. the default, then dnspython will use the default backend.
  177. See :py:func:`dns.query.udp()` for the documentation of the other
  178. parameters, exceptions, and return type of this method.
  179. """
  180. wire = q.to_wire()
  181. (begin_time, expiration) = _compute_times(timeout)
  182. af = dns.inet.af_for_address(where)
  183. destination = _lltuple((where, port), af)
  184. if sock:
  185. cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
  186. else:
  187. if not backend:
  188. backend = dns.asyncbackend.get_default_backend()
  189. stuple = _source_tuple(af, source, source_port)
  190. if backend.datagram_connection_required():
  191. dtuple = (where, port)
  192. else:
  193. dtuple = None
  194. cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
  195. async with cm as s:
  196. await send_udp(s, wire, destination, expiration) # pyright: ignore
  197. (r, received_time, _) = await receive_udp(
  198. s, # pyright: ignore
  199. destination,
  200. expiration,
  201. ignore_unexpected,
  202. one_rr_per_rrset,
  203. q.keyring,
  204. q.mac,
  205. ignore_trailing,
  206. raise_on_truncation,
  207. ignore_errors,
  208. q,
  209. )
  210. r.time = received_time - begin_time
  211. # We don't need to check q.is_response() if we are in ignore_errors mode
  212. # as receive_udp() will have checked it.
  213. if not (ignore_errors or q.is_response(r)):
  214. raise BadResponse
  215. return r
  216. async def udp_with_fallback(
  217. q: dns.message.Message,
  218. where: str,
  219. timeout: float | None = None,
  220. port: int = 53,
  221. source: str | None = None,
  222. source_port: int = 0,
  223. ignore_unexpected: bool = False,
  224. one_rr_per_rrset: bool = False,
  225. ignore_trailing: bool = False,
  226. udp_sock: dns.asyncbackend.DatagramSocket | None = None,
  227. tcp_sock: dns.asyncbackend.StreamSocket | None = None,
  228. backend: dns.asyncbackend.Backend | None = None,
  229. ignore_errors: bool = False,
  230. ) -> Tuple[dns.message.Message, bool]:
  231. """Return the response to the query, trying UDP first and falling back
  232. to TCP if UDP results in a truncated response.
  233. *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
  234. the socket to use for the UDP query. If ``None``, the default, a
  235. socket is created. Note that if a socket is provided the *source*,
  236. *source_port*, and *backend* are ignored for the UDP query.
  237. *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
  238. socket to use for the TCP query. If ``None``, the default, a
  239. socket is created. Note that if a socket is provided *where*,
  240. *source*, *source_port*, and *backend* are ignored for the TCP query.
  241. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  242. the default, then dnspython will use the default backend.
  243. See :py:func:`dns.query.udp_with_fallback()` for the documentation
  244. of the other parameters, exceptions, and return type of this
  245. method.
  246. """
  247. try:
  248. response = await udp(
  249. q,
  250. where,
  251. timeout,
  252. port,
  253. source,
  254. source_port,
  255. ignore_unexpected,
  256. one_rr_per_rrset,
  257. ignore_trailing,
  258. True,
  259. udp_sock,
  260. backend,
  261. ignore_errors,
  262. )
  263. return (response, False)
  264. except dns.message.Truncated:
  265. response = await tcp(
  266. q,
  267. where,
  268. timeout,
  269. port,
  270. source,
  271. source_port,
  272. one_rr_per_rrset,
  273. ignore_trailing,
  274. tcp_sock,
  275. backend,
  276. )
  277. return (response, True)
  278. async def send_tcp(
  279. sock: dns.asyncbackend.StreamSocket,
  280. what: dns.message.Message | bytes,
  281. expiration: float | None = None,
  282. ) -> Tuple[int, float]:
  283. """Send a DNS message to the specified TCP socket.
  284. *sock*, a ``dns.asyncbackend.StreamSocket``.
  285. See :py:func:`dns.query.send_tcp()` for the documentation of the other
  286. parameters, exceptions, and return type of this method.
  287. """
  288. if isinstance(what, dns.message.Message):
  289. tcpmsg = what.to_wire(prepend_length=True)
  290. else:
  291. # copying the wire into tcpmsg is inefficient, but lets us
  292. # avoid writev() or doing a short write that would get pushed
  293. # onto the net
  294. tcpmsg = len(what).to_bytes(2, "big") + what
  295. sent_time = time.time()
  296. await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
  297. return (len(tcpmsg), sent_time)
  298. async def _read_exactly(sock, count, expiration):
  299. """Read the specified number of bytes from stream. Keep trying until we
  300. either get the desired amount, or we hit EOF.
  301. """
  302. s = b""
  303. while count > 0:
  304. n = await sock.recv(count, _timeout(expiration))
  305. if n == b"":
  306. raise EOFError("EOF")
  307. count = count - len(n)
  308. s = s + n
  309. return s
  310. async def receive_tcp(
  311. sock: dns.asyncbackend.StreamSocket,
  312. expiration: float | None = None,
  313. one_rr_per_rrset: bool = False,
  314. keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
  315. request_mac: bytes | None = b"",
  316. ignore_trailing: bool = False,
  317. ) -> Tuple[dns.message.Message, float]:
  318. """Read a DNS message from a TCP socket.
  319. *sock*, a ``dns.asyncbackend.StreamSocket``.
  320. See :py:func:`dns.query.receive_tcp()` for the documentation of the other
  321. parameters, exceptions, and return type of this method.
  322. """
  323. ldata = await _read_exactly(sock, 2, expiration)
  324. (l,) = struct.unpack("!H", ldata)
  325. wire = await _read_exactly(sock, l, expiration)
  326. received_time = time.time()
  327. r = dns.message.from_wire(
  328. wire,
  329. keyring=keyring,
  330. request_mac=request_mac,
  331. one_rr_per_rrset=one_rr_per_rrset,
  332. ignore_trailing=ignore_trailing,
  333. )
  334. return (r, received_time)
  335. async def tcp(
  336. q: dns.message.Message,
  337. where: str,
  338. timeout: float | None = None,
  339. port: int = 53,
  340. source: str | None = None,
  341. source_port: int = 0,
  342. one_rr_per_rrset: bool = False,
  343. ignore_trailing: bool = False,
  344. sock: dns.asyncbackend.StreamSocket | None = None,
  345. backend: dns.asyncbackend.Backend | None = None,
  346. ) -> dns.message.Message:
  347. """Return the response obtained after sending a query via TCP.
  348. *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
  349. socket to use for the query. If ``None``, the default, a socket
  350. is created. Note that if a socket is provided
  351. *where*, *port*, *source*, *source_port*, and *backend* are ignored.
  352. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  353. the default, then dnspython will use the default backend.
  354. See :py:func:`dns.query.tcp()` for the documentation of the other
  355. parameters, exceptions, and return type of this method.
  356. """
  357. wire = q.to_wire()
  358. (begin_time, expiration) = _compute_times(timeout)
  359. if sock:
  360. # Verify that the socket is connected, as if it's not connected,
  361. # it's not writable, and the polling in send_tcp() will time out or
  362. # hang forever.
  363. await sock.getpeername()
  364. cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
  365. else:
  366. # These are simple (address, port) pairs, not family-dependent tuples
  367. # you pass to low-level socket code.
  368. af = dns.inet.af_for_address(where)
  369. stuple = _source_tuple(af, source, source_port)
  370. dtuple = (where, port)
  371. if not backend:
  372. backend = dns.asyncbackend.get_default_backend()
  373. cm = await backend.make_socket(
  374. af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
  375. )
  376. async with cm as s:
  377. await send_tcp(s, wire, expiration) # pyright: ignore
  378. (r, received_time) = await receive_tcp(
  379. s, # pyright: ignore
  380. expiration,
  381. one_rr_per_rrset,
  382. q.keyring,
  383. q.mac,
  384. ignore_trailing,
  385. )
  386. r.time = received_time - begin_time
  387. if not q.is_response(r):
  388. raise BadResponse
  389. return r
  390. async def tls(
  391. q: dns.message.Message,
  392. where: str,
  393. timeout: float | None = None,
  394. port: int = 853,
  395. source: str | None = None,
  396. source_port: int = 0,
  397. one_rr_per_rrset: bool = False,
  398. ignore_trailing: bool = False,
  399. sock: dns.asyncbackend.StreamSocket | None = None,
  400. backend: dns.asyncbackend.Backend | None = None,
  401. ssl_context: ssl.SSLContext | None = None,
  402. server_hostname: str | None = None,
  403. verify: bool | str = True,
  404. ) -> dns.message.Message:
  405. """Return the response obtained after sending a query via TLS.
  406. *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
  407. to use for the query. If ``None``, the default, a socket is
  408. created. Note that if a socket is provided, it must be a
  409. connected SSL stream socket, and *where*, *port*,
  410. *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
  411. are ignored.
  412. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  413. the default, then dnspython will use the default backend.
  414. See :py:func:`dns.query.tls()` for the documentation of the other
  415. parameters, exceptions, and return type of this method.
  416. """
  417. (begin_time, expiration) = _compute_times(timeout)
  418. if sock:
  419. cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
  420. else:
  421. if ssl_context is None:
  422. ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
  423. af = dns.inet.af_for_address(where)
  424. stuple = _source_tuple(af, source, source_port)
  425. dtuple = (where, port)
  426. if not backend:
  427. backend = dns.asyncbackend.get_default_backend()
  428. cm = await backend.make_socket(
  429. af,
  430. socket.SOCK_STREAM,
  431. 0,
  432. stuple,
  433. dtuple,
  434. timeout,
  435. ssl_context,
  436. server_hostname,
  437. )
  438. async with cm as s:
  439. timeout = _timeout(expiration)
  440. response = await tcp(
  441. q,
  442. where,
  443. timeout,
  444. port,
  445. source,
  446. source_port,
  447. one_rr_per_rrset,
  448. ignore_trailing,
  449. s,
  450. backend,
  451. )
  452. end_time = time.time()
  453. response.time = end_time - begin_time
  454. return response
  455. def _maybe_get_resolver(
  456. resolver: Optional["dns.asyncresolver.Resolver"], # pyright: ignore
  457. ) -> "dns.asyncresolver.Resolver": # pyright: ignore
  458. # We need a separate method for this to avoid overriding the global
  459. # variable "dns" with the as-yet undefined local variable "dns"
  460. # in https().
  461. if resolver is None:
  462. # pylint: disable=import-outside-toplevel,redefined-outer-name
  463. import dns.asyncresolver
  464. resolver = dns.asyncresolver.Resolver()
  465. return resolver
  466. async def https(
  467. q: dns.message.Message,
  468. where: str,
  469. timeout: float | None = None,
  470. port: int = 443,
  471. source: str | None = None,
  472. source_port: int = 0, # pylint: disable=W0613
  473. one_rr_per_rrset: bool = False,
  474. ignore_trailing: bool = False,
  475. client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None,
  476. path: str = "/dns-query",
  477. post: bool = True,
  478. verify: bool | str | ssl.SSLContext = True,
  479. bootstrap_address: str | None = None,
  480. resolver: Optional["dns.asyncresolver.Resolver"] = None, # pyright: ignore
  481. family: int = socket.AF_UNSPEC,
  482. http_version: HTTPVersion = HTTPVersion.DEFAULT,
  483. ) -> dns.message.Message:
  484. """Return the response obtained after sending a query via DNS-over-HTTPS.
  485. *client*, a ``httpx.AsyncClient``. If provided, the client to use for
  486. the query.
  487. Unlike the other dnspython async functions, a backend cannot be provided
  488. in this function because httpx always auto-detects the async backend.
  489. See :py:func:`dns.query.https()` for the documentation of the other
  490. parameters, exceptions, and return type of this method.
  491. """
  492. try:
  493. af = dns.inet.af_for_address(where)
  494. except ValueError:
  495. af = None
  496. # we bind url and then override as pyright can't figure out all paths bind.
  497. url = where
  498. if af is not None and dns.inet.is_address(where):
  499. if af == socket.AF_INET:
  500. url = f"https://{where}:{port}{path}"
  501. elif af == socket.AF_INET6:
  502. url = f"https://[{where}]:{port}{path}"
  503. extensions = {}
  504. if bootstrap_address is None:
  505. # pylint: disable=possibly-used-before-assignment
  506. parsed = urllib.parse.urlparse(url)
  507. if parsed.hostname is None:
  508. raise ValueError("no hostname in URL")
  509. if dns.inet.is_address(parsed.hostname):
  510. bootstrap_address = parsed.hostname
  511. extensions["sni_hostname"] = parsed.hostname
  512. if parsed.port is not None:
  513. port = parsed.port
  514. if http_version == HTTPVersion.H3 or (
  515. http_version == HTTPVersion.DEFAULT and not have_doh
  516. ):
  517. if bootstrap_address is None:
  518. resolver = _maybe_get_resolver(resolver)
  519. assert parsed.hostname is not None # pyright: ignore
  520. answers = await resolver.resolve_name( # pyright: ignore
  521. parsed.hostname, family # pyright: ignore
  522. )
  523. bootstrap_address = random.choice(list(answers.addresses()))
  524. if client and not isinstance(
  525. client, dns.quic.AsyncQuicConnection
  526. ): # pyright: ignore
  527. raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.")
  528. assert client is None or isinstance(client, dns.quic.AsyncQuicConnection)
  529. return await _http3(
  530. q,
  531. bootstrap_address,
  532. url,
  533. timeout,
  534. port,
  535. source,
  536. source_port,
  537. one_rr_per_rrset,
  538. ignore_trailing,
  539. verify=verify,
  540. post=post,
  541. connection=client,
  542. )
  543. if not have_doh:
  544. raise NoDOH # pragma: no cover
  545. # pylint: disable=possibly-used-before-assignment
  546. if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore
  547. raise ValueError("client parameter must be an httpx.AsyncClient")
  548. # pylint: enable=possibly-used-before-assignment
  549. wire = q.to_wire()
  550. headers = {"accept": "application/dns-message"}
  551. h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
  552. h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)
  553. backend = dns.asyncbackend.get_default_backend()
  554. if source is None:
  555. local_address = None
  556. local_port = 0
  557. else:
  558. local_address = source
  559. local_port = source_port
  560. if client:
  561. cm: contextlib.AbstractAsyncContextManager = NullContext(client)
  562. else:
  563. transport = backend.get_transport_class()(
  564. local_address=local_address,
  565. http1=h1,
  566. http2=h2,
  567. verify=verify,
  568. local_port=local_port,
  569. bootstrap_address=bootstrap_address,
  570. resolver=resolver,
  571. family=family,
  572. )
  573. cm = httpx.AsyncClient( # pyright: ignore
  574. http1=h1, http2=h2, verify=verify, transport=transport # type: ignore
  575. )
  576. async with cm as the_client:
  577. # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
  578. # GET and POST examples
  579. if post:
  580. headers.update(
  581. {
  582. "content-type": "application/dns-message",
  583. "content-length": str(len(wire)),
  584. }
  585. )
  586. response = await backend.wait_for(
  587. the_client.post( # pyright: ignore
  588. url,
  589. headers=headers,
  590. content=wire,
  591. extensions=extensions,
  592. ),
  593. timeout,
  594. )
  595. else:
  596. wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
  597. twire = wire.decode() # httpx does a repr() if we give it bytes
  598. response = await backend.wait_for(
  599. the_client.get( # pyright: ignore
  600. url,
  601. headers=headers,
  602. params={"dns": twire},
  603. extensions=extensions,
  604. ),
  605. timeout,
  606. )
  607. # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
  608. # status codes
  609. if response.status_code < 200 or response.status_code > 299:
  610. raise ValueError(
  611. f"{where} responded with status code {response.status_code}"
  612. f"\nResponse body: {response.content!r}"
  613. )
  614. r = dns.message.from_wire(
  615. response.content,
  616. keyring=q.keyring,
  617. request_mac=q.request_mac,
  618. one_rr_per_rrset=one_rr_per_rrset,
  619. ignore_trailing=ignore_trailing,
  620. )
  621. r.time = response.elapsed.total_seconds()
  622. if not q.is_response(r):
  623. raise BadResponse
  624. return r
  625. async def _http3(
  626. q: dns.message.Message,
  627. where: str,
  628. url: str,
  629. timeout: float | None = None,
  630. port: int = 443,
  631. source: str | None = None,
  632. source_port: int = 0,
  633. one_rr_per_rrset: bool = False,
  634. ignore_trailing: bool = False,
  635. verify: bool | str | ssl.SSLContext = True,
  636. backend: dns.asyncbackend.Backend | None = None,
  637. post: bool = True,
  638. connection: dns.quic.AsyncQuicConnection | None = None,
  639. ) -> dns.message.Message:
  640. if not dns.quic.have_quic:
  641. raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
  642. url_parts = urllib.parse.urlparse(url)
  643. hostname = url_parts.hostname
  644. assert hostname is not None
  645. if url_parts.port is not None:
  646. port = url_parts.port
  647. q.id = 0
  648. wire = q.to_wire()
  649. the_connection: dns.quic.AsyncQuicConnection
  650. if connection:
  651. cfactory = dns.quic.null_factory
  652. mfactory = dns.quic.null_factory
  653. else:
  654. (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
  655. async with cfactory() as context:
  656. async with mfactory(
  657. context, verify_mode=verify, server_name=hostname, h3=True
  658. ) as the_manager:
  659. if connection:
  660. the_connection = connection
  661. else:
  662. the_connection = the_manager.connect( # pyright: ignore
  663. where, port, source, source_port
  664. )
  665. (start, expiration) = _compute_times(timeout)
  666. stream = await the_connection.make_stream(timeout) # pyright: ignore
  667. async with stream:
  668. # note that send_h3() does not need await
  669. stream.send_h3(url, wire, post)
  670. wire = await stream.receive(_remaining(expiration))
  671. _check_status(stream.headers(), where, wire)
  672. finish = time.time()
  673. r = dns.message.from_wire(
  674. wire,
  675. keyring=q.keyring,
  676. request_mac=q.request_mac,
  677. one_rr_per_rrset=one_rr_per_rrset,
  678. ignore_trailing=ignore_trailing,
  679. )
  680. r.time = max(finish - start, 0.0)
  681. if not q.is_response(r):
  682. raise BadResponse
  683. return r
  684. async def quic(
  685. q: dns.message.Message,
  686. where: str,
  687. timeout: float | None = None,
  688. port: int = 853,
  689. source: str | None = None,
  690. source_port: int = 0,
  691. one_rr_per_rrset: bool = False,
  692. ignore_trailing: bool = False,
  693. connection: dns.quic.AsyncQuicConnection | None = None,
  694. verify: bool | str = True,
  695. backend: dns.asyncbackend.Backend | None = None,
  696. hostname: str | None = None,
  697. server_hostname: str | None = None,
  698. ) -> dns.message.Message:
  699. """Return the response obtained after sending an asynchronous query via
  700. DNS-over-QUIC.
  701. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  702. the default, then dnspython will use the default backend.
  703. See :py:func:`dns.query.quic()` for the documentation of the other
  704. parameters, exceptions, and return type of this method.
  705. """
  706. if not dns.quic.have_quic:
  707. raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
  708. if server_hostname is not None and hostname is None:
  709. hostname = server_hostname
  710. q.id = 0
  711. wire = q.to_wire()
  712. the_connection: dns.quic.AsyncQuicConnection
  713. if connection:
  714. cfactory = dns.quic.null_factory
  715. mfactory = dns.quic.null_factory
  716. the_connection = connection
  717. else:
  718. (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
  719. async with cfactory() as context:
  720. async with mfactory(
  721. context,
  722. verify_mode=verify,
  723. server_name=server_hostname,
  724. ) as the_manager:
  725. if not connection:
  726. the_connection = the_manager.connect( # pyright: ignore
  727. where, port, source, source_port
  728. )
  729. (start, expiration) = _compute_times(timeout)
  730. stream = await the_connection.make_stream(timeout) # pyright: ignore
  731. async with stream:
  732. await stream.send(wire, True)
  733. wire = await stream.receive(_remaining(expiration))
  734. finish = time.time()
  735. r = dns.message.from_wire(
  736. wire,
  737. keyring=q.keyring,
  738. request_mac=q.request_mac,
  739. one_rr_per_rrset=one_rr_per_rrset,
  740. ignore_trailing=ignore_trailing,
  741. )
  742. r.time = max(finish - start, 0.0)
  743. if not q.is_response(r):
  744. raise BadResponse
  745. return r
  746. async def _inbound_xfr(
  747. txn_manager: dns.transaction.TransactionManager,
  748. s: dns.asyncbackend.Socket,
  749. query: dns.message.Message,
  750. serial: int | None,
  751. timeout: float | None,
  752. expiration: float,
  753. ) -> Any:
  754. """Given a socket, does the zone transfer."""
  755. rdtype = query.question[0].rdtype
  756. is_ixfr = rdtype == dns.rdatatype.IXFR
  757. origin = txn_manager.from_wire_origin()
  758. wire = query.to_wire()
  759. is_udp = s.type == socket.SOCK_DGRAM
  760. if is_udp:
  761. udp_sock = cast(dns.asyncbackend.DatagramSocket, s)
  762. await udp_sock.sendto(wire, None, _timeout(expiration))
  763. else:
  764. tcp_sock = cast(dns.asyncbackend.StreamSocket, s)
  765. tcpmsg = struct.pack("!H", len(wire)) + wire
  766. await tcp_sock.sendall(tcpmsg, expiration)
  767. with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
  768. done = False
  769. tsig_ctx = None
  770. r: dns.message.Message | None = None
  771. while not done:
  772. (_, mexpiration) = _compute_times(timeout)
  773. if mexpiration is None or (
  774. expiration is not None and mexpiration > expiration
  775. ):
  776. mexpiration = expiration
  777. if is_udp:
  778. timeout = _timeout(mexpiration)
  779. (rwire, _) = await udp_sock.recvfrom(65535, timeout) # pyright: ignore
  780. else:
  781. ldata = await _read_exactly(tcp_sock, 2, mexpiration) # pyright: ignore
  782. (l,) = struct.unpack("!H", ldata)
  783. rwire = await _read_exactly(tcp_sock, l, mexpiration) # pyright: ignore
  784. r = dns.message.from_wire(
  785. rwire,
  786. keyring=query.keyring,
  787. request_mac=query.mac,
  788. xfr=True,
  789. origin=origin,
  790. tsig_ctx=tsig_ctx,
  791. multi=(not is_udp),
  792. one_rr_per_rrset=is_ixfr,
  793. )
  794. done = inbound.process_message(r)
  795. yield r
  796. tsig_ctx = r.tsig_ctx
  797. if query.keyring and r is not None and not r.had_tsig:
  798. raise dns.exception.FormError("missing TSIG")
  799. async def inbound_xfr(
  800. where: str,
  801. txn_manager: dns.transaction.TransactionManager,
  802. query: dns.message.Message | None = None,
  803. port: int = 53,
  804. timeout: float | None = None,
  805. lifetime: float | None = None,
  806. source: str | None = None,
  807. source_port: int = 0,
  808. udp_mode: UDPMode = UDPMode.NEVER,
  809. backend: dns.asyncbackend.Backend | None = None,
  810. ) -> None:
  811. """Conduct an inbound transfer and apply it via a transaction from the
  812. txn_manager.
  813. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
  814. the default, then dnspython will use the default backend.
  815. See :py:func:`dns.query.inbound_xfr()` for the documentation of
  816. the other parameters, exceptions, and return type of this method.
  817. """
  818. if query is None:
  819. (query, serial) = dns.xfr.make_query(txn_manager)
  820. else:
  821. serial = dns.xfr.extract_serial_from_query(query)
  822. af = dns.inet.af_for_address(where)
  823. stuple = _source_tuple(af, source, source_port)
  824. dtuple = (where, port)
  825. if not backend:
  826. backend = dns.asyncbackend.get_default_backend()
  827. (_, expiration) = _compute_times(lifetime)
  828. if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
  829. s = await backend.make_socket(
  830. af, socket.SOCK_DGRAM, 0, stuple, dtuple, _timeout(expiration)
  831. )
  832. async with s:
  833. try:
  834. async for _ in _inbound_xfr( # pyright: ignore
  835. txn_manager,
  836. s,
  837. query,
  838. serial,
  839. timeout,
  840. expiration, # pyright: ignore
  841. ):
  842. pass
  843. return
  844. except dns.xfr.UseTCP:
  845. if udp_mode == UDPMode.ONLY:
  846. raise
  847. s = await backend.make_socket(
  848. af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
  849. )
  850. async with s:
  851. async for _ in _inbound_xfr( # pyright: ignore
  852. txn_manager, s, query, serial, timeout, expiration # pyright: ignore
  853. ):
  854. pass