# _point.py
# This file is licensed under the BSD 2-Clause License.
# See https://opensource.org/licenses/BSD-2-Clause for details.
  3. import threading
  4. from Crypto.Util.number import bytes_to_long, long_to_bytes
  5. from Crypto.Util._raw_api import (VoidPointer, null_pointer,
  6. SmartPointer, c_size_t, c_uint8_ptr,
  7. c_ulonglong)
  8. from Crypto.Math.Numbers import Integer
  9. from Crypto.Random.random import getrandbits
  10. class CurveID(object):
  11. P192 = 1
  12. P224 = 2
  13. P256 = 3
  14. P384 = 4
  15. P521 = 5
  16. ED25519 = 6
  17. ED448 = 7
  18. CURVE25519 = 8
  19. CURVE448 = 9
  20. class _Curves(object):
  21. curves = {}
  22. curves_lock = threading.RLock()
  23. p192_names = ["p192", "NIST P-192", "P-192", "prime192v1", "secp192r1",
  24. "nistp192"]
  25. p224_names = ["p224", "NIST P-224", "P-224", "prime224v1", "secp224r1",
  26. "nistp224"]
  27. p256_names = ["p256", "NIST P-256", "P-256", "prime256v1", "secp256r1",
  28. "nistp256"]
  29. p384_names = ["p384", "NIST P-384", "P-384", "prime384v1", "secp384r1",
  30. "nistp384"]
  31. p521_names = ["p521", "NIST P-521", "P-521", "prime521v1", "secp521r1",
  32. "nistp521"]
  33. ed25519_names = ["ed25519", "Ed25519"]
  34. ed448_names = ["ed448", "Ed448"]
  35. curve25519_names = ["curve25519", "Curve25519", "X25519"]
  36. curve448_names = ["curve448", "Curve448", "X448"]
  37. all_names = p192_names + p224_names + p256_names + p384_names + p521_names + \
  38. ed25519_names + ed448_names + curve25519_names + curve448_names
  39. def __contains__(self, item):
  40. return item in self.all_names
  41. def __dir__(self):
  42. return self.all_names
  43. def load(self, name):
  44. if name in self.p192_names:
  45. from . import _nist_ecc
  46. p192 = _nist_ecc.p192_curve()
  47. p192.id = CurveID.P192
  48. self.curves.update(dict.fromkeys(self.p192_names, p192))
  49. elif name in self.p224_names:
  50. from . import _nist_ecc
  51. p224 = _nist_ecc.p224_curve()
  52. p224.id = CurveID.P224
  53. self.curves.update(dict.fromkeys(self.p224_names, p224))
  54. elif name in self.p256_names:
  55. from . import _nist_ecc
  56. p256 = _nist_ecc.p256_curve()
  57. p256.id = CurveID.P256
  58. self.curves.update(dict.fromkeys(self.p256_names, p256))
  59. elif name in self.p384_names:
  60. from . import _nist_ecc
  61. p384 = _nist_ecc.p384_curve()
  62. p384.id = CurveID.P384
  63. self.curves.update(dict.fromkeys(self.p384_names, p384))
  64. elif name in self.p521_names:
  65. from . import _nist_ecc
  66. p521 = _nist_ecc.p521_curve()
  67. p521.id = CurveID.P521
  68. self.curves.update(dict.fromkeys(self.p521_names, p521))
  69. elif name in self.ed25519_names:
  70. from . import _edwards
  71. ed25519 = _edwards.ed25519_curve()
  72. ed25519.id = CurveID.ED25519
  73. self.curves.update(dict.fromkeys(self.ed25519_names, ed25519))
  74. elif name in self.ed448_names:
  75. from . import _edwards
  76. ed448 = _edwards.ed448_curve()
  77. ed448.id = CurveID.ED448
  78. self.curves.update(dict.fromkeys(self.ed448_names, ed448))
  79. elif name in self.curve25519_names:
  80. from . import _montgomery
  81. curve25519 = _montgomery.curve25519_curve()
  82. curve25519.id = CurveID.CURVE25519
  83. self.curves.update(dict.fromkeys(self.curve25519_names, curve25519))
  84. elif name in self.curve448_names:
  85. from . import _montgomery
  86. curve448 = _montgomery.curve448_curve()
  87. curve448.id = CurveID.CURVE448
  88. self.curves.update(dict.fromkeys(self.curve448_names, curve448))
  89. else:
  90. raise ValueError("Unsupported curve '%s'" % name)
  91. return self.curves[name]
  92. def __getitem__(self, name):
  93. with self.curves_lock:
  94. curve = self.curves.get(name)
  95. if curve is None:
  96. curve = self.load(name)
  97. if name in self.curve25519_names or name in self.curve448_names:
  98. curve.G = EccXPoint(curve.Gx, name)
  99. else:
  100. curve.G = EccPoint(curve.Gx, curve.Gy, name)
  101. curve.is_edwards = curve.id in (CurveID.ED25519, CurveID.ED448)
  102. curve.is_montgomery = curve.id in (CurveID.CURVE25519,
  103. CurveID.CURVE448)
  104. curve.is_weierstrass = not (curve.is_edwards or
  105. curve.is_montgomery)
  106. return curve
  107. def items(self):
  108. # Load all curves
  109. for name in self.all_names:
  110. _ = self[name]
  111. return self.curves.items()
  112. _curves = _Curves()
  113. class EccPoint(object):
  114. """A class to model a point on an Elliptic Curve.
  115. The class supports operators for:
  116. * Adding two points: ``R = S + T``
  117. * In-place addition: ``S += T``
  118. * Negating a point: ``R = -T``
  119. * Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
  120. * Multiplying a point by a scalar: ``R = S*k``
  121. * In-place multiplication by a scalar: ``T *= k``
  122. :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
  123. :vartype curve: string
  124. :ivar x: The affine X-coordinate of the ECC point
  125. :vartype x: integer
  126. :ivar y: The affine Y-coordinate of the ECC point
  127. :vartype y: integer
  128. :ivar xy: The tuple with affine X- and Y- coordinates
  129. """
  130. def __init__(self, x, y, curve="p256"):
  131. try:
  132. self._curve = _curves[curve]
  133. except KeyError:
  134. raise ValueError("Unknown curve name %s" % str(curve))
  135. self.curve = self._curve.canonical
  136. if self._curve.id == CurveID.CURVE25519:
  137. raise ValueError("EccPoint cannot be created for Curve25519")
  138. modulus_bytes = self.size_in_bytes()
  139. xb = long_to_bytes(x, modulus_bytes)
  140. yb = long_to_bytes(y, modulus_bytes)
  141. if len(xb) != modulus_bytes or len(yb) != modulus_bytes:
  142. raise ValueError("Incorrect coordinate length")
  143. new_point = self._curve.rawlib.new_point
  144. free_func = self._curve.rawlib.free_point
  145. self._point = VoidPointer()
  146. try:
  147. context = self._curve.context.get()
  148. except AttributeError:
  149. context = null_pointer
  150. result = new_point(self._point.address_of(),
  151. c_uint8_ptr(xb),
  152. c_uint8_ptr(yb),
  153. c_size_t(modulus_bytes),
  154. context)
  155. if result:
  156. if result == 15:
  157. raise ValueError("The EC point does not belong to the curve")
  158. raise ValueError("Error %d while instantiating an EC point" % result)
  159. # Ensure that object disposal of this Python object will (eventually)
  160. # free the memory allocated by the raw library for the EC point
  161. self._point = SmartPointer(self._point.get(), free_func)
  162. def set(self, point):
  163. clone = self._curve.rawlib.clone
  164. free_func = self._curve.rawlib.free_point
  165. self._point = VoidPointer()
  166. result = clone(self._point.address_of(),
  167. point._point.get())
  168. if result:
  169. raise ValueError("Error %d while cloning an EC point" % result)
  170. self._point = SmartPointer(self._point.get(), free_func)
  171. return self
  172. def __eq__(self, point):
  173. if not isinstance(point, EccPoint):
  174. return False
  175. cmp_func = self._curve.rawlib.cmp
  176. return 0 == cmp_func(self._point.get(), point._point.get())
  177. # Only needed for Python 2
  178. def __ne__(self, point):
  179. return not self == point
  180. def __neg__(self):
  181. neg_func = self._curve.rawlib.neg
  182. np = self.copy()
  183. result = neg_func(np._point.get())
  184. if result:
  185. raise ValueError("Error %d while inverting an EC point" % result)
  186. return np
  187. def copy(self):
  188. """Return a copy of this point."""
  189. x, y = self.xy
  190. np = EccPoint(x, y, self.curve)
  191. return np
  192. def is_point_at_infinity(self):
  193. """``True`` if this is the *point-at-infinity*."""
  194. if self._curve.is_edwards:
  195. return self.x == 0
  196. else:
  197. return self.xy == (0, 0)
  198. def point_at_infinity(self):
  199. """Return the *point-at-infinity* for the curve."""
  200. if self._curve.is_edwards:
  201. return EccPoint(0, 1, self.curve)
  202. else:
  203. return EccPoint(0, 0, self.curve)
  204. @property
  205. def x(self):
  206. return self.xy[0]
  207. @property
  208. def y(self):
  209. return self.xy[1]
  210. @property
  211. def xy(self):
  212. modulus_bytes = self.size_in_bytes()
  213. xb = bytearray(modulus_bytes)
  214. yb = bytearray(modulus_bytes)
  215. get_xy = self._curve.rawlib.get_xy
  216. result = get_xy(c_uint8_ptr(xb),
  217. c_uint8_ptr(yb),
  218. c_size_t(modulus_bytes),
  219. self._point.get())
  220. if result:
  221. raise ValueError("Error %d while encoding an EC point" % result)
  222. return (Integer(bytes_to_long(xb)), Integer(bytes_to_long(yb)))
  223. def size_in_bytes(self):
  224. """Size of each coordinate, in bytes."""
  225. return (self.size_in_bits() + 7) // 8
  226. def size_in_bits(self):
  227. """Size of each coordinate, in bits."""
  228. return self._curve.modulus_bits
  229. def double(self):
  230. """Double this point (in-place operation).
  231. Returns:
  232. This same object (to enable chaining).
  233. """
  234. double_func = self._curve.rawlib.double
  235. result = double_func(self._point.get())
  236. if result:
  237. raise ValueError("Error %d while doubling an EC point" % result)
  238. return self
  239. def __iadd__(self, point):
  240. """Add a second point to this one"""
  241. add_func = self._curve.rawlib.add
  242. result = add_func(self._point.get(), point._point.get())
  243. if result:
  244. if result == 16:
  245. raise ValueError("EC points are not on the same curve")
  246. raise ValueError("Error %d while adding two EC points" % result)
  247. return self
  248. def __add__(self, point):
  249. """Return a new point, the addition of this one and another"""
  250. np = self.copy()
  251. np += point
  252. return np
  253. def __imul__(self, scalar):
  254. """Multiply this point by a scalar"""
  255. scalar_func = self._curve.rawlib.scalar
  256. if scalar < 0:
  257. raise ValueError("Scalar multiplication is only defined for non-negative integers")
  258. sb = long_to_bytes(scalar)
  259. result = scalar_func(self._point.get(),
  260. c_uint8_ptr(sb),
  261. c_size_t(len(sb)),
  262. c_ulonglong(getrandbits(64)))
  263. if result:
  264. raise ValueError("Error %d during scalar multiplication" % result)
  265. return self
  266. def __mul__(self, scalar):
  267. """Return a new point, the scalar product of this one"""
  268. np = self.copy()
  269. np *= scalar
  270. return np
  271. def __rmul__(self, left_hand):
  272. return self.__mul__(left_hand)
  273. class EccXPoint(object):
  274. """A class to model a point on an Elliptic Curve,
  275. where only the X-coordinate is exposed.
  276. The class supports operators for:
  277. * Multiplying a point by a scalar: ``R = S*k``
  278. * In-place multiplication by a scalar: ``T *= k``
  279. :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
  280. :vartype curve: string
  281. :ivar x: The affine X-coordinate of the ECC point
  282. :vartype x: integer
  283. """
  284. def __init__(self, x, curve):
  285. # Once encoded, x must not exceed the length of the modulus,
  286. # but its value may match or exceed the modulus itself
  287. # (i.e., non-canonical value)
  288. try:
  289. self._curve = _curves[curve]
  290. except KeyError:
  291. raise ValueError("Unknown curve name %s" % str(curve))
  292. self.curve = self._curve.canonical
  293. if self._curve.id not in (CurveID.CURVE25519, CurveID.CURVE448):
  294. raise ValueError("EccXPoint can only be created for Curve25519/Curve448")
  295. new_point = self._curve.rawlib.new_point
  296. free_func = self._curve.rawlib.free_point
  297. self._point = VoidPointer()
  298. try:
  299. context = self._curve.context.get()
  300. except AttributeError:
  301. context = null_pointer
  302. modulus_bytes = self.size_in_bytes()
  303. if x is None:
  304. xb = null_pointer
  305. else:
  306. xb = c_uint8_ptr(long_to_bytes(x, modulus_bytes))
  307. if len(xb) != modulus_bytes:
  308. raise ValueError("Incorrect coordinate length")
  309. self._point = VoidPointer()
  310. result = new_point(self._point.address_of(),
  311. xb,
  312. c_size_t(modulus_bytes),
  313. context)
  314. if result == 15:
  315. raise ValueError("The EC point does not belong to the curve")
  316. if result:
  317. raise ValueError("Error %d while instantiating an EC point" % result)
  318. # Ensure that object disposal of this Python object will (eventually)
  319. # free the memory allocated by the raw library for the EC point
  320. self._point = SmartPointer(self._point.get(), free_func)
  321. def set(self, point):
  322. clone = self._curve.rawlib.clone
  323. free_func = self._curve.rawlib.free_point
  324. self._point = VoidPointer()
  325. result = clone(self._point.address_of(),
  326. point._point.get())
  327. if result:
  328. raise ValueError("Error %d while cloning an EC point" % result)
  329. self._point = SmartPointer(self._point.get(), free_func)
  330. return self
  331. def __eq__(self, point):
  332. if not isinstance(point, EccXPoint):
  333. return False
  334. cmp_func = self._curve.rawlib.cmp
  335. p1 = self._point.get()
  336. p2 = point._point.get()
  337. res = cmp_func(p1, p2)
  338. return 0 == res
  339. def copy(self):
  340. """Return a copy of this point."""
  341. try:
  342. x = self.x
  343. except ValueError:
  344. return self.point_at_infinity()
  345. return EccXPoint(x, self.curve)
  346. def is_point_at_infinity(self):
  347. """``True`` if this is the *point-at-infinity*."""
  348. try:
  349. _ = self.x
  350. except ValueError:
  351. return True
  352. return False
  353. def point_at_infinity(self):
  354. """Return the *point-at-infinity* for the curve."""
  355. return EccXPoint(None, self.curve)
  356. @property
  357. def x(self):
  358. modulus_bytes = self.size_in_bytes()
  359. xb = bytearray(modulus_bytes)
  360. get_x = self._curve.rawlib.get_x
  361. result = get_x(c_uint8_ptr(xb),
  362. c_size_t(modulus_bytes),
  363. self._point.get())
  364. if result == 19: # ERR_ECC_PAI
  365. raise ValueError("No X coordinate for the point at infinity")
  366. if result:
  367. raise ValueError("Error %d while getting X of an EC point" % result)
  368. return Integer(bytes_to_long(xb))
  369. def size_in_bytes(self):
  370. """Size of each coordinate, in bytes."""
  371. return (self.size_in_bits() + 7) // 8
  372. def size_in_bits(self):
  373. """Size of each coordinate, in bits."""
  374. return self._curve.modulus_bits
  375. def __imul__(self, scalar):
  376. """Multiply this point by a scalar"""
  377. scalar_func = self._curve.rawlib.scalar
  378. if scalar < 0:
  379. raise ValueError("Scalar multiplication is only defined for non-negative integers")
  380. sb = long_to_bytes(scalar)
  381. result = scalar_func(self._point.get(),
  382. c_uint8_ptr(sb),
  383. c_size_t(len(sb)),
  384. c_ulonglong(getrandbits(64)))
  385. if result:
  386. raise ValueError("Error %d during scalar multiplication" % result)
  387. return self
  388. def __mul__(self, scalar):
  389. """Return a new point, the scalar product of this one"""
  390. np = self.copy()
  391. np *= scalar
  392. return np
  393. def __rmul__(self, left_hand):
  394. return self.__mul__(left_hand)