# rsa.py (3.6 KB)
  1. import math
  2. import struct
  3. from cryptography.hazmat.backends import default_backend
  4. from cryptography.hazmat.primitives import hashes
  5. from cryptography.hazmat.primitives.asymmetric import padding, rsa
  6. from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
  7. from dns.dnssectypes import Algorithm
  8. from dns.rdtypes.ANY.DNSKEY import DNSKEY
  9. class PublicRSA(CryptographyPublicKey):
  10. key: rsa.RSAPublicKey
  11. key_cls = rsa.RSAPublicKey
  12. algorithm: Algorithm
  13. chosen_hash: hashes.HashAlgorithm
  14. def verify(self, signature: bytes, data: bytes) -> None:
  15. self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
  16. def encode_key_bytes(self) -> bytes:
  17. """Encode a public key per RFC 3110, section 2."""
  18. pn = self.key.public_numbers()
  19. _exp_len = math.ceil(int.bit_length(pn.e) / 8)
  20. exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
  21. if _exp_len > 255:
  22. exp_header = b"\0" + struct.pack("!H", _exp_len)
  23. else:
  24. exp_header = struct.pack("!B", _exp_len)
  25. if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
  26. raise ValueError("unsupported RSA key length")
  27. return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
  28. @classmethod
  29. def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
  30. cls._ensure_algorithm_key_combination(key)
  31. keyptr = key.key
  32. (bytes_,) = struct.unpack("!B", keyptr[0:1])
  33. keyptr = keyptr[1:]
  34. if bytes_ == 0:
  35. (bytes_,) = struct.unpack("!H", keyptr[0:2])
  36. keyptr = keyptr[2:]
  37. rsa_e = keyptr[0:bytes_]
  38. rsa_n = keyptr[bytes_:]
  39. return cls(
  40. key=rsa.RSAPublicNumbers(
  41. int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
  42. ).public_key(default_backend())
  43. )
  44. class PrivateRSA(CryptographyPrivateKey):
  45. key: rsa.RSAPrivateKey
  46. key_cls = rsa.RSAPrivateKey
  47. public_cls = PublicRSA
  48. default_public_exponent = 65537
  49. def sign(
  50. self,
  51. data: bytes,
  52. verify: bool = False,
  53. deterministic: bool = True,
  54. ) -> bytes:
  55. """Sign using a private key per RFC 3110, section 3."""
  56. signature = self.key.sign(
  57. data, padding.PKCS1v15(), self.public_cls.chosen_hash # pyright: ignore
  58. )
  59. if verify:
  60. self.public_key().verify(signature, data)
  61. return signature
  62. @classmethod
  63. def generate(cls, key_size: int) -> "PrivateRSA":
  64. return cls(
  65. key=rsa.generate_private_key(
  66. public_exponent=cls.default_public_exponent,
  67. key_size=key_size,
  68. backend=default_backend(),
  69. )
  70. )
  71. class PublicRSAMD5(PublicRSA):
  72. algorithm = Algorithm.RSAMD5
  73. chosen_hash = hashes.MD5()
  74. class PrivateRSAMD5(PrivateRSA):
  75. public_cls = PublicRSAMD5
  76. class PublicRSASHA1(PublicRSA):
  77. algorithm = Algorithm.RSASHA1
  78. chosen_hash = hashes.SHA1()
  79. class PrivateRSASHA1(PrivateRSA):
  80. public_cls = PublicRSASHA1
  81. class PublicRSASHA1NSEC3SHA1(PublicRSA):
  82. algorithm = Algorithm.RSASHA1NSEC3SHA1
  83. chosen_hash = hashes.SHA1()
  84. class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
  85. public_cls = PublicRSASHA1NSEC3SHA1
  86. class PublicRSASHA256(PublicRSA):
  87. algorithm = Algorithm.RSASHA256
  88. chosen_hash = hashes.SHA256()
  89. class PrivateRSASHA256(PrivateRSA):
  90. public_cls = PublicRSASHA256
  91. class PublicRSASHA512(PublicRSA):
  92. algorithm = Algorithm.RSASHA512
  93. chosen_hash = hashes.SHA512()
  94. class PrivateRSASHA512(PrivateRSA):
  95. public_cls = PublicRSASHA512