util.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. # Copyright (c) 2019 - 2025, Ilan Schnell; All Rights Reserved
  2. # bitarray is published under the PSF license.
  3. #
  4. # Author: Ilan Schnell
  5. """
  6. Useful utilities for working with bitarrays.
  7. """
  8. import os
  9. import sys
  10. import math
  11. import random
  12. from bitarray import bitarray, bits2bytes
  13. from bitarray._util import (
  14. zeros, ones, count_n, parity, _ssqi, xor_indices,
  15. count_and, count_or, count_xor, any_and, subset,
  16. correspond_all, byteswap,
  17. serialize, deserialize,
  18. ba2hex, hex2ba,
  19. ba2base, base2ba,
  20. sc_encode, sc_decode,
  21. vl_encode, vl_decode,
  22. canonical_decode,
  23. )
# Public names of this module, i.e. what `import *` exports.  Some of them
# are imported above from `bitarray._util`, the remaining ones are defined
# further down in this file.
__all__ = [
    'zeros', 'ones', 'urandom', 'random_k', 'random_p', 'gen_primes',
    'pprint', 'strip', 'count_n',
    'parity', 'sum_indices', 'xor_indices',
    'count_and', 'count_or', 'count_xor', 'any_and', 'subset',
    'correspond_all', 'byteswap', 'intervals',
    'ba2hex', 'hex2ba',
    'ba2base', 'base2ba',
    'ba2int', 'int2ba',
    'serialize', 'deserialize',
    'sc_encode', 'sc_decode',
    'vl_encode', 'vl_decode',
    'huffman_code', 'canonical_huffman', 'canonical_decode',
]
  38. def urandom(__length, endian=None):
  39. """urandom(n, /, endian=None) -> bitarray
  40. Return random bitarray of length `n` (uses `os.urandom()`).
  41. """
  42. a = bitarray(os.urandom(bits2bytes(__length)), endian)
  43. del a[__length:]
  44. return a
  45. def random_k(__n, k, endian=None):
  46. """random_k(n, /, k, endian=None) -> bitarray
  47. Return (pseudo-) random bitarray of length `n` with `k` elements
  48. set to one. Mathematically equivalent to setting (in a bitarray of
  49. length `n`) all bits at indices `random.sample(range(n), k)` to one.
  50. The random bitarrays are reproducible when giving Python's `random.seed()`
  51. a specific seed value.
  52. """
  53. r = _Random(__n, endian)
  54. if not isinstance(k, int):
  55. raise TypeError("int expected, got '%s'" % type(k).__name__)
  56. return r.random_k(k)
  57. def random_p(__n, p=0.5, endian=None):
  58. """random_p(n, /, p=0.5, endian=None) -> bitarray
  59. Return (pseudo-) random bitarray of length `n`, where each bit has
  60. probability `p` of being one (independent of any other bits). Mathematically
  61. equivalent to `bitarray((random() < p for _ in range(n)), endian)`, but much
  62. faster for large `n`. The random bitarrays are reproducible when giving
  63. Python's `random.seed()` with a specific seed value.
  64. This function requires Python 3.12 or higher, as it depends on the standard
  65. library function `random.binomialvariate()`. Raises `NotImplementedError`
  66. when Python version is too low.
  67. """
  68. if sys.version_info[:2] < (3, 12):
  69. raise NotImplementedError("bitarray.util.random_p() requires "
  70. "Python 3.12 or higher")
  71. r = _Random(__n, endian)
  72. return r.random_p(p)
