ruler.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. """
  2. class Ruler
  3. Helper class, used by [[MarkdownIt#core]], [[MarkdownIt#block]] and
  4. [[MarkdownIt#inline]] to manage sequences of functions (rules):
  5. - keep rules in defined order
  6. - assign the name to each rule
  7. - enable/disable rules
  8. - add/replace rules
  9. - allow assigning rules to additional named chains (within the same ruler)
  10. - caching lists of active rules
  11. You will not need to use this class directly unless you write plugins. For simple
  12. rules control use [[MarkdownIt.disable]], [[MarkdownIt.enable]] and
  13. [[MarkdownIt.use]].
  14. """
  15. from __future__ import annotations
  16. from collections.abc import Iterable
  17. from dataclasses import dataclass, field
  18. from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
  19. import warnings
  20. from .utils import EnvType
  21. if TYPE_CHECKING:
  22. from markdown_it import MarkdownIt
  23. class StateBase:
  24. def __init__(self, src: str, md: MarkdownIt, env: EnvType):
  25. self.src = src
  26. self.env = env
  27. self.md = md
  28. @property
  29. def src(self) -> str:
  30. return self._src
  31. @src.setter
  32. def src(self, value: str) -> None:
  33. self._src = value
  34. self._srcCharCode: tuple[int, ...] | None = None
  35. @property
  36. def srcCharCode(self) -> tuple[int, ...]:
  37. warnings.warn(
  38. "StateBase.srcCharCode is deprecated. Use StateBase.src instead.",
  39. DeprecationWarning,
  40. stacklevel=2,
  41. )
  42. if self._srcCharCode is None:
  43. self._srcCharCode = tuple(ord(c) for c in self._src)
  44. return self._srcCharCode
  45. class RuleOptionsType(TypedDict, total=False):
  46. alt: list[str]
  47. RuleFuncTv = TypeVar("RuleFuncTv")
  48. """A rule function, whose signature is dependent on the state type."""
  49. @dataclass(slots=True)
  50. class Rule(Generic[RuleFuncTv]):
  51. name: str
  52. enabled: bool
  53. fn: RuleFuncTv = field(repr=False)
  54. alt: list[str]
  55. class Ruler(Generic[RuleFuncTv]):
  56. def __init__(self) -> None:
  57. # List of added rules.
  58. self.__rules__: list[Rule[RuleFuncTv]] = []
  59. # Cached rule chains.
  60. # First level - chain name, '' for default.
  61. # Second level - diginal anchor for fast filtering by charcodes.
  62. self.__cache__: dict[str, list[RuleFuncTv]] | None = None
  63. def __find__(self, name: str) -> int:
  64. """Find rule index by name"""
  65. for i, rule in enumerate(self.__rules__):
  66. if rule.name == name:
  67. return i
  68. return -1
  69. def __compile__(self) -> None:
  70. """Build rules lookup cache"""
  71. chains = {""}
  72. # collect unique names
  73. for rule in self.__rules__:
  74. if not rule.enabled:
  75. continue
  76. for name in rule.alt:
  77. chains.add(name)
  78. self.__cache__ = {}
  79. for chain in chains:
  80. self.__cache__[chain] = []
  81. for rule in self.__rules__:
  82. if not rule.enabled:
  83. continue
  84. if chain and (chain not in rule.alt):
  85. continue
  86. self.__cache__[chain].append(rule.fn)
  87. def at(
  88. self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
  89. ) -> None:
  90. """Replace rule by name with new function & options.
  91. :param ruleName: rule name to replace.
  92. :param fn: new rule function.
  93. :param options: new rule options (not mandatory).
  94. :raises: KeyError if name not found
  95. """
  96. index = self.__find__(ruleName)
  97. options = options or {}
  98. if index == -1:
  99. raise KeyError(f"Parser rule not found: {ruleName}")
  100. self.__rules__[index].fn = fn
  101. self.__rules__[index].alt = options.get("alt", [])
  102. self.__cache__ = None
  103. def before(
  104. self,
  105. beforeName: str,
  106. ruleName: str,
  107. fn: RuleFuncTv,
  108. options: RuleOptionsType | None = None,
  109. ) -> None:
  110. """Add new rule to chain before one with given name.
  111. :param beforeName: new rule will be added before this one.
  112. :param ruleName: new rule will be added before this one.
  113. :param fn: new rule function.
  114. :param options: new rule options (not mandatory).
  115. :raises: KeyError if name not found
  116. """
  117. index = self.__find__(beforeName)
  118. options = options or {}
  119. if index == -1:
  120. raise KeyError(f"Parser rule not found: {beforeName}")
  121. self.__rules__.insert(
  122. index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
  123. )
  124. self.__cache__ = None
  125. def after(
  126. self,
  127. afterName: str,
  128. ruleName: str,
  129. fn: RuleFuncTv,
  130. options: RuleOptionsType | None = None,
  131. ) -> None:
  132. """Add new rule to chain after one with given name.
  133. :param afterName: new rule will be added after this one.
  134. :param ruleName: new rule will be added after this one.
  135. :param fn: new rule function.
  136. :param options: new rule options (not mandatory).
  137. :raises: KeyError if name not found
  138. """
  139. index = self.__find__(afterName)
  140. options = options or {}
  141. if index == -1:
  142. raise KeyError(f"Parser rule not found: {afterName}")
  143. self.__rules__.insert(
  144. index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
  145. )
  146. self.__cache__ = None
  147. def push(
  148. self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
  149. ) -> None:
  150. """Push new rule to the end of chain.
  151. :param ruleName: new rule will be added to the end of chain.
  152. :param fn: new rule function.
  153. :param options: new rule options (not mandatory).
  154. """
  155. self.__rules__.append(
  156. Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
  157. )
  158. self.__cache__ = None
  159. def enable(
  160. self, names: str | Iterable[str], ignoreInvalid: bool = False
  161. ) -> list[str]:
  162. """Enable rules with given names.
  163. :param names: name or list of rule names to enable.
  164. :param ignoreInvalid: ignore errors when rule not found
  165. :raises: KeyError if name not found and not ignoreInvalid
  166. :return: list of found rule names
  167. """
  168. if isinstance(names, str):
  169. names = [names]
  170. result: list[str] = []
  171. for name in names:
  172. idx = self.__find__(name)
  173. if (idx < 0) and ignoreInvalid:
  174. continue
  175. if (idx < 0) and not ignoreInvalid:
  176. raise KeyError(f"Rules manager: invalid rule name {name}")
  177. self.__rules__[idx].enabled = True
  178. result.append(name)
  179. self.__cache__ = None
  180. return result
  181. def enableOnly(
  182. self, names: str | Iterable[str], ignoreInvalid: bool = False
  183. ) -> list[str]:
  184. """Enable rules with given names, and disable everything else.
  185. :param names: name or list of rule names to enable.
  186. :param ignoreInvalid: ignore errors when rule not found
  187. :raises: KeyError if name not found and not ignoreInvalid
  188. :return: list of found rule names
  189. """
  190. if isinstance(names, str):
  191. names = [names]
  192. for rule in self.__rules__:
  193. rule.enabled = False
  194. return self.enable(names, ignoreInvalid)
  195. def disable(
  196. self, names: str | Iterable[str], ignoreInvalid: bool = False
  197. ) -> list[str]:
  198. """Disable rules with given names.
  199. :param names: name or list of rule names to enable.
  200. :param ignoreInvalid: ignore errors when rule not found
  201. :raises: KeyError if name not found and not ignoreInvalid
  202. :return: list of found rule names
  203. """
  204. if isinstance(names, str):
  205. names = [names]
  206. result = []
  207. for name in names:
  208. idx = self.__find__(name)
  209. if (idx < 0) and ignoreInvalid:
  210. continue
  211. if (idx < 0) and not ignoreInvalid:
  212. raise KeyError(f"Rules manager: invalid rule name {name}")
  213. self.__rules__[idx].enabled = False
  214. result.append(name)
  215. self.__cache__ = None
  216. return result
  217. def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
  218. """Return array of active functions (rules) for given chain name.
  219. It analyzes rules configuration, compiles caches if not exists and returns result.
  220. Default chain name is `''` (empty string). It can't be skipped.
  221. That's done intentionally, to keep signature monomorphic for high speed.
  222. """
  223. if self.__cache__ is None:
  224. self.__compile__()
  225. assert self.__cache__ is not None
  226. # Chain can be empty, if rules disabled. But we still have to return Array.
  227. return self.__cache__.get(chainName, []) or []
  228. def get_all_rules(self) -> list[str]:
  229. """Return all available rule names."""
  230. return [r.name for r in self.__rules__]
  231. def get_active_rules(self) -> list[str]:
  232. """Return the active rule names."""
  233. return [r.name for r in self.__rules__ if r.enabled]