nameserver.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. from urllib.parse import urlparse
  2. import dns.asyncbackend
  3. import dns.asyncquery
  4. import dns.message
  5. import dns.query
  class Nameserver:
      """Common interface for the nameserver types defined in this module.

      The base class holds no state and every operation raises
      ``NotImplementedError``; subclasses override what they support.
      """

      def __init__(self):
          """Initialize the nameserver; the base class has nothing to set up."""

      def __str__(self):
          """Return a printable form of the nameserver (provided by subclasses)."""
          raise NotImplementedError

      def kind(self) -> str:
          """Return a short label naming the nameserver type, e.g. ``"Do53"``."""
          raise NotImplementedError

      def is_always_max_size(self) -> bool:
          """Return the transport's fixed answer to the "always max size" question.

          NOTE(review): the subclasses in this file return ``False`` for
          address/port servers and ``True`` for DoH; how callers use the value
          is not visible here.
          """
          raise NotImplementedError

      def answer_nameserver(self) -> str:
          """Return the string identifying this nameserver (address or URL)."""
          raise NotImplementedError

      def answer_port(self) -> int:
          """Return the port associated with this nameserver."""
          raise NotImplementedError

      def query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* to this nameserver and return the response.

          Args:
              request: The query message to send.
              timeout: Time limit handed to the underlying query call.
              source: Source address for the query, or ``None``.
              source_port: Source port for the query.
              max_size: Transport hint; the Do53 subclass uses TCP when it is
                  true and UDP otherwise.
              one_rr_per_rrset: Forwarded to the underlying query call.
              ignore_trailing: Forwarded to the underlying query call.

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError

      async def async_query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          backend: dns.asyncbackend.Backend,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Asynchronous counterpart of :meth:`query`.

          Takes the same arguments as :meth:`query` plus *backend*, the async
          backend object handed to the underlying asynchronous query call.

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError
  class AddressAndPortNameserver(Nameserver):
      """Nameserver that is reached at an address and a port.

      Stores both values and answers the identity-related queries of the
      ``Nameserver`` interface from them; ``kind`` is still left to subclasses.
      """

      def __init__(self, address: str, port: int):
          """Remember the *address* and *port* of the server."""
          super().__init__()
          self.address = address
          self.port = port

      def __str__(self):
          """Return ``"<kind>:<address>@<port>"``."""
          return f"{self.kind()}:{self.address}@{self.port}"

      def kind(self) -> str:
          """Return the type label; concrete subclasses must override this."""
          raise NotImplementedError

      def is_always_max_size(self) -> bool:
          """Return ``False`` for address/port nameservers."""
          return False

      def answer_nameserver(self) -> str:
          """Return the stored address."""
          return self.address

      def answer_port(self) -> int:
          """Return the stored port."""
          return self.port
  class Do53Nameserver(AddressAndPortNameserver):
      """Address/port nameserver queried with ``dns.query.udp`` or ``dns.query.tcp``."""

      def __init__(self, address: str, port: int = 53):
          """Store *address* and *port* (53 unless overridden)."""
          super().__init__(address, port)

      def kind(self):
          """Return the type label ``"Do53"``."""
          return "Do53"

      def query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* over UDP, or over TCP when *max_size* is true.

          The remaining arguments are forwarded unchanged to the chosen
          ``dns.query`` function, and its response is returned.
          """
          if not max_size:
              # UDP path: truncated responses raise, while errors and
              # unexpected packets are ignored, per the flags passed below.
              return dns.query.udp(
                  request,
                  self.address,
                  timeout=timeout,
                  port=self.port,
                  source=source,
                  source_port=source_port,
                  raise_on_truncation=True,
                  one_rr_per_rrset=one_rr_per_rrset,
                  ignore_trailing=ignore_trailing,
                  ignore_errors=True,
                  ignore_unexpected=True,
              )
          return dns.query.tcp(
              request,
              self.address,
              timeout=timeout,
              port=self.port,
              source=source,
              source_port=source_port,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
          )

      async def async_query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          backend: dns.asyncbackend.Backend,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Asynchronous form of :meth:`query` using ``dns.asyncquery``.

          Chooses TCP when *max_size* is true and UDP otherwise, and also hands
          *backend* to the selected ``dns.asyncquery`` function.
          """
          if not max_size:
              # Same UDP flags as the synchronous path.
              return await dns.asyncquery.udp(
                  request,
                  self.address,
                  timeout=timeout,
                  port=self.port,
                  source=source,
                  source_port=source_port,
                  raise_on_truncation=True,
                  backend=backend,
                  one_rr_per_rrset=one_rr_per_rrset,
                  ignore_trailing=ignore_trailing,
                  ignore_errors=True,
                  ignore_unexpected=True,
              )
          return await dns.asyncquery.tcp(
              request,
              self.address,
              timeout=timeout,
              port=self.port,
              source=source,
              source_port=source_port,
              backend=backend,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
          )
  class DoHNameserver(Nameserver):
      """Nameserver identified by a URL and queried with ``dns.query.https``."""

      def __init__(
          self,
          url: str,
          bootstrap_address: str | None = None,
          verify: bool | str = True,
          want_get: bool = False,
          http_version: dns.query.HTTPVersion = dns.query.HTTPVersion.DEFAULT,
      ):
          """Store the settings later forwarded to the HTTPS query functions.

          Args:
              url: URL of the server; also used as its printable form.
              bootstrap_address: Passed as ``bootstrap_address`` on each query.
              verify: Passed as ``verify`` on each query.
              want_get: When true, queries are sent with ``post=False``.
              http_version: Passed as ``http_version`` on each query.
          """
          super().__init__()
          self.url = url
          self.bootstrap_address = bootstrap_address
          self.verify = verify
          self.want_get = want_get
          self.http_version = http_version

      def kind(self):
          """Return the type label ``"DoH"``."""
          return "DoH"

      def is_always_max_size(self) -> bool:
          """Return ``True`` for DoH nameservers."""
          return True

      def __str__(self):
          """Return the server URL."""
          return self.url

      def answer_nameserver(self) -> str:
          """Return the server URL."""
          return self.url

      def answer_port(self) -> int:
          """Return the port given in the URL, or 443 when the URL has none."""
          explicit_port = urlparse(self.url).port
          return 443 if explicit_port is None else explicit_port

      def query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool = False,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.query.https`` and return the response.

          *max_size* is accepted for interface compatibility and is not used.
          """
          use_post = not self.want_get
          return dns.query.https(
              request,
              self.url,
              timeout=timeout,
              source=source,
              source_port=source_port,
              bootstrap_address=self.bootstrap_address,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              verify=self.verify,
              post=use_post,
              http_version=self.http_version,
          )

      async def async_query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          backend: dns.asyncbackend.Backend,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.asyncquery.https`` and return the response.

          *max_size* and *backend* are accepted for interface compatibility and
          are not passed to the HTTPS call.
          """
          use_post = not self.want_get
          return await dns.asyncquery.https(
              request,
              self.url,
              timeout=timeout,
              source=source,
              source_port=source_port,
              bootstrap_address=self.bootstrap_address,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              verify=self.verify,
              post=use_post,
              http_version=self.http_version,
          )
  class DoTNameserver(AddressAndPortNameserver):
      """Address/port nameserver queried with ``dns.query.tls``."""

      def __init__(
          self,
          address: str,
          port: int = 853,
          hostname: str | None = None,
          verify: bool | str = True,
      ):
          """Store the connection settings.

          Args:
              address: Address of the server.
              port: Port of the server (853 unless overridden).
              hostname: Passed as ``server_hostname`` on each query.
              verify: Passed as ``verify`` on each query.
          """
          super().__init__(address, port)
          self.hostname = hostname
          self.verify = verify

      def kind(self):
          """Return the type label ``"DoT"``."""
          return "DoT"

      def query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool = False,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.query.tls`` and return the response.

          *source* and *source_port* are forwarded to the TLS call, matching
          the Do53 and DoH nameservers, instead of being silently dropped.
          *max_size* is accepted for interface compatibility and is not used.
          """
          return dns.query.tls(
              request,
              self.address,
              port=self.port,
              timeout=timeout,
              source=source,
              source_port=source_port,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              server_hostname=self.hostname,
              verify=self.verify,
          )

      async def async_query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          backend: dns.asyncbackend.Backend,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.asyncquery.tls`` and return the response.

          *source*, *source_port* and *backend* are forwarded to the TLS call
          so the caller's choices are honored. *max_size* is accepted for
          interface compatibility and is not used.
          """
          return await dns.asyncquery.tls(
              request,
              self.address,
              port=self.port,
              timeout=timeout,
              source=source,
              source_port=source_port,
              backend=backend,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              server_hostname=self.hostname,
              verify=self.verify,
          )
  class DoQNameserver(AddressAndPortNameserver):
      """Address/port nameserver queried with ``dns.query.quic``."""

      def __init__(
          self,
          address: str,
          port: int = 853,
          verify: bool | str = True,
          server_hostname: str | None = None,
      ):
          """Store the connection settings.

          Args:
              address: Address of the server.
              port: Port of the server (853 unless overridden).
              verify: Passed as ``verify`` on each query.
              server_hostname: Passed as ``server_hostname`` on each query.
          """
          super().__init__(address, port)
          self.verify = verify
          self.server_hostname = server_hostname

      def kind(self):
          """Return the type label ``"DoQ"``."""
          return "DoQ"

      def query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool = False,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.query.quic`` and return the response.

          *source* and *source_port* are forwarded to the QUIC call, matching
          the Do53 and DoH nameservers, instead of being silently dropped.
          *max_size* is accepted for interface compatibility and is not used.
          """
          return dns.query.quic(
              request,
              self.address,
              port=self.port,
              timeout=timeout,
              source=source,
              source_port=source_port,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              verify=self.verify,
              server_hostname=self.server_hostname,
          )

      async def async_query(
          self,
          request: dns.message.QueryMessage,
          timeout: float,
          source: str | None,
          source_port: int,
          max_size: bool,
          backend: dns.asyncbackend.Backend,
          one_rr_per_rrset: bool = False,
          ignore_trailing: bool = False,
      ) -> dns.message.Message:
          """Send *request* with ``dns.asyncquery.quic`` and return the response.

          *source*, *source_port* and *backend* are forwarded to the QUIC call
          so the caller's choices are honored. *max_size* is accepted for
          interface compatibility and is not used.
          """
          return await dns.asyncquery.quic(
              request,
              self.address,
              port=self.port,
              timeout=timeout,
              source=source,
              source_port=source_port,
              backend=backend,
              one_rr_per_rrset=one_rr_per_rrset,
              ignore_trailing=ignore_trailing,
              verify=self.verify,
              server_hostname=self.server_hostname,
          )