serializable.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import abc
  2. import collections
  3. import copy
  4. import enum
  5. import re
  6. from eth_utils import (
  7. to_dict,
  8. to_set,
  9. to_tuple,
  10. )
  11. from rlp.exceptions import (
  12. ListDeserializationError,
  13. ListSerializationError,
  14. ObjectDeserializationError,
  15. ObjectSerializationError,
  16. )
  17. from .lists import (
  18. List,
  19. )
  20. class MetaBase:
  21. fields = None
  22. field_names = None
  23. field_attrs = None
  24. sedes = None
  25. def _get_duplicates(values):
  26. counts = collections.Counter(values)
  27. return tuple(item for item, num in counts.items() if num > 1)
  28. def validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=False):
  29. duplicate_arg_names = _get_duplicates(arg_names)
  30. if duplicate_arg_names:
  31. raise TypeError(f"Duplicate argument names: {sorted(duplicate_arg_names)}")
  32. needed_kwargs = arg_names[len(args) :]
  33. used_kwargs = set(arg_names[: len(args)])
  34. duplicate_kwargs = used_kwargs.intersection(kwargs.keys())
  35. if duplicate_kwargs:
  36. raise TypeError(f"Duplicate kwargs: {sorted(duplicate_kwargs)}")
  37. unknown_kwargs = set(kwargs.keys()).difference(arg_names)
  38. if unknown_kwargs:
  39. raise TypeError(f"Unknown kwargs: {sorted(unknown_kwargs)}")
  40. missing_kwargs = set(needed_kwargs).difference(kwargs.keys())
  41. if not allow_missing and missing_kwargs:
  42. raise TypeError(f"Missing kwargs: {sorted(missing_kwargs)}")
  43. @to_tuple
  44. def merge_kwargs_to_args(args, kwargs, arg_names, allow_missing=False):
  45. validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=allow_missing)
  46. needed_kwargs = arg_names[len(args) :]
  47. yield from args
  48. for arg_name in needed_kwargs:
  49. yield kwargs[arg_name]
  50. @to_dict
  51. def merge_args_to_kwargs(args, kwargs, arg_names, allow_missing=False):
  52. validate_args_and_kwargs(args, kwargs, arg_names, allow_missing=allow_missing)
  53. yield from kwargs.items()
  54. for value, name in zip(args, arg_names):
  55. yield name, value
  56. def _eq(left, right):
  57. """
  58. Equality comparison that allows for equality between tuple and list types with
  59. equivalent elements.
  60. """
  61. if isinstance(left, (tuple, list)) and isinstance(right, (tuple, list)):
  62. return len(left) == len(right) and all(_eq(*pair) for pair in zip(left, right))
  63. else:
  64. return left == right
  65. class ChangesetState(enum.Enum):
  66. INITIALIZED = "INITIALIZED"
  67. OPEN = "OPEN"
  68. CLOSED = "CLOSED"
  69. class ChangesetField:
  70. field = None
  71. def __init__(self, field):
  72. self.field = field
  73. def __get__(self, instance, type=None):
  74. if instance is None:
  75. return self
  76. elif instance.__state__ is not ChangesetState.OPEN:
  77. raise AttributeError(
  78. "Changeset is not active. Attribute access not allowed"
  79. )
  80. else:
  81. try:
  82. return instance.__diff__[self.field]
  83. except KeyError:
  84. return getattr(instance.__original__, self.field)
  85. def __set__(self, instance, value):
  86. if instance.__state__ is not ChangesetState.OPEN:
  87. raise AttributeError(
  88. "Changeset is not active. Attribute access not allowed"
  89. )
  90. instance.__diff__[self.field] = value
  91. class BaseChangeset:
  92. # reference to the original Serializable instance.
  93. __original__ = None
  94. # the state of this fieldset. Initialized -> Open -> Closed
  95. __state__ = None
  96. # the field changes that have been made in this change
  97. __diff__ = None
  98. def __init__(self, obj, changes=None):
  99. self.__original__ = obj
  100. self.__state__ = ChangesetState.INITIALIZED
  101. self.__diff__ = changes or {}
  102. def commit(self):
  103. obj = self.build_rlp()
  104. self.close()
  105. return obj
  106. def build_rlp(self):
  107. if self.__state__ == ChangesetState.OPEN:
  108. field_kwargs = {
  109. name: self.__diff__.get(name, self.__original__[name])
  110. for name in self.__original__._meta.field_names
  111. }
  112. return type(self.__original__)(**field_kwargs)
  113. else:
  114. raise ValueError("Cannot open Changeset which is not in the OPEN state")
  115. def open(self):
  116. if self.__state__ == ChangesetState.INITIALIZED:
  117. self.__state__ = ChangesetState.OPEN
  118. else:
  119. raise ValueError(
  120. "Cannot open Changeset which is not in the INITIALIZED state"
  121. )
  122. def close(self):
  123. if self.__state__ == ChangesetState.OPEN:
  124. self.__state__ = ChangesetState.CLOSED
  125. else:
  126. raise ValueError("Cannot close Changeset which is not in the OPEN state")
  127. def __enter__(self):
  128. if self.__state__ == ChangesetState.INITIALIZED:
  129. self.open()
  130. return self
  131. else:
  132. raise ValueError(
  133. "Cannot open Changeset which is not in the INITIALIZED state"
  134. )
  135. def __exit__(self, exc_type, exc_value, traceback):
  136. if self.__state__ == ChangesetState.OPEN:
  137. self.close()
  138. def Changeset(obj, changes):
  139. namespace = {name: ChangesetField(name) for name in obj._meta.field_names}
  140. cls = type(
  141. f"{obj.__class__.__name__}Changeset",
  142. (BaseChangeset,),
  143. namespace,
  144. )
  145. return cls(obj, changes)
  146. class BaseSerializable(collections.abc.Sequence):
  147. def __init__(self, *args, **kwargs):
  148. if kwargs:
  149. field_values = merge_kwargs_to_args(args, kwargs, self._meta.field_names)
  150. else:
  151. field_values = args
  152. if len(field_values) != len(self._meta.field_names):
  153. raise TypeError(
  154. f"Argument count mismatch. expected {len(self._meta.field_names)} - "
  155. f"got {len(field_values)} - "
  156. f"missing {','.join(self._meta.field_names[len(field_values) :])}"
  157. )
  158. for value, attr in zip(field_values, self._meta.field_attrs):
  159. setattr(self, attr, make_immutable(value))
  160. _cached_rlp = None
  161. def as_dict(self):
  162. return {field: value for field, value in zip(self._meta.field_names, self)}
  163. def __iter__(self):
  164. for attr in self._meta.field_attrs:
  165. yield getattr(self, attr)
  166. def __getitem__(self, idx):
  167. if isinstance(idx, int):
  168. attr = self._meta.field_attrs[idx]
  169. return getattr(self, attr)
  170. elif isinstance(idx, slice):
  171. field_slice = self._meta.field_attrs[idx]
  172. return tuple(getattr(self, field) for field in field_slice)
  173. elif isinstance(idx, str):
  174. return getattr(self, idx)
  175. else:
  176. raise IndexError(f"Unsupported type for __getitem__: {type(idx)}")
  177. def __len__(self):
  178. return len(self._meta.fields)
  179. def __eq__(self, other):
  180. return isinstance(other, Serializable) and hash(self) == hash(other)
  181. def __getstate__(self):
  182. state = self.__dict__.copy()
  183. # The hash() builtin is not stable across processes
  184. # (https://docs.python.org/3/reference/datamodel.html#object.__hash__), so we do
  185. # this here to ensure pickled instances don't carry the cached hash() as that
  186. # may cause issues like https://github.com/ethereum/py-evm/issues/1318
  187. state["_hash_cache"] = None
  188. return state
  189. _hash_cache = None
  190. def __hash__(self):
  191. if self._hash_cache is None:
  192. self._hash_cache = hash(tuple(self))
  193. return self._hash_cache
  194. def __repr__(self):
  195. keyword_args = tuple(f"{k}={v!r}" for k, v in self.as_dict().items())
  196. return f"{type(self).__name__}({', '.join(keyword_args)})"
  197. @classmethod
  198. def serialize(cls, obj):
  199. try:
  200. return cls._meta.sedes.serialize(obj)
  201. except ListSerializationError as e:
  202. raise ObjectSerializationError(obj=obj, sedes=cls, list_exception=e)
  203. @classmethod
  204. def deserialize(cls, serial, **extra_kwargs):
  205. try:
  206. values = cls._meta.sedes.deserialize(serial)
  207. except ListDeserializationError as e:
  208. raise ObjectDeserializationError(serial=serial, sedes=cls, list_exception=e)
  209. args_as_kwargs = merge_args_to_kwargs(values, {}, cls._meta.field_names)
  210. return cls(**args_as_kwargs, **extra_kwargs)
  211. def copy(self, *args, **kwargs):
  212. missing_overrides = (
  213. set(self._meta.field_names)
  214. .difference(kwargs.keys())
  215. .difference(self._meta.field_names[: len(args)])
  216. )
  217. unchanged_kwargs = {
  218. key: copy.deepcopy(value)
  219. for key, value in self.as_dict().items()
  220. if key in missing_overrides
  221. }
  222. combined_kwargs = dict(**unchanged_kwargs, **kwargs)
  223. all_kwargs = merge_args_to_kwargs(args, combined_kwargs, self._meta.field_names)
  224. return type(self)(**all_kwargs)
  225. def __copy__(self):
  226. return self.copy()
  227. def __deepcopy__(self, *args):
  228. return self.copy()
  229. _in_mutable_context = False
  230. def build_changeset(self, *args, **kwargs):
  231. args_as_kwargs = merge_args_to_kwargs(
  232. args,
  233. kwargs,
  234. self._meta.field_names,
  235. allow_missing=True,
  236. )
  237. return Changeset(self, changes=args_as_kwargs)
  238. def make_immutable(value):
  239. if isinstance(value, list):
  240. return tuple(make_immutable(item) for item in value)
  241. else:
  242. return value
  243. @to_tuple
  244. def _mk_field_attrs(field_names, extra_namespace):
  245. namespace = set(field_names).union(extra_namespace)
  246. for field in field_names:
  247. while True:
  248. field = "_" + field
  249. if field not in namespace:
  250. namespace.add(field)
  251. yield field
  252. break
  253. def _mk_field_property(field, attr):
  254. def field_fn_getter(self):
  255. return getattr(self, attr)
  256. def field_fn_setter(self, value):
  257. if not self._in_mutable_context:
  258. raise AttributeError("can't set attribute")
  259. setattr(self, attr, value)
  260. return property(field_fn_getter, field_fn_setter)
  261. IDENTIFIER_REGEX = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE)
  262. def _is_valid_identifier(value):
  263. # Source: https://stackoverflow.com/questions/5474008/regular-expression-to-confirm-whether-a-string-is-a-valid-identifier-in-python # noqa: E501
  264. if not isinstance(value, str):
  265. return False
  266. return bool(IDENTIFIER_REGEX.match(value))
  267. @to_set
  268. def _get_class_namespace(cls):
  269. if hasattr(cls, "__dict__"):
  270. yield from cls.__dict__.keys()
  271. if hasattr(cls, "__slots__"):
  272. yield from cls.__slots__
