| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483 |
- import struct
- from enum import IntEnum
- from types import ModuleType
- from typing import Optional
- from .KDF import _HKDF_extract, _HKDF_expand
- from .DH import key_agreement, import_x25519_public_key, import_x448_public_key
- from Crypto.Util.strxor import strxor
- from Crypto.PublicKey import ECC
- from Crypto.PublicKey.ECC import EccKey
- from Crypto.Hash import SHA256, SHA384, SHA512
- from Crypto.Cipher import AES, ChaCha20_Poly1305
class MODE(IntEnum):
    """The four HPKE operating modes (RFC 9180, section 5).

    The numeric value is the ``mode`` byte mixed into the key schedule.
    """

    BASE = 0        # receiver key only
    PSK = 1         # plus a pre-shared key
    AUTH = 2        # plus the sender's asymmetric key
    AUTH_PSK = 3    # sender key and pre-shared key
class AEAD(IntEnum):
    """Identifiers of the AEAD algorithms usable with HPKE (RFC 9180, section 7.3)."""

    AES128_GCM = 1          # AES-128 in GCM mode
    AES256_GCM = 2          # AES-256 in GCM mode
    CHACHA20_POLY1305 = 3   # ChaCha20-Poly1305
class DeserializeError(ValueError):
    """Raised when the encapsulated key (``enc``) cannot be decoded into a public key."""
class MessageLimitReachedError(ValueError):
    """Raised when an HPKE context has run out of message sequence numbers."""
# Per-curve HPKE parameters: curve name -> (KEM ID, KDF ID, hash module).
# KEM IDs are the DHKEM identifiers of RFC 9180, section 7.1; KDF IDs are the
# HKDF identifiers of section 7.2 (0x0001 = HKDF-SHA256, 0x0002 = HKDF-SHA384,
# 0x0003 = HKDF-SHA512), matching the hash module in the third column.
_Curve_Config = dict([
    ("NIST P-256", (0x0010, 0x0001, SHA256)),
    ("NIST P-384", (0x0011, 0x0002, SHA384)),
    ("NIST P-521", (0x0012, 0x0003, SHA512)),
    ("Curve25519", (0x0020, 0x0001, SHA256)),
    ("Curve448", (0x0021, 0x0003, SHA512)),
])
def _labeled_extract(salt: bytes,
                     label: bytes,
                     ikm: bytes,
                     suite_id: bytes,
                     hashmod: ModuleType):
    """LabeledExtract() as defined in RFC 9180, section 4.

    The input keying material is prefixed with the ``HPKE-v1`` version tag,
    the suite identifier and the label, then handed to ``_HKDF_extract``
    together with ``salt`` and ``hashmod``.
    """
    prefixed_ikm = b"".join((b"HPKE-v1", suite_id, label, ikm))
    return _HKDF_extract(salt, prefixed_ikm, hashmod)
def _labeled_expand(prk: bytes,
                    label: bytes,
                    info: bytes,
                    L: int,
                    suite_id: bytes,
                    hashmod: ModuleType):
    """LabeledExpand() as defined in RFC 9180, section 4.

    Derives ``L`` bytes from ``prk`` through ``_HKDF_expand``. The info
    string is prefixed with ``L`` as a 2-byte big-endian integer, the
    ``HPKE-v1`` version tag, the suite identifier and the label.
    ``L`` must fit in 16 bits, otherwise ``struct.pack`` raises.
    """
    length_prefix = struct.pack('>H', L)
    labeled_info = b"".join((length_prefix, b"HPKE-v1", suite_id, label, info))
    return _HKDF_expand(prk, labeled_info, L, hashmod)
def _extract_and_expand(dh: bytes,
                        kem_context: bytes,
                        suite_id: bytes,
                        hashmod: ModuleType):
    """ExtractAndExpand() of DHKEM (RFC 9180, section 4.1).

    Turns the Diffie-Hellman output ``dh`` into the KEM shared secret,
    bound to ``kem_context``. The secret is ``hashmod.digest_size`` bytes long.
    """
    secret_len = hashmod.digest_size
    prk = _labeled_extract(b"", b"eae_prk", dh, suite_id, hashmod)
    return _labeled_expand(prk, b"shared_secret", kem_context,
                           secret_len, suite_id, hashmod)
