_trio.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import socket
  3. import ssl
  4. import struct
  5. import time
  6. import aioquic.h3.connection # type: ignore
  7. import aioquic.h3.events # type: ignore
  8. import aioquic.quic.configuration # type: ignore
  9. import aioquic.quic.connection # type: ignore
  10. import aioquic.quic.events # type: ignore
  11. import trio
  12. import dns.exception
  13. import dns.inet
  14. from dns._asyncbackend import NullContext
  15. from dns.quic._common import (
  16. QUIC_MAX_DATAGRAM,
  17. AsyncQuicConnection,
  18. AsyncQuicManager,
  19. BaseQuicStream,
  20. UnexpectedEOF,
  21. )
  22. class TrioQuicStream(BaseQuicStream):
  23. def __init__(self, connection, stream_id):
  24. super().__init__(connection, stream_id)
  25. self._wake_up = trio.Condition()
  26. async def wait_for(self, amount):
  27. while True:
  28. if self._buffer.have(amount):
  29. return
  30. self._expecting = amount
  31. async with self._wake_up:
  32. await self._wake_up.wait()
  33. self._expecting = 0
  34. async def wait_for_end(self):
  35. while True:
  36. if self._buffer.seen_end():
  37. return
  38. async with self._wake_up:
  39. await self._wake_up.wait()
  40. async def receive(self, timeout=None):
  41. if timeout is None:
  42. context = NullContext(None)
  43. else:
  44. context = trio.move_on_after(timeout)
  45. with context:
  46. if self._connection.is_h3():
  47. await self.wait_for_end()
  48. return self._buffer.get_all()
  49. else:
  50. await self.wait_for(2)
  51. (size,) = struct.unpack("!H", self._buffer.get(2))
  52. await self.wait_for(size)
  53. return self._buffer.get(size)
  54. raise dns.exception.Timeout
  55. async def send(self, datagram, is_end=False):
  56. data = self._encapsulate(datagram)
  57. await self._connection.write(self._stream_id, data, is_end)
  58. async def _add_input(self, data, is_end):
  59. if self._common_add_input(data, is_end):
  60. async with self._wake_up:
  61. self._wake_up.notify()
  62. async def close(self):
  63. self._close()
  64. # Streams are async context managers
  65. async def __aenter__(self):
  66. return self
  67. async def __aexit__(self, exc_type, exc_val, exc_tb):
  68. await self.close()
  69. async with self._wake_up:
  70. self._wake_up.notify()
  71. return False
  72. class TrioQuicConnection(AsyncQuicConnection):
  73. def __init__(self, connection, address, port, source, source_port, manager=None):
  74. super().__init__(connection, address, port, source, source_port, manager)
  75. self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
  76. self._handshake_complete = trio.Event()
  77. self._run_done = trio.Event()
  78. self._worker_scope = None
  79. self._send_pending = False
  80. async def _worker(self):
  81. try:
  82. if self._source:
  83. await self._socket.bind(
  84. dns.inet.low_level_address_tuple(self._source, self._af)
  85. )
  86. await self._socket.connect(self._peer)
  87. while not self._done:
  88. (expiration, interval) = self._get_timer_values(False)
  89. if self._send_pending:
  90. # Do not block forever if sends are pending. Even though we
  91. # have a wake-up mechanism if we've already started the blocking
  92. # read, the possibility of context switching in send means that
  93. # more writes can happen while we have no wake up context, so
  94. # we need self._send_pending to avoid (effectively) a "lost wakeup"
  95. # race.
  96. interval = 0.0
  97. with trio.CancelScope(
  98. deadline=trio.current_time() + interval # pyright: ignore
  99. ) as self._worker_scope:
  100. datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
  101. self._connection.receive_datagram(datagram, self._peer, time.time())
  102. self._worker_scope = None
  103. self._handle_timer(expiration)
  104. await self._handle_events()
  105. # We clear this now, before sending anything, as sending can cause
  106. # context switches that do more sends. We want to know if that
  107. # happens so we don't block a long time on the recv() above.
  108. self._send_pending = False
  109. datagrams = self._connection.datagrams_to_send(time.time())
  110. for datagram, _ in datagrams:
  111. await self._socket.send(datagram)
  112. finally:
  113. self._done = True
  114. self._socket.close()
  115. self._handshake_complete.set()
  116. async def _handle_events(self):
  117. count = 0
  118. while True:
  119. event = self._connection.next_event()
  120. if event is None:
  121. return
  122. if isinstance(event, aioquic.quic.events.StreamDataReceived):
  123. if self.is_h3():
  124. assert self._h3_conn is not None
  125. h3_events = self._h3_conn.handle_event(event)
  126. for h3_event in h3_events:
  127. if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
  128. stream = self._streams.get(event.stream_id)
  129. if stream:
  130. if stream._headers is None:
  131. stream._headers = h3_event.headers
  132. elif stream._trailers is None:
  133. stream._trailers = h3_event.headers
  134. if h3_event.stream_ended:
  135. await stream._add_input(b"", True)
  136. elif isinstance(h3_event, aioquic.h3.events.DataReceived):
  137. stream = self._streams.get(event.stream_id)
  138. if stream:
  139. await stream._add_input(
  140. h3_event.data, h3_event.stream_ended
  141. )
  142. else:
  143. stream = self._streams.get(event.stream_id)
  144. if stream:
  145. await stream._add_input(event.data, event.end_stream)
  146. elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
  147. self._handshake_complete.set()
  148. elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
  149. self._done = True
  150. self._socket.close()
  151. elif isinstance(event, aioquic.quic.events.StreamReset):
  152. stream = self._streams.get(event.stream_id)
  153. if stream:
  154. await stream._add_input(b"", True)
  155. count += 1
  156. if count > 10:
  157. # yield
  158. count = 0
  159. await trio.sleep(0)
  160. async def write(self, stream, data, is_end=False):
  161. self._connection.send_stream_data(stream, data, is_end)
  162. self._send_pending = True
  163. if self._worker_scope is not None:
  164. self._worker_scope.cancel()
  165. async def run(self):
  166. if self._closed:
  167. return
  168. async with trio.open_nursery() as nursery:
  169. nursery.start_soon(self._worker)
  170. self._run_done.set()
  171. async def make_stream(self, timeout=None):
  172. if timeout is None:
  173. context = NullContext(None)
  174. else:
  175. context = trio.move_on_after(timeout)
  176. with context:
  177. await self._handshake_complete.wait()
  178. if self._done:
  179. raise UnexpectedEOF
  180. stream_id = self._connection.get_next_available_stream_id(False)
  181. stream = TrioQuicStream(self, stream_id)
  182. self._streams[stream_id] = stream
  183. return stream
  184. raise dns.exception.Timeout
  185. async def close(self):
  186. if not self._closed:
  187. if self._manager is not None:
  188. self._manager.closed(self._peer[0], self._peer[1])
  189. self._closed = True
  190. self._connection.close()
  191. self._send_pending = True
  192. if self._worker_scope is not None:
  193. self._worker_scope.cancel()
  194. await self._run_done.wait()
  195. class TrioQuicManager(AsyncQuicManager):
  196. def __init__(
  197. self,
  198. nursery,
  199. conf=None,
  200. verify_mode=ssl.CERT_REQUIRED,
  201. server_name=None,
  202. h3=False,
  203. ):
  204. super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
  205. self._nursery = nursery
  206. def connect(
  207. self, address, port=853, source=None, source_port=0, want_session_ticket=True
  208. ):
  209. (connection, start) = self._connect(
  210. address, port, source, source_port, want_session_ticket
  211. )
  212. if start:
  213. self._nursery.start_soon(connection.run)
  214. return connection
  215. async def __aenter__(self):
  216. return self
  217. async def __aexit__(self, exc_type, exc_val, exc_tb):
  218. # Copy the iterator into a list as exiting things will mutate the connections
  219. # table.
  220. connections = list(self._connections.values())
  221. for connection in connections:
  222. await connection.close()
  223. return False