class SerializableBase(abc.ABCMeta):
    """Metaclass that turns a ``fields`` declaration into ``_meta`` and properties.

    For every subclass of a class created by this metaclass it:
    resolves the field list (declared or inherited), validates the field
    names, builds a ``Meta`` class holding field information, and installs
    a property per field that reads from a private storage attribute.
    """

    def __new__(cls, name, bases, attrs):
        """Create a ``Serializable`` (sub)class.

        :raises TypeError: if fields are inherited from several serializable
            parents without an explicit ``fields`` declaration, if field names
            are duplicated or not valid identifiers, or if the declared fields
            omit any field of a serializable parent.
        """
        super_new = super().__new__

        serializable_bases = tuple(b for b in bases if isinstance(b, SerializableBase))
        has_multiple_serializable_parents = len(serializable_bases) > 1
        # true when at least one base was created by this metaclass
        is_serializable_subclass = any(serializable_bases)
        declares_fields = "fields" in attrs

        if not is_serializable_subclass:
            # If this is the original creation of the `Serializable` class,
            # just create the class.
            return super_new(cls, name, bases, attrs)
        elif not declares_fields:
            if has_multiple_serializable_parents:
                raise TypeError(
                    "Cannot create subclass from multiple parent `Serializable` "
                    "classes without explicit `fields` declaration."
                )
            else:
                # This is just a vanilla subclass of a `Serializable` parent class.
                parent_serializable = serializable_bases[0]

                if hasattr(parent_serializable, "_meta"):
                    fields = parent_serializable._meta.fields
                else:
                    # This is a subclass of `Serializable` which has no
                    # `fields`, likely intended for further subclassing.
                    fields = ()
        else:
            # ensure that the `fields` property is a tuple of tuples to ensure
            # immutability. `fields` is removed from `attrs` here, so it is not
            # part of the reserved namespace computed below.
            fields = tuple(tuple(field) for field in attrs.pop("fields"))

        # split the fields into names and sedes; each entry is expected to be
        # a (name, sedes) pair
        if fields:
            field_names, sedes = zip(*fields)
        else:
            field_names, sedes = (), ()

        # check that field names are unique
        duplicate_field_names = _get_duplicates(field_names)
        if duplicate_field_names:
            raise TypeError(
                "The following fields are duplicated in the `fields` "
                f"declaration: {','.join(sorted(duplicate_field_names))}"
            )

        # check that field names are valid identifiers
        invalid_field_names = {
            field_name
            for field_name in field_names
            if not _is_valid_identifier(field_name)
        }
        if invalid_field_names:
            raise TypeError(
                "The following field names are not valid python identifiers: "
                f"{','.join(f'`{item}`' for item in sorted(invalid_field_names))}"
            )

        # extract all of the fields from parent `Serializable` classes.
        parent_field_names = {
            field_name
            for base in serializable_bases
            if hasattr(base, "_meta")
            for field_name in base._meta.field_names
        }

        # check that all fields from parent serializable classes are
        # represented on this class.
        missing_fields = parent_field_names.difference(field_names)
        if missing_fields:
            raise TypeError(
                "Subclasses of `Serializable` **must** contain a full superset "
                "of the fields defined in their parent classes. The following "
                f"fields are missing: {','.join(sorted(missing_fields))}"
            )

        # the actual field values are stored in separate *private* attributes.
        # This computes attribute names that don't conflict with other
        # attributes already present on the class: names in this class body
        # plus every name defined on any class in each base's MRO. Note that
        # `_meta` is only popped from `attrs` further down, so a user-supplied
        # `_meta` name is still counted as reserved here.
        reserved_namespace = set(attrs.keys()).union(
            attr
            for base in bases
            for parent_cls in base.__mro__
            for attr in _get_class_namespace(parent_cls)
        )
        field_attrs = _mk_field_attrs(field_names, reserved_namespace)

        # construct the Meta object to store field information for the class
        meta_namespace = {
            "fields": fields,
            "field_attrs": field_attrs,
            "field_names": field_names,
            "sedes": List(sedes),
        }

        # a `_meta` class supplied in the class body becomes the base of the
        # generated `Meta`; otherwise `MetaBase` is used
        meta_base = attrs.pop("_meta", MetaBase)
        meta = type(
            "Meta",
            (meta_base,),
            meta_namespace,
        )
        attrs["_meta"] = meta

        # construct `property` attributes for read only access to the fields.
        field_props = tuple(
            (field, _mk_field_property(field, attr))
            for field, attr in zip(meta.field_names, meta.field_attrs)
        )

        # `attrs` items come after the field properties, so a class-body
        # attribute with the same name as a field replaces its property.
        return super_new(
            cls,
            name,
            bases,
            dict(field_props + tuple(attrs.items())),
        )
class Serializable(BaseSerializable, metaclass=SerializableBase):
    """
    The base class for serializable objects.

    Subclasses declare ``fields`` as a sequence of ``(name, sedes)`` pairs.
    The ``SerializableBase`` metaclass validates that declaration and builds
    ``_meta`` and one read-only property per field from it.
    """