| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493 |
- # This file is licensed under the BSD 2-Clause License.
- # See https://opensource.org/licenses/BSD-2-Clause for details.
- import threading
- from Crypto.Util.number import bytes_to_long, long_to_bytes
- from Crypto.Util._raw_api import (VoidPointer, null_pointer,
- SmartPointer, c_size_t, c_uint8_ptr,
- c_ulonglong)
- from Crypto.Math.Numbers import Integer
- from Crypto.Random.random import getrandbits
class CurveID(object):
    """Numeric identifiers of the supported elliptic curves.

    A loaded curve object stores one of these values in its ``id``
    attribute; the registry uses it to tell the Edwards and Montgomery
    curves apart from the (Weierstrass) NIST curves.
    """

    # NIST curves
    P192, P224, P256, P384, P521 = 1, 2, 3, 4, 5
    # Edwards curves
    ED25519, ED448 = 6, 7
    # Montgomery curves (only X-only points exist for them)
    CURVE25519, CURVE448 = 8, 9
class _Curves(object):
    """Registry of the supported elliptic curves, loaded lazily.

    A curve is looked up with ``registry[name]`` using any of its aliases,
    e.g. ``"P-256"``, ``"secp256r1"`` or ``"prime256v1"``.  The first lookup
    of any alias imports the module implementing the curve, builds the curve
    object, registers it under every alias of that curve and attaches the
    derived attributes ``G``, ``is_edwards``, ``is_montgomery`` and
    ``is_weierstrass``.
    """

    # Cache of loaded curves (alias -> curve object).  These are class
    # attributes, so they are shared by every instance of the registry.
    curves = {}
    # Reentrant: building a curve's generator (EccPoint / EccXPoint) looks
    # the curve up again through __getitem__ while the lock is held.
    curves_lock = threading.RLock()

    p192_names = ["p192", "NIST P-192", "P-192", "prime192v1", "secp192r1",
                  "nistp192"]
    p224_names = ["p224", "NIST P-224", "P-224", "prime224v1", "secp224r1",
                  "nistp224"]
    p256_names = ["p256", "NIST P-256", "P-256", "prime256v1", "secp256r1",
                  "nistp256"]
    p384_names = ["p384", "NIST P-384", "P-384", "prime384v1", "secp384r1",
                  "nistp384"]
    p521_names = ["p521", "NIST P-521", "P-521", "prime521v1", "secp521r1",
                  "nistp521"]
    ed25519_names = ["ed25519", "Ed25519"]
    ed448_names = ["ed448", "Ed448"]
    curve25519_names = ["curve25519", "Curve25519", "X25519"]
    curve448_names = ["curve448", "Curve448", "X448"]

    all_names = p192_names + p224_names + p256_names + p384_names + p521_names + \
        ed25519_names + ed448_names + curve25519_names + curve448_names

    # One entry per curve: (attribute holding its aliases, sibling module
    # implementing it, factory function in that module, CurveID attribute).
    # Plain strings keep the class body independent of other module-level
    # names.  The alias lists are disjoint, so the entry order is irrelevant.
    _families = (
        ("p192_names", "_nist_ecc", "p192_curve", "P192"),
        ("p224_names", "_nist_ecc", "p224_curve", "P224"),
        ("p256_names", "_nist_ecc", "p256_curve", "P256"),
        ("p384_names", "_nist_ecc", "p384_curve", "P384"),
        ("p521_names", "_nist_ecc", "p521_curve", "P521"),
        ("ed25519_names", "_edwards", "ed25519_curve", "ED25519"),
        ("ed448_names", "_edwards", "ed448_curve", "ED448"),
        ("curve25519_names", "_montgomery", "curve25519_curve", "CURVE25519"),
        ("curve448_names", "_montgomery", "curve448_curve", "CURVE448"),
    )

    def __contains__(self, item):
        """``True`` if *item* is an alias of a supported curve."""
        return item in self.all_names

    def __dir__(self):
        return self.all_names

    @staticmethod
    def _implementation(module_name):
        """Import and return the sibling module called *module_name*."""
        # Literal relative imports (rather than importlib) keep these
        # dependencies visible to static analysers and application freezers.
        if module_name == "_nist_ecc":
            from . import _nist_ecc as module
        elif module_name == "_edwards":
            from . import _edwards as module
        elif module_name == "_montgomery":
            from . import _montgomery as module
        else:
            raise ValueError("Unknown curve module '%s'" % module_name)
        return module

    def load(self, name):
        """Build the curve called *name* and register it under all its aliases.

        Only the curve object and its ``id`` are set up here; ``__getitem__``
        attaches the remaining derived attributes.

        Returns:
          The newly registered curve object.

        Raises:
          ValueError: if *name* is not an alias of any supported curve.
        """
        for names_attr, module_name, factory_name, id_name in self._families:
            aliases = getattr(self, names_attr)
            if name in aliases:
                break
        else:
            raise ValueError("Unsupported curve '%s'" % name)

        module = self._implementation(module_name)
        curve = getattr(module, factory_name)()
        curve.id = getattr(CurveID, id_name)
        self.curves.update(dict.fromkeys(aliases, curve))
        return self.curves[name]

    @staticmethod
    def _complete(curve, name):
        """Attach the generator ``G`` and the curve-family flags to *curve*."""
        is_montgomery = curve.id in (CurveID.CURVE25519, CurveID.CURVE448)
        # Montgomery curves only have X-only points.
        if is_montgomery:
            curve.G = EccXPoint(curve.Gx, name)
        else:
            curve.G = EccPoint(curve.Gx, curve.Gy, name)
        curve.is_edwards = curve.id in (CurveID.ED25519, CurveID.ED448)
        curve.is_montgomery = is_montgomery
        curve.is_weierstrass = not (curve.is_edwards or is_montgomery)

    def __getitem__(self, name):
        """Return the curve called *name*, loading it on first use.

        Raises:
          ValueError: if *name* is not an alias of any supported curve.
        """
        with self.curves_lock:
            curve = self.curves.get(name)
            if curve is None:
                curve = self.load(name)
                try:
                    self._complete(curve, name)
                except Exception:
                    # Never leave a half-built curve (e.g. one without G)
                    # cached: forget all of its aliases so the next lookup
                    # starts from scratch, then propagate the error.
                    stale = [k for k, v in self.curves.items() if v is curve]
                    for alias in stale:
                        del self.curves[alias]
                    raise
            return curve

    def items(self):
        """Load every supported curve and return the ``(alias, curve)`` pairs."""
        with self.curves_lock:
            for name in self.all_names:
                _ = self[name]
            return self.curves.items()
