HPKE.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. import struct
  2. from enum import IntEnum
  3. from types import ModuleType
  4. from typing import Optional
  5. from .KDF import _HKDF_extract, _HKDF_expand
  6. from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
  7. from Crypto.Util.strxor import strxor
  8. from Crypto.PublicKey import ECC
  9. from Crypto.PublicKey.ECC import EccKey
  10. from Crypto.Hash import SHA256, SHA384, SHA512
  11. from Crypto.Cipher import AES, ChaCha20_Poly1305
  12. class MODE(IntEnum):
  13. """HPKE modes"""
  14. BASE = 0x00
  15. PSK = 0x01
  16. AUTH = 0x02
  17. AUTH_PSK = 0x03
  18. class AEAD(IntEnum):
  19. """Authenticated Encryption with Associated Data (AEAD) Functions"""
  20. AES128_GCM = 0x0001
  21. AES256_GCM = 0x0002
  22. CHACHA20_POLY1305 = 0x0003
  23. class DeserializeError(ValueError):
  24. pass
  25. class MessageLimitReachedError(ValueError):
  26. pass
  27. # CURVE to (KEM ID, KDF ID, HASH)
  28. _Curve_Config = {
  29. "NIST P-256": (0x0010, 0x0001, SHA256),
  30. "NIST P-384": (0x0011, 0x0002, SHA384),
  31. "NIST P-521": (0x0012, 0x0003, SHA512),
  32. "Curve25519": (0x0020, 0x0001, SHA256),
  33. "Curve448": (0x0021, 0x0003, SHA512),
  34. }
  35. def _labeled_extract(salt: bytes,
  36. label: bytes,
  37. ikm: bytes,
  38. suite_id: bytes,
  39. hashmod: ModuleType):
  40. labeled_ikm = b"HPKE-v1" + suite_id + label + ikm
  41. return _HKDF_extract(salt, labeled_ikm, hashmod)
  42. def _labeled_expand(prk: bytes,
  43. label: bytes,
  44. info: bytes,
  45. L: int,
  46. suite_id: bytes,
  47. hashmod: ModuleType):
  48. labeled_info = struct.pack('>H', L) + b"HPKE-v1" + suite_id + \
  49. label + info
  50. return _HKDF_expand(prk, labeled_info, L, hashmod)
  51. def _extract_and_expand(dh: bytes,
  52. kem_context: bytes,
  53. suite_id: bytes,
  54. hashmod: ModuleType):
  55. Nsecret = hashmod.digest_size
  56. eae_prk = _labeled_extract(b"",
  57. b"eae_prk",
  58. dh,
  59. suite_id,
  60. hashmod)
  61. shared_secret = _labeled_expand(eae_prk,
  62. b"shared_secret",
  63. kem_context,
  64. Nsecret,
  65. suite_id,
  66. hashmod)
  67. return shared_secret
  68. class HPKE_Cipher:
  69. def __init__(self,
  70. receiver_key: EccKey,
  71. enc: Optional[bytes],
  72. sender_key: Optional[EccKey],
  73. psk_pair: tuple[bytes, bytes],
  74. info: bytes,
  75. aead_id: AEAD,
  76. mode: MODE):
  77. self.enc: bytes = b'' if enc is None else enc
  78. """The encapsulated session key."""
  79. self._verify_psk_inputs(mode, psk_pair)
  80. self._curve = receiver_key.curve
  81. self._aead_id = aead_id
  82. self._mode = mode
  83. try:
  84. self._kem_id, \
  85. self._kdf_id, \
  86. self._hashmod = _Curve_Config[self._curve]
  87. except KeyError as ke:
  88. raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke
  89. self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
  90. self._Nn = 12
  91. self._Nt = 16
  92. self._Nh = self._hashmod.digest_size
  93. self._encrypt = not receiver_key.has_private()
  94. if self._encrypt:
  95. # SetupBaseS (encryption)
  96. if enc is not None:
  97. raise ValueError("Parameter 'enc' cannot be an input when sealing")
  98. shared_secret, self.enc = self._encap(receiver_key,
  99. self._kem_id,
  100. self._hashmod,
  101. sender_key)
  102. else:
  103. # SetupBaseR (decryption)
  104. if enc is None:
  105. raise ValueError("Parameter 'enc' required when unsealing")
  106. shared_secret = self._decap(enc,
  107. receiver_key,
  108. self._kem_id,
  109. self._hashmod,
  110. sender_key)
  111. self._sequence = 0
  112. self._max_sequence = (1 << (8 * self._Nn)) - 1
  113. self._key, \
  114. self._base_nonce, \
  115. self._export_secret = self._key_schedule(shared_secret,
  116. info,
  117. *psk_pair)
  118. @staticmethod
  119. def _encap(receiver_key: EccKey,
  120. kem_id: int,
  121. hashmod: ModuleType,
  122. sender_key: Optional[EccKey] = None,
  123. eph_key: Optional[EccKey] = None):
  124. assert (sender_key is None) or sender_key.has_private()
  125. assert (eph_key is None) or eph_key.has_private()
  126. if eph_key is None:
  127. eph_key = ECC.generate(curve=receiver_key.curve)
  128. enc = eph_key.public_key().export_key(format='raw')
  129. pkRm = receiver_key.public_key().export_key(format='raw')
  130. kem_context = enc + pkRm
  131. extra_param = {}
  132. if sender_key:
  133. kem_context += sender_key.public_key().export_key(format='raw')
  134. extra_param = {'static_priv': sender_key}
  135. suite_id = b"KEM" + struct.pack('>H', kem_id)
  136. def kdf(dh,
  137. kem_context=kem_context,
  138. suite_id=suite_id,
  139. hashmod=hashmod):
  140. return _extract_and_expand(dh, kem_context, suite_id, hashmod)
  141. shared_secret = key_agreement(eph_priv=eph_key,
  142. static_pub=receiver_key,
  143. kdf=kdf,
  144. **extra_param)
  145. return shared_secret, enc
  146. @staticmethod
  147. def _decap(enc: bytes,
  148. receiver_key: EccKey,
  149. kem_id: int,
  150. hashmod: ModuleType,
  151. sender_key: Optional[EccKey] = None):
  152. assert receiver_key.has_private()
  153. try:
  154. if receiver_key.curve == 'Curve25519':
  155. pkE = import_x25519_public_key(enc)
  156. elif receiver_key.curve == 'Curve448':
  157. pkE = import_x448_public_key(enc)
  158. else:
  159. pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
  160. except ValueError as ve:
  161. raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve
  162. pkRm = receiver_key.public_key().export_key(format='raw')
  163. kem_context = enc + pkRm
  164. extra_param = {}
  165. if sender_key:
  166. kem_context += sender_key.public_key().export_key(format='raw')
  167. extra_param = {'static_pub': sender_key}
  168. suite_id = b"KEM" + struct.pack('>H', kem_id)
  169. def kdf(dh,
  170. kem_context=kem_context,
  171. suite_id=suite_id,
  172. hashmod=hashmod):
  173. return _extract_and_expand(dh, kem_context, suite_id, hashmod)
  174. shared_secret = key_agreement(eph_pub=pkE,
  175. static_priv=receiver_key,
  176. kdf=kdf,
  177. **extra_param)
  178. return shared_secret
  179. @staticmethod
  180. def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
  181. psk_id, psk = psk_pair
  182. if (psk == b'') ^ (psk_id == b''):
  183. raise ValueError("Inconsistent PSK inputs")
  184. if (psk == b''):
  185. if mode in (MODE.PSK, MODE.AUTH_PSK):
  186. raise ValueError(f"PSK is required with mode {mode.name}")
  187. else:
  188. if len(psk) < 32:
  189. raise ValueError("PSK must be at least 32 byte long")
  190. if mode in (MODE.BASE, MODE.AUTH):
  191. raise ValueError("PSK is not compatible with this mode")
  192. def _key_schedule(self,
  193. shared_secret: bytes,
  194. info: bytes,
  195. psk_id: bytes,
  196. psk: bytes):
  197. suite_id = b"HPKE" + struct.pack('>HHH',
  198. self._kem_id,
  199. self._kdf_id,
  200. self._aead_id)
  201. psk_id_hash = _labeled_extract(b'',
  202. b'psk_id_hash',
  203. psk_id,
  204. suite_id,
  205. self._hashmod)
  206. info_hash = _labeled_extract(b'',
  207. b'info_hash',
  208. info,
  209. suite_id,
  210. self._hashmod)
  211. key_schedule_context = self._mode.to_bytes(1, 'big') + psk_id_hash + info_hash
  212. secret = _labeled_extract(shared_secret,
  213. b'secret',
  214. psk,
  215. suite_id,
  216. self._hashmod)
  217. key = _labeled_expand(secret,
  218. b'key',
  219. key_schedule_context,
  220. self._Nk,
  221. suite_id,
  222. self._hashmod)
  223. base_nonce = _labeled_expand(secret,
  224. b'base_nonce',
  225. key_schedule_context,
  226. self._Nn,
  227. suite_id,
  228. self._hashmod)
  229. exporter_secret = _labeled_expand(secret,
  230. b'exp',
  231. key_schedule_context,
  232. self._Nh,
  233. suite_id,
  234. self._hashmod)
  235. return key, base_nonce, exporter_secret
  236. def _new_cipher(self):
  237. nonce = strxor(self._base_nonce, self._sequence.to_bytes(self._Nn, 'big'))
  238. if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
  239. cipher = AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
  240. elif self._aead_id == AEAD.CHACHA20_POLY1305:
  241. cipher = ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
  242. else:
  243. raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")
  244. if self._sequence >= self._max_sequence:
  245. raise MessageLimitReachedError()
  246. self._sequence += 1
  247. return cipher
  248. def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
  249. """Encrypt and authenticate a message.
  250. This method can be invoked multiple times
  251. to seal an ordered sequence of messages.
  252. Arguments:
  253. plaintext: bytes
  254. The message to seal.
  255. auth_data: bytes
  256. Optional. Additional Authenticated data (AAD) that is not encrypted
  257. but that will be also covered by the authentication tag.
  258. Returns:
  259. The ciphertext concatenated with the authentication tag.
  260. """
  261. if not self._encrypt:
  262. raise ValueError("This cipher can only be used to seal")
  263. cipher = self._new_cipher()
  264. if auth_data:
  265. cipher.update(auth_data)
  266. ct, tag = cipher.encrypt_and_digest(plaintext)
  267. return ct + tag
  268. def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
  269. """Decrypt a message and validate its authenticity.
  270. This method can be invoked multiple times
  271. to unseal an ordered sequence of messages.
  272. Arguments:
  273. cipertext: bytes
  274. The message to unseal.
  275. auth_data: bytes
  276. Optional. Additional Authenticated data (AAD) that
  277. was also covered by the authentication tag.
  278. Returns:
  279. The original plaintext.
  280. Raises: ValueError
  281. If the ciphertext (in combination with the AAD) is not valid.
  282. But if it is the first time you call ``unseal()`` this
  283. exception may also mean that any of the parameters or keys
  284. used to establish the session is wrong or that one is missing.
  285. """
  286. if self._encrypt:
  287. raise ValueError("This cipher can only be used to unseal")
  288. if len(ciphertext) < self._Nt:
  289. raise ValueError("Ciphertext is too small")
  290. cipher = self._new_cipher()
  291. if auth_data:
  292. cipher.update(auth_data)
  293. try:
  294. pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
  295. ciphertext[-self._Nt:])
  296. except ValueError:
  297. if self._sequence == 1:
  298. raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
  299. raise ValueError("Invalid message (wrong MAC tag)")
  300. return pt
  301. def new(*, receiver_key: EccKey,
  302. aead_id: AEAD,
  303. enc: Optional[bytes] = None,
  304. sender_key: Optional[EccKey] = None,
  305. psk: Optional[tuple[bytes, bytes]] = None,
  306. info: Optional[bytes] = None) -> HPKE_Cipher:
  307. """Create an HPKE context which can be used:
  308. - by the sender to seal (encrypt) a message or
  309. - by the receiver to unseal (decrypt) it.
  310. As a minimum, the two parties agree on the receiver's asymmetric key
  311. (of which the sender will only know the public half).
  312. Additionally, for authentication purposes, they may also agree on:
  313. * the sender's asymmetric key (of which the receiver will only know the public half)
  314. * a shared secret (e.g., a symmetric key derived from a password)
  315. Args:
  316. receiver_key:
  317. The ECC key of the receiver.
  318. It must be on one of the following curves: ``NIST P-256``,
  319. ``NIST P-384``, ``NIST P-521``, ``X25519`` or ``X448``.
  320. If this is a **public** key, the HPKE context can only be used to
  321. **seal** (**encrypt**).
  322. If this is a **private** key, the HPKE context can only be used to
  323. **unseal** (**decrypt**).
  324. aead_id:
  325. The HPKE identifier of the symmetric cipher.
  326. The possible values are:
  327. * ``HPKE.AEAD.AES128_GCM``
  328. * ``HPKE.AEAD.AES256_GCM``
  329. * ``HPKE.AEAD.CHACHA20_POLY1305``
  330. enc:
  331. The encapsulated session key (i.e., the KEM shared secret).
  332. The receiver must always specify this parameter.
  333. The sender must always omit this parameter.
  334. sender_key:
  335. The ECC key of the sender.
  336. It must be on the same curve as the ``receiver_key``.
  337. If the ``receiver_key`` is a public key, ``sender_key`` must be a
  338. private key, and vice versa.
  339. psk:
  340. A Pre-Shared Key (PSK) as a 2-tuple of non-empty
  341. byte strings: the identifier and the actual secret value.
  342. Sender and receiver must use the same PSK (or none).
  343. The secret value must be at least 32 bytes long,
  344. but it must not be a low-entropy password
  345. (use a KDF like PBKDF2 or scrypt to derive a secret
  346. from a password).
  347. info:
  348. A non-secret parameter that contributes
  349. to the generation of all session keys.
  350. Sender and receive must use the same **info** parameter (or none).
  351. Returns:
  352. An object that can be used for
  353. sealing (if ``receiver_key`` is a public key) or
  354. unsealing (if ``receiver_key`` is a private key).
  355. In the latter case,
  356. correctness of all the keys and parameters will only
  357. be assessed with the first call to ``unseal()``.
  358. """
  359. if aead_id not in AEAD:
  360. raise ValueError(f"Unknown AEAD cipher ID {aead_id:#x}")
  361. curve = receiver_key.curve
  362. if curve not in ('NIST P-256', 'NIST P-384', 'NIST P-521',
  363. 'Curve25519', 'Curve448'):
  364. raise ValueError(f"Unsupported curve {curve}")
  365. if sender_key:
  366. count_private_keys = int(receiver_key.has_private()) + \
  367. int(sender_key.has_private())
  368. if count_private_keys != 1:
  369. raise ValueError("Exactly 1 private key required")
  370. if sender_key.curve != curve:
  371. raise ValueError("Sender key uses {} but recipient key {}".
  372. format(sender_key.curve, curve))
  373. mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
  374. else:
  375. mode = MODE.BASE if psk is None else MODE.PSK
  376. if psk is None:
  377. psk = b'', b''
  378. if info is None:
  379. info = b''
  380. return HPKE_Cipher(receiver_key,
  381. enc,
  382. sender_key,
  383. psk,
  384. info,
  385. aead_id,
  386. mode)