class HPKE_Cipher:
    """An HPKE context (RFC 9180, section 5) for one sender/receiver pair.

    The role is fixed at construction time. If ``receiver_key`` is a public
    key, the context can only ``seal()``. If it is a private key, the context
    can only ``unseal()``. Instances are normally obtained through the
    module-level ``new()`` function, which validates the keys first.
    """

    def __init__(self,
                 receiver_key: EccKey,
                 enc: Optional[bytes],
                 sender_key: Optional[EccKey],
                 psk_pair: tuple[bytes, bytes],
                 info: bytes,
                 aead_id: AEAD,
                 mode: MODE):
        """Set up the context and derive the session keys.

        Args:
          receiver_key: The receiver's ECC key. A public key selects the
            sender role (sealing); a private key selects the receiver role
            (unsealing).
          enc: The encapsulated key. It must be ``None`` for the sender, in
            which case it is generated and exposed as ``self.enc``. It is
            required for the receiver.
          sender_key: The sender's ECC key for the authenticated modes,
            otherwise ``None``.
          psk_pair: ``(psk_id, psk)``; both empty when no PSK is used.
          info: Application-supplied info bound into the key schedule.
          aead_id: The AEAD algorithm identifier.
          mode: The HPKE mode.

        Raises:
          ValueError: If the PSK inputs do not match ``mode``, the curve is
            not supported, or ``enc`` is given or missing for the wrong role.
          DeserializeError: If ``enc`` cannot be decoded (receiver only).
        """
        # The encapsulated key: generated below for the sender, supplied by
        # the caller for the receiver.
        self.enc: bytes = b'' if enc is None else enc

        self._verify_psk_inputs(mode, psk_pair)

        self._curve = receiver_key.curve
        self._aead_id = aead_id
        self._mode = mode

        try:
            self._kem_id, \
                self._kdf_id, \
                self._hashmod = _Curve_Config[self._curve]
        except KeyError as ke:
            raise ValueError("Curve {} is not supported by HPKE".format(self._curve)) from ke

        # AEAD key / nonce / tag lengths (RFC 9180, section 7.3) and the
        # KDF output length.
        self._Nk = 16 if self._aead_id == AEAD.AES128_GCM else 32
        self._Nn = 12
        self._Nt = 16
        self._Nh = self._hashmod.digest_size

        # Only the holder of a public receiver key can be the sender.
        self._encrypt = not receiver_key.has_private()

        if self._encrypt:
            # Sender side (SetupBaseS / SetupPSKS / SetupAuthS / SetupAuthPSKS)
            if enc is not None:
                raise ValueError("Parameter 'enc' cannot be an input when sealing")
            shared_secret, self.enc = self._encap(receiver_key,
                                                  self._kem_id,
                                                  self._hashmod,
                                                  sender_key)
        else:
            # Receiver side (the matching SetupR* functions)
            if enc is None:
                raise ValueError("Parameter 'enc' required when unsealing")
            shared_secret = self._decap(enc,
                                        receiver_key,
                                        self._kem_id,
                                        self._hashmod,
                                        sender_key)

        # The sequence number is XORed into the Nn-byte base nonce, so it
        # must stay below 2**(8*Nn) - 1.
        self._sequence = 0
        self._max_sequence = (1 << (8 * self._Nn)) - 1

        # The exporter secret is derived as in RFC 9180 but is not used by
        # any method of this class.
        self._key, \
            self._base_nonce, \
            self._export_secret = self._key_schedule(shared_secret,
                                                     info,
                                                     *psk_pair)

    @staticmethod
    def _encap(receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None,
               eph_key: Optional[EccKey] = None):
        """DHKEM Encap() / AuthEncap() (RFC 9180, section 4.1).

        Args:
          receiver_key: The receiver's public key.
          kem_id: KEM identifier, used to build the KEM suite_id.
          hashmod: Hash module used by the KEM's KDF.
          sender_key: The sender's private key (AuthEncap), or ``None``.
          eph_key: Ephemeral private key to use; a fresh one is generated
            on the receiver's curve when ``None``.

        Returns:
          ``(shared_secret, enc)``, where ``enc`` is the raw encoding of the
          ephemeral public key.
        """
        # Internal invariants; new() has already validated the key types.
        assert (sender_key is None) or sender_key.has_private()
        assert (eph_key is None) or eph_key.has_private()

        if eph_key is None:
            eph_key = ECC.generate(curve=receiver_key.curve)
        enc = eph_key.public_key().export_key(format='raw')

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key is not None:
            # AuthEncap binds the sender's public key into the KEM context.
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_priv': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        # Default arguments freeze the current values inside the closure.
        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_priv=eph_key,
                                      static_pub=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret, enc

    @staticmethod
    def _decap(enc: bytes,
               receiver_key: EccKey,
               kem_id: int,
               hashmod: ModuleType,
               sender_key: Optional[EccKey] = None):
        """DHKEM Decap() / AuthDecap() (RFC 9180, section 4.1).

        Args:
          enc: Raw encoding of the sender's ephemeral public key.
          receiver_key: The receiver's private key.
          kem_id: KEM identifier, used to build the KEM suite_id.
          hashmod: Hash module used by the KEM's KDF.
          sender_key: The sender's public key (AuthDecap), or ``None``.

        Returns:
          The KEM shared secret.

        Raises:
          DeserializeError: If ``enc`` is not a valid public key encoding.
        """
        assert receiver_key.has_private()

        try:
            if receiver_key.curve == 'Curve25519':
                pkE = import_x25519_public_key(enc)
            elif receiver_key.curve == 'Curve448':
                pkE = import_x448_public_key(enc)
            else:
                pkE = ECC.import_key(enc, curve_name=receiver_key.curve)
        except ValueError as ve:
            raise DeserializeError("'enc' is not a valid encapsulated HPKE key") from ve

        pkRm = receiver_key.public_key().export_key(format='raw')
        kem_context = enc + pkRm
        extra_param = {}
        if sender_key is not None:
            # AuthDecap binds the sender's public key into the KEM context.
            kem_context += sender_key.public_key().export_key(format='raw')
            extra_param = {'static_pub': sender_key}

        suite_id = b"KEM" + struct.pack('>H', kem_id)

        # Default arguments freeze the current values inside the closure.
        def kdf(dh,
                kem_context=kem_context,
                suite_id=suite_id,
                hashmod=hashmod):
            return _extract_and_expand(dh, kem_context, suite_id, hashmod)

        shared_secret = key_agreement(eph_pub=pkE,
                                      static_priv=receiver_key,
                                      kdf=kdf,
                                      **extra_param)
        return shared_secret

    @staticmethod
    def _verify_psk_inputs(mode: MODE, psk_pair: tuple[bytes, bytes]):
        """Check that the PSK inputs are consistent with each other and with ``mode``.

        Raises:
          ValueError: If only one of psk/psk_id is empty, if a PSK is
            missing for a PSK mode or present for a non-PSK mode, or if the
            PSK is shorter than 32 bytes.
        """
        psk_id, psk = psk_pair
        if (psk == b'') ^ (psk_id == b''):
            raise ValueError("Inconsistent PSK inputs")
        if (psk == b''):
            if mode in (MODE.PSK, MODE.AUTH_PSK):
                raise ValueError(f"PSK is required with mode {mode.name}")
        else:
            if len(psk) < 32:
                raise ValueError("PSK must be at least 32 byte long")
            if mode in (MODE.BASE, MODE.AUTH):
                raise ValueError("PSK is not compatible with this mode")

    def _key_schedule(self,
                      shared_secret: bytes,
                      info: bytes,
                      psk_id: bytes,
                      psk: bytes):
        """KeySchedule() from RFC 9180, section 5.1.

        Returns:
          ``(key, base_nonce, exporter_secret)`` of lengths Nk, Nn and Nh.
        """
        suite_id = b"HPKE" + struct.pack('>HHH',
                                         self._kem_id,
                                         self._kdf_id,
                                         self._aead_id)

        psk_id_hash = _labeled_extract(b'',
                                       b'psk_id_hash',
                                       psk_id,
                                       suite_id,
                                       self._hashmod)
        info_hash = _labeled_extract(b'',
                                     b'info_hash',
                                     info,
                                     suite_id,
                                     self._hashmod)
        key_schedule_context = self._mode.to_bytes(1, 'big') + psk_id_hash + info_hash

        secret = _labeled_extract(shared_secret,
                                  b'secret',
                                  psk,
                                  suite_id,
                                  self._hashmod)

        key = _labeled_expand(secret,
                              b'key',
                              key_schedule_context,
                              self._Nk,
                              suite_id,
                              self._hashmod)

        base_nonce = _labeled_expand(secret,
                                     b'base_nonce',
                                     key_schedule_context,
                                     self._Nn,
                                     suite_id,
                                     self._hashmod)

        exporter_secret = _labeled_expand(secret,
                                          b'exp',
                                          key_schedule_context,
                                          self._Nh,
                                          suite_id,
                                          self._hashmod)

        return key, base_nonce, exporter_secret

    def _new_cipher(self):
        """Return an AEAD cipher keyed for the current sequence number.

        The sequence number is NOT advanced here. ``seal()`` and ``unseal()``
        advance it only after their AEAD operation succeeds, matching
        ContextS.Seal / ContextR.Open in RFC 9180, section 5.2.

        Raises:
          MessageLimitReachedError: If the sequence number space is exhausted.
          ValueError: If the AEAD identifier is unknown.
        """
        if self._sequence >= self._max_sequence:
            raise MessageLimitReachedError()

        nonce = strxor(self._base_nonce, self._sequence.to_bytes(self._Nn, 'big'))
        if self._aead_id in (AEAD.AES128_GCM, AEAD.AES256_GCM):
            return AES.new(self._key, AES.MODE_GCM, nonce=nonce, mac_len=self._Nt)
        if self._aead_id == AEAD.CHACHA20_POLY1305:
            return ChaCha20_Poly1305.new(key=self._key, nonce=nonce)
        raise ValueError(f"Unknown AEAD cipher ID {self._aead_id:#x}")

    def seal(self, plaintext: bytes, auth_data: Optional[bytes] = None):
        """Encrypt and authenticate a message.

        This method can be invoked multiple times
        to seal an ordered sequence of messages.

        Arguments:
          plaintext: bytes
            The message to seal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that is not encrypted
            but that will be also covered by the authentication tag.

        Returns:
           The ciphertext concatenated with the authentication tag.

        Raises:
          ValueError: If this context was created for unsealing.
          MessageLimitReachedError: If no sequence numbers are left.
        """
        if not self._encrypt:
            raise ValueError("This cipher can only be used to unseal")

        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)
        ct, tag = cipher.encrypt_and_digest(plaintext)
        # Advance only once the message has actually been sealed.
        self._sequence += 1
        return ct + tag

    def unseal(self, ciphertext: bytes, auth_data: Optional[bytes] = None):
        """Decrypt a message and validate its authenticity.

        This method can be invoked multiple times
        to unseal an ordered sequence of messages.

        Arguments:
          ciphertext: bytes
            The message to unseal.
          auth_data: bytes
            Optional. Additional Authenticated data (AAD) that
            was also covered by the authentication tag.

        Returns:
           The original plaintext.

        Raises: ValueError
           If this context was created for sealing, or if the ciphertext
           (in combination with the AAD) is not valid.
           An invalid message does not advance the sequence number, so the
           next call still expects the same message position.
           As long as no message has been unsealed successfully, this
           exception may also mean that any of the parameters or keys
           used to establish the session is wrong or that one is missing.
        Raises: MessageLimitReachedError
           If no sequence numbers are left.
        """
        if self._encrypt:
            raise ValueError("This cipher can only be used to seal")

        if len(ciphertext) < self._Nt:
            raise ValueError("Ciphertext is too small")

        cipher = self._new_cipher()
        if auth_data:
            cipher.update(auth_data)

        try:
            pt = cipher.decrypt_and_verify(ciphertext[:-self._Nt],
                                           ciphertext[-self._Nt:])
        except ValueError:
            # RFC 9180 ContextR.Open leaves the sequence number untouched on
            # failure, so one forged or corrupted message cannot desync the
            # receiver from every later message.
            if self._sequence == 0:
                raise ValueError("Incorrect HPKE keys/parameters or invalid message (wrong MAC tag)")
            raise ValueError("Invalid message (wrong MAC tag)")

        self._sequence += 1
        return pt
