keyfile.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import hashlib
  2. import hmac
  3. import io
  4. import json
  5. from typing import (
  6. IO,
  7. Any,
  8. AnyStr,
  9. Callable,
  10. Dict,
  11. Iterable,
  12. Literal,
  13. Mapping,
  14. Tuple,
  15. TypeVar,
  16. Union,
  17. cast,
  18. )
  19. import uuid
  20. from Crypto import (
  21. Random,
  22. )
  23. from Crypto.Cipher import (
  24. AES,
  25. )
  26. from Crypto.Protocol.KDF import (
  27. scrypt,
  28. )
  29. from Crypto.Util import (
  30. Counter,
  31. )
  32. from eth_keys import (
  33. keys,
  34. )
  35. from eth_typing import (
  36. HexStr,
  37. )
  38. from eth_utils import (
  39. big_endian_to_int,
  40. decode_hex,
  41. encode_hex,
  42. int_to_big_endian,
  43. is_dict,
  44. is_string,
  45. keccak,
  46. remove_0x_prefix,
  47. to_dict,
  48. )
  49. KDFType = Literal["pbkdf2", "scrypt"]
  50. TKey = TypeVar("TKey")
  51. TVal = TypeVar("TVal")
  52. typed_to_dict = cast(
  53. Callable[
  54. [Callable[..., Iterable[Union[Mapping[TKey, TVal], Tuple[TKey, TVal]]]]],
  55. Callable[..., Dict[TKey, TVal]],
  56. ],
  57. to_dict,
  58. )
  59. def encode_hex_no_prefix(value: AnyStr) -> HexStr:
  60. return remove_0x_prefix(encode_hex(value))
  61. def load_keyfile(path_or_file_obj: Union[str, IO[str]]) -> Any:
  62. if is_string(path_or_file_obj):
  63. assert isinstance(path_or_file_obj, str)
  64. with open(path_or_file_obj) as keyfile_file:
  65. return json.load(keyfile_file)
  66. else:
  67. assert isinstance(path_or_file_obj, io.TextIOBase)
  68. return json.load(path_or_file_obj)
  69. def create_keyfile_json(
  70. private_key: Union[bytes, bytearray, memoryview],
  71. password: str,
  72. version: int = 3,
  73. kdf: KDFType = "pbkdf2",
  74. iterations: Union[int, None] = None,
  75. salt_size: int = 16,
  76. ) -> Dict[str, Any]:
  77. if version == 3:
  78. return _create_v3_keyfile_json(
  79. private_key, password, kdf, iterations, salt_size
  80. )
  81. else:
  82. raise NotImplementedError("Not yet implemented")
  83. def decode_keyfile_json(raw_keyfile_json: Dict[Any, Any], password: str) -> bytes:
  84. keyfile_json = normalize_keys(raw_keyfile_json)
  85. version = keyfile_json["version"]
  86. if version == 3:
  87. return _decode_keyfile_json_v3(keyfile_json, password)
  88. if version == 4:
  89. return _decode_keyfile_json_v4(keyfile_json, password)
  90. else:
  91. raise NotImplementedError("Not yet implemented")
  92. def extract_key_from_keyfile(
  93. path_or_file_obj: Union[str, IO[str]], password: str
  94. ) -> bytes:
  95. keyfile_json = load_keyfile(path_or_file_obj)
  96. private_key = decode_keyfile_json(keyfile_json, password)
  97. return private_key
  98. @typed_to_dict
  99. def normalize_keys(keyfile_json: Dict[Any, Any]) -> Any:
  100. for key, value in keyfile_json.items():
  101. if is_string(key):
  102. norm_key = key.lower()
  103. else:
  104. norm_key = key
  105. if is_dict(value):
  106. norm_value = normalize_keys(value)
  107. else:
  108. norm_value = value
  109. yield norm_key, norm_value
  110. #
  111. # Version 3 creators
  112. #
  113. DKLEN = 32
  114. SCRYPT_R = 8
  115. SCRYPT_P = 1
  116. def _create_v3_keyfile_json(
  117. private_key: Union[bytes, bytearray, memoryview],
  118. password: str,
  119. kdf: KDFType,
  120. work_factor: Union[int, None] = None,
  121. salt_size: int = 16,
  122. ) -> Dict[str, Any]:
  123. salt = Random.get_random_bytes(salt_size)
  124. if work_factor is None:
  125. work_factor = get_default_work_factor_for_kdf(kdf)
  126. if kdf == "pbkdf2":
  127. derived_key = _pbkdf2_hash(
  128. password,
  129. hash_name="sha256",
  130. salt=salt,
  131. iterations=work_factor,
  132. dklen=DKLEN,
  133. )
  134. kdfparams = {
  135. "c": work_factor,
  136. "dklen": DKLEN,
  137. "prf": "hmac-sha256",
  138. "salt": encode_hex_no_prefix(salt),
  139. }
  140. elif kdf == "scrypt":
  141. derived_key = _scrypt_hash(
  142. password,
  143. salt=salt,
  144. buflen=DKLEN,
  145. r=SCRYPT_R,
  146. p=SCRYPT_P,
  147. n=work_factor,
  148. )
  149. kdfparams = {
  150. "dklen": DKLEN,
  151. "n": work_factor,
  152. "r": SCRYPT_R,
  153. "p": SCRYPT_P,
  154. "salt": encode_hex_no_prefix(salt),
  155. }
  156. else:
  157. raise NotImplementedError(f"KDF not implemented: {kdf}")
  158. iv = big_endian_to_int(Random.get_random_bytes(16))
  159. encrypt_key = derived_key[:16]
  160. ciphertext = encrypt_aes_ctr(private_key, encrypt_key, iv)
  161. mac = keccak(derived_key[16:32] + ciphertext)
  162. address = keys.PrivateKey(private_key).public_key.to_checksum_address()
  163. return {
  164. "address": remove_0x_prefix(address),
  165. "crypto": {
  166. "cipher": "aes-128-ctr",
  167. "cipherparams": {
  168. "iv": encode_hex_no_prefix(int_to_big_endian(iv)),
  169. },
  170. "ciphertext": encode_hex_no_prefix(ciphertext),
  171. "kdf": kdf,
  172. "kdfparams": kdfparams,
  173. "mac": encode_hex_no_prefix(mac),
  174. },
  175. "id": str(uuid.uuid4()),
  176. "version": 3,
  177. }
  178. #
  179. # Verson 3 decoder
  180. #
  181. def _decode_keyfile_json_v3(keyfile_json: Dict[str, Any], password: str) -> bytes:
  182. crypto = keyfile_json["crypto"]
  183. kdf = crypto["kdf"]
  184. # Derive the encryption key from the password using the key derivation
  185. # function.
  186. if kdf == "pbkdf2":
  187. derived_key = _derive_pbkdf_key(crypto["kdfparams"], password)
  188. elif kdf == "scrypt":
  189. derived_key = _derive_scrypt_key(crypto["kdfparams"], password)
  190. else:
  191. raise TypeError(f"Unsupported key derivation function: {kdf}")
  192. # Validate that the derived key matchs the provided MAC
  193. ciphertext = decode_hex(crypto["ciphertext"])
  194. mac = keccak(derived_key[16:32] + ciphertext)
  195. expected_mac = decode_hex(crypto["mac"])
  196. if not hmac.compare_digest(mac, expected_mac):
  197. raise ValueError("MAC mismatch")
  198. # Decrypt the ciphertext using the derived encryption key to get the
  199. # private key.
  200. encrypt_key = derived_key[:16]
  201. cipherparams = crypto["cipherparams"]
  202. iv = big_endian_to_int(decode_hex(cipherparams["iv"]))
  203. private_key = decrypt_aes_ctr(ciphertext, encrypt_key, iv)
  204. return private_key
  205. #
  206. # Verson 4 decoder
  207. #
  208. def _decode_keyfile_json_v4(keyfile_json: Dict[str, Any], password: str) -> bytes:
  209. crypto = keyfile_json["crypto"]
  210. kdf = crypto["kdf"]["function"]
  211. # Derive the encryption key from the password using the key derivation
  212. # function.
  213. if kdf == "pbkdf2":
  214. derived_key = _derive_pbkdf_key(crypto["kdf"]["params"], password)
  215. elif kdf == "scrypt":
  216. derived_key = _derive_scrypt_key(crypto["kdf"]["params"], password)
  217. else:
  218. raise TypeError(f"Unsupported key derivation function: {kdf}")
  219. cipher_message = decode_hex(crypto["cipher"]["message"])
  220. checksum_message = crypto["checksum"]["message"]
  221. if (
  222. hashlib.sha256(derived_key[16:32] + cipher_message).hexdigest()
  223. != checksum_message
  224. ):
  225. raise ValueError("Checksum mismatch")
  226. # Decrypt the cipher message using the derived encryption key to get the
  227. # private key.
  228. encrypt_key = derived_key[:16]
  229. cipherparams = crypto["cipher"]["params"]
  230. iv = big_endian_to_int(decode_hex(cipherparams["iv"]))
  231. private_key = decrypt_aes_ctr(cipher_message, encrypt_key, iv)
  232. return private_key
  233. #
  234. # Key derivation
  235. #
  236. def _derive_pbkdf_key(kdf_params: Dict[str, Any], password: str) -> bytes:
  237. salt = decode_hex(kdf_params["salt"])
  238. dklen = kdf_params["dklen"]
  239. should_be_hmac, _, hash_name = kdf_params["prf"].partition("-")
  240. assert should_be_hmac == "hmac"
  241. iterations = kdf_params["c"]
  242. derive_pbkdf_key = _pbkdf2_hash(password, hash_name, salt, iterations, dklen)
  243. return derive_pbkdf_key
  244. def _derive_scrypt_key(kdf_params: Dict[str, Any], password: str) -> bytes:
  245. salt = decode_hex(kdf_params["salt"])
  246. p = kdf_params["p"]
  247. r = kdf_params["r"]
  248. n = kdf_params["n"]
  249. buflen = kdf_params["dklen"]
  250. derived_scrypt_key = _scrypt_hash(
  251. password,
  252. salt=salt,
  253. n=n,
  254. r=r,
  255. p=p,
  256. buflen=buflen,
  257. )
  258. return derived_scrypt_key
  259. def _scrypt_hash(
  260. password: str, salt: bytes, n: int, r: int, p: int, buflen: int
  261. ) -> bytes:
  262. derived_key = scrypt(
  263. password,
  264. salt=salt,
  265. key_len=buflen,
  266. N=n,
  267. r=r,
  268. p=p,
  269. num_keys=1,
  270. )
  271. return cast(bytes, derived_key)
  272. def _pbkdf2_hash(
  273. password: Any, hash_name: str, salt: bytes, iterations: int, dklen: int
  274. ) -> bytes:
  275. derived_key = hashlib.pbkdf2_hmac(
  276. hash_name=hash_name,
  277. password=password,
  278. salt=salt,
  279. iterations=iterations,
  280. dklen=dklen,
  281. )
  282. return derived_key
  283. #
  284. # Encryption and Decryption
  285. #
  286. def decrypt_aes_ctr(ciphertext: bytes, key: bytes, iv: int) -> bytes:
  287. ctr = Counter.new(128, initial_value=iv, allow_wraparound=True)
  288. encryptor = AES.new(key, AES.MODE_CTR, counter=ctr)
  289. return cast(bytes, encryptor.decrypt(ciphertext))
  290. def encrypt_aes_ctr(
  291. value: Union[bytes, bytearray, memoryview], key: bytes, iv: int
  292. ) -> bytes:
  293. ctr = Counter.new(128, initial_value=iv, allow_wraparound=True)
  294. encryptor = AES.new(key, AES.MODE_CTR, counter=ctr)
  295. ciphertext = encryptor.encrypt(value)
  296. return cast(bytes, ciphertext)
  297. #
  298. # Utility
  299. #
  300. def get_default_work_factor_for_kdf(kdf: KDFType) -> int:
  301. if kdf == "pbkdf2":
  302. return 1000000
  303. elif kdf == "scrypt":
  304. return 262144
  305. else:
  306. raise ValueError(f"Unsupported key derivation function: {kdf}")