# __init__.py (2.5 KB)
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. from typing import Any, Dict, List, Tuple
  3. import dns._features
  4. import dns.asyncbackend
  5. if dns._features.have("doq"):
  6. from dns._asyncbackend import NullContext
  7. from dns.quic._asyncio import AsyncioQuicConnection as AsyncioQuicConnection
  8. from dns.quic._asyncio import AsyncioQuicManager
  9. from dns.quic._asyncio import AsyncioQuicStream as AsyncioQuicStream
  10. from dns.quic._common import AsyncQuicConnection # pyright: ignore
  11. from dns.quic._common import AsyncQuicManager as AsyncQuicManager
  12. from dns.quic._sync import SyncQuicConnection # pyright: ignore
  13. from dns.quic._sync import SyncQuicStream # pyright: ignore
  14. from dns.quic._sync import SyncQuicManager as SyncQuicManager
  15. have_quic = True
  16. def null_factory(
  17. *args, # pylint: disable=unused-argument
  18. **kwargs, # pylint: disable=unused-argument
  19. ):
  20. return NullContext(None)
  21. def _asyncio_manager_factory(
  22. context, *args, **kwargs # pylint: disable=unused-argument
  23. ):
  24. return AsyncioQuicManager(*args, **kwargs)
  25. # We have a context factory and a manager factory as for trio we need to have
  26. # a nursery.
  27. _async_factories: Dict[str, Tuple[Any, Any]] = {
  28. "asyncio": (null_factory, _asyncio_manager_factory)
  29. }
  30. if dns._features.have("trio"):
  31. import trio
  32. # pylint: disable=ungrouped-imports
  33. from dns.quic._trio import TrioQuicConnection as TrioQuicConnection
  34. from dns.quic._trio import TrioQuicManager
  35. from dns.quic._trio import TrioQuicStream as TrioQuicStream
  36. def _trio_context_factory():
  37. return trio.open_nursery()
  38. def _trio_manager_factory(context, *args, **kwargs):
  39. return TrioQuicManager(context, *args, **kwargs)
  40. _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
  41. def factories_for_backend(backend=None):
  42. if backend is None:
  43. backend = dns.asyncbackend.get_default_backend()
  44. return _async_factories[backend.name()]
  45. else: # pragma: no cover
  46. have_quic = False
  47. class AsyncQuicStream: # type: ignore
  48. pass
  49. class AsyncQuicConnection: # type: ignore
  50. async def make_stream(self) -> Any:
  51. raise NotImplementedError
  52. class SyncQuicStream: # type: ignore
  53. pass
  54. class SyncQuicConnection: # type: ignore
  55. def make_stream(self) -> Any:
  56. raise NotImplementedError
# Type alias for a list of header fields: each entry is a 2-tuple of bytes
# objects (presumably name and value -- confirm against callers).
Headers = List[Tuple[bytes, bytes]]