_asyncio.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import asyncio
  3. import socket
  4. import ssl
  5. import struct
  6. import time
  7. import aioquic.h3.connection # type: ignore
  8. import aioquic.h3.events # type: ignore
  9. import aioquic.quic.configuration # type: ignore
  10. import aioquic.quic.connection # type: ignore
  11. import aioquic.quic.events # type: ignore
  12. import dns.asyncbackend
  13. import dns.exception
  14. import dns.inet
  15. from dns.quic._common import (
  16. QUIC_MAX_DATAGRAM,
  17. AsyncQuicConnection,
  18. AsyncQuicManager,
  19. BaseQuicStream,
  20. UnexpectedEOF,
  21. )
  22. class AsyncioQuicStream(BaseQuicStream):
  23. def __init__(self, connection, stream_id):
  24. super().__init__(connection, stream_id)
  25. self._wake_up = asyncio.Condition()
  26. async def _wait_for_wake_up(self):
  27. async with self._wake_up:
  28. await self._wake_up.wait()
  29. async def wait_for(self, amount, expiration):
  30. while True:
  31. timeout = self._timeout_from_expiration(expiration)
  32. if self._buffer.have(amount):
  33. return
  34. self._expecting = amount
  35. try:
  36. await asyncio.wait_for(self._wait_for_wake_up(), timeout)
  37. except TimeoutError:
  38. raise dns.exception.Timeout
  39. self._expecting = 0
  40. async def wait_for_end(self, expiration):
  41. while True:
  42. timeout = self._timeout_from_expiration(expiration)
  43. if self._buffer.seen_end():
  44. return
  45. try:
  46. await asyncio.wait_for(self._wait_for_wake_up(), timeout)
  47. except TimeoutError:
  48. raise dns.exception.Timeout
  49. async def receive(self, timeout=None):
  50. expiration = self._expiration_from_timeout(timeout)
  51. if self._connection.is_h3():
  52. await self.wait_for_end(expiration)
  53. return self._buffer.get_all()
  54. else:
  55. await self.wait_for(2, expiration)
  56. (size,) = struct.unpack("!H", self._buffer.get(2))
  57. await self.wait_for(size, expiration)
  58. return self._buffer.get(size)
  59. async def send(self, datagram, is_end=False):
  60. data = self._encapsulate(datagram)
  61. await self._connection.write(self._stream_id, data, is_end)
  62. async def _add_input(self, data, is_end):
  63. if self._common_add_input(data, is_end):
  64. async with self._wake_up:
  65. self._wake_up.notify()
  66. async def close(self):
  67. self._close()
  68. # Streams are async context managers
  69. async def __aenter__(self):
  70. return self
  71. async def __aexit__(self, exc_type, exc_val, exc_tb):
  72. await self.close()
  73. async with self._wake_up:
  74. self._wake_up.notify()
  75. return False
  76. class AsyncioQuicConnection(AsyncQuicConnection):
  77. def __init__(self, connection, address, port, source, source_port, manager=None):
  78. super().__init__(connection, address, port, source, source_port, manager)
  79. self._socket = None
  80. self._handshake_complete = asyncio.Event()
  81. self._socket_created = asyncio.Event()
  82. self._wake_timer = asyncio.Condition()
  83. self._receiver_task = None
  84. self._sender_task = None
  85. self._wake_pending = False
  86. async def _receiver(self):
  87. try:
  88. af = dns.inet.af_for_address(self._address)
  89. backend = dns.asyncbackend.get_backend("asyncio")
  90. # Note that peer is a low-level address tuple, but make_socket() wants
  91. # a high-level address tuple, so we convert.
  92. self._socket = await backend.make_socket(
  93. af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
  94. )
  95. self._socket_created.set()
  96. async with self._socket:
  97. while not self._done:
  98. (datagram, address) = await self._socket.recvfrom(
  99. QUIC_MAX_DATAGRAM, None
  100. )
  101. if address[0] != self._peer[0] or address[1] != self._peer[1]:
  102. continue
  103. self._connection.receive_datagram(datagram, address, time.time())
  104. # Wake up the timer in case the sender is sleeping, as there may be
  105. # stuff to send now.
  106. await self._wakeup()
  107. except Exception:
  108. pass
  109. finally:
  110. self._done = True
  111. await self._wakeup()
  112. self._handshake_complete.set()
  113. async def _wakeup(self):
  114. self._wake_pending = True
  115. async with self._wake_timer:
  116. self._wake_timer.notify_all()
  117. async def _wait_for_wake_timer(self):
  118. async with self._wake_timer:
  119. if not self._wake_pending:
  120. await self._wake_timer.wait()
  121. self._wake_pending = False
  122. async def _sender(self):
  123. await self._socket_created.wait()
  124. while not self._done:
  125. datagrams = self._connection.datagrams_to_send(time.time())
  126. for datagram, address in datagrams:
  127. assert address == self._peer
  128. assert self._socket is not None
  129. await self._socket.sendto(datagram, self._peer, None)
  130. (expiration, interval) = self._get_timer_values()
  131. try:
  132. await asyncio.wait_for(self._wait_for_wake_timer(), interval)
  133. except Exception:
  134. pass
  135. self._handle_timer(expiration)
  136. await self._handle_events()
  137. async def _handle_events(self):
  138. count = 0
  139. while True:
  140. event = self._connection.next_event()
  141. if event is None:
  142. return
  143. if isinstance(event, aioquic.quic.events.StreamDataReceived):
  144. if self.is_h3():
  145. assert self._h3_conn is not None
  146. h3_events = self._h3_conn.handle_event(event)
  147. for h3_event in h3_events:
  148. if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
  149. stream = self._streams.get(event.stream_id)
  150. if stream:
  151. if stream._headers is None:
  152. stream._headers = h3_event.headers
  153. elif stream._trailers is None:
  154. stream._trailers = h3_event.headers
  155. if h3_event.stream_ended:
  156. await stream._add_input(b"", True)
  157. elif isinstance(h3_event, aioquic.h3.events.DataReceived):
  158. stream = self._streams.get(event.stream_id)
  159. if stream:
  160. await stream._add_input(
  161. h3_event.data, h3_event.stream_ended
  162. )
  163. else:
  164. stream = self._streams.get(event.stream_id)
  165. if stream:
  166. await stream._add_input(event.data, event.end_stream)
  167. elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
  168. self._handshake_complete.set()
  169. elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
  170. self._done = True
  171. if self._receiver_task is not None:
  172. self._receiver_task.cancel()
  173. elif isinstance(event, aioquic.quic.events.StreamReset):
  174. stream = self._streams.get(event.stream_id)
  175. if stream:
  176. await stream._add_input(b"", True)
  177. count += 1
  178. if count > 10:
  179. # yield
  180. count = 0
  181. await asyncio.sleep(0)
  182. async def write(self, stream, data, is_end=False):
  183. self._connection.send_stream_data(stream, data, is_end)
  184. await self._wakeup()
  185. def run(self):
  186. if self._closed:
  187. return
  188. self._receiver_task = asyncio.Task(self._receiver())
  189. self._sender_task = asyncio.Task(self._sender())
  190. async def make_stream(self, timeout=None):
  191. try:
  192. await asyncio.wait_for(self._handshake_complete.wait(), timeout)
  193. except TimeoutError:
  194. raise dns.exception.Timeout
  195. if self._done:
  196. raise UnexpectedEOF
  197. stream_id = self._connection.get_next_available_stream_id(False)
  198. stream = AsyncioQuicStream(self, stream_id)
  199. self._streams[stream_id] = stream
  200. return stream
  201. async def close(self):
  202. if not self._closed:
  203. if self._manager is not None:
  204. self._manager.closed(self._peer[0], self._peer[1])
  205. self._closed = True
  206. self._connection.close()
  207. # sender might be blocked on this, so set it
  208. self._socket_created.set()
  209. await self._wakeup()
  210. try:
  211. if self._receiver_task is not None:
  212. await self._receiver_task
  213. except asyncio.CancelledError:
  214. pass
  215. try:
  216. if self._sender_task is not None:
  217. await self._sender_task
  218. except asyncio.CancelledError:
  219. pass
  220. if self._socket is not None:
  221. await self._socket.close()
  222. class AsyncioQuicManager(AsyncQuicManager):
  223. def __init__(
  224. self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
  225. ):
  226. super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
  227. def connect(
  228. self, address, port=853, source=None, source_port=0, want_session_ticket=True
  229. ):
  230. (connection, start) = self._connect(
  231. address, port, source, source_port, want_session_ticket
  232. )
  233. if start:
  234. connection.run()
  235. return connection
  236. async def __aenter__(self):
  237. return self
  238. async def __aexit__(self, exc_type, exc_val, exc_tb):
  239. # Copy the iterator into a list as exiting things will mutate the connections
  240. # table.
  241. connections = list(self._connections.values())
  242. for connection in connections:
  243. await connection.close()
  244. return False