_asyncio_backend.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """asyncio library query support"""
  3. import asyncio
  4. import socket
  5. import sys
  6. import dns._asyncbackend
  7. import dns._features
  8. import dns.exception
  9. import dns.inet
# True when running on Windows; Backend.make_socket() consults this to bind
# datagram sockets explicitly when no source address was supplied.
_is_win32 = sys.platform == "win32"
  11. def _get_running_loop():
  12. try:
  13. return asyncio.get_running_loop()
  14. except AttributeError: # pragma: no cover
  15. return asyncio.get_event_loop()
  16. class _DatagramProtocol:
  17. def __init__(self):
  18. self.transport = None
  19. self.recvfrom = None
  20. def connection_made(self, transport):
  21. self.transport = transport
  22. def datagram_received(self, data, addr):
  23. if self.recvfrom and not self.recvfrom.done():
  24. self.recvfrom.set_result((data, addr))
  25. def error_received(self, exc): # pragma: no cover
  26. if self.recvfrom and not self.recvfrom.done():
  27. self.recvfrom.set_exception(exc)
  28. def connection_lost(self, exc):
  29. if self.recvfrom and not self.recvfrom.done():
  30. if exc is None:
  31. # EOF we triggered. Is there a better way to do this?
  32. try:
  33. raise EOFError("EOF")
  34. except EOFError as e:
  35. self.recvfrom.set_exception(e)
  36. else:
  37. self.recvfrom.set_exception(exc)
  38. def close(self):
  39. if self.transport is not None:
  40. self.transport.close()
  41. async def _maybe_wait_for(awaitable, timeout):
  42. if timeout is not None:
  43. try:
  44. return await asyncio.wait_for(awaitable, timeout)
  45. except asyncio.TimeoutError:
  46. raise dns.exception.Timeout(timeout=timeout)
  47. else:
  48. return await awaitable
  49. class DatagramSocket(dns._asyncbackend.DatagramSocket):
  50. def __init__(self, family, transport, protocol):
  51. super().__init__(family, socket.SOCK_DGRAM)
  52. self.transport = transport
  53. self.protocol = protocol
  54. async def sendto(self, what, destination, timeout): # pragma: no cover
  55. # no timeout for asyncio sendto
  56. self.transport.sendto(what, destination)
  57. return len(what)
  58. async def recvfrom(self, size, timeout):
  59. # ignore size as there's no way I know to tell protocol about it
  60. done = _get_running_loop().create_future()
  61. try:
  62. assert self.protocol.recvfrom is None
  63. self.protocol.recvfrom = done
  64. await _maybe_wait_for(done, timeout)
  65. return done.result()
  66. finally:
  67. self.protocol.recvfrom = None
  68. async def close(self):
  69. self.protocol.close()
  70. async def getpeername(self):
  71. return self.transport.get_extra_info("peername")
  72. async def getsockname(self):
  73. return self.transport.get_extra_info("sockname")
  74. async def getpeercert(self, timeout):
  75. raise NotImplementedError
  76. class StreamSocket(dns._asyncbackend.StreamSocket):
  77. def __init__(self, af, reader, writer):
  78. super().__init__(af, socket.SOCK_STREAM)
  79. self.reader = reader
  80. self.writer = writer
  81. async def sendall(self, what, timeout):
  82. self.writer.write(what)
  83. return await _maybe_wait_for(self.writer.drain(), timeout)
  84. async def recv(self, size, timeout):
  85. return await _maybe_wait_for(self.reader.read(size), timeout)
  86. async def close(self):
  87. self.writer.close()
  88. async def getpeername(self):
  89. return self.writer.get_extra_info("peername")
  90. async def getsockname(self):
  91. return self.writer.get_extra_info("sockname")
  92. async def getpeercert(self, timeout):
  93. return self.writer.get_extra_info("peercert")
  94. if dns._features.have("doh"):
  95. import anyio
  96. import httpcore
  97. import httpcore._backends.anyio
  98. import httpx
  99. _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
  100. _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream # pyright: ignore
  101. from dns.query import _compute_times, _expiration_for_this_attempt, _remaining
  102. class _NetworkBackend(_CoreAsyncNetworkBackend):
  103. def __init__(self, resolver, local_port, bootstrap_address, family):
  104. super().__init__()
  105. self._local_port = local_port
  106. self._resolver = resolver
  107. self._bootstrap_address = bootstrap_address
  108. self._family = family
  109. if local_port != 0:
  110. raise NotImplementedError(
  111. "the asyncio transport for HTTPX cannot set the local port"
  112. )
  113. async def connect_tcp(
  114. self, host, port, timeout=None, local_address=None, socket_options=None
  115. ): # pylint: disable=signature-differs
  116. addresses = []
  117. _, expiration = _compute_times(timeout)
  118. if dns.inet.is_address(host):
  119. addresses.append(host)
  120. elif self._bootstrap_address is not None:
  121. addresses.append(self._bootstrap_address)
  122. else:
  123. timeout = _remaining(expiration)
  124. family = self._family
  125. if local_address:
  126. family = dns.inet.af_for_address(local_address)
  127. answers = await self._resolver.resolve_name(
  128. host, family=family, lifetime=timeout
  129. )
  130. addresses = answers.addresses()
  131. for address in addresses:
  132. try:
  133. attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
  134. timeout = _remaining(attempt_expiration)
  135. with anyio.fail_after(timeout):
  136. stream = await anyio.connect_tcp(
  137. remote_host=address,
  138. remote_port=port,
  139. local_host=local_address,
  140. )
  141. return _CoreAnyIOStream(stream)
  142. except Exception:
  143. pass
  144. raise httpcore.ConnectError
  145. async def connect_unix_socket(
  146. self, path, timeout=None, socket_options=None
  147. ): # pylint: disable=signature-differs
  148. raise NotImplementedError
  149. async def sleep(self, seconds): # pylint: disable=signature-differs
  150. await anyio.sleep(seconds)
  151. class _HTTPTransport(httpx.AsyncHTTPTransport):
  152. def __init__(
  153. self,
  154. *args,
  155. local_port=0,
  156. bootstrap_address=None,
  157. resolver=None,
  158. family=socket.AF_UNSPEC,
  159. **kwargs,
  160. ):
  161. if resolver is None and bootstrap_address is None:
  162. # pylint: disable=import-outside-toplevel,redefined-outer-name
  163. import dns.asyncresolver
  164. resolver = dns.asyncresolver.Resolver()
  165. super().__init__(*args, **kwargs)
  166. self._pool._network_backend = _NetworkBackend(
  167. resolver, local_port, bootstrap_address, family
  168. )
  169. else:
  170. _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
  171. class Backend(dns._asyncbackend.Backend):
  172. def name(self):
  173. return "asyncio"
  174. async def make_socket(
  175. self,
  176. af,
  177. socktype,
  178. proto=0,
  179. source=None,
  180. destination=None,
  181. timeout=None,
  182. ssl_context=None,
  183. server_hostname=None,
  184. ):
  185. loop = _get_running_loop()
  186. if socktype == socket.SOCK_DGRAM:
  187. if _is_win32 and source is None:
  188. # Win32 wants explicit binding before recvfrom(). This is the
  189. # proper fix for [#637].
  190. source = (dns.inet.any_for_af(af), 0)
  191. transport, protocol = await loop.create_datagram_endpoint(
  192. _DatagramProtocol, # pyright: ignore
  193. source,
  194. family=af,
  195. proto=proto,
  196. remote_addr=destination,
  197. )
  198. return DatagramSocket(af, transport, protocol)
  199. elif socktype == socket.SOCK_STREAM:
  200. if destination is None:
  201. # This shouldn't happen, but we check to make code analysis software
  202. # happier.
  203. raise ValueError("destination required for stream sockets")
  204. (r, w) = await _maybe_wait_for(
  205. asyncio.open_connection(
  206. destination[0],
  207. destination[1],
  208. ssl=ssl_context,
  209. family=af,
  210. proto=proto,
  211. local_addr=source,
  212. server_hostname=server_hostname,
  213. ),
  214. timeout,
  215. )
  216. return StreamSocket(af, r, w)
  217. raise NotImplementedError(
  218. "unsupported socket " + f"type {socktype}"
  219. ) # pragma: no cover
  220. async def sleep(self, interval):
  221. await asyncio.sleep(interval)
  222. def datagram_connection_required(self):
  223. return False
  224. def get_transport_class(self):
  225. return _HTTPTransport
  226. async def wait_for(self, awaitable, timeout):
  227. return await _maybe_wait_for(awaitable, timeout)