# The module-wide curve registry used by EccPoint and EccXPoint; each curve
# is only loaded on its first lookup.
_curves = _Curves()
class EccPoint(object):
    """A class to model a point on an Elliptic Curve.

    The class supports operators for:

    * Adding two points: ``R = S + T``
    * In-place addition: ``S += T``
    * Negating a point: ``R = -T``
    * Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
    * Multiplying a point by a scalar: ``R = S*k``
    * In-place multiplication by a scalar: ``T *= k``

    Points cannot be created on the Montgomery curves (Curve25519 and
    Curve448); those only have X-only points, modelled by ``EccXPoint``.

    :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
    :vartype curve: string

    :ivar x: The affine X-coordinate of the ECC point
    :vartype x: integer

    :ivar y: The affine Y-coordinate of the ECC point
    :vartype y: integer

    :ivar xy: The tuple with affine X- and Y- coordinates
    """

    def __init__(self, x, y, curve="p256"):
        """Create the point ``(x, y)`` on *curve*.

        Args:
          x (integer): the affine X-coordinate.
          y (integer): the affine Y-coordinate.
          curve (string): any alias of a supported, non-Montgomery curve.

        Raises:
          ValueError: if the curve is unknown or is Curve25519/Curve448, if a
            coordinate does not encode to exactly the modulus size, if the
            point is not on the curve, or if the native library fails.
        """
        try:
            self._curve = _curves[curve]
        except KeyError:
            raise ValueError("Unknown curve name %s" % str(curve))
        self.curve = self._curve.canonical

        # Montgomery curves only support X-only points: their native
        # new_point() takes no Y coordinate (compare EccXPoint.__init__),
        # so both of them must be refused here, not just Curve25519.
        if self._curve.id == CurveID.CURVE25519:
            raise ValueError("EccPoint cannot be created for Curve25519")
        if self._curve.id == CurveID.CURVE448:
            raise ValueError("EccPoint cannot be created for Curve448")

        modulus_bytes = self.size_in_bytes()
        xb = long_to_bytes(x, modulus_bytes)
        yb = long_to_bytes(y, modulus_bytes)
        # Each encoded coordinate must fit exactly in modulus_bytes.
        if len(xb) != modulus_bytes or len(yb) != modulus_bytes:
            raise ValueError("Incorrect coordinate length")

        new_point = self._curve.rawlib.new_point
        free_func = self._curve.rawlib.free_point

        self._point = VoidPointer()
        try:
            context = self._curve.context.get()
        except AttributeError:
            # The curve object has no usable context: pass NULL instead.
            context = null_pointer
        result = new_point(self._point.address_of(),
                           c_uint8_ptr(xb),
                           c_uint8_ptr(yb),
                           c_size_t(modulus_bytes),
                           context)

        if result:
            if result == 15:
                raise ValueError("The EC point does not belong to the curve")
            raise ValueError("Error %d while instantiating an EC point" % result)

        # Ensure that object disposal of this Python object will (eventually)
        # free the memory allocated by the raw library for the EC point
        self._point = SmartPointer(self._point.get(), free_func)

    def set(self, point):
        """Overwrite this point with a copy of *point* (in-place operation).

        Returns:
          This same object (to enable chaining).

        Raises:
          ValueError: if the native clone fails; this point is then left
            unchanged.
        """
        clone = self._curve.rawlib.clone
        free_func = self._curve.rawlib.free_point

        # Clone into a temporary handle and swap it in only on success, so a
        # failure cannot leave self._point as an unset, unmanaged pointer.
        new_handle = VoidPointer()
        result = clone(new_handle.address_of(),
                       point._point.get())
        if result:
            raise ValueError("Error %d while cloning an EC point" % result)

        self._point = SmartPointer(new_handle.get(), free_func)
        return self

    def __eq__(self, point):
        """``True`` if *point* is an ``EccPoint`` equal to this one."""
        if not isinstance(point, EccPoint):
            return False

        cmp_func = self._curve.rawlib.cmp
        return 0 == cmp_func(self._point.get(), point._point.get())

    # Only needed for Python 2
    def __ne__(self, point):
        return not self == point

    def __neg__(self):
        """Return a new point, the opposite of this one."""
        neg_func = self._curve.rawlib.neg
        np = self.copy()
        result = neg_func(np._point.get())
        if result:
            raise ValueError("Error %d while inverting an EC point" % result)
        return np

    def copy(self):
        """Return a copy of this point."""
        x, y = self.xy
        np = EccPoint(x, y, self.curve)
        return np

    def is_point_at_infinity(self):
        """``True`` if this is the *point-at-infinity*."""
        if self._curve.is_edwards:
            # NOTE(review): only X is checked, so any on-curve point with
            # x == 0 matches, not just the neutral element (0, 1) returned by
            # point_at_infinity(); on Edwards curves that presumably includes
            # (0, -1). Confirm whether callers rely on this before tightening.
            return self.x == 0
        else:
            # Weierstrass curves represent the point-at-infinity as (0, 0).
            return self.xy == (0, 0)

    def point_at_infinity(self):
        """Return the *point-at-infinity* for the curve."""
        if self._curve.is_edwards:
            return EccPoint(0, 1, self.curve)
        else:
            return EccPoint(0, 0, self.curve)

    @property
    def x(self):
        return self.xy[0]

    @property
    def y(self):
        return self.xy[1]

    @property
    def xy(self):
        """The affine coordinates, as a pair of ``Integer`` objects.

        Raises:
          ValueError: if the native library cannot encode the point.
        """
        modulus_bytes = self.size_in_bytes()
        xb = bytearray(modulus_bytes)
        yb = bytearray(modulus_bytes)
        get_xy = self._curve.rawlib.get_xy
        result = get_xy(c_uint8_ptr(xb),
                        c_uint8_ptr(yb),
                        c_size_t(modulus_bytes),
                        self._point.get())
        if result:
            raise ValueError("Error %d while encoding an EC point" % result)

        return (Integer(bytes_to_long(xb)), Integer(bytes_to_long(yb)))

    def size_in_bytes(self):
        """Size of each coordinate, in bytes."""
        return (self.size_in_bits() + 7) // 8

    def size_in_bits(self):
        """Size of each coordinate, in bits."""
        return self._curve.modulus_bits

    def double(self):
        """Double this point (in-place operation).

        Returns:
          This same object (to enable chaining).
        """
        double_func = self._curve.rawlib.double
        result = double_func(self._point.get())
        if result:
            raise ValueError("Error %d while doubling an EC point" % result)
        return self

    def __iadd__(self, point):
        """Add a second point to this one"""
        add_func = self._curve.rawlib.add
        result = add_func(self._point.get(), point._point.get())
        if result:
            if result == 16:
                raise ValueError("EC points are not on the same curve")
            raise ValueError("Error %d while adding two EC points" % result)
        return self

    def __add__(self, point):
        """Return a new point, the addition of this one and another"""
        np = self.copy()
        np += point
        return np

    def __imul__(self, scalar):
        """Multiply this point by a scalar"""
        scalar_func = self._curve.rawlib.scalar
        if scalar < 0:
            raise ValueError("Scalar multiplication is only defined for non-negative integers")
        sb = long_to_bytes(scalar)
        # A fresh random 64-bit value is handed to the native routine
        # (presumably a blinding seed against side channels -- confirm).
        result = scalar_func(self._point.get(),
                             c_uint8_ptr(sb),
                             c_size_t(len(sb)),
                             c_ulonglong(getrandbits(64)))
        if result:
            raise ValueError("Error %d during scalar multiplication" % result)
        return self

    def __mul__(self, scalar):
        """Return a new point, the scalar product of this one"""
        np = self.copy()
        np *= scalar
        return np

    def __rmul__(self, left_hand):
        """Support ``k * P`` as an alias of ``P * k``."""
        return self.__mul__(left_hand)