def new(*, receiver_key: EccKey,
        aead_id: AEAD,
        enc: Optional[bytes] = None,
        sender_key: Optional[EccKey] = None,
        psk: Optional[tuple[bytes, bytes]] = None,
        info: Optional[bytes] = None) -> HPKE_Cipher:
    """Create an HPKE context which can be used:

    - by the sender to seal (encrypt) a message or

    - by the receiver to unseal (decrypt) it.

    As a minimum, the two parties agree on the receiver's asymmetric key
    (of which the sender will only know the public half).

    Additionally, for authentication purposes, they may also agree on:

    * the sender's asymmetric key (of which the receiver will only know the public half)

    * a shared secret (e.g., a symmetric key derived from a password)

    Args:
      receiver_key:
        The ECC key of the receiver.
        It must be on one of the following curves: ``NIST P-256``,
        ``NIST P-384``, ``NIST P-521``, ``Curve25519`` (X25519) or
        ``Curve448`` (X448).

        If this is a **public** key, the HPKE context can only be used to
        **seal** (**encrypt**).

        If this is a **private** key, the HPKE context can only be used to
        **unseal** (**decrypt**).

      aead_id:
        The HPKE identifier of the symmetric cipher.
        The possible values are:

        * ``HPKE.AEAD.AES128_GCM``
        * ``HPKE.AEAD.AES256_GCM``
        * ``HPKE.AEAD.CHACHA20_POLY1305``

        The equivalent plain integers are accepted as well.

      enc:
        The encapsulated key, i.e. the ``enc`` attribute of the sender's
        context: the serialized ephemeral public key, not the shared secret
        itself.
        The receiver must always specify this parameter.
        The sender must always omit this parameter.

      sender_key:
        The ECC key of the sender.
        It must be on the same curve as the ``receiver_key``.
        If the ``receiver_key`` is a public key, ``sender_key`` must be a
        private key, and vice versa.

      psk:
        A Pre-Shared Key (PSK) as a 2-tuple of non-empty
        byte strings: the identifier and the actual secret value.
        Sender and receiver must use the same PSK (or none).
        The secret value must be at least 32 bytes long,
        and it must not be a low-entropy password
        (use a KDF like PBKDF2 or scrypt to derive a secret
        from a password).

      info:
        A non-secret parameter that contributes
        to the generation of all session keys.
        Sender and receiver must use the same **info** parameter (or none).

    Returns:
        An object that can be used for
        sealing (if ``receiver_key`` is a public key) or
        unsealing (if ``receiver_key`` is a private key).
        In the latter case,
        correctness of all the keys and parameters will only
        be assessed with the first call to ``unseal()``.

    Raises:
      ValueError: If ``aead_id`` is unknown or the curve is unsupported.
        It is also raised when ``sender_key`` is given but the two keys are
        not exactly one private and one public key on the same curve, and
        when the PSK or ``enc`` inputs are inconsistent.
        ``DeserializeError``, a subclass, signals an undecodable ``enc``.
    """
    # Normalize to an AEAD member. A plain ``aead_id in AEAD`` test raises
    # TypeError for raw integers on Python versions before 3.12.
    try:
        aead_id = AEAD(aead_id)
    except ValueError:
        raise ValueError(f"Unknown AEAD cipher ID {aead_id:#x}") from None

    curve = receiver_key.curve
    if curve not in ('NIST P-256', 'NIST P-384', 'NIST P-521',
                     'Curve25519', 'Curve448'):
        raise ValueError(f"Unsupported curve {curve}")

    if sender_key is not None:
        # Authenticated modes: exactly one of the two keys is private.
        # That is the sender's key when sealing and the receiver's when unsealing.
        count_private_keys = int(receiver_key.has_private()) + \
            int(sender_key.has_private())
        if count_private_keys != 1:
            raise ValueError("Exactly 1 private key required")
        if sender_key.curve != curve:
            raise ValueError("Sender key uses {} but recipient key {}".
                             format(sender_key.curve, curve))
        mode = MODE.AUTH if psk is None else MODE.AUTH_PSK
    else:
        mode = MODE.BASE if psk is None else MODE.PSK

    if psk is None:
        psk = b'', b''

    if info is None:
        info = b''

    return HPKE_Cipher(receiver_key,
                       enc,
                       sender_key,
                       psk,
                       info,
                       aead_id,
                       mode)
|