ecdsa.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """
  2. Functions lifted from https://github.com/vbuterin/pybitcointools
  3. """
  4. import hashlib
  5. import hmac
  6. from typing import (
  7. Any,
  8. Callable,
  9. Tuple,
  10. )
  11. from eth_utils import (
  12. big_endian_to_int,
  13. int_to_big_endian,
  14. )
  15. from eth_keys.constants import (
  16. SECPK1_A as A,
  17. SECPK1_B as B,
  18. SECPK1_G as G,
  19. SECPK1_N as N,
  20. SECPK1_P as P,
  21. SECPK1_Gx as Gx,
  22. SECPK1_Gy as Gy,
  23. )
  24. from eth_keys.exceptions import (
  25. BadSignature,
  26. )
  27. from eth_keys.utils.padding import (
  28. pad32,
  29. )
  30. from .jacobian import (
  31. fast_add,
  32. fast_multiply,
  33. from_jacobian,
  34. inv,
  35. is_identity,
  36. jacobian_add,
  37. jacobian_multiply,
  38. )
  39. def decode_public_key(public_key_bytes: bytes) -> Tuple[int, int]:
  40. left = big_endian_to_int(public_key_bytes[0:32])
  41. right = big_endian_to_int(public_key_bytes[32:64])
  42. return left, right
  43. def encode_raw_public_key(raw_public_key: Tuple[int, int]) -> bytes:
  44. left, right = raw_public_key
  45. return b"".join(
  46. (
  47. pad32(int_to_big_endian(left)),
  48. pad32(int_to_big_endian(right)),
  49. )
  50. )
  51. def private_key_to_public_key(private_key_bytes: bytes) -> bytes:
  52. private_key_as_num = big_endian_to_int(private_key_bytes)
  53. if private_key_as_num >= N:
  54. raise Exception("Invalid privkey")
  55. raw_public_key = fast_multiply(G, private_key_as_num)
  56. public_key_bytes = encode_raw_public_key(raw_public_key)
  57. return public_key_bytes
  58. def compress_public_key(uncompressed_public_key_bytes: bytes) -> bytes:
  59. x, y = decode_public_key(uncompressed_public_key_bytes)
  60. if y % 2 == 0:
  61. prefix = b"\x02"
  62. else:
  63. prefix = b"\x03"
  64. return prefix + pad32(int_to_big_endian(x))
  65. def decompress_public_key(compressed_public_key_bytes: bytes) -> bytes:
  66. if len(compressed_public_key_bytes) != 33:
  67. raise ValueError("Invalid compressed public key")
  68. prefix = compressed_public_key_bytes[0]
  69. if prefix not in (2, 3):
  70. raise ValueError("Invalid compressed public key")
  71. x = big_endian_to_int(compressed_public_key_bytes[1:])
  72. y_squared = (x**3 + A * x + B) % P
  73. y_abs = pow(y_squared, ((P + 1) // 4), P)
  74. if (prefix == 2 and y_abs & 1 == 1) or (prefix == 3 and y_abs & 1 == 0):
  75. y = (-y_abs) % P
  76. else:
  77. y = y_abs
  78. return encode_raw_public_key((x, y))
  79. def deterministic_generate_k(
  80. msg_hash: bytes,
  81. private_key_bytes: bytes,
  82. digest_fn: Callable[[], Any] = hashlib.sha256,
  83. ) -> int:
  84. v_0 = b"\x01" * 32
  85. k_0 = b"\x00" * 32
  86. k_1 = hmac.new(
  87. k_0, v_0 + b"\x00" + private_key_bytes + msg_hash, digest_fn
  88. ).digest()
  89. v_1 = hmac.new(k_1, v_0, digest_fn).digest()
  90. k_2 = hmac.new(
  91. k_1, v_1 + b"\x01" + private_key_bytes + msg_hash, digest_fn
  92. ).digest()
  93. v_2 = hmac.new(k_2, v_1, digest_fn).digest()
  94. kb = hmac.new(k_2, v_2, digest_fn).digest()
  95. k = big_endian_to_int(kb)
  96. return k
  97. def ecdsa_raw_sign(msg_hash: bytes, private_key_bytes: bytes) -> Tuple[int, int, int]:
  98. z = big_endian_to_int(msg_hash)
  99. msg_hash_mod_n = pad32(int_to_big_endian(z % N))
  100. k = deterministic_generate_k(msg_hash_mod_n, private_key_bytes)
  101. r, y = fast_multiply(G, k)
  102. s_raw = inv(k, N) * (z + r * big_endian_to_int(private_key_bytes)) % N
  103. v = 27 + ((y % 2) ^ (0 if s_raw * 2 < N else 1))
  104. s = s_raw if s_raw * 2 < N else N - s_raw
  105. return v - 27, r, s
  106. def ecdsa_raw_verify(
  107. msg_hash: bytes, rs: Tuple[int, int], public_key_bytes: bytes
  108. ) -> bool:
  109. raw_public_key = decode_public_key(public_key_bytes)
  110. r, s = rs
  111. w = inv(s, N)
  112. z = big_endian_to_int(msg_hash)
  113. u1, u2 = z * w % N, r * w % N
  114. x, y = fast_add(
  115. fast_multiply(G, u1),
  116. fast_multiply(raw_public_key, u2),
  117. )
  118. return bool(r == x and (r % N) and (s % N))
  119. def ecdsa_raw_recover(msg_hash: bytes, vrs: Tuple[int, int, int]) -> bytes:
  120. v, r, s = vrs
  121. if v not in (0, 1):
  122. raise BadSignature(f"value of v, aka y-parity, was {v}, must be either 0 or 1")
  123. v += 27
  124. x = r
  125. xcubedaxb = (x * x * x + A * x + B) % P
  126. beta = pow(xcubedaxb, (P + 1) // 4, P)
  127. y = beta if v % 2 ^ beta % 2 else (P - beta)
  128. # If xcubedaxb is not a quadratic residue, then r cannot be the x coord
  129. # for a point on the curve, and so the sig is invalid
  130. if (xcubedaxb - y * y) % P != 0 or not (r % N) or not (s % N):
  131. raise BadSignature("Invalid signature")
  132. z = big_endian_to_int(msg_hash)
  133. Gz = jacobian_multiply((Gx, Gy, 1), (N - z) % N)
  134. XY = jacobian_multiply((x, y, 1), s)
  135. Qr = jacobian_add(Gz, XY)
  136. Q = jacobian_multiply(Qr, inv(r, N))
  137. if is_identity(Q):
  138. raise BadSignature("InvalidSignature")
  139. raw_public_key = from_jacobian(Q)
  140. return encode_raw_public_key(raw_public_key)