class EccXPoint(object):
    """A class to model a point on an Elliptic Curve,
    where only the X-coordinate is exposed.

    The class supports operators for:

    * Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
    * Multiplying a point by a scalar: ``R = S*k``
    * In-place multiplication by a scalar: ``T *= k``

    Only Curve25519 and Curve448 are supported.

    :ivar curve: The **canonical** name of the curve as defined in the `ECC table`_.
    :vartype curve: string

    :ivar x: The affine X-coordinate of the ECC point
    :vartype x: integer
    """

    def __init__(self, x, curve):
        """Create the point with X-coordinate *x* on *curve*.

        Args:
          x (integer or None): the X-coordinate, or ``None`` for the
            point-at-infinity.
          curve (string): any alias of Curve25519 or Curve448.

        Raises:
          ValueError: if the curve is unknown or not a Montgomery curve, if
            *x* does not encode to exactly the modulus size, if the point is
            not on the curve, or if the native library fails.
        """
        # Once encoded, x must not exceed the length of the modulus,
        # but its value may match or exceed the modulus itself
        # (i.e., non-canonical value)

        try:
            self._curve = _curves[curve]
        except KeyError:
            raise ValueError("Unknown curve name %s" % str(curve))
        self.curve = self._curve.canonical

        if self._curve.id not in (CurveID.CURVE25519, CurveID.CURVE448):
            raise ValueError("EccXPoint can only be created for Curve25519/Curve448")

        new_point = self._curve.rawlib.new_point
        free_func = self._curve.rawlib.free_point

        try:
            context = self._curve.context.get()
        except AttributeError:
            # The curve object has no usable context: pass NULL instead.
            context = null_pointer

        modulus_bytes = self.size_in_bytes()

        if x is None:
            # A NULL X-coordinate is how point_at_infinity() requests the
            # point-at-infinity from the native library.
            xb = null_pointer
        else:
            encoded_x = long_to_bytes(x, modulus_bytes)
            # Check the byte string itself, before wrapping it for C.
            if len(encoded_x) != modulus_bytes:
                raise ValueError("Incorrect coordinate length")
            xb = c_uint8_ptr(encoded_x)

        self._point = VoidPointer()
        result = new_point(self._point.address_of(),
                           xb,
                           c_size_t(modulus_bytes),
                           context)

        if result == 15:
            raise ValueError("The EC point does not belong to the curve")
        if result:
            raise ValueError("Error %d while instantiating an EC point" % result)

        # Ensure that object disposal of this Python object will (eventually)
        # free the memory allocated by the raw library for the EC point
        self._point = SmartPointer(self._point.get(), free_func)

    def set(self, point):
        """Overwrite this point with a copy of *point* (in-place operation).

        Returns:
          This same object (to enable chaining).

        Raises:
          ValueError: if the native clone fails; this point is then left
            unchanged.
        """
        clone = self._curve.rawlib.clone
        free_func = self._curve.rawlib.free_point

        # Clone into a temporary handle and swap it in only on success, so a
        # failure cannot leave self._point as an unset, unmanaged pointer.
        new_handle = VoidPointer()
        result = clone(new_handle.address_of(),
                       point._point.get())
        if result:
            raise ValueError("Error %d while cloning an EC point" % result)

        self._point = SmartPointer(new_handle.get(), free_func)
        return self

    def __eq__(self, point):
        """``True`` if *point* is an ``EccXPoint`` equal to this one."""
        if not isinstance(point, EccXPoint):
            return False

        cmp_func = self._curve.rawlib.cmp
        p1 = self._point.get()
        p2 = point._point.get()
        res = cmp_func(p1, p2)
        return 0 == res

    # Only needed for Python 2, which does not derive != from __eq__
    # (mirrors EccPoint.__ne__).
    def __ne__(self, point):
        return not self == point

    def copy(self):
        """Return a copy of this point."""
        try:
            x = self.x
        except ValueError:
            return self.point_at_infinity()
        return EccXPoint(x, self.curve)

    def is_point_at_infinity(self):
        """``True`` if this is the *point-at-infinity*."""
        try:
            _ = self.x
        except ValueError:
            return True
        return False

    def point_at_infinity(self):
        """Return the *point-at-infinity* for the curve."""
        return EccXPoint(None, self.curve)

    @property
    def x(self):
        """The X-coordinate, as an ``Integer``.

        Raises:
          ValueError: for the point-at-infinity, or on a native error.
        """
        modulus_bytes = self.size_in_bytes()
        xb = bytearray(modulus_bytes)
        get_x = self._curve.rawlib.get_x
        result = get_x(c_uint8_ptr(xb),
                       c_size_t(modulus_bytes),
                       self._point.get())
        if result == 19:  # ERR_ECC_PAI
            raise ValueError("No X coordinate for the point at infinity")
        if result:
            raise ValueError("Error %d while getting X of an EC point" % result)
        return Integer(bytes_to_long(xb))

    def size_in_bytes(self):
        """Size of each coordinate, in bytes."""
        return (self.size_in_bits() + 7) // 8

    def size_in_bits(self):
        """Size of each coordinate, in bits."""
        return self._curve.modulus_bits

    def __imul__(self, scalar):
        """Multiply this point by a scalar"""
        scalar_func = self._curve.rawlib.scalar
        if scalar < 0:
            raise ValueError("Scalar multiplication is only defined for non-negative integers")
        sb = long_to_bytes(scalar)
        # A fresh random 64-bit value is handed to the native routine
        # (presumably a blinding seed against side channels -- confirm).
        result = scalar_func(self._point.get(),
                             c_uint8_ptr(sb),
                             c_size_t(len(sb)),
                             c_ulonglong(getrandbits(64)))
        if result:
            raise ValueError("Error %d during scalar multiplication" % result)
        return self

    def __mul__(self, scalar):
        """Return a new point, the scalar product of this one"""
        np = self.copy()
        np *= scalar
        return np

    def __rmul__(self, left_hand):
        """Support ``k * P`` as an alias of ``P * k``."""
        return self.__mul__(left_hand)
|