class _Random:
    # Helper implementing the algorithms used by the module level
    # functions random_k() and random_p().
    #
    # The main reason for this class is to enable testing functionality
    # individually in the test class Random_P_Tests in 'test_util.py'.
    # The test class also contains many comments and explanations.
    # To better understand how the algorithm works, see ./doc/random_p.rst
    # See also, VerificationTests in devel/test_random.py

    # maximal number of calls to .random_half() in .combine_half()
    M = 8

    # number of resulting probability intervals
    K = 1 << M

    # limit for setting individual bits randomly
    SMALL_P = 0.01

    def __init__(self, n=0, endian=None):
        # n:      length (in bits) of the bitarrays to generate
        # endian: bit-endianness, handed on unchanged to bitarray(),
        #         zeros() and ones()
        self.n = n
        # size of the random byte strings requested in .random_half()
        self.nbytes = bits2bytes(n)
        self.endian = endian

    def random_half(self):
        """
        Return bitarray with each bit having probability p = 1/2 of being 1.
        """
        nbytes = self.nbytes
        # use random module function for reproducibility (not urandom())
        b = random.getrandbits(8 * nbytes).to_bytes(nbytes, 'little')
        a = bitarray(b, self.endian)
        # the bytes may hold more than n bits - drop the surplus
        del a[self.n:]
        return a

    def op_seq(self, i):
        """
        Return bitarray containing operator sequence.
        Each item represents a bitwise operation:   0: AND   1: OR
        After applying the sequence (see .combine_half()), we
        obtain a bitarray with probability  q = i / K

        Raises ValueError unless 0 < i < K.
        """
        if not 0 < i < self.K:
            raise ValueError("0 < i < %d, got i = %d" % (self.K, i))

        # sequence of &, | operations - least significant operations first
        # (a holds the 16 bits of i, starting with the least significant one)
        a = bitarray(i.to_bytes(2, byteorder="little"), "little")
        # the sequence consists of the bits above the lowest set bit of i,
        # up to (but excluding) bit M
        return a[a.index(1) + 1 : self.M]

    def combine_half(self, seq):
        """
        Combine random bitarrays with probability 1/2
        according to given operator sequence.
        """
        # one initial call to .random_half(), plus one call per operator
        a = self.random_half()
        for k in seq:
            if k:
                a |= self.random_half()
            else:
                a &= self.random_half()
        return a

    def random_k(self, k):
        """
        Return bitarray of length n in which exactly k bits are 1.
        Raises ValueError unless 0 <= k <= n.
        """
        n = self.n
        # error check inputs and handle edge cases
        if k <= 0 or k >= n:
            if k == 0:
                return zeros(n, self.endian)
            if k == n:
                return ones(n, self.endian)
            raise ValueError("k must be in range 0 <= k <= n, got %s" % k)

        # exploit symmetry to establish: k <= n // 2
        if k > n // 2:
            a = self.random_k(n - k)
            a.invert()  # use in-place to avoid copying
            return a

        # decide on sequence, see VerificationTests devel/test_random.py
        if k < 16 or k * self.K < 3 * n:
            # few bits to set: start from all zeros (i < 3 below)
            i = 0
        else:
            p = k / n  # p <= 0.5
            # NOTE(review): p is lowered a little before i is derived from
            # it - presumably a tuning of the expected initial count, see
            # devel/test_random.py (to be confirmed)
            p -= (0.2 - 0.4 * p) / math.sqrt(n)
            i = int(p * (self.K + 1))

        # combine random bitarrays using bitwise AND and OR operations
        if i < 3:
            a = zeros(n, self.endian)
            diff = -k
        else:
            a = self.combine_half(self.op_seq(i))
            # surplus (> 0) or shortage (< 0) of 1 bits relative to k
            diff = a.count() - k

        randrange = random.randrange
        if diff < 0:  # not enough bits 1 - increase count
            for _ in range(-diff):
                # draw indices until one is found whose bit is still 0
                i = randrange(n)
                while a[i]:
                    i = randrange(n)
                a[i] = 1
        elif diff > 0:  # too many bits 1 - decrease count
            for _ in range(diff):
                # draw indices until one is found whose bit is still 1
                i = randrange(n)
                while not a[i]:
                    i = randrange(n)
                a[i] = 0

        return a

    def random_p(self, p):
        """
        Return random bitarray of length n for probability p, see the
        module level function random_p().
        Raises ValueError unless 0.0 <= p <= 1.0.
        """
        # error check inputs and handle edge cases
        if p <= 0.0 or p == 0.5 or p >= 1.0:
            if p == 0.0:
                return zeros(self.n, self.endian)
            if p == 0.5:
                return self.random_half()
            if p == 1.0:
                return ones(self.n, self.endian)
            raise ValueError("p must be in range 0.0 <= p <= 1.0, got %s" % p)

        # for small n, use literal definition
        if self.n < 16:
            return bitarray((random.random() < p for _ in range(self.n)),
                            self.endian)

        # exploit symmetry to establish: p < 0.5
        if p > 0.5:
            a = self.random_p(1.0 - p)
            a.invert()  # use in-place to avoid copying
            return a

        # for small p, set randomly individual bits
        # (the number of bits to set is drawn from a binomial distribution)
        if p < self.SMALL_P:
            return self.random_k(random.binomialvariate(self.n, p))

        # calculate operator sequence
        i = int(p * self.K)
        if p * (self.K + 1) > i + 1:  # see devel/test_random.py
            i += 1
        seq = self.op_seq(i)
        # probability obtained from the operator sequence alone
        q = i / self.K

        # when n is small compared to number of operations, also use literal
        if self.n < 100 and self.nbytes <= len(seq) + 3 * bool(q != p):
            return bitarray((random.random() < p for _ in range(self.n)),
                            self.endian)

        # combine random bitarrays using bitwise AND and OR operations
        a = self.combine_half(seq)
        # a has probability q - correct towards p using one more bitarray
        if q < p:
            # OR with probability x:  q + (1 - q) * x = p
            x = (p - q) / (1.0 - q)
            a |= self.random_p(x)
        elif q > p:
            # AND with probability x:  q * x = p
            x = p / q
            a &= self.random_p(x)

        return a
  206. def gen_primes(__n, endian=None, odd=False):
  207. """gen_primes(n, /, endian=None, odd=False) -> bitarray
  208. Generate a bitarray of length `n` in which active indices are prime numbers.
  209. By default (`odd=False`), active indices correspond to prime numbers directly.
  210. When `odd=True`, only odd prime numbers are represented in the resulting
  211. bitarray `a`, and `a[i]` corresponds to `2*i+1` being prime or not.
  212. """
  213. n = int(__n)
  214. if n < 0:
  215. raise ValueError("bitarray length must be >= 0")
  216. if odd:
  217. a = ones(105, endian) # 105 = 3 * 5 * 7
  218. a[1::3] = 0
  219. a[2::5] = 0
  220. a[3::7] = 0
  221. f = "01110110"
  222. else:
  223. a = ones(210, endian) # 210 = 2 * 3 * 5 * 7
  224. for i in 2, 3, 5, 7:
  225. a[::i] = 0
  226. f = "00110101"
  227. # repeating the array many times is faster than setting the multiples
  228. # of the low primes to 0
  229. a *= (n + len(a) - 1) // len(a)
  230. a[:8] = bitarray(f, endian)
  231. del a[n:]
  232. # perform sieve starting at 11
  233. if odd:
  234. for i in a.search(1, 5, int(math.sqrt(n // 2) + 1.0)): # 11//2 = 5
  235. j = 2 * i + 1
  236. a[(j * j) // 2 :: j] = 0
  237. else:
  238. # i*i is always odd, and even bits are already set to 0: use step 2*i
  239. for i in a.search(1, 11, int(math.sqrt(n) + 1.0)):
  240. a[i * i :: 2 * i] = 0
  241. return a
  242. def sum_indices(__a, mode=1):
  243. """sum_indices(a, /, mode=1) -> int
  244. Return sum of indices of all active bits in bitarray `a`.
  245. Equivalent to `sum(i for i, v in enumerate(a) if v)`.
  246. `mode=2` sums square of indices.
  247. """
  248. if mode not in (1, 2):
  249. raise ValueError("unexpected mode %r" % mode)
  250. # For details see: devel/test_sum_indices.py
  251. n = 1 << 19 # block size 512 Kbits
  252. if len(__a) <= n: # shortcut for single block
  253. return _ssqi(__a, mode)
  254. # Constants
  255. m = n // 8 # block size in bytes
  256. o1 = n * (n - 1) // 2
  257. o2 = o1 * (2 * n - 1) // 3
  258. nblocks = (len(__a) + n - 1) // n
  259. padbits = __a.padbits
  260. sm = 0
  261. for i in range(nblocks):
  262. # use memoryview to avoid copying memory
  263. v = memoryview(__a)[i * m : (i + 1) * m]
  264. block = bitarray(None, __a.endian, buffer=v)
  265. if padbits and i == nblocks - 1:
  266. if block.readonly:
  267. block = bitarray(block)
  268. block[-padbits:] = 0
  269. k = block.count()
  270. if k:
  271. y = n * i
  272. z1 = o1 if k == n else _ssqi(block)
  273. if mode == 1:
  274. sm += k * y + z1
  275. else:
  276. z2 = o2 if k == n else _ssqi(block, 2)
  277. sm += (k * y + 2 * z1) * y + z2
  278. return sm
  279. def pprint(__a, stream=None, group=8, indent=4, width=80):
  280. """pprint(bitarray, /, stream=None, group=8, indent=4, width=80)
  281. Pretty-print bitarray object to `stream`, defaults is `sys.stdout`.
  282. By default, bits are grouped in bytes (8 bits), and 64 bits per line.
  283. Non-bitarray objects are printed using `pprint.pprint()`.
  284. """
  285. if stream is None:
  286. stream = sys.stdout
  287. if not isinstance(__a, bitarray):
  288. import pprint as _pprint
  289. _pprint.pprint(__a, stream=stream, indent=indent, width=width)
  290. return
  291. group = int(group)
  292. if group < 1:
  293. raise ValueError('group must be >= 1')
  294. indent = int(indent)
  295. if indent < 0:
  296. raise ValueError('indent must be >= 0')
  297. width = int(width)
  298. if width <= indent:
  299. raise ValueError('width must be > %d (indent)' % indent)
  300. gpl = (width - indent) // (group + 1) # groups per line
  301. epl = group * gpl # elements per line
  302. if epl == 0:
  303. epl = width - indent - 2
  304. type_name = type(__a).__name__
  305. # here 4 is len("'()'")
  306. multiline = len(type_name) + 4 + len(__a) + len(__a) // group >= width
  307. if multiline:
  308. quotes = "'''"
  309. elif __a:
  310. quotes = "'"
  311. else:
  312. quotes = ""
  313. stream.write("%s(%s" % (type_name, quotes))
  314. for i, b in enumerate(__a):
  315. if multiline and i % epl == 0:
  316. stream.write('\n%s' % (indent * ' '))
  317. if i % group == 0 and i % epl != 0:
  318. stream.write(' ')
  319. stream.write(str(b))
  320. if multiline:
  321. stream.write('\n')
  322. stream.write("%s)\n" % quotes)
  323. stream.flush()
  324. def strip(__a, mode='right'):
  325. """strip(bitarray, /, mode='right') -> bitarray
  326. Return a new bitarray with zeros stripped from left, right or both ends.
  327. Allowed values for mode are the strings: `left`, `right`, `both`
  328. """
  329. if not isinstance(mode, str):
  330. raise TypeError("str expected for mode, got '%s'" %
  331. type(__a).__name__)
  332. if mode not in ('left', 'right', 'both'):
  333. raise ValueError("mode must be 'left', 'right' or 'both', got %r" %
  334. mode)
  335. start = None if mode == 'right' else __a.find(1)
  336. if start == -1:
  337. return __a[:0]
  338. stop = None if mode == 'left' else __a.find(1, right=1) + 1
  339. return __a[start:stop]
  340. def intervals(__a):
  341. """intervals(bitarray, /) -> iterator
  342. Compute all uninterrupted intervals of 1s and 0s, and return an
  343. iterator over tuples `(value, start, stop)`. The intervals are guaranteed
  344. to be in order, and their size is always non-zero (`stop - start > 0`).
  345. """
  346. try:
  347. value = __a[0] # value of current interval
  348. except IndexError:
  349. return
  350. n = len(__a)
  351. stop = 0 # "previous" stop - becomes next start
  352. while stop < n:
  353. start = stop
  354. # assert __a[start] == value
  355. try: # find next occurrence of opposite value
  356. stop = __a.index(not value, start)
  357. except ValueError:
  358. stop = n
  359. yield int(value), start, stop
  360. value = not value # next interval has opposite value
  361. def ba2int(__a, signed=False):
  362. """ba2int(bitarray, /, signed=False) -> int
  363. Convert the given bitarray to an integer.
  364. The bit-endianness of the bitarray is respected.
  365. `signed` indicates whether two's complement is used to represent the integer.
  366. """
  367. if not isinstance(__a, bitarray):
  368. raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
  369. length = len(__a)
  370. if length == 0:
  371. raise ValueError("non-empty bitarray expected")
  372. if __a.padbits:
  373. pad = zeros(__a.padbits, __a.endian)
  374. __a = __a + pad if __a.endian == "little" else pad + __a
  375. res = int.from_bytes(__a.tobytes(), byteorder=__a.endian)
  376. if signed and res >> length - 1:
  377. res -= 1 << length
  378. return res
  379. def int2ba(__i, length=None, endian=None, signed=False):
  380. """int2ba(int, /, length=None, endian=None, signed=False) -> bitarray
  381. Convert the given integer to a bitarray (with given bit-endianness,
  382. and no leading (big-endian) / trailing (little-endian) zeros), unless
  383. the `length` of the bitarray is provided. An `OverflowError` is raised
  384. if the integer is not representable with the given number of bits.
  385. `signed` determines whether two's complement is used to represent the integer,
  386. and requires `length` to be provided.
  387. """
  388. if not isinstance(__i, int):
  389. raise TypeError("int expected, got '%s'" % type(__i).__name__)
  390. if length is not None:
  391. if not isinstance(length, int):
  392. raise TypeError("int expected for argument 'length'")
  393. if length <= 0:
  394. raise ValueError("length must be > 0")
  395. if signed:
  396. if length is None:
  397. raise TypeError("signed requires argument 'length'")
  398. m = 1 << length - 1
  399. if not (-m <= __i < m):
  400. raise OverflowError("signed integer not in range(%d, %d), "
  401. "got %d" % (-m, m, __i))
  402. if __i < 0:
  403. __i += 1 << length
  404. else: # unsigned
  405. if length and __i >> length:
  406. raise OverflowError("unsigned integer not in range(0, %d), "
  407. "got %d" % (1 << length, __i))
  408. a = bitarray(0, endian)
  409. b = __i.to_bytes(bits2bytes(__i.bit_length()), byteorder=a.endian)
  410. a.frombytes(b)
  411. le = a.endian == 'little'
  412. if length is None:
  413. return strip(a, 'right' if le else 'left') if a else a + '0'
  414. if len(a) > length:
  415. return a[:length] if le else a[-length:]
  416. if len(a) == length:
  417. return a
  418. # len(a) < length, we need padding
  419. pad = zeros(length - len(a), a.endian)
  420. return a + pad if le else pad + a
  421. # ------------------------------ Huffman coding -----------------------------
  422. def _huffman_tree(__freq_map):
  423. """_huffman_tree(dict, /) -> Node
  424. Given a dict mapping symbols to their frequency, construct a Huffman tree
  425. and return its root node.
  426. """
  427. from heapq import heappush, heappop
  428. class Node(object):
  429. """
  430. There are to tyes of Node instances (both have 'freq' attribute):
  431. * leaf node: has 'symbol' attribute
  432. * parent node: has 'child' attribute (tuple with both children)
  433. """
  434. def __lt__(self, other):
  435. # heapq needs to be able to compare the nodes
  436. return self.freq < other.freq
  437. minheap = []
  438. # create all leaf nodes and push them onto the queue
  439. for sym, f in __freq_map.items():
  440. leaf = Node()
  441. leaf.symbol = sym
  442. leaf.freq = f
  443. heappush(minheap, leaf)
  444. # repeat the process until only one node remains
  445. while len(minheap) > 1:
  446. # take the two nodes with lowest frequencies from the queue
  447. # to construct a new parent node and push it onto the queue
  448. parent = Node()
  449. parent.child = heappop(minheap), heappop(minheap)
  450. parent.freq = parent.child[0].freq + parent.child[1].freq
  451. heappush(minheap, parent)
  452. # the single remaining node is the root of the Huffman tree
  453. return minheap[0]
  454. def huffman_code(__freq_map, endian=None):
  455. """huffman_code(dict, /, endian=None) -> dict
  456. Given a frequency map, a dictionary mapping symbols to their frequency,
  457. calculate the Huffman code, i.e. a dict mapping those symbols to
  458. bitarrays (with given bit-endianness). Note that the symbols are not limited
  459. to being strings. Symbols may be any hashable object.
  460. """
  461. if not isinstance(__freq_map, dict):
  462. raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)
  463. if len(__freq_map) < 2:
  464. if len(__freq_map) == 0:
  465. raise ValueError("cannot create Huffman code with no symbols")
  466. # Only one symbol: Normally if only one symbol is given, the code
  467. # could be represented with zero bits. However here, the code should
  468. # be at least one bit for the .encode() and .decode() methods to work.
  469. # So we represent the symbol by a single code of length one, in
  470. # particular one 0 bit. This is an incomplete code, since if a 1 bit
  471. # is received, it has no meaning and will result in an error.
  472. sym = list(__freq_map)[0]
  473. return {sym: bitarray('0', endian)}
  474. result = {}
  475. def traverse(nd, prefix=bitarray(0, endian)):
  476. try: # leaf
  477. result[nd.symbol] = prefix
  478. except AttributeError: # parent, so traverse each child
  479. traverse(nd.child[0], prefix + '0')
  480. traverse(nd.child[1], prefix + '1')
  481. traverse(_huffman_tree(__freq_map))
  482. return result
  483. def canonical_huffman(__freq_map):
  484. """canonical_huffman(dict, /) -> tuple
  485. Given a frequency map, a dictionary mapping symbols to their frequency,
  486. calculate the canonical Huffman code. Returns a tuple containing:
  487. 0. the canonical Huffman code as a dict mapping symbols to bitarrays
  488. 1. a list containing the number of symbols of each code length
  489. 2. a list of symbols in canonical order
  490. Note: the two lists may be used as input for `canonical_decode()`.
  491. """
  492. if not isinstance(__freq_map, dict):
  493. raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)
  494. if len(__freq_map) < 2:
  495. if len(__freq_map) == 0:
  496. raise ValueError("cannot create Huffman code with no symbols")
  497. # Only one symbol: see note above in huffman_code()
  498. sym = list(__freq_map)[0]
  499. return {sym: bitarray('0', 'big')}, [0, 1], [sym]
  500. code_length = {} # map symbols to their code length
  501. def traverse(nd, length=0):
  502. # traverse the Huffman tree, but (unlike in huffman_code() above) we
  503. # now just simply record the length for reaching each symbol
  504. try: # leaf
  505. code_length[nd.symbol] = length
  506. except AttributeError: # parent, so traverse each child
  507. traverse(nd.child[0], length + 1)
  508. traverse(nd.child[1], length + 1)
  509. traverse(_huffman_tree(__freq_map))
  510. # We now have a mapping of symbols to their code length, which is all we
  511. # need to construct a list of tuples (symbol, code length) sorted by
  512. # code length:
  513. table = sorted(code_length.items(), key=lambda item: item[1])
  514. maxbits = table[-1][1]
  515. codedict = {}
  516. count = (maxbits + 1) * [0]
  517. code = 0
  518. for i, (sym, length) in enumerate(table):
  519. codedict[sym] = int2ba(code, length, 'big')
  520. count[length] += 1
  521. if i + 1 < len(table):
  522. code += 1
  523. code <<= table[i + 1][1] - length
  524. return codedict, count, [item[0] for item in table]