_common.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. import base64
  3. import copy
  4. import functools
  5. import socket
  6. import struct
  7. import time
  8. import urllib.parse
  9. from typing import Any
  10. import aioquic.h3.connection # type: ignore
  11. import aioquic.quic.configuration # type: ignore
  12. import aioquic.quic.connection # type: ignore
  13. import dns._tls_util
  14. import dns.inet
  15. QUIC_MAX_DATAGRAM = 2048
  16. MAX_SESSION_TICKETS = 8
  17. # If we hit the max sessions limit we will delete this many of the oldest connections.
  18. # The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
  19. SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
  20. class UnexpectedEOF(Exception):
  21. pass
  22. class Buffer:
  23. def __init__(self):
  24. self._buffer = b""
  25. self._seen_end = False
  26. def put(self, data, is_end):
  27. if self._seen_end:
  28. return
  29. self._buffer += data
  30. if is_end:
  31. self._seen_end = True
  32. def have(self, amount):
  33. if len(self._buffer) >= amount:
  34. return True
  35. if self._seen_end:
  36. raise UnexpectedEOF
  37. return False
  38. def seen_end(self):
  39. return self._seen_end
  40. def get(self, amount):
  41. assert self.have(amount)
  42. data = self._buffer[:amount]
  43. self._buffer = self._buffer[amount:]
  44. return data
  45. def get_all(self):
  46. assert self.seen_end()
  47. data = self._buffer
  48. self._buffer = b""
  49. return data
  50. class BaseQuicStream:
  51. def __init__(self, connection, stream_id):
  52. self._connection = connection
  53. self._stream_id = stream_id
  54. self._buffer = Buffer()
  55. self._expecting = 0
  56. self._headers = None
  57. self._trailers = None
  58. def id(self):
  59. return self._stream_id
  60. def headers(self):
  61. return self._headers
  62. def trailers(self):
  63. return self._trailers
  64. def _expiration_from_timeout(self, timeout):
  65. if timeout is not None:
  66. expiration = time.time() + timeout
  67. else:
  68. expiration = None
  69. return expiration
  70. def _timeout_from_expiration(self, expiration):
  71. if expiration is not None:
  72. timeout = max(expiration - time.time(), 0.0)
  73. else:
  74. timeout = None
  75. return timeout
  76. # Subclass must implement receive() as sync / async and which returns a message
  77. # or raises.
  78. # Subclass must implement send() as sync / async and which takes a message and
  79. # an EOF indicator.
  80. def send_h3(self, url, datagram, post=True):
  81. if not self._connection.is_h3():
  82. raise SyntaxError("cannot send H3 to a non-H3 connection")
  83. url_parts = urllib.parse.urlparse(url)
  84. path = url_parts.path.encode()
  85. if post:
  86. method = b"POST"
  87. else:
  88. method = b"GET"
  89. path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
  90. headers = [
  91. (b":method", method),
  92. (b":scheme", url_parts.scheme.encode()),
  93. (b":authority", url_parts.netloc.encode()),
  94. (b":path", path),
  95. (b"accept", b"application/dns-message"),
  96. ]
  97. if post:
  98. headers.extend(
  99. [
  100. (b"content-type", b"application/dns-message"),
  101. (b"content-length", str(len(datagram)).encode()),
  102. ]
  103. )
  104. self._connection.send_headers(self._stream_id, headers, not post)
  105. if post:
  106. self._connection.send_data(self._stream_id, datagram, True)
  107. def _encapsulate(self, datagram):
  108. if self._connection.is_h3():
  109. return datagram
  110. l = len(datagram)
  111. return struct.pack("!H", l) + datagram
  112. def _common_add_input(self, data, is_end):
  113. self._buffer.put(data, is_end)
  114. try:
  115. return (
  116. self._expecting > 0 and self._buffer.have(self._expecting)
  117. ) or self._buffer.seen_end
  118. except UnexpectedEOF:
  119. return True
  120. def _close(self):
  121. self._connection.close_stream(self._stream_id)
  122. self._buffer.put(b"", True) # send EOF in case we haven't seen it.
  123. class BaseQuicConnection:
  124. def __init__(
  125. self,
  126. connection,
  127. address,
  128. port,
  129. source=None,
  130. source_port=0,
  131. manager=None,
  132. ):
  133. self._done = False
  134. self._connection = connection
  135. self._address = address
  136. self._port = port
  137. self._closed = False
  138. self._manager = manager
  139. self._streams = {}
  140. if manager is not None and manager.is_h3():
  141. self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
  142. else:
  143. self._h3_conn = None
  144. self._af = dns.inet.af_for_address(address)
  145. self._peer = dns.inet.low_level_address_tuple((address, port))
  146. if source is None and source_port != 0:
  147. if self._af == socket.AF_INET:
  148. source = "0.0.0.0"
  149. elif self._af == socket.AF_INET6:
  150. source = "::"
  151. else:
  152. raise NotImplementedError
  153. if source:
  154. self._source = (source, source_port)
  155. else:
  156. self._source = None
  157. def is_h3(self):
  158. return self._h3_conn is not None
  159. def close_stream(self, stream_id):
  160. del self._streams[stream_id]
  161. def send_headers(self, stream_id, headers, is_end=False):
  162. assert self._h3_conn is not None
  163. self._h3_conn.send_headers(stream_id, headers, is_end)
  164. def send_data(self, stream_id, data, is_end=False):
  165. assert self._h3_conn is not None
  166. self._h3_conn.send_data(stream_id, data, is_end)
  167. def _get_timer_values(self, closed_is_special=True):
  168. now = time.time()
  169. expiration = self._connection.get_timer()
  170. if expiration is None:
  171. expiration = now + 3600 # arbitrary "big" value
  172. interval = max(expiration - now, 0)
  173. if self._closed and closed_is_special:
  174. # lower sleep interval to avoid a race in the closing process
  175. # which can lead to higher latency closing due to sleeping when
  176. # we have events.
  177. interval = min(interval, 0.05)
  178. return (expiration, interval)
  179. def _handle_timer(self, expiration):
  180. now = time.time()
  181. if expiration <= now:
  182. self._connection.handle_timer(now)
  183. class AsyncQuicConnection(BaseQuicConnection):
  184. async def make_stream(self, timeout: float | None = None) -> Any:
  185. pass
  186. class BaseQuicManager:
  187. def __init__(
  188. self, conf, verify_mode, connection_factory, server_name=None, h3=False
  189. ):
  190. self._connections = {}
  191. self._connection_factory = connection_factory
  192. self._session_tickets = {}
  193. self._tokens = {}
  194. self._h3 = h3
  195. if conf is None:
  196. verify_path = None
  197. if isinstance(verify_mode, str):
  198. verify_path = verify_mode
  199. verify_mode = True
  200. if h3:
  201. alpn_protocols = ["h3"]
  202. else:
  203. alpn_protocols = ["doq", "doq-i03"]
  204. conf = aioquic.quic.configuration.QuicConfiguration(
  205. alpn_protocols=alpn_protocols,
  206. verify_mode=verify_mode,
  207. server_name=server_name,
  208. )
  209. if verify_path is not None:
  210. cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(
  211. verify_path
  212. )
  213. conf.load_verify_locations(cafile=cafile, capath=capath)
  214. self._conf = conf
  215. def _connect(
  216. self,
  217. address,
  218. port=853,
  219. source=None,
  220. source_port=0,
  221. want_session_ticket=True,
  222. want_token=True,
  223. ):
  224. connection = self._connections.get((address, port))
  225. if connection is not None:
  226. return (connection, False)
  227. conf = self._conf
  228. if want_session_ticket:
  229. try:
  230. session_ticket = self._session_tickets.pop((address, port))
  231. # We found a session ticket, so make a configuration that uses it.
  232. conf = copy.copy(conf)
  233. conf.session_ticket = session_ticket
  234. except KeyError:
  235. # No session ticket.
  236. pass
  237. # Whether or not we found a session ticket, we want a handler to save
  238. # one.
  239. session_ticket_handler = functools.partial(
  240. self.save_session_ticket, address, port
  241. )
  242. else:
  243. session_ticket_handler = None
  244. if want_token:
  245. try:
  246. token = self._tokens.pop((address, port))
  247. # We found a token, so make a configuration that uses it.
  248. conf = copy.copy(conf)
  249. conf.token = token
  250. except KeyError:
  251. # No token
  252. pass
  253. # Whether or not we found a token, we want a handler to save # one.
  254. token_handler = functools.partial(self.save_token, address, port)
  255. else:
  256. token_handler = None
  257. qconn = aioquic.quic.connection.QuicConnection(
  258. configuration=conf,
  259. session_ticket_handler=session_ticket_handler,
  260. token_handler=token_handler,
  261. )
  262. lladdress = dns.inet.low_level_address_tuple((address, port))
  263. qconn.connect(lladdress, time.time())
  264. connection = self._connection_factory(
  265. qconn, address, port, source, source_port, self
  266. )
  267. self._connections[(address, port)] = connection
  268. return (connection, True)
  269. def closed(self, address, port):
  270. try:
  271. del self._connections[(address, port)]
  272. except KeyError:
  273. pass
  274. def is_h3(self):
  275. return self._h3
  276. def save_session_ticket(self, address, port, ticket):
  277. # We rely on dictionaries keys() being in insertion order here. We
  278. # can't just popitem() as that would be LIFO which is the opposite of
  279. # what we want.
  280. l = len(self._session_tickets)
  281. if l >= MAX_SESSION_TICKETS:
  282. keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
  283. for key in keys_to_delete:
  284. del self._session_tickets[key]
  285. self._session_tickets[(address, port)] = ticket
  286. def save_token(self, address, port, token):
  287. # We rely on dictionaries keys() being in insertion order here. We
  288. # can't just popitem() as that would be LIFO which is the opposite of
  289. # what we want.
  290. l = len(self._tokens)
  291. if l >= MAX_SESSION_TICKETS:
  292. keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
  293. for key in keys_to_delete:
  294. del self._tokens[key]
  295. self._tokens[(address, port)] = token
  296. class AsyncQuicManager(BaseQuicManager):
  297. def connect(self, address, port=853, source=None, source_port=0):
  298. raise NotImplementedError