# test_HPKE.py (17 KB)
  1. import os
  2. import json
  3. import unittest
  4. from binascii import unhexlify
  5. from Crypto.Protocol import HPKE
  6. from Crypto.Protocol.HPKE import DeserializeError
  7. from Crypto.PublicKey import ECC
  8. from Crypto.SelfTest.st_common import list_test_cases
  9. from Crypto.Protocol import DH
  10. from Crypto.Hash import SHA256, SHA384, SHA512
  class HPKE_Tests(unittest.TestCase):
      """Unit tests for Crypto.Protocol.HPKE.

      Covers seal/unseal round trips over every supported curve and AEAD,
      the optional PSK and info inputs, rejection of invalid arguments, and
      the X25519 / AES-128-GCM vectors from RFC 9180, Appendix A.1.
      """

      # Two P-256 key pairs shared by the tests that just need "a" key.
      key1 = ECC.generate(curve='p256')
      key2 = ECC.generate(curve='p256')

      # curve name -> expected length (in bytes) of the encapsulated key 'enc'
      curves = {
          'p256': 65,
          'p384': 97,
          'p521': 133,
          'curve25519': 32,
          'curve448': 56,
      }

      # Inputs common to every RFC 9180 A.1 vector checked below.
      _RFC_PT = bytes.fromhex(
          "4265617574792069732074727574682c20747275746820626561757479")
      _RFC_AAD0 = bytes.fromhex("436f756e742d30")
      _RFC_AAD1 = bytes.fromhex("436f756e742d31")
      _RFC_INFO = bytes.fromhex("4f6465206f6e2061204772656369616e2055726e")
      # (psk_id, psk) pair used by the PSK-based modes (1 and 3)
      _RFC_PSK = (
          bytes.fromhex("456e6e796e20447572696e206172616e204d6f726961"),
          bytes.fromhex("0247fd33b913760fa1fa51e1892d9f307fbe65eb171e8132c2af18555a738b82"),
      )

      def round_trip(self, curve, aead_id):
          """Encrypt two messages to a fresh key on *curve*, then decrypt them.

          The first message carries associated data and the second does not;
          both go through the same sender/receiver context pair, so the
          second one also exercises the per-message sequence handling.
          """
          key = ECC.generate(curve=curve)
          encryptor = HPKE.new(receiver_key=key.public_key(),
                               aead_id=aead_id)
          self.assertEqual(len(encryptor.enc), self.curves[curve])

          # First message, with associated data
          ct = encryptor.seal(b'ABC', auth_data=b'DEF')
          decryptor = HPKE.new(receiver_key=key,
                               aead_id=aead_id,
                               enc=encryptor.enc)
          self.assertEqual(b'ABC', decryptor.unseal(ct, auth_data=b'DEF'))

          # Second message, without associated data
          ct2 = encryptor.seal(b'GHI')
          self.assertEqual(b'GHI', decryptor.unseal(ct2))

      def test_round_trip(self):
          # subTest makes a failure report the exact curve/AEAD combination
          for curve in self.curves:
              for aead_id in HPKE.AEAD:
                  with self.subTest(curve=curve, aead_id=aead_id):
                      self.round_trip(curve, aead_id)

      def test_psk(self):
          """A PSK-mode receiver must decrypt what a PSK-mode sender sealed."""
          aead_id = HPKE.AEAD.AES128_GCM
          psk = (b'a', b'c' * 32)
          encryptor = HPKE.new(receiver_key=self.key1.public_key(),
                               aead_id=aead_id,
                               psk=psk)
          ct = encryptor.seal(b'ABC')
          decryptor = HPKE.new(receiver_key=self.key1,
                               aead_id=aead_id,
                               psk=psk,
                               enc=encryptor.enc)
          self.assertEqual(b'ABC', decryptor.unseal(ct))

      def test_info(self):
          """Sender and receiver sharing the same 'info' must interoperate."""
          aead_id = HPKE.AEAD.AES128_GCM
          encryptor = HPKE.new(receiver_key=self.key1.public_key(),
                               aead_id=aead_id,
                               info=b'baba')
          ct = encryptor.seal(b'ABC')
          decryptor = HPKE.new(receiver_key=self.key1,
                               aead_id=aead_id,
                               info=b'baba',
                               enc=encryptor.enc)
          self.assertEqual(b'ABC', decryptor.unseal(ct))

      def test_neg_unsupported_curve(self):
          key3 = ECC.generate(curve='p224')
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=key3.public_key(),
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("Unsupported curve", str(cm.exception))

      def test_neg_too_many_private_keys(self):
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=self.key1,
                       sender_key=self.key2,
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("Exactly 1 private key", str(cm.exception))

      def test_neg_curve_mismatch(self):
          key3 = ECC.generate(curve='p384')
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=self.key1.public_key(),
                       sender_key=key3,
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("but recipient key", str(cm.exception))

      def test_neg_psk(self):
          # Empty PSK identifier
          with self.assertRaises(ValueError):
              HPKE.new(receiver_key=self.key1.public_key(),
                       psk=(b'', b'G' * 32),
                       aead_id=HPKE.AEAD.AES128_GCM)
          # Empty PSK value
          with self.assertRaises(ValueError):
              HPKE.new(receiver_key=self.key1.public_key(),
                       psk=(b'JJJ', b''),
                       aead_id=HPKE.AEAD.AES128_GCM)
          # PSK one byte shorter than the required 32
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=self.key1.public_key(),
                       psk=(b'JJJ', b'Y' * 31),
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("at least 32", str(cm.exception))

      def test_neg_wrong_enc(self):
          # 0xFF is not a valid leading byte for an encoded P-256 point
          wrong_enc = b'\xFF' + b'8' * 64
          with self.assertRaises(DeserializeError):
              HPKE.new(receiver_key=self.key1,
                       aead_id=HPKE.AEAD.AES128_GCM,
                       enc=wrong_enc)
          # A sender (public receiver key) must not be given 'enc'
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=self.key1.public_key(),
                       enc=self.key1.public_key().export_key(format='raw'),
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("'enc' cannot be an input", str(cm.exception))
          # A receiver (private receiver key) must be given 'enc'
          with self.assertRaises(ValueError) as cm:
              HPKE.new(receiver_key=self.key1,
                       aead_id=HPKE.AEAD.AES128_GCM)
          self.assertIn("'enc' required", str(cm.exception))

      def test_neg_unseal_wrong_ct(self):
          decryptor = HPKE.new(receiver_key=self.key1,
                               aead_id=HPKE.AEAD.CHACHA20_POLY1305,
                               enc=self.key2.public_key().export_key(format='raw'))
          with self.assertRaises(ValueError):
              decryptor.unseal(b'XYZ' * 20)

      def test_neg_unseal_no_auth_data(self):
          aead_id = HPKE.AEAD.CHACHA20_POLY1305
          encryptor = HPKE.new(receiver_key=self.key1.public_key(),
                               aead_id=aead_id)
          ct = encryptor.seal(b'ABC', auth_data=b'DEF')
          decryptor = HPKE.new(receiver_key=self.key1,
                               aead_id=aead_id,
                               enc=encryptor.enc)
          # The ciphertext was bound to b'DEF'; omitting it must fail
          with self.assertRaises(ValueError):
              decryptor.unseal(ct)

      def _check_rfc9180_vector(self, enc_hex, ct0_hex, ct1_hex, **kwargs):
          """Decrypt sequence numbers 0 and 1 of an RFC 9180 A.1 vector.

          *kwargs* carries the mode-specific HPKE.new() arguments
          (receiver_key, and optionally sender_key and/or psk).
          """
          decryptor = HPKE.new(aead_id=HPKE.AEAD.AES128_GCM,
                               info=self._RFC_INFO,
                               enc=bytes.fromhex(enc_hex),
                               **kwargs)
          pt_X0 = decryptor.unseal(bytes.fromhex(ct0_hex), self._RFC_AAD0)
          self.assertEqual(pt_X0, self._RFC_PT)
          pt_X1 = decryptor.unseal(bytes.fromhex(ct1_hex), self._RFC_AAD1)
          self.assertEqual(pt_X1, self._RFC_PT)

      def test_x25519_mode_0(self):
          # RFC 9180, A.1.1.1 (base mode), seq 0 and 1
          keyR = DH.import_x25519_private_key(bytes.fromhex(
              "4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8"))
          self._check_rfc9180_vector(
              enc_hex="37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431",
              ct0_hex="f938558b5d72f1a23810b4be2ab4f84331acc02fc97babc53a52ae8218a355a96d8770ac83d07bea87e13c512a",
              ct1_hex="af2d7e9ac9ae7e270f46ba1f975be53c09f8d875bdc8535458c2494e8a6eab251c03d0c22a56b8ca42c2063b84",
              receiver_key=keyR)

      def test_x25519_mode_1(self):
          # RFC 9180, A.1.2.1 (PSK mode), seq 0 and 1
          keyR = DH.import_x25519_private_key(bytes.fromhex(
              "c5eb01eb457fe6c6f57577c5413b931550a162c71a03ac8d196babbd4e5ce0fd"))
          self._check_rfc9180_vector(
              enc_hex="0ad0950d9fb9588e59690b74f1237ecdf1d775cd60be2eca57af5a4b0471c91b",
              ct0_hex="e52c6fed7f758d0cf7145689f21bc1be6ec9ea097fef4e959440012f4feb73fb611b946199e681f4cfc34db8ea",
              ct1_hex="49f3b19b28a9ea9f43e8c71204c00d4a490ee7f61387b6719db765e948123b45b61633ef059ba22cd62437c8ba",
              receiver_key=keyR,
              psk=self._RFC_PSK)

      def test_x25519_mode_2(self):
          # RFC 9180, A.1.3.1 (auth mode), seq 0 and 1
          keyR = DH.import_x25519_private_key(bytes.fromhex(
              "fdea67cf831f1ca98d8e27b1f6abeb5b7745e9d35348b80fa407ff6958f9137e"))
          keyS = DH.import_x25519_private_key(bytes.fromhex(
              "dc4a146313cce60a278a5323d321f051c5707e9c45ba21a3479fecdf76fc69dd"))
          self._check_rfc9180_vector(
              enc_hex="23fb952571a14a25e3d678140cd0e5eb47a0961bb18afcf85896e5453c312e76",
              ct0_hex="5fd92cc9d46dbf8943e72a07e42f363ed5f721212cd90bcfd072bfd9f44e06b80fd17824947496e21b680c141b",
              ct1_hex="d3736bb256c19bfa93d79e8f80b7971262cb7c887e35c26370cfed62254369a1b52e3d505b79dd699f002bc8ed",
              receiver_key=keyR,
              sender_key=keyS.public_key())

      def test_x25519_mode_3(self):
          # RFC 9180, A.1.4.1 (auth + PSK mode), seq 0 and 1
          keyR = DH.import_x25519_private_key(bytes.fromhex(
              "cb29a95649dc5656c2d054c1aa0d3df0493155e9d5da6d7e344ed8b6a64a9423"))
          keyS = DH.import_x25519_private_key(bytes.fromhex(
              "fc1c87d2f3832adb178b431fce2ac77c7ca2fd680f3406c77b5ecdf818b119f4"))
          self._check_rfc9180_vector(
              enc_hex="820818d3c23993492cc5623ab437a48a0a7ca3e9639c140fe1e33811eb844b7c",
              ct0_hex="a84c64df1e11d8fd11450039d4fe64ff0c8a99fca0bd72c2d4c3e0400bc14a40f27e45e141a24001697737533e",
              ct1_hex="4d19303b848f424fc3c3beca249b2c6de0a34083b8e909b6aa4c3688505c05ffe0c8f57a0a4c5ab9da127435d9",
              receiver_key=keyR,
              sender_key=keyS.public_key(),
              psk=self._RFC_PSK)
  class HPKE_TestVectors(unittest.TestCase):
      """Check HPKE against the JSON vectors from the optional
      pycryptodome_test_vectors package (skipped when it is not installed).
      """

      # kem_id -> ECC curve name for the NIST-curve KEMs
      _NIST_CURVES = {
          0x0010: 'p256',
          0x0011: 'p384',
          0x0012: 'p521',
      }

      def setUp(self):
          self.vectors = []
          try:
              import pycryptodome_test_vectors  # type: ignore
              init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
              full_file_name = os.path.join(init_dir, "Protocol", "wycheproof",
                                            "HPKE-test-vectors.json")
              # JSON is UTF-8 by definition; don't depend on the locale
              with open(full_file_name, "r", encoding="utf-8") as f:
                  self.vectors = json.load(f)
          except (FileNotFoundError, ImportError):
              print("\nWarning: skipping extended tests for HPKE (install pycryptodome-test-vectors)")

      def import_private_key(self, key_hex, kem_id):
          """Build the private key hex-encoded in *key_hex* for KEM *kem_id*.

          For the NIST-curve KEMs the bytes are read as a big-endian scalar;
          for the X25519 (0x0020) and X448 (0x0021) KEMs they are passed to
          the corresponding DH importer unchanged.

          :raises ValueError: if *kem_id* is not one of the handled KEMs
              (previously this silently returned None).
          """
          key_bin = unhexlify(key_hex)
          curve = self._NIST_CURVES.get(kem_id)
          if curve is not None:
              return ECC.construct(curve=curve,
                                   d=int.from_bytes(key_bin, byteorder="big"))
          if kem_id == 0x0020:
              return DH.import_x25519_private_key(key_bin)
          if kem_id == 0x0021:
              return DH.import_x448_private_key(key_bin)
          raise ValueError("Unsupported KEM id 0x%04x" % kem_id)

      @staticmethod
      def _kdf_hash(kem_id, kdf_id):
          """Return the hash module for a supported (kem_id, kdf_id) pair.

          Only one KDF is supported per curve; ``None`` means the vector uses
          an unsupported combination and should be skipped.
          """
          supported_combi = {
              (0x10, 0x1): SHA256,
              (0x11, 0x2): SHA384,
              (0x12, 0x3): SHA512,
              (0x20, 0x1): SHA256,
              (0x21, 0x3): SHA512,
          }
          return supported_combi.get((kem_id, kdf_id))

      def test_hpke_encap(self):
          """Test HPKE encapsulation using test vectors."""
          if not self.vectors:
              self.skipTest("No test vectors available")
          for idx, vector in enumerate(self.vectors):
              kem_id = vector["kem_id"]
              aead_id = vector["aead_id"]
              # No export-only pseudo-cipher
              if aead_id == 0xffff:
                  continue
              hashmod = self._kdf_hash(kem_id, vector["kdf_id"])
              if hashmod is None:
                  continue
              with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):
                  receiver_pub = self.import_private_key(vector["skRm"],
                                                         kem_id).public_key()
                  sender_priv = None
                  if "skSm" in vector:
                      sender_priv = self.import_private_key(vector["skSm"],
                                                            kem_id)
                  # Fixed ephemeral key so that 'enc' is deterministic
                  encap_key = self.import_private_key(vector["skEm"], kem_id)
                  shared_secret, enc = HPKE.HPKE_Cipher._encap(receiver_pub,
                                                               kem_id,
                                                               hashmod,
                                                               sender_priv,
                                                               encap_key)
                  # Compare bytes (not hex text) so hex case cannot matter
                  self.assertEqual(enc, unhexlify(vector["enc"]))
                  self.assertEqual(shared_secret,
                                   unhexlify(vector["shared_secret"]))
                  print(".", end="", flush=True)

      def test_hpke_unseal(self):
          """Test HPKE encryption and decryption using test vectors."""
          if not self.vectors:
              self.skipTest("No test vectors available")
          for idx, vector in enumerate(self.vectors):
              kem_id = vector["kem_id"]
              aead_id = vector["aead_id"]
              # No export-only pseudo-cipher
              if aead_id == 0xffff:
                  continue
              if self._kdf_hash(kem_id, vector["kdf_id"]) is None:
                  continue
              with self.subTest(idx=idx, kem_id=kem_id, aead_id=aead_id):
                  receiver_priv = self.import_private_key(vector["skRm"],
                                                          kem_id)
                  sender_pub = None
                  if "skSm" in vector:
                      sender_priv = self.import_private_key(vector["skSm"],
                                                            kem_id)
                      sender_pub = sender_priv.public_key()
                  encap_key = unhexlify(vector["enc"])
                  psk = None
                  if "psk_id" in vector:
                      psk = unhexlify(vector["psk_id"]), unhexlify(vector["psk"])
                  receiver_hpke = HPKE.new(receiver_key=receiver_priv,
                                           aead_id=HPKE.AEAD(aead_id),
                                           enc=encap_key,
                                           sender_key=sender_pub,
                                           psk=psk,
                                           info=unhexlify(vector["info"]))
                  # Messages must be unsealed in order: one context, one sequence
                  for encryption in vector['encryptions']:
                      plaintext = unhexlify(encryption["pt"])
                      ciphertext = unhexlify(encryption["ct"])
                      aad = unhexlify(encryption["aad"])
                      decrypted = receiver_hpke.unseal(ciphertext, aad)
                      self.assertEqual(decrypted, plaintext, "Decryption failed")
                  print(".", end="", flush=True)
  def get_tests(config=None):
      """Return the HPKE test cases for the PyCryptodome self-test runner.

      The vector-file tests (HPKE_TestVectors) are included only when
      ``config['slow_tests']`` is truthy.

      :param config: optional dict of self-test options; ``None`` (the
          default) is treated as an empty configuration. Using ``None``
          instead of a ``{}`` default avoids a shared mutable default.
      :returns: list of the test cases collected by ``list_test_cases``.
      """
      if config is None:
          config = {}
      tests = []
      tests += list_test_cases(HPKE_Tests)
      if config.get('slow_tests'):
          tests += list_test_cases(HPKE_TestVectors)
      return tests
  if __name__ == '__main__':
      def suite():
          """Wrap the default (non-slow) HPKE test cases in a TestSuite."""
          selected = get_tests()
          return unittest.TestSuite(selected)

      # unittest.main resolves the name 'suite' in this module
      unittest.main(defaultTest='suite')