| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914 |
- # Copyright (c) 2019 - 2025, Ilan Schnell; All Rights Reserved
- # bitarray is published under the PSF license.
- #
- # Author: Ilan Schnell
- """
- Tests for bitarray.util module
- """
- import os
- import sys
- import math
- import array
- import base64
- import binascii
- import operator
- import struct
- import shutil
- import tempfile
- import unittest
- from io import StringIO
- from functools import reduce
- from random import (choice, choices, getrandbits, randrange, randint, random,
- sample, seed)
- from string import hexdigits, whitespace
- from collections import Counter
- from bitarray import (bitarray, frozenbitarray, decodetree, bits2bytes,
- _set_default_endian)
- from bitarray.test_bitarray import Util, skipIf, is_pypy, urandom_2, PTRSIZE
- from bitarray.util import (
- zeros, ones, urandom, random_k, random_p, pprint, strip, count_n,
- parity, gen_primes, sum_indices, xor_indices,
- count_and, count_or, count_xor, any_and, subset,
- correspond_all, byteswap, intervals,
- serialize, deserialize, ba2hex, hex2ba, ba2base, base2ba,
- ba2int, int2ba,
- sc_encode, sc_decode, vl_encode, vl_decode,
- _huffman_tree, huffman_code, canonical_huffman, canonical_decode,
- )
- from bitarray.util import _Random, _ssqi # type: ignore
- # --------------------------- zeros() ones() -----------------------------
class ZerosOnesTests(unittest.TestCase):
    """Checks for the zeros() and ones() bitarray factories."""

    def test_basic(self):
        # Exercise both factories with a random default endianness, every
        # accepted spelling of the endian argument, and random lengths.
        for _ in range(50):
            fallback = choice(['little', 'big'])
            _set_default_endian(fallback)

            # all ways of asking for an empty bitarray with default endian
            empty = choice([zeros(0), zeros(0, None), zeros(0, endian=None),
                            ones(0), ones(0, None), ones(0, endian=None)])
            self.assertEqual(empty, bitarray())
            self.assertEqual(empty.endian, fallback)
            self.assertEqual(type(empty), bitarray)

            endian = choice(['little', 'big', None])
            size = randrange(100)

            # positional and keyword endian must be interchangeable
            all_clear = choice([zeros(size, endian),
                                zeros(size, endian=endian)])
            self.assertEqual(all_clear.to01(), size * "0")
            # endian=None falls back to the current default
            self.assertEqual(all_clear.endian, endian or fallback)

            all_set = choice([ones(size, endian),
                              ones(size, endian=endian)])
            self.assertEqual(all_set.to01(), size * "1")
            self.assertEqual(all_set.endian, endian or fallback)

    def test_errors(self):
        # Both factories must reject the same invalid arguments.
        for factory in (zeros, ones):
            self.assertRaises(TypeError, factory)  # no argument
            for bad_length in ('', bitarray(), [], 1.0):
                self.assertRaises(TypeError, factory, bad_length)
            self.assertRaises(ValueError, factory, -1)

            # endian not string
            for bad_endian in (0, 1, {}, [], False, True):
                self.assertRaises(TypeError, factory, 0, bad_endian)
            # endian wrong string
            self.assertRaises(ValueError, factory, 0, 'foo')
- # ----------------------------- urandom() ---------------------------------
class URandomTests(unittest.TestCase):
    """Checks for urandom(): length, endianness, type and error handling."""

    def test_basic(self):
        # The content is random, so only shape properties are asserted.
        for _ in range(20):
            fallback = choice(['little', 'big'])
            _set_default_endian(fallback)

            empty = choice([urandom(0), urandom(0, endian=None)])
            self.assertEqual(empty, bitarray())
            self.assertEqual(empty.endian, fallback)

            endian = choice(['little', 'big', None])
            nbits = randrange(100)
            # positional and keyword endian must be interchangeable
            bits = choice([urandom(nbits, endian),
                           urandom(nbits, endian=endian)])
            self.assertEqual(len(bits), nbits)
            # endian=None falls back to the current default
            self.assertEqual(bits.endian, endian or fallback)
            self.assertEqual(type(bits), bitarray)

    def test_errors(self):
        expect = self.assertRaises
        expect(TypeError, urandom)
        for bad_length in ('', bitarray(), [], 1.0):
            expect(TypeError, urandom, bad_length)
        expect(ValueError, urandom, -1)
        expect(TypeError, urandom, 0, 1)
        expect(ValueError, urandom, 0, 'foo')

    def test_count(self):
        population = urandom(10_000_000).count()
        # see if population is within expectation; for fair bits the
        # standard deviation is sqrt(n / 4) ~ 1581, so this allows
        # roughly ten standard deviations
        deviation = abs(population - 5_000_000)
        self.assertTrue(deviation <= 15_811)
- # ---------------------------- random_k() ---------------------------------
class Random_K_Tests(unittest.TestCase):
    """Tests for random_k(n, k) and the internal _Random helper methods.

    random_k(n, k) is expected to return a bitarray of length n with
    exactly k bits set.
    """

    def test_basic(self):
        # Random n, k and endianness: check type, length, population
        # count and resulting endianness.
        for _ in range(250):
            default_endian = choice(['little', 'big'])
            _set_default_endian(default_endian)
            endian = choice(['little', 'big', None])
            n = randrange(120)
            k = randint(0, n)
            a = random_k(n, k, endian)
            # assertEqual, not assertTrue: assertTrue(x, msg) would take
            # `bitarray` as the failure message and never fail.
            self.assertEqual(type(a), bitarray)
            self.assertEqual(len(a), n)
            self.assertEqual(a.count(), k)
            # endian=None falls back to the current default
            self.assertEqual(a.endian, endian or default_endian)

    def test_inputs_and_edge_cases(self):
        R = random_k
        self.assertRaises(TypeError, R)
        self.assertRaises(TypeError, R, 4)
        self.assertRaises(TypeError, R, 1, "0.5")
        self.assertRaises(TypeError, R, 1, p=1)
        self.assertRaises(TypeError, R, 11, 5.5)  # see issue #239
        self.assertRaises(ValueError, R, -1, 0)
        for k in -1, 11:  # k is not 0 <= k <= n
            self.assertRaises(ValueError, R, 10, k)
        self.assertRaises(ValueError, R, 10, 7, 'foo')
        self.assertRaises(ValueError, R, 10, 7, endian='foo')
        # k=0 and k=n leave no randomness: result is fully determined
        for n in range(20):
            self.assertEqual(R(n, k=0), zeros(n))
            self.assertEqual(R(n, k=n), ones(n))

    def test_count(self):
        for n in range(10):  # test explicitly for small n
            for k in range(n + 1):
                a = random_k(n, k)
                self.assertEqual(len(a), n)
                self.assertEqual(a.count(), k)

        for _ in range(100):
            n = randrange(10_000)
            k = randint(0, n)
            a = random_k(n, k)
            self.assertEqual(len(a), n)
            self.assertEqual(a.count(), k)

    def test_active_bits(self):
        # test if all bits are active: OR-ing many results together must
        # eventually set every position
        n = 240
        cum = zeros(n)
        for _ in range(1000):
            k = randint(30, 40)
            a = random_k(n, k)
            self.assertEqual(a.count(), k)
            cum |= a
            if cum.all():
                break
        else:
            self.fail("not every bit position was set within 1000 draws")

    # test uses math.comb, added in 3.8
    @skipIf(sys.version_info[:2] < (3, 8))
    def test_combinations(self):
        # for entire range of 0 <= k <= n, validate that random_k()
        # generates all possible combinations
        n = 7
        total = 0
        for k in range(n + 1):
            expected = math.comb(n, k)
            combs = set()
            for _ in range(10_000):
                combs.add(frozenbitarray(random_k(n, k)))
                if len(combs) == expected:
                    total += expected
                    break
            else:
                self.fail("only %d of %d combinations seen for k=%d"
                          % (len(combs), expected, k))
        # sum of C(n, k) over all k is 2 ** n
        self.assertEqual(total, 2 ** n)

    def collect_code_branches(self):
        # return list of bitarrays from all code branches of random_k()
        res = []
        # test small k (no .combine_half())
        res.append(random_k(300, 10))
        # general cases
        for k in 100, 500, 2_500, 4_000:
            res.append(random_k(5_000, k))
        return res

    def test_seed(self):
        # We ensure that after setting a seed value, random_k() will
        # always return the same random bitarrays.  However, we do not ensure
        # that these results will not change in future versions of bitarray.
        _set_default_endian("little")
        # initialize seed with current system time again, even when an
        # assertion below fails, so later tests do not run on a fixed seed
        self.addCleanup(seed)
        a = []
        for val in 654321, 654322, 654321, 654322:
            seed(val)
            a.append(self.collect_code_branches())
        # same seed -> same results
        self.assertEqual(a[0], a[2])
        self.assertEqual(a[1], a[3])
        # different seed -> different result in every code branch
        for item0, item1 in zip(a[0], a[1]):
            self.assertNotEqual(item0, item1)

    # ---------------- tests for internal _Random methods -------------------

    def test_op_seq(self):
        r = _Random()
        G = r.op_seq
        K = r.K
        M = r.M

        # special cases
        self.assertRaises(ValueError, G, 0)
        self.assertEqual(G(1), zeros(M - 1))
        self.assertEqual(G(K // 2), bitarray())
        self.assertEqual(G(K - 1), ones(M - 1))
        self.assertRaises(ValueError, G, K)

        # examples
        for p, s in [
                (0.15625, '0100'),
                (0.25,       '0'),  # 1/2   AND ->   1/4
                (0.375,     '10'),  # 1/2   OR ->   3/4   AND ->   3/8
                (0.5,         ''),
                (0.625,     '01'),  # 1/2   AND ->   1/4   OR ->   5/8
                (0.6875,   '101'),
                (0.75,       '1'),  # 1/2   OR ->   3/4
        ]:
            seq = G(int(p * K))
            self.assertEqual(seq.to01(), s)

        # replay every sequence and verify it yields probability i / K
        for i in range(1, K):
            seq = G(i)
            # check the sequence just computed (not the stale `s` left
            # over from the examples loop above)
            self.assertTrue(0 <= len(seq) < M)
            q = 0.5                   # a = random_half()
            for k in seq:
                # k=0: AND    k=1: OR
                if k:
                    q += 0.5 * (1.0 - q)  # a |= random_half()
                else:
                    q *= 0.5              # a &= random_half()
            self.assertEqual(q, i / K)

    def test_combine_half(self):
        r = _Random(1_000_000)
        for seq, mean in [
                ([],     500_000),  # .random_half() itself
                ([0],    250_000),  # AND
                ([1],    750_000),  # OR
                ([1, 0], 375_000),  # OR followed by AND
        ]:
            a = r.combine_half(seq)
            # statistical check: population count close to expected mean
            self.assertTrue(abs(a.count() - mean) < 5_000)
- # ---------------------------- random_p() ---------------------------------
# True on Python 3.12 and later; selects which of the two random_p() test
# classes runs (presumably tied to random.binomialvariate, new in 3.12 --
# confirm against bitarray.util).
HAVE_BINOMIALVARIATE = sys.version_info >= (3, 12)
@skipIf(HAVE_BINOMIALVARIATE)
class Random_P_Not_Implemented(unittest.TestCase):
    """Runs only when HAVE_BINOMIALVARIATE is false."""

    def test_not_implemented(self):
        # random_p() must refuse to run rather than return a result
        with self.assertRaises(NotImplementedError):
            random_p(100, 0.25)
@skipIf(not HAVE_BINOMIALVARIATE)
class Random_P_Tests(unittest.TestCase):
    """Tests for random_p(n, p); runs only when HAVE_BINOMIALVARIATE is true.

    random_p(n, p) is expected to return a bitarray of length n in which
    each bit is set with probability p (p defaults to 0.5).
    """

    def test_basic(self):
        # Random n, p and endianness: check type, length and endianness.
        for _ in range(250):
            default_endian = choice(['little', 'big'])
            _set_default_endian(default_endian)
            endian = choice(['little', 'big', None])
            n = randrange(120)
            p = choice([0.0, 0.0001, 0.2, 0.5, 0.9, 1.0])
            a = random_p(n, p, endian)
            # assertEqual, not assertTrue: assertTrue(x, msg) would take
            # `bitarray` as the failure message and never fail.
            self.assertEqual(type(a), bitarray)
            self.assertEqual(len(a), n)
            # endian=None falls back to the current default
            self.assertEqual(a.endian, endian or default_endian)

    def test_inputs_and_edge_cases(self):
        R = random_p
        self.assertRaises(TypeError, R)
        self.assertRaises(TypeError, R, 0.25)
        self.assertRaises(TypeError, R, 1, "0.5")
        self.assertRaises(ValueError, R, -1)
        self.assertRaises(ValueError, R, 1, -0.5)
        self.assertRaises(ValueError, R, 1, p=1.5)
        self.assertRaises(ValueError, R, 1, 0.15, 'foo')
        self.assertRaises(ValueError, R, 10, 0.5, endian='foo')
        self.assertEqual(R(0), bitarray())
        # p=0 and p=1 leave no randomness: result is fully determined
        for n in range(20):
            self.assertEqual(R(n, 0), zeros(n))
            self.assertEqual(len(R(n, 0.5)), n)
            self.assertEqual(R(n, p=1), ones(n))

    def test_default(self):
        a = random_p(10_000_000)  # p defaults to 0.5
        # see if population is within expectation
        self.assertTrue(abs(a.count() - 5_000_000) <= 15_811)

    def test_count(self):
        for _ in range(500):
            n = choice([randrange(4, 120), randrange(100, 1000)])
            p = choice([0.0001, 0.001, 0.01, 0.1, 0.25, 0.5, 0.9])
            # standard deviation of a binomial distribution
            sigma = math.sqrt(n * p * (1.0 - p))
            a = random_p(n, p)
            self.assertEqual(len(a), n)
            # allow 10 sigma, but at least 4, so tiny sigma cannot make
            # the bound unreasonably tight
            self.assertTrue(abs(a.count() - n * p) < max(4, 10 * sigma))

    def collect_code_branches(self):
        # return list of bitarrays from all code branches of random_p()
        res = []
        # for default p=0.5, random_p uses getrandbits
        res.append(random_p(32))
        # test small p
        res.append(random_p(5_000, 0.002))
        # small n (note that p=0.4 will call the "literal definition" case)
        res.append(random_p(15, 0.4))
        # general cases
        for p in 0.1, 0.2, 0.375, 0.4999, 0.7:
            res.append(random_p(150, p))
        return res

    def test_seed(self):
        # We ensure that after setting a seed value, random_p() will always
        # return the same random bitarrays.  However, we do not ensure that
        # these results will not change in future versions of bitarray.
        _set_default_endian("little")
        # initialize seed with current system time again, even when an
        # assertion below fails, so later tests do not run on a fixed seed
        self.addCleanup(seed)
        a = []
        for val in 123456, 123457, 123456, 123457:
            seed(val)
            a.append(self.collect_code_branches())
        # same seed -> same results
        self.assertEqual(a[0], a[2])
        self.assertEqual(a[1], a[3])
        # different seed -> different result in every code branch
        for item0, item1 in zip(a[0], a[1]):
            self.assertNotEqual(item0, item1)

    def test_small_p_limit(self):
        # For understanding how the algorithm works, see ./doc/random_p.rst
        # Also, see VerificationTests in devel/test_random.py
        r = _Random()
        limit = 1.0 / (r.K + 1)  # lower limit for p
        self.assertTrue(r.SMALL_P > limit)
- # ---------------------------- gen_primes() -------------------------------
- class PrimeTests(unittest.TestCase):
- primes = [
- 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
- 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
- 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211,
- 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283,
- 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379,
- 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461,
- ]
- def test_errors(self):
- P = gen_primes
- self.assertRaises(TypeError, P, 3, 1)
- self.assertRaises(ValueError, P, "1.0")
- self.assertRaises(ValueError, P, -1)
- self.assertRaises(TypeError, P, 8, 4)
- self.assertRaises(TypeError, P, 8, foo="big")
- self.assertRaises(ValueError, P, 8, "foo")
- self.assertRaises(ValueError, P, 8, endian="foo")
- def test_explitcit(self):
- for n in range(230):
- default_endian = choice(['little', 'big'])
- _set_default_endian(default_endian)
- endian = choice(["little", "big", None])
- odd = getrandbits(1)
- a = gen_primes(n, endian, odd)
- self.assertEqual(len(a), n)
- self.assertEqual(a.endian, endian or default_endian)
- if odd:
- lst = [2] + [2 * i + 1 for i in a.search(1)]
- else:
- lst = [i for i in a.search(1)]
- self.assertEqual(lst, self.primes[:len(lst)])
- def test_cmp(self):
- N = 10_000
- c = ones(N)
- c[:2] = 0
- for i in range(int(math.sqrt(N) + 1.0)):
- if c[i]:
- c[i * i :: i] = 0
- self.assertEqual(list(c.search(1, 0, 462)), self.primes)
- for _ in range(20):
- n = randrange(N)
- endian = choice(["little", "big"])
- a = gen_primes(n, endian=endian)
- self.assertEqual(a, c[:n])
- self.assertEqual(a.endian, endian)
- b = gen_primes(n // 2, endian, odd=True)
- self.assertEqual(b, a[1::2])
- self.assertEqual(b, c[1:n:2])
- for _ in range(20):
- i = randrange(10, 100)
- x = randint(-1, 1)
- n = i * i + x
- self.assertEqual(gen_primes(n), c[:n])
- self.assertEqual(gen_primes(n // 2, odd=1), c[1:n:2])
- self.assertEqual(gen_primes(N), c)
- self.assertEqual(gen_primes(N // 2, odd=1), c[1::2])
- def test_count(self):
- for n, count, sum_p, sum_sqr_p in [
- ( 10, 4, 17, 87),
- ( 100, 25, 1_060, 65_796),
- ( 1_000, 168, 76_127, 49_345_379),
- (10_000, 1229, 5_736_396, 37_546_387_960),
- ]:
- a = gen_primes(n)
- self.assertEqual(len(a), n)
- self.assertEqual(a.count(), count)
- self.assertEqual(sum_indices(a), sum_p)
- self.assertEqual(sum_indices(a, 2), sum_sqr_p)
- b = gen_primes(n // 2, odd=1)
- self.assertEqual(len(b), n // 2)
- self.assertEqual(b.count() + 1, count) # +1 because of prime 2
- self.assertEqual(b, a[1::2])
- # ----------------------------- pprint() ----------------------------------
class PPrintTests(unittest.TestCase):
    """Tests for pprint(): the printed text must evaluate back to the
    original object."""

    @staticmethod
    def get_code_string(a):
        # capture what pprint() writes for `a` in an in-memory stream
        buf = StringIO()
        pprint(a, stream=buf)
        return buf.getvalue()

    def round_trip(self, a):
        # the printed form must rebuild an equal object of the same type
        rebuilt = eval(self.get_code_string(a))
        self.assertEqual(rebuilt, a)
        self.assertEqual(type(rebuilt), type(a))

    def test_bitarray(self):
        obj = bitarray('110')
        self.assertEqual(self.get_code_string(obj), "bitarray('110')\n")
        self.round_trip(obj)

    def test_frozenbitarray(self):
        obj = frozenbitarray('01')
        self.assertEqual(self.get_code_string(obj),
                         "frozenbitarray('01')\n")
        self.round_trip(obj)

    def test_formatting(self):
        obj = bitarray(200)
        for width in range(40, 130, 10):
            for group_size in range(1, 10):
                buf = StringIO()
                pprint(obj, stream=buf, group=group_size, width=width)
                text = buf.getvalue()
                self.assertEqual(eval(text), obj)

                # drop the surrounding "bitarray('" ... "')" wrapper
                digits = text.strip("bitary(')\n")
                # every group but the last must be full
                for chunk in digits.split()[:-1]:
                    self.assertEqual(len(chunk), group_size)
                for row in digits.split('\n'):
                    self.assertTrue(len(row) < width)

    def test_fallback(self):
        # non-bitarray objects must round-trip as well
        for obj in None, 'asd', [1, 2], bitarray(), frozenbitarray('1'):
            self.round_trip(obj)

    def test_subclass(self):
        class Foo(bitarray):
            pass

        obj = Foo()
        text = self.get_code_string(obj)
        self.assertEqual(text, "Foo()\n")
        # evaluated here (not via round_trip) so the local name Foo is
        # visible to eval()
        rebuilt = eval(text)
        self.assertEqual(rebuilt, obj)
        self.assertEqual(type(rebuilt), type(obj))

    def test_random(self):
        for nbits in range(150):
            self.round_trip(urandom(nbits))

    def test_file(self):
        # pprint() must also work with a real file as positional stream
        workdir = tempfile.mkdtemp()
        path = os.path.join(workdir, 'testfile')
        original = urandom_2(1000)
        try:
            with open(path, 'w') as out:
                pprint(original, out)
            with open(path, 'r') as inp:
                restored = eval(inp.read())
            self.assertEqual(original, restored)
        finally:
            shutil.rmtree(workdir)
- # ----------------------------- strip() -----------------------------------
class StripTests(unittest.TestCase, Util):
    """Tests for strip(): removal of leading and/or trailing zero bits."""

    def test_simple(self):
        expect = self.assertRaises
        expect(TypeError, strip, '0110')
        expect(TypeError, strip, bitarray(), 123)
        expect(ValueError, strip, bitarray(), 'up')

        for fallback in ('big', 'little'):
            _set_default_endian(fallback)
            src = bitarray('00010110000')
            # with no mode given only the right side is stripped
            self.assertEQUAL(strip(src), bitarray('0001011'))
            self.assertEQUAL(strip(src, 'left'), bitarray('10110000'))
            self.assertEQUAL(strip(src, 'both'), bitarray('1011'))

            # the result keeps the type of the input
            frozen = frozenbitarray('00010110000')
            stripped = strip(frozen, 'both')
            self.assertEqual(stripped, bitarray('1011'))
            self.assertEqual(type(stripped), frozenbitarray)

    def test_zeros_ones(self):
        for _ in range(50):
            size = randrange(10)
            mode = choice(['left', 'right', 'both'])

            # all zeros: everything is stripped, input is left untouched
            src = zeros(size)
            out = strip(src, mode)
            self.assertEqual(type(out), bitarray)
            self.assertEqual(len(out), 0)
            self.assertEqual(src, zeros(size))

            frozen = frozenbitarray(src)
            out = strip(frozen, mode)
            self.assertEqual(type(out), frozenbitarray)
            self.assertEqual(len(out), 0)

            # all ones: nothing is stripped
            src.setall(1)
            out = strip(src, mode)
            self.assertEqual(out, ones(size))

    def test_random(self):
        # compare against the str strip methods applied to the 01-string
        for src in self.randombitarrays():
            snapshot = src.copy()
            frozen = frozenbitarray(src)
            text = src.to01()
            cases = (
                ('left', text.lstrip('0')),
                ('right', text.rstrip('0')),
                ('both', text.strip('0')),
            )
            for mode, kept in cases:
                expected = bitarray(kept, src.endian)

                out = strip(src, mode)
                self.assertEQUAL(out, expected)
                self.assertEqual(type(out), bitarray)
                self.assertEQUAL(src, snapshot)   # input not modified

                out = strip(frozen, mode)
                self.assertEQUAL(out, expected)
                self.assertEqual(type(out), frozenbitarray)
                self.assertEQUAL(frozen, snapshot)

    def test_one_set(self):
        # a single set bit anywhere must strip down to exactly '1'
        for _ in range(10):
            size = randint(1, 10000)
            src = bitarray(size)
            src.setall(0)
            src[randrange(size)] = 1
            self.assertEqual(strip(src, 'both'), bitarray('1'))
            self.assertEqual(len(src), size)
- # ----------------------------- count_n() ---------------------------------
class CountN_Tests(unittest.TestCase, Util):
    """Tests for count_n(a, n, value=1).

    The tests expect the lowest index i for which a[:i] contains exactly
    n bits equal to value.
    """

    @staticmethod
    def count_n(a, n):
        "reference implementation: lowest i with a[:i].count() == n"
        # at least n bits are needed, so start the scan at index n
        pos = n
        seen = a.count(1, 0, n)
        while seen < n:
            seen += a[pos]
            pos += 1
        return pos

    def check_result(self, a, n, i, v=1):
        # a[:i] holds exactly n bits of value v ...
        self.assertEqual(a.count(v, 0, i), n)
        # ... and i is minimal: the last bit before i is itself a v
        if i != 0:
            self.assertEqual(a[i - 1], v)
        else:
            self.assertEqual(n, 0)

    def test_empty(self):
        empty = bitarray()
        for extra in ((), (0,), (1,)):
            self.assertEqual(count_n(empty, 0, *extra), 0)
        self.assertRaises(ValueError, count_n, empty, 1)
        self.assertRaises(TypeError, count_n, '', 0)
        self.assertRaises(TypeError, count_n, empty, 7.0)
        self.assertRaises(ValueError, count_n, empty, 0, 2)
        self.assertRaisesMessage(
            ValueError, "n = 1 larger than bitarray length 0",
            count_n, empty, 1)

    def test_simple(self):
        bits = bitarray('111110111110111110111110011110111110111110111000')
        snapshot = bits.copy()
        self.assertEqual(len(bits), 48)
        self.assertEqual(bits.count(), 37)
        self.assertEqual(bits.count(0), 11)

        # (arguments after the bitarray, expected index)
        for args, expected in [
                ((0,), 0),
                ((0, 0), 0),
                ((2, 0), 12),
                ((10, 0), 47),
                ((20,), 23),
                ((20, 1), 23),
                ((37,), 45),
        ]:
            self.assertEqual(count_n(bits, *args), expected)

        # n < 0
        self.assertRaisesMessage(
            ValueError, "non-negative integer expected",
            count_n, bits, -1)
        # n > len(a)
        self.assertRaisesMessage(
            ValueError, "n = 49 larger than bitarray length 48",
            count_n, bits, 49)
        # n > a.count(0)
        self.assertRaisesMessage(
            ValueError, "n = 12 exceeds total count (a.count(0) = 11)",
            count_n, bits, 12, 0)
        # n > a.count(1)
        self.assertRaisesMessage(
            ValueError, "n = 38 exceeds total count (a.count(1) = 37)",
            count_n, bits, 38, 1)

        for value in (0, 1):
            # the reference helper only counts ones, so invert for zeros
            ones_view = bits if value else ~bits
            for n in range(bits.count(value) + 1):
                idx = count_n(bits, n, value)
                self.check_result(bits, n, idx, value)
                self.assertEqual(bits[:idx].count(value), n)
                self.assertEqual(idx, self.count_n(ones_view, n))
        # input must not have been modified
        self.assertEQUAL(bits, snapshot)

    def test_frozenbitarray(self):
        frozen = frozenbitarray('001111101111101111101111100111100')
        self.assertEqual(len(frozen), 33)
        self.assertEqual(frozen.count(), 24)
        for n, expected in ((0, 0), (10, 13), (24, 31)):
            self.assertEqual(count_n(frozen, n), expected)
        self.assertRaises(ValueError, count_n, frozen, -1)  # n < 0
        self.assertRaises(ValueError, count_n, frozen, 25)  # n > a.count()
        self.assertRaises(ValueError, count_n, frozen, 34)  # n > len(a)
        for n in range(25):
            self.check_result(frozen, n, count_n(frozen, n))

    def test_ones(self):
        size = randint(1, 100_000)
        full = ones(size)
        self.assertEqual(count_n(full, size), size)
        self.assertRaises(ValueError, count_n, full, 1, 0)
        self.assertRaises(ValueError, count_n, full, size + 1)
        # with all bits set, the n-th one ends at index n
        for _ in range(20):
            n = randint(0, size)
            self.assertEqual(count_n(full, n), n)

    def test_one_set(self):
        size = randint(1, 100_000)
        bits = zeros(size)
        self.assertEqual(count_n(bits, 0), 0)
        self.assertRaises(ValueError, count_n, bits, 1)
        for _ in range(20):
            bits.setall(0)
            pos = randrange(size)
            bits[pos] = 1
            self.assertEqual(count_n(bits, 1), pos + 1)
            self.assertRaises(ValueError, count_n, bits, 2)

    def test_last(self):
        # only the very last bit is set
        for size in range(1, 1000):
            bits = zeros(size)
            bits[-1] = 1
            self.assertEqual(count_n(bits, 1), size)
            # the length check triggers before the total-count check
            msg = ("n = 2 larger than bitarray length 1" if size == 1 else
                   "n = 2 exceeds total count (a.count(1) = 1)")
            self.assertRaisesMessage(ValueError, msg, count_n, bits, 2)

    def test_primes(self):
        sieve = gen_primes(10_000)
        # there are 1229 primes below 10,000
        self.assertEqual(sieve.count(), 1229)
        nth_prime = [
            (  10,   29),  # the 10th prime number is 29
            ( 100,  541),  # the 100th prime number is 541
            (1000, 7919),  # the 1000th prime number is 7919
        ]
        for n, prime in nth_prime:
            self.assertEqual(count_n(sieve, n) - 1, prime)

    def test_large(self):
        # sparse occurrences of v in a large bitarray of (not v)
        for _ in range(100):
            size = randint(100_000, 250_000)
            bits = bitarray(size)
            v = getrandbits(1)
            bits.setall(not v)
            for _ in range(randrange(100)):
                bits[randrange(size)] = v

            total = bits.count(v)      # total count
            idx = count_n(bits, total, v)
            self.check_result(bits, total, idx, v)

            # one more than available must fail with a precise message
            too_many = total + 1
            self.assertRaisesMessage(
                ValueError,
                "n = %d exceeds total count "
                "(a.count(%d) = %d)" % (too_many, v, total),
                count_n, bits, too_many, v)

            for _ in range(20):
                n = randint(0, total)
                idx = count_n(bits, n, v)
                self.check_result(bits, n, idx, v)
- # ---------------------------------------------------------------------------
class BitwiseCountTests(unittest.TestCase, Util):
    """Tests for count_and(), count_or() and count_xor()."""

    def test_count_byte(self):
        # every possible byte combined with all zeros, all ones and itself
        for byte in range(256):
            a = bitarray(bytearray([byte]))
            cnt = a.count()
            for func, with_zeros, with_ones, with_self in [
                    (count_and, 0, cnt, cnt),
                    (count_or, cnt, 8, cnt),
                    (count_xor, cnt, 8 - cnt, 0)]:
                self.assertEqual(func(a, zeros(8)), with_zeros)
                self.assertEqual(func(a, ones(8)), with_ones)
                self.assertEqual(func(a, a), with_self)

    def test_1(self):
        a = bitarray('001111')
        a_before = a.copy()
        b = bitarray('010011')
        b_before = b.copy()
        self.assertEqual(count_and(a, b), 2)
        self.assertEqual(count_or(a, b), 5)
        self.assertEqual(count_xor(a, b), 3)
        functions = (count_and, count_or, count_xor)
        for f in functions:
            # not two arguments
            for args in [(), (a,), (a, b, 3)]:
                self.assertRaises(TypeError, f, *args)
            # wrong argument types
            for args in [(a, ''), ('1', b), (a, 4)]:
                self.assertRaises(TypeError, f, *args)
        # the operands are expected to be left unchanged
        self.assertEQUAL(a, a_before)
        self.assertEQUAL(b, b_before)

        b.append(1)
        for f in functions:
            # operands of different length
            self.assertRaises(ValueError, f, a, b)
            # operands of different endianness
            self.assertRaises(ValueError, f,
                              bitarray('110', 'big'),
                              bitarray('101', 'little'))

    def test_frozen(self):
        a = frozenbitarray('001111')
        b = frozenbitarray('010011')
        for func, expected in [(count_and, 2),
                               (count_or, 5),
                               (count_xor, 3)]:
            self.assertEqual(func(a, b), expected)

    def test_random(self):
        # compare against counting the result of the bitwise operation
        for _ in range(100):
            size = randrange(1000)
            a = urandom_2(size)
            b = urandom(size, a.endian)
            self.assertEqual(count_and(a, b), (a & b).count())
            self.assertEqual(count_or(a, b), (a | b).count())
            self.assertEqual(count_xor(a, b), (a ^ b).count())

    def test_misc(self):
        # relations between any() / all() of bitwise results and counts
        for a in self.randombitarrays():
            size = len(a)
            b = urandom(size, a.endian)
            # any and
            self.assertEqual(any(a & b), count_and(a, b) > 0)
            self.assertEqual(any_and(a, b), any(a & b))
            # any or
            self.assertEqual(any(a | b), count_or(a, b) > 0)
            self.assertEqual(any(a | b), any(a) or any(b))
            # any xor
            self.assertEqual(any(a ^ b), count_xor(a, b) > 0)
            self.assertEqual(any(a ^ b), a != b)
            # all and
            self.assertEqual(all(a & b), count_and(a, b) == size)
            self.assertEqual(all(a & b), all(a) and all(b))
            # all or
            self.assertEqual(all(a | b), count_or(a, b) == size)
            # all xor
            self.assertEqual(all(a ^ b), count_xor(a, b) == size)
            self.assertEqual(all(a ^ b), a == ~b)
- # --------------------------- any_and() -----------------------------------
class BitwiseAnyTests(unittest.TestCase, Util):
    """Tests for any_and()."""

    def test_basic(self):
        a = frozenbitarray('0101')
        b = bitarray('0111')
        self.assertTrue(any_and(a, b))
        # wrong number or type of arguments
        self.assertRaises(TypeError, any_and)
        self.assertRaises(TypeError, any_and, a, 4)
        # operands of different length
        b.append(1)
        self.assertRaises(ValueError, any_and, a, b)
        # operands of different endianness
        self.assertRaises(ValueError, any_and,
                          bitarray('01', 'little'),
                          bitarray('11', 'big'))

    def test_overlap(self):
        # two contiguous blocks of 1s: any_and() must tell whether the
        # corresponding index ranges overlap
        n = 100
        for _ in range(500):
            start1 = randint(0, n)
            stop1 = randint(start1, n)
            start2 = randint(0, n)
            stop2 = randint(start2, n)
            r1 = range(start1, stop1)
            r2 = range(start2, stop2)
            # two equivalent ways to test if ranges r1 and r2 overlap
            overlap = (bool(r1) and bool(r2) and
                       (start2 in r1 or start1 in r2))
            self.assertEqual(overlap, bool(set(r1) & set(r2)))
            a1 = bitarray(n)
            a2 = bitarray(n)
            a1[start1:stop1] = 1
            a2[start2:stop2] = 1
            self.assertEqual(any_and(a1, a2), overlap)

    def test_common(self):
        # same idea as test_overlap, but using arbitrary random slices
        n = 100
        for _ in range(500):
            s1 = self.random_slice(n)
            s2 = self.random_slice(n)
            # test if the ranges selected by s1 and s2 have common items
            common = set(range(n)[s1]) & set(range(n)[s2])
            a1 = bitarray(n)
            a2 = bitarray(n)
            a1[s1] = 1
            a2[s2] = 1
            self.assertEqual(any_and(a1, a2), bool(common))

    def check(self, a, b):
        """Compare any_and(a, b) with equivalent formulations."""
        result = any_and(a, b)
        self.assertIs(type(result), bool)
        self.assertEqual(result, any_and(b, a))  # symmetry
        self.assertEqual(result, any(a & b))
        self.assertEqual(result, (a & b).any())
        self.assertEqual(result, count_and(a, b) > 0)

    def test_explitcit(self):
        cases = [
            ('', '', False),
            ('0', '1', False),
            ('0', '0', False),
            ('1', '1', True),
            ('00011', '11100', False),
            ('00001011 1', '01000100 1', True),
        ]
        for s, t, res in cases:
            a = bitarray(s)
            b = bitarray(t)
            self.assertIs(any_and(a, b), res)
            self.check(a, b)

    def test_random(self):
        for a in self.randombitarrays():
            b = urandom(len(a), a.endian)
            self.check(a, b)

    def test_one(self):
        # a has a single bit set, so the result is just that bit of b
        for size in range(1, 300):
            a = zeros(size)
            b = urandom(size)
            pos = randrange(size)
            a[pos] = 1
            self.assertEqual(b[pos], any_and(a, b))
- # ---------------------------- subset() -----------------------------------
class SubsetTests(unittest.TestCase, Util):
    """Tests for subset()."""

    def test_basic(self):
        a = frozenbitarray('0101')
        b = bitarray('0111')
        self.assertTrue(subset(a, b))
        self.assertFalse(subset(b, a))
        # wrong number or type of arguments
        for args in [(), (a, ''), ('1', b), (a, 4)]:
            self.assertRaises(TypeError, subset, *args)
        # operands of different length
        b.append(1)
        self.assertRaises(ValueError, subset, a, b)
        # operands of different endianness
        self.assertRaises(ValueError, subset,
                          bitarray('01', 'little'),
                          bitarray('11', 'big'))

    def check(self, a, b, res):
        """Verify that subset(a, b) is the bool `res` and that the
        equivalent bitwise formulations give the same answer."""
        result = subset(a, b)
        self.assertIs(type(result), bool)
        self.assertEqual(result, res)
        self.assertEqual(a | b == b, res)
        self.assertEqual(a & b == a, res)

    def test_True(self):
        pairs = [
            ('', ''), ('0', '1'), ('0', '0'), ('1', '1'),
            ('000', '111'), ('0101', '0111'),
            ('000010111', '010011111'),
        ]
        for s, t in pairs:
            self.check(bitarray(s), bitarray(t), True)

    def test_False(self):
        pairs = [
            ('1', '0'), ('1101', '0111'),
            ('0000101111', '0100111011'),
        ]
        for s, t in pairs:
            self.check(bitarray(s), bitarray(t), False)

    def test_random(self):
        for a in self.randombitarrays(start=1):
            b = a.copy()
            # we set one random bit in b to 1, so a is always a subset of b
            b[randrange(len(a))] = 1
            self.check(a, b, True)
            # but b is only a subset when they are equal
            self.check(b, a, a == b)
            # we set all bits in a, which ensures that b is a subset of a
            a.setall(1)
            self.check(b, a, True)
- # ------------------------- correspond_all() ------------------------------
class CorrespondAllTests(unittest.TestCase):
    """Tests for correspond_all().

    As verified in test_random(), correspond_all(a, b) is expected to
    return the 4-tuple of counts
    (count_and(~a, ~b), count_and(~a, b), count_and(a, ~b), count_and(a, b)).
    """

    def test_basic(self):
        a = frozenbitarray('0101')
        b = bitarray('0111')
        # Use assertEqual: the second argument of assertTrue is only the
        # failure message, so the result would never be compared.  For
        # these operands: one index is 0 in both, one is set only in b,
        # none is set only in a, and two are set in both.
        self.assertEqual(correspond_all(a, b), (1, 1, 0, 2))
        self.assertRaises(TypeError, correspond_all)
        # operands of different length
        b.append(1)
        self.assertRaises(ValueError, correspond_all, a, b)
        # operands of different endianness
        self.assertRaises(ValueError, correspond_all,
                          bitarray('01', 'little'),
                          bitarray('11', 'big'))

    def test_explitcit(self):
        for a, b, res in [
            ('', '', (0, 0, 0, 0)),
            ('0000011111',
             '0000100111', (4, 1, 2, 3)),
        ]:
            self.assertEqual(correspond_all(bitarray(a), bitarray(b)), res)

    def test_random(self):
        for _ in range(100):
            n = randrange(3000)
            a = urandom_2(n)
            b = urandom(n, a.endian)
            res = correspond_all(a, b)
            # each element against the corresponding count_and() call
            self.assertEqual(res[0], count_and(~a, ~b))
            self.assertEqual(res[1], count_and(~a, b))
            self.assertEqual(res[2], count_and(a, ~b))
            self.assertEqual(res[3], count_and(a, b))
            # consistency with count_or() and count_xor()
            self.assertEqual(res[0], n - count_or(a, b))
            self.assertEqual(res[1] + res[2], count_xor(a, b))
            # every index falls into exactly one of the four categories
            self.assertEqual(sum(res), n)
- # ----------------------------- byteswap() --------------------------------
@skipIf(is_pypy)
class ByteSwapTests(unittest.TestCase):
    """Tests for byteswap(), which operates in place on writable buffers."""

    def test_basic_bytearray(self):
        buf = bytearray(b"ABCD")
        byteswap(buf, 2)
        self.assertEqual(buf, bytearray(b"BADC"))
        byteswap(buf)
        self.assertEqual(buf, bytearray(b"CDAB"))

        buf = bytearray(b"ABCDEF")
        byteswap(buf, 3)
        self.assertEqual(buf, bytearray(b"CBAFED"))
        # swapping within groups of one byte leaves the buffer as it is
        byteswap(buf, 1)
        self.assertEqual(buf, bytearray(b"CBAFED"))

    def test_basic_bitarray(self):
        a = bitarray("11110000 01010101")
        byteswap(a)
        self.assertEqual(a, bitarray("01010101 11110000"))

        a = bitarray("01111000 1001")
        original = a.copy()
        a.tobytes()  # clear padbits
        byteswap(a)
        self.assertEqual(a, bitarray("10010000 0111"))
        # swapping a second time restores the original
        byteswap(a)
        self.assertEqual(a, original)

    def test_basic_array(self):
        # array.byteswap() followed by byteswap() with the item size
        # must give back the original bytes
        data = os.urandom(64)
        for typecode in array.typecodes:
            # type code 'u' is deprecated and will be removed in Python 3.16
            if typecode == 'u':
                continue
            arr = array.array(typecode, data)
            self.assertEqual(len(arr) * arr.itemsize, 64)
            arr.byteswap()
            byteswap(arr, arr.itemsize)
            self.assertEqual(arr.tobytes(), data)

    def test_empty(self):
        buf = bytearray()
        byteswap(buf)
        self.assertEqual(buf, bytearray())
        for n in range(10):
            byteswap(buf, n)
            self.assertEqual(buf, bytearray())

    def test_one_byte(self):
        expected = bytearray(b'\xab')
        buf = bytearray(b'\xab')
        byteswap(buf)
        self.assertEqual(buf, expected)
        for n in range(2):
            byteswap(buf, n)
            self.assertEqual(buf, expected)

    def test_errors(self):
        # buffer not writable
        for readonly in b"AB", frozenbitarray(16):
            self.assertRaises(BufferError, byteswap, readonly)

        buf = bytearray(b"ABCD")
        bits = bitarray(32)
        for n in -1, 3, 5, 6:
            # byte size not multiple of n
            self.assertRaises(ValueError, byteswap, buf, n)
            self.assertRaises(ValueError, byteswap, bits, n)

    def test_range(self):
        # compare with reversing each of the m groups of n bytes by hand
        for n in range(20):
            for m in range(20):
                data = os.urandom(m * n)
                buf = bytearray(data)
                byteswap(buf, n)
                expected = bytearray()
                for i in range(m):
                    group = data[i * n:i * n + n]
                    expected.extend(group[::-1])
                self.assertEqual(buf, expected)

    def test_reverse_bytearray(self):
        # without a group size, the whole buffer is reversed
        for size in range(100):
            data = os.urandom(size)
            buf = bytearray(data)
            byteswap(buf)
            self.assertEqual(buf, bytearray(data[::-1]))

    def test_reverse_bitarray(self):
        # reversing the byte order and then the bits within each byte
        # reverses the entire bitarray
        for nbytes in range(100):
            a = urandom(8 * nbytes)
            original = a.copy()
            byteswap(a)
            a.bytereverse()
            self.assertEqual(a, original[::-1])
- # ------------------------------ parity() ---------------------------------
class ParityTests(unittest.TestCase):
    """Tests for parity()."""

    def test_explitcit(self):
        cases = [('', 0), ('1', 1), ('0010011', 1), ('10100110', 0)]
        for s, res in cases:
            for cls in bitarray, frozenbitarray:
                self.assertIs(parity(cls(s)), res)

    def test_zeros_ones(self):
        for size in range(2000):
            self.assertEqual(parity(zeros(size)), 0)
            self.assertEqual(parity(ones(size)), size % 2)

    def test_random(self):
        # grow a bitarray bit by bit while tracking the parity by hand
        endian = choice(["little", "big"])
        a = bitarray(endian=endian)
        expected = 0
        for length in range(2000):
            self.assertEqual(parity(a), expected)
            self.assertEqual(expected, a.count() % 2)
            self.assertEqual(a.endian, endian)
            self.assertEqual(len(a), length)
            bit = getrandbits(1)
            a.append(bit)
            expected ^= bit

    def test_wrong_args(self):
        for args in [('',), (1,), (), (bitarray("110"), 1)]:
            self.assertRaises(TypeError, parity, *args)
- # ---------------------------- sum_indices() ------------------------------
class SumIndicesUtil(unittest.TestCase):
    """Shared checks for functions S(a, mode) summing the indices
    (mode 1) or the squared indices (mode 2) of all set bits."""

    def check_explicit(self, S):
        """Check S against hand-computed sums for both modes."""
        cases = [
            ("", 0, 0), ("0", 0, 0), ("1", 0, 0), ("11", 1, 1),
            ("011", 3, 5), ("001", 2, 4), ("0001100", 7, 25),
            ("00001111", 22, 126), ("01100111 1101", 49, 381),
        ]
        for s, sum1, sum2 in cases:
            for a in [bitarray(s, choice(['little', 'big'])),
                      frozenbitarray(s, choice(['little', 'big']))]:
                self.assertEqual(S(a, 1), sum1)
                self.assertEqual(S(a, 2), sum2)
                # the argument is expected to be left unchanged
                self.assertEqual(a, bitarray(s))

    def check_wrong_args(self, S):
        """Check that S rejects wrong argument types and invalid modes."""
        for args in [('',), (1.0,), ()]:
            self.assertRaises(TypeError, S, *args)
        for mode in -1, 0, 3, 4:
            self.assertRaises(ValueError, S, bitarray("110"), mode)

    def check_urandom(self, S, n):
        """Compare S on a random bitarray of length n with plain sums."""
        a = urandom_2(n)
        self.assertEqual(S(a, 1), sum(i for i, v in enumerate(a) if v))
        self.assertEqual(S(a, 2), sum(i * i for i, v in enumerate(a) if v))

    def check_sparse(self, S, n, k, mode=1, freeze=False, inv=False):
        """Check S on a bitarray of length n with k random bits set.

        With inv, the bitarray is inverted, such that n - k bits are set.
        With freeze, a frozenbitarray is passed to S.
        """
        a = zeros(n, choice(['little', 'big']))
        self.assertEqual(S(a, mode), 0)
        self.assertFalse(a.any())
        indices = sample(range(n), k)
        a[indices] = 1
        if mode == 1:
            expected = sum(indices)
        else:
            expected = sum(i * i for i in indices)
        if inv:
            a.invert()
            # closed form of the sum over all 0 <= i < n:
            #   mode 1:  n * (n - 1) / 2  =  3 * n * (n - 1) / 6
            #   mode 2:  (2 * n - 1) * n * (n - 1) / 6
            factor = 3 if mode == 1 else 2 * n - 1
            sum_all = factor * (n * (n - 1)) // 6
            expected = sum_all - expected
        if freeze:
            a = frozenbitarray(a)
        before = a.copy()
        self.assertEqual(a.count(), n - k if inv else k)
        self.assertEqual(S(a, mode), expected)
        self.assertEqual(a, before)
class SSQI_Tests(SumIndicesUtil):
    """Tests for _ssqi(), using the shared checks of SumIndicesUtil."""

    # Additional tests for _ssqi() in: devel/test_sum_indices.py

    def test_explicit(self):
        self.check_explicit(_ssqi)

    def test_wrong_args(self):
        self.check_wrong_args(_ssqi)

    def test_small(self):
        # grow a bitarray bit by bit while keeping both sums by hand
        a = bitarray()
        sum1 = 0
        sum2 = 0
        for index in range(100):
            bit = getrandbits(1)
            a.append(bit)
            if bit:
                sum1 += index
                sum2 += index * index
            self.assertEqual(_ssqi(a, 1), sum1)
            self.assertEqual(_ssqi(a, 2), sum2)

    def test_urandom(self):
        self.check_urandom(_ssqi, 10_037)

    def test_sparse(self):
        for _ in range(5):
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(_ssqi, n=1_000_003, k=400,
                              mode=mode, freeze=freeze, inv=inv)
class SumIndicesTests(SumIndicesUtil):
    """Tests for sum_indices(), using the shared checks of SumIndicesUtil."""

    # Additional tests in: devel/test_sum_indices.py

    def test_explicit(self):
        self.check_explicit(sum_indices)
        primes = gen_primes(100)
        self.assertEqual(sum_indices(primes, mode=1), 1_060)
        self.assertEqual(sum_indices(primes, mode=2), 65_796)

    def test_wrong_args(self):
        self.check_wrong_args(sum_indices)

    def test_ones(self):
        # k=0 with inv=True gives a bitarray with all bits set
        for mode in 1, 2:
            self.check_sparse(sum_indices, n=1_600_037, k=0,
                              mode=mode, freeze=True, inv=True)

    def test_sparse(self):
        for _ in range(20):
            n = choice([500_029, 600_011])  # below and above block size
            k = randrange(1_000)
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(sum_indices, n, k, mode, freeze, inv)
- # ---------------------------------------------------------------------------
class XoredIndicesTests(unittest.TestCase, Util):
    """Tests for xor_indices(), the XOR of the indices of all set bits."""

    def test_explicit(self):
        for s, r in [("", 0), ("0", 0), ("1", 0), ("11", 1),
                     ("011", 3), ("001", 2), ("0001100", 7),
                     ("01100111 1101", 13)]:
            for a in [bitarray(s, self.random_endian()),
                      frozenbitarray(s, self.random_endian())]:
                self.assertEqual(xor_indices(a), r)

    def test_wrong_args(self):
        X = xor_indices
        self.assertRaises(TypeError, X, '')
        self.assertRaises(TypeError, X, 1)
        self.assertRaises(TypeError, X)
        self.assertRaises(TypeError, X, bitarray("110"), 1)

    def test_ones(self):
        # OEIS A003815
        lst = [0, 1, 3, 0, 4, 1, 7, 0, 8, 1, 11, 0, 12, 1, 15, 0, 16, 1, 19]
        self.assertEqual([xor_indices(ones(i)) for i in range(1, 20)], lst)
        # grow an all-ones bitarray while keeping the XOR by hand
        a = bitarray()
        x = 0
        for i in range(1000):
            a.append(1)
            x ^= i
            self.assertEqual(xor_indices(a), x)
            if i < 19:
                self.assertEqual(lst[i], x)

    def test_primes(self):
        # OEIS A126084
        lst = [0, 2, 1, 4, 3, 8, 5, 20, 7, 16, 13, 18, 55, 30, 53, 26, 47]
        primes = gen_primes(1000)
        x = 0
        for i, p in enumerate(primes.search(1)):
            # primes[:p] excludes p itself, so x is updated afterwards
            self.assertEqual(xor_indices(primes[:p]), x)
            if i < 17:
                self.assertEqual(lst[i], x)
            x ^= p

    def test_large_random(self):
        n = 10_037
        for a in [urandom_2(n), frozenbitarray(urandom_2(n))]:
            # The initial value 0 (the identity of XOR) keeps reduce()
            # from raising TypeError should no bit be set at all.
            res = reduce(operator.xor,
                         (i for i, v in enumerate(a) if v), 0)
            b = a.copy()
            self.assertEqual(xor_indices(a), res)
            # the argument is expected to be left unchanged
            self.assertEqual(a, b)

    def test_random(self):
        for a in self.randombitarrays():
            c = 0
            for i, v in enumerate(a):
                c ^= i * v
            self.assertEqual(xor_indices(a), c)

    def test_flips(self):
        # flipping the bit at index i toggles i in the XOR of indices
        a = bitarray(128)
        c = 0
        for _ in range(1000):
            self.assertEqual(xor_indices(a), c)
            i = randrange(len(a))
            a.invert(i)
            c ^= i

    def test_error_correct(self):
        parity_bits = [1, 2, 4, 8, 16, 32, 64, 128]  # parity bit positions
        a = urandom(256)
        a[parity_bits] = 0
        c = xor_indices(a)
        # set parity bits such that block is well prepared
        a[parity_bits] = int2ba(c, length=8, endian="little")
        for i in range(0, 256):
            self.assertEqual(xor_indices(a), 0)  # ensure well prepared
            a.invert(i)
            self.assertEqual(xor_indices(a), i)  # index of the flipped bit!
            a.invert(i)
- # ------------------ intervals of uninterrupted runs --------------------
def runs(a):
    """Return the number of uninterrupted intervals of 1s and 0s in `a`.

    Sequences with fewer than two items have exactly len(a) intervals.
    Otherwise, every index at which two neighboring bits differ starts a
    new interval, so the result is one more than the number of such
    indices.
    """
    length = len(a)
    if length >= 2:
        # comparing a with itself shifted by one marks each boundary
        return count_xor(a[:-1], a[1:]) + 1
    return length
class IntervalsTests(unittest.TestCase, Util):
    """Tests for intervals(), which yields (value, start, stop) tuples."""

    def test_explicit(self):
        cases = [
            ('', []),
            ('0', [(0, 0, 1)]),
            ('1', [(1, 0, 1)]),
            ('00111100 0000011',
             [(0, 0, 2), (1, 2, 6), (0, 6, 13), (1, 13, 15)]),
        ]
        for s, expected in cases:
            a = bitarray(s)
            self.assertEqual(list(intervals(a)), expected)
            self.assertEqual(runs(a), len(expected))

    def test_uniform(self):
        # all bits equal: a single interval covering the whole bitarray
        for size in range(1, 100):
            for bit in 0, 1:
                a = size * bitarray([bit], self.random_endian())
                self.assertEqual(list(intervals(a)), [(bit, 0, size)])
                self.assertEqual(runs(a), 1)

    def test_random(self):
        # rebuild each bitarray from its intervals
        for a in self.randombitarrays():
            size = len(a)
            rebuilt = urandom(size)
            for value, start, stop in intervals(a):
                self.assertFalse(isinstance(value, bool))
                self.assertTrue(0 <= start < stop <= size)
                rebuilt[start:stop] = value
            self.assertEqual(a, rebuilt)

    def test_list_runs(self):
        for a in self.randombitarrays():
            # list of length of runs of alternating bits
            lengths = [stop - start for _, start, stop in intervals(a)]
            self.assertEqual(len(lengths), runs(a))

            rebuilt = bitarray()
            bit = a[0] if a else None  # value of first run
            for length in lengths:
                self.assertTrue(length > 0)
                rebuilt.extend(length * bitarray([bit]))
                bit = not bit
            self.assertEqual(a, rebuilt)
- # -------------------------- ba2hex() hex2ba() ---------------------------
class HexlifyTests(unittest.TestCase, Util):
    """Tests for ba2hex() and hex2ba()."""

    def test_explicit(self):
        data = [  #                       little   big
            ('',                          '',      ''),
            ('1000',                      '1',     '8'),
            ('0101 0110',                 'a6',    '56'),
            ('0100 1001 1101',            '29b',   '49d'),
            ('0000 1100 1110 1111',       '037f',  '0cef'),
        ]
        for bs, hex_le, hex_be in data:
            a_be = bitarray(bs, 'big')
            a_le = bitarray(bs, 'little')
            self.assertEQUAL(hex2ba(hex_be, 'big'), a_be)
            self.assertEQUAL(hex2ba(hex_le, 'little'), a_le)
            self.assertEqual(ba2hex(a_be), hex_be)
            self.assertEqual(ba2hex(a_le), hex_le)

    def test_ba2hex_group(self):
        a = bitarray('1000 0000 0101 1111', 'little')
        self.assertEqual(ba2hex(a), "10af")
        self.assertEqual(ba2hex(a, 0), "10af")
        self.assertEqual(ba2hex(a, 1, ""), "10af")
        self.assertEqual(ba2hex(a, 1), "1 0 a f")
        self.assertEqual(ba2hex(a, group=2), "10 af")
        self.assertEqual(ba2hex(a, 2, "-"), "10-af")
        self.assertEqual(ba2hex(a, group=3, sep="_"), "10a_f")
        self.assertEqual(ba2hex(a, 3, sep=", "), "10a, f")

    def test_ba2hex_errors(self):
        self.assertRaises(TypeError, ba2hex)
        self.assertRaises(TypeError, ba2hex, None)
        self.assertRaises(TypeError, ba2hex, '101')
        # length not multiple of 4
        self.assertRaises(ValueError, ba2hex, bitarray('10'))

        a = bitarray('1000 0000 0101 1111', 'little')
        # negative group size
        self.assertRaises(ValueError, ba2hex, a, -1)
        self.assertRaises(ValueError, ba2hex, a, group=-1)
        # sep not str
        self.assertRaises(TypeError, ba2hex, a, 1, b" ")
        # embedded null character in sep
        self.assertRaises(ValueError, ba2hex, a, 2, " \0")

    def test_hex2ba_whitespace(self):
        _set_default_endian('big')
        self.assertEqual(hex2ba("F1 FA %s f3 c0" % whitespace),
                         bitarray("11110001 11111010 11110011 11000000"))
        self.assertEQUAL(hex2ba(b' a F ', 'big'),
                         bitarray('1010 1111', 'big'))
        self.assertEQUAL(hex2ba(860 * " " + '0  1D' + 590 * " ", 'little'),
                         bitarray('0000 1000 1011', 'little'))

    def test_hex2ba_errors(self):
        self.assertRaises(TypeError, hex2ba, 0)
        self.assertRaises(TypeError, hex2ba, "F", 1)
        self.assertRaises(ValueError, hex2ba, "F", "foo")

        for s in '01a7g89', '0\u20ac', '0 \0', b'\x00':
            self.assertRaises(ValueError, hex2ba, s)

        # Each item is a separate case.  Without the comma after 'aag',
        # the adjacent literals 'aag' 'aaaga' would be concatenated into
        # a single string and both cases would silently be lost.
        for s in 'g', 'ag', 'aag', 'aaaga', 'ag':
            msg = "invalid digit found for base16, got 'g' (0x67)"
            self.assertRaisesMessage(ValueError, msg, hex2ba, s, 'big')

    def test_hex2ba_types(self):
        # str, bytes and bytearray input in either letter case
        for c in 'e', 'E', b'e', b'E', bytearray(b'e'), bytearray(b'E'):
            a = hex2ba(c, "big")
            self.assertEqual(a.to01(), '1110')
            self.assertEqual(a.endian, 'big')
            self.assertEqual(type(a), bitarray)

    def test_random(self):
        # round trip through ba2hex() with random grouping and separator
        for _ in range(100):
            default_endian = self.random_endian()
            _set_default_endian(default_endian)
            endian = choice(["little", "big", None])
            a = urandom_2(4 * randrange(100), endian)
            s = ba2hex(a, group=randrange(10), sep=choice(whitespace))
            b = hex2ba(s, endian)
            # endian=None is expected to fall back to the default
            self.assertEqual(b.endian, endian or default_endian)
            self.assertEqual(a, b)
            self.check_obj(b)

    def test_hexdigits(self):
        a = hex2ba(hexdigits)
        self.assertEqual(len(a), 4 * len(hexdigits))
        self.assertEqual(type(a), bitarray)
        self.check_obj(a)

        # ba2hex() is expected to give lowercase digits
        t = ba2hex(a)
        self.assertEqual(t, hexdigits.lower())
        self.assertEqual(type(t), str)
        self.assertEQUAL(a, hex2ba(t))

    def test_binascii(self):
        # big-endian hex strings must agree with the binascii module
        a = urandom(80, 'big')
        s = binascii.hexlify(a.tobytes()).decode()
        self.assertEqual(ba2hex(a), s)
        b = bitarray(binascii.unhexlify(s), endian='big')
        self.assertEQUAL(hex2ba(s, 'big'), b)
- # -------------------------- ba2base() base2ba() -------------------------
class BaseTests(unittest.TestCase, Util):
    """Tests for ba2base() and base2ba()."""

    def test_explicit(self):
        cases = [  #              n   little   big
            ('',                  2,  '',      ''),
            ('1 0 1',             2,  '101',   '101'),
            ('11 01 00',          4,  '320',   '310'),
            ('111 001',           8,  '74',    '71'),
            ('1111 0001',        16,  'f8',    'f1'),
            ('11111 00001',      32,  '7Q',    '7B'),
            ('111111 000001',    64,  '/g',    '/B'),
        ]
        for bits, base, digits_le, digits_be in cases:
            little = bitarray(bits, 'little')
            big = bitarray(bits, 'big')
            self.assertEQUAL(base2ba(base, digits_le, 'little'), little)
            self.assertEQUAL(base2ba(base, digits_be, 'big'), big)
            self.assertEqual(ba2base(base, little), digits_le)
            self.assertEqual(ba2base(base, big), digits_be)

    def test_base2ba_types(self):
        # str, bytes and bytearray are all accepted as input
        for digits in '7', b'7', bytearray(b'7'):
            a = base2ba(32, digits)
            self.assertEqual(a.to01(), '11111')
            self.assertEqual(type(a), bitarray)

    def test_base2ba_whitespace(self):
        self.assertEqual(base2ba(8, bytearray(b"17 0"), "little"),
                         bitarray("100 111 000"))
        self.assertEqual(base2ba(32, "7 A"), bitarray("11111 00000"))
        self.assertEqual(base2ba(64, b"A /"), bitarray("000000 111111"))

        for base in 2, 4, 8, 16, 32, 64:
            # only whitespace gives an empty bitarray
            self.assertEqual(base2ba(base, whitespace), bitarray())

            # whitespace inserted at random positions is ignored
            a = urandom(60)
            chars = list(ba2base(base, a))
            for _ in range(randrange(80)):
                chars.insert(randint(0, len(chars)), choice(whitespace))
            self.assertEqual(base2ba(base, ''.join(chars)), a)

    def test_ba2base_group(self):
        a = bitarray("001 011 100 111", "little")
        self.assertEqual(ba2base(8, a, 3), "461 7")
        self.assertEqual(ba2base(8, a, group=2), "46 17")
        self.assertEqual(ba2base(8, a, sep="_", group=2), "46_17")
        self.assertEqual(ba2base(8, a, 2, sep="."), "46.17")

        cases = [
            (2, '10100', 2, '-', '10-10-0'),
            (4, '10 11 00 01', 1, "_", "2_3_0_1"),
            (8, "101 100 011 101 001 010", 3, "  ", "543  512"),
            (8, "101 100 011 101 001 010", 3, "", "543512"),
            (16, '1011 0001 1101 1010 1111', 4, "+", "b1da+f"),
            (32, "10110 00111 01101 01111", 2, ", ", "WH, NP"),
            (64, "101100 011101 101011 111110 101110", 2, ".", "sd.r+.u"),
        ]
        for base, bits, group, sep, expected in cases:
            result = ba2base(base, bitarray(bits, "big"), group, sep)
            self.assertEqual(type(result), str)
            self.assertEqual(result, expected)

    def test_empty(self):
        for base in 2, 4, 8, 16, 32, 64:
            a = base2ba(base, '')
            self.assertEqual(a, bitarray())
            self.assertEqual(ba2base(base, a), '')

    def test_invalid_characters(self):
        for base, c in ((2, '2'), (4, '4'), (8, '8'), (16, 'g'), (32, '8'),
                        (32, '1'), (32, 'a'), (64, '-'), (64, '_')):
            msg = ("invalid digit found for base%d, "
                   "got '%s' (0x%02x)" % (base, c, ord(c)))
            self.assertRaisesMessage(ValueError, msg, base2ba, base, c)

        for base in 2, 4, 8, 16, 32, 64:
            for c in '_', '@', '[', '\u20ac', '\0', b'\0', b'\x80', b'\xff':
                self.assertRaises(ValueError, base2ba, base, c)

            msg = "invalid digit found for base%d, got '{' (0x7b)" % base
            self.assertRaisesMessage(ValueError, msg, base2ba, base, '{')

    def test_invalid_args(self):
        a = bitarray()
        # the base must be an int, the second argument not None
        for base in None, 16.0:
            self.assertRaises(TypeError, ba2base, base, a)
            self.assertRaises(TypeError, base2ba, base, '')
        self.assertRaises(TypeError, ba2base, 32, None)
        self.assertRaises(TypeError, base2ba, 32, None)

        for bases, msg in [
            ([-1023, -16, -1, 0, 3, 5, 31, 48, 63, 129, 511, 4123],
             "base must be a power of 2"),
            ([1, 128, 256, 512, 1024, 2048, 4096, 8192],
             "base must be 2, 4, 8, 16, 32 or 64")]:
            for base in bases:
                self.assertRaisesMessage(ValueError, msg, ba2base, base, a)
                self.assertRaisesMessage(ValueError, msg, base2ba, base, '')

        # 29 is not a multiple of any number of bits per digit (2 to 6)
        a = bitarray(29)
        for m in range(2, 7):
            msg = "bitarray length 29 not multiple of %d" % m
            self.assertRaisesMessage(ValueError, msg, ba2base, 1 << m, a)

    def test_hexadecimal(self):
        a = base2ba(16, 'F61', 'big')
        self.assertEqual(a, bitarray('1111 0110 0001'))
        self.assertEqual(ba2base(16, a), 'f61')

        # base 16 must agree with hex2ba() and ba2hex()
        for length in range(50):
            digits = ''.join(choices(hexdigits, k=length))
            endian = self.random_endian()
            a = base2ba(16, digits, endian)
            self.assertEQUAL(a, hex2ba(digits, endian))
            self.assertEqual(ba2base(16, a), ba2hex(a))

    def test_base32(self):
        # big-endian base 32 must agree with the base64 module
        msg = os.urandom(randint(10, 100) * 5)
        encoded = base64.b32encode(msg).decode()
        a = base2ba(32, encoded, 'big')
        self.assertEqual(a.tobytes(), msg)
        self.assertEqual(ba2base(32, a), encoded)
        self.assertEqual(base64.b32decode(encoded), msg)

    def test_base64(self):
        # big-endian base 64 must agree with the base64 module
        msg = os.urandom(randint(10, 100) * 3)
        encoded = base64.standard_b64encode(msg).decode()
        a = base2ba(64, encoded, 'big')
        self.assertEqual(a.tobytes(), msg)
        self.assertEqual(ba2base(64, a), encoded)
        self.assertEqual(base64.standard_b64decode(encoded), msg)

    def test_primes(self):
        primes = gen_primes(60, odd=True)
        base_2 = primes.to01()
        cases = [
            ( 2, "little", base_2),
            ( 2, "big",    base_2),
            ( 4, "little", "232132030132012122122010132110"),
            ( 4, "big",    "131231030231021211211020231220"),
            ( 8, "little", "65554155441515405550"),
            ( 8, "big",    "35551455114545105550"),
            (16, "little", "e6bc4b46a921d61"),
            (16, "big",    "76d32d265948b68"),
            (32, "little", "O3SJLSJTSI3C"),
            (32, "big",    "O3JS2JSZJC3I"),
            (64, "little", "utMtkppEtF"),
            (64, "big",    "dtMtJllIto"),
        ]
        for base, endian, rep in cases:
            a = bitarray(primes, endian)
            encoded = ba2base(base, a)
            self.assertEqual(type(encoded), str)
            self.assertEqual(encoded, rep)
            decoded = base2ba(base, rep, endian)
            self.assertEqual(decoded, a)
            self.assertEqual(type(decoded), bitarray)
            self.assertEqual(decoded.endian, endian)

    # digits of each base in order of their value; n = 2 ** m
    alphabets = [
    #    m   n  alphabet
        (1,  2, '01'),
        (2,  4, '0123'),
        (3,  8, '01234567'),
        (4, 16, '0123456789abcdef'),
        (4, 16, '0123456789ABCDEF'),
        (5, 32, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'),
        (6, 64, 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdef'
                'ghijklmnopqrstuvwxyz0123456789+/'),
    ]

    def test_alphabets(self):
        for m, base, alphabet in self.alphabets:
            self.assertEqual(1 << m, base)
            self.assertEqual(len(alphabet), base)
            for value, c in enumerate(alphabet):
                endian = self.random_endian()
                self.assertEqual(ba2int(base2ba(base, c, endian)), value)
                if m == 4 and c in "ABCDEF":
                    # uppercase hex digits are encoded in lowercase
                    c = chr(ord(c) + 32)
                self.assertEqual(ba2base(base, int2ba(value, m, endian)), c)

    def test_not_alphabets(self):
        # every character which is neither digit nor whitespace is invalid
        for m, base, alphabet in self.alphabets:
            for code in range(256):
                c = chr(code)
                if c in alphabet or c.isspace():
                    continue
                if base == 16 and c in hexdigits:
                    continue
                self.assertRaises(ValueError, base2ba, base, c)

    def test_random(self):
        # round trip with random base, grouping, case and input type
        for _ in range(100):
            m = randint(1, 6)
            a = urandom_2(m * randrange(100))
            base = 1 << m
            encoded = ba2base(base, a, group=randrange(10),
                              sep=randrange(5) * " ")
            if m == 4 and getrandbits(1):
                encoded = encoded.upper()
            if getrandbits(1):
                encoded = encoded.encode()
            b = base2ba(base, encoded, a.endian)
            self.assertEQUAL(a, b)
            self.check_obj(b)
- # --------------------------- sparse compression ----------------------------
- class SC_Tests(unittest.TestCase, Util):
def test_explicit(self):
    # (encoded bytes, bits, endianness)
    cases = [
        (b'\x00\0', '', 'little'),
        (b'\x01\x03\x01\x03\0', '110', 'little'),
        (b'\x01\x07\x01\x40\0', '0000001', 'little'),
        (b'\x11\x07\x01\x02\0', '0000001', 'big'),
        (b'\x01\x10\x02\xf0\x0f\0', '00001111 11110000', 'little'),
        (b'\x11\x10\xa1\x0c\0', '00000000 00001000', 'big'),
        (b'\x11\x09\xa1\x08\0', '00000000 1', 'big'),
        (b'\x01g\xa4abde\0', 97 * '0' + '110110', 'little'),
    ]
    for blob, bits, endian in cases:
        a = bitarray(bits, endian)
        self.assertEqual(sc_encode(a), blob)
        self.assertEQUAL(sc_decode(blob), a)
def test_encode_types(self):
    # bitarray and frozenbitarray are both encoded to bytes
    for a in bitarray('1', 'big'), frozenbitarray('1', 'big'):
        blob = sc_encode(a)
        self.assertEqual(type(blob), bytes)
        self.assertEqual(blob, b'\x11\x01\x01\x80\0')
    # anything else is rejected
    for obj in None, [], 0, 123, b'', b'\x00', 3.14:
        self.assertRaises(TypeError, sc_encode, obj)
def test_decode_types(self):
    blob = b'\x11\x03\x01\x20\0'
    # different iterables over the same byte values
    for stream in blob, bytearray(blob), list(blob), array.array('B', blob):
        a = sc_decode(stream)
        self.assertEqual(type(a), bitarray)
        self.assertEqual(a.endian, 'big')
        self.assertEqual(a.to01(), '001')

    items = [17, 3, 1, 32, 0]
    self.assertEqual(sc_decode(items), bitarray("001"))
    # items outside range(256)
    for bad in 256, -1:
        items[-1] = bad
        self.assertRaises(ValueError, sc_decode, items)

    # items which are not integers, and objects which are not iterable
    self.assertRaises(TypeError, sc_decode, [0x02, None])
    for obj in None, 3, 3.2, Ellipsis, 'foo':
        self.assertRaises(TypeError, sc_decode, obj)
def test_decode_header_nbits(self):
    # (encoded bytes, expected length of the decoded bitarray)
    cases = [
        (b'\x00\0', 0),
        (b'\x01\x00\0', 0),
        (b'\x01\x01\0', 1),
        (b'\x02\x00\x00\0', 0),
        (b'\x02\x00\x01\0', 256),
        (b'\x03\x00\x00\x00\0', 0),
        (b'\x03\x00\x00\x01\0', 65536),
    ]
    for blob, nbits in cases:
        a = sc_decode(blob)
        self.assertEqual(len(a), nbits)
        # without any blocks, no bit is set
        self.assertFalse(a.any())
def test_decode_untouch(self):
    # sc_decode() must not consume items after the stop byte
    it = iter(b'\x01\x03\x01\x03\0XYZ')
    self.assertEqual(sc_decode(it), bitarray('110'))
    self.assertEqual(next(it), ord('X'))

    # the remaining items do not even have to be integers
    it = iter([0x11, 0x05, 0x01, 0xff, 0, None, 'foo'])
    self.assertEqual(sc_decode(it), bitarray('11111'))
    self.assertIs(next(it), None)
    self.assertEqual(next(it), 'foo')
def test_decode_header_errors(self):
    # invalid header
    for head in 0x20, 0x21, 0x40, 0x80, 0xc0, 0xf0, 0xff:
        msg = "invalid header: 0x%02x" % head
        self.assertRaisesMessage(ValueError, msg, sc_decode, [head])
    # invalid block head
    for head in 0xc0, 0xc1, 0xc5, 0xff:
        msg = "invalid block head: 0x%02x" % head
        self.assertRaisesMessage(ValueError, msg,
                                 sc_decode, [0x01, 0x10, head])
def test_decode_header_overflow(self):
    # more length bytes than fit into Py_ssize_t
    msg = "sizeof(Py_ssize_t) = %d: cannot read 9 bytes" % PTRSIZE
    self.assertRaisesMessage(OverflowError, msg,
                             sc_decode, b'\x09' + 9 * b'\x00')
    # all length bytes 0xff read as -1
    msg = "read %d bytes got negative value: -1" % PTRSIZE
    self.assertRaisesMessage(ValueError, msg,
                             sc_decode, [PTRSIZE] + PTRSIZE * [0xff])

    if PTRSIZE != 4:
        return
    # the remaining cases only apply when PTRSIZE is 4
    self.assertRaisesMessage(
        OverflowError,
        "sizeof(Py_ssize_t) = 4: cannot read 5 bytes",
        sc_decode, b'\x05' + 5 * b'\x00')
    self.assertRaisesMessage(
        ValueError,
        "read 4 bytes got negative value: -2147483648",
        sc_decode, b'\x04\x00\x00\x00\x80')
def test_decode_errors(self):
    # (expected error message, encoded bytes)
    cases = [
        # too many raw bytes
        ("decode error (raw): 0 + 2 > 1",
         b"\x01\x05\x02\xff\xff\0"),
        ("decode error (raw): 32 + 3 > 34",
         b"\x02\x0f\x01\xa0\x03\xff\xff\xff\0"),
        # sparse index too high
        ("decode error (n=1): 128 >= 128",
         b"\x01\x80\xa1\x80\0"),
        ("decode error (n=2): 512 >= 512",
         b"\x02\x00\x02\xc2\x01\x00\x02\0"),
        ("decode error (n=3): 32768 >= 32768",
         b"\x02\x00\x80\xc3\x01\x00\x80\x00\0"),
    ]
    for msg, blob in cases:
        self.assertRaisesMessage(ValueError, msg, sc_decode, blob)

    # for these, the error message depends on PTRSIZE
    ptrsize_cases = [
        ({4: "read 4 bytes got negative value: -2147483648",
          8: "decode error (n=4): 2147483648 >= 16"},
         b"\x01\x10\xc4\x01\x00\x00\x00\x80\0"),
        ({4: "read 4 bytes got negative value: -1",
          8: "decode error (n=4): 4294967295 >= 16"},
         b"\x01\x10\xc4\x01\xff\xff\xff\xff\0"),
    ]
    for messages, blob in ptrsize_cases:
        self.assertRaisesMessage(ValueError, messages[PTRSIZE],
                                 sc_decode, blob)
def test_decode_end_of_stream(self):
    # streams which end before decoding is complete
    truncated = [
        b'', b'\x00', b'\x01', b'\x02\x77',
        b'\x01\x04\x01', b'\x01\x04\xa1', b'\x01\x04\xa0',
    ]
    for stream in truncated:
        self.assertRaises(StopIteration, sc_decode, stream)
def test_decode_ambiguity(self):
    # Several different encodings all decode to the same bits '001'.
    equivalent_blobs = [
        # raw:
        b'\x11\x03\x01\x20\0',    # this is what sc_encode gives us
        b'\x11\x03\x01\x3f\0',    # but we can set the pad bits to 1
        # sparse:
        b'\x11\x03\xa1\x02\0',                  # block type 1
        b'\x11\x03\xc2\x01\x02\x00\0',          # block type 2
        b'\x11\x03\xc3\x01\x02\x00\x00\0',      # block type 3
        b'\x11\x03\xc4\x01\x02\x00\x00\x00\0',  # block type 4
    ]
    for blob in equivalent_blobs:
        decoded = sc_decode(blob)
        self.assertEqual(decoded.to01(), '001')
def test_block_type0(self):
    # Walk through every block head byte in 0x01 .. 0x9f.
    for head in range(0x01, 0xa0):
        if head <= 32:
            nbytes = head
        else:
            nbytes = 32 * (head - 31)
        nbits = 8 * nbytes
        expected = ones(nbits, "little")

        # stream header: number of length bytes, then the length itself
        if nbits < 256:
            header = [0x01, nbits]
        else:
            header = [0x02, nbits % 256, nbits // 256]
        blob = bytearray(header)
        blob.append(head)
        blob.extend(expected.tobytes())
        blob.append(0)  # stop byte

        self.assertEqual(sc_decode(blob), expected)
        self.assertEqual(sc_encode(expected), blob)
def test_block_type1(self):
    # Set one random bit per round and compare against a hand-built blob.
    bits = bitarray(256, 'little')
    for _ in range(1, 32):
        bits[getrandbits(8)] = 1
        blob = bytearray([0x02, 0x00, 0x01, 0xa0 + bits.count()])
        # sorted indices with no duplicates
        blob.extend(list(bits.search(1)))
        blob.append(0)  # stop byte

        self.assertEqual(sc_decode(blob), bits)
        self.assertEqual(sc_encode(bits), blob)
def test_block_type2(self):
    # Set one random bit per round; indices are packed as 2-byte values.
    bits = bitarray(65536, 'little')
    for round_no in range(1, 256):
        bits[getrandbits(16)] = 1
        blob = bytearray([0x03, 0x00, 0x00, 0x01, 0xc2, bits.count()])
        for index in bits.search(1):
            blob.extend(struct.pack("<H", index))
        blob.append(0)  # stop byte

        self.assertEqual(sc_decode(blob), bits)
        if round_no >= 250:
            # We cannot compare for the highest populations, as for
            # such high values sc_encode() may find better compression
            # with type 1 blocks.
            self.assertTrue(len(sc_encode(bits)) <= len(blob))
        else:
            self.assertEqual(sc_encode(bits), blob)
def test_block_type3(self):
    # 255 random picks; each set index is packed into 3 bytes.
    bits = bitarray(16_777_216, 'little')
    bits[choices(range(1 << 24), k=255)] = 1

    blob = bytearray([0x04, 0x00, 0x00, 0x00, 0x01, 0xc3, bits.count()])
    for index in bits.search(1):
        # keep only the low three bytes of the little-endian uint32
        packed = struct.pack("<I", index)
        blob.extend(packed[:3])
    blob.append(0)  # stop byte

    self.assertEqual(sc_decode(blob), bits)
    self.assertEqual(sc_encode(bits), blob)
def test_block_type4(self):
    bits = bitarray(1 << 26, 'little')
    # To understand why we cannot have a population larger than 5 for
    # an array size 4 times the size of a type 3 block, take a look
    # at the cost comparison in sc_encode_block().  (2 + 6 >= 2 * 4)
    positions = sorted(set(choices(range(len(bits)), k=5)))
    bits[positions] = 1

    blob = bytearray(b'\x04\x00\x00\x00\x04\xc4')
    blob.append(len(positions))
    # each index is stored as a 4-byte little-endian value
    blob.extend(b''.join(struct.pack("<I", pos) for pos in positions))
    blob.append(0)  # stop byte

    self.assertEqual(sc_decode(blob), bits)
    self.assertEqual(sc_encode(bits), blob)
def test_decode_random_bytes(self):
    # ensure random input doesn't crash the decoder
    header = b'\x02\x00\x04'
    for _ in range(100):
        payload_len = randrange(20)
        blob = header + os.urandom(payload_len)
        try:
            decoded = sc_decode(blob)
        except (StopIteration, ValueError):
            # rejecting the garbage is an acceptable outcome
            continue
        self.assertEqual(len(decoded), 1024)
        self.assertEqual(decoded.endian, 'little')
def check_blob_length(self, a, m):
    # Encode `a`, require exactly `m` bytes, and require a lossless decode.
    encoded = sc_encode(a)
    self.assertEqual(len(encoded), m)
    self.assertEqual(sc_decode(encoded), a)
def test_encode_zeros(self):
    for exp in range(26):
        n = 1 << exp
        a = zeros(n)
        # head byte and stop byte, plus the size of n in bytes
        empty_len = 2 + bits2bytes(n.bit_length())
        self.check_blob_length(a, empty_len)

        a[0] = 1
        extra = 2                       # block head byte and one index byte
        extra += 2 * bool(exp > 9)      # count byte and second index byte
        extra += bool(exp > 16)         # third index byte
        extra += bool(exp > 24)         # fourth index byte
        self.check_blob_length(a, empty_len + extra)
def test_encode_ones(self):
    for _ in range(10):
        nbits = randrange(100_000)
        a = ones(nbits)
        nbytes = bits2bytes(nbits)

        # number of head bytes, all of block type 0:
        short_heads = bool(nbytes % 32)              # number in 0x01 .. 0x1f
        long_heads = (nbytes // 32 + 127) // 128     # number in 0x20 .. 0xbf

        expected_len = (
            2                                        # head byte and stop byte
            + bits2bytes(nbits.bit_length())         # size bytes
            + nbytes                                 # actual raw bytes
            + short_heads
            + long_heads
        )
        self.check_blob_length(a, expected_len)
def round_trip(self, a):
    # Encode and decode `a`, checking value, endianness and that the
    # decoder leaves nothing behind in the byte iterator.
    snapshot = a.copy()
    stream = iter(sc_encode(a))
    decoded = sc_decode(stream)
    self.assertTrue(a == decoded == snapshot)
    self.assertTrue(a.endian == decoded.endian == snapshot.endian)
    self.assertEqual(list(stream), [])
def test_random(self):
    for _ in range(10):
        nbits = randrange(100_000)
        endian = self.random_endian()
        a = ones(nbits, endian)
        # AND with fresh random bits until no set bit is left,
        # round-tripping the array after every step
        while a.count() != 0:
            a &= urandom(nbits, endian)
            self.round_trip(a)
- # ---------------------------------------------------------------------------
class VLFTests(unittest.TestCase, Util):
    # Tests for vl_encode() and vl_decode().

    def test_explicit(self):
        known_pairs = [
            (b'\x40', ''),
            (b'\x30', '0'),
            (b'\x38', '1'),
            (b'\x00', '0000'),
            (b'\x01', '0001'),
            (b'\xd3\x20', '001101'),
            (b'\xe0\x40', '0000 1'),
            (b'\x90\x02', '0000 000001'),
            (b'\xb5\xa7\x18', '0101 0100111 0011'),
            (b'\x95\xb7\x1c', '0101 0110111 001110'),
        ]
        for blob, s in known_pairs:
            default_endian = self.random_endian()
            _set_default_endian(default_endian)

            # without an explicit endianness the default one is used
            a = bitarray(s)
            self.assertEqual(vl_encode(a), blob)
            decoded = vl_decode(blob)
            self.assertEqual(decoded, a)
            self.assertEqual(decoded.endian, default_endian)

            for endian in 'big', 'little', None:
                a = bitarray(s, endian)
                encoded = vl_encode(a)
                self.assertEqual(type(encoded), bytes)
                self.assertEqual(encoded, blob)

                decoded = vl_decode(blob, endian)
                self.assertEqual(decoded, a)
                self.assertEqual(decoded.endian, endian or default_endian)

    def test_encode_types(self):
        s = "0011 01"
        for a in bitarray(s), frozenbitarray(s):
            encoded = vl_encode(a)
            self.assertEqual(type(encoded), bytes)
            self.assertEqual(encoded, b'\xd3\x20')

        for bad in None, [], 0, 123, b'', b'\x00', 3.14:
            self.assertRaises(TypeError, vl_encode, bad)

    def test_decode_types(self):
        blob = b'\xd3\x20'
        sources = (blob, iter(blob), memoryview(blob), iter([0xd3, 0x20]),
                   bytearray(blob))
        for src in sources:
            a = vl_decode(src, endian=self.random_endian())
            self.assertEqual(type(a), bitarray)
            self.assertEqual(a, bitarray('0011 01'))

        # these objects are not iterable
        for arg in None, 0, 1, 0.0:
            self.assertRaises(TypeError, vl_decode, arg)

        # these items cannot be interpreted as ints
        for item in None, 2.34, Ellipsis, 'foo':
            self.assertRaises(TypeError, vl_decode, iter([0x95, item]))

    def test_decode_args(self):
        # item not integer
        self.assertRaises(TypeError, vl_decode, iter([b'\x40']))

        self.assertRaises(TypeError, vl_decode, b'\x40', 'big', 3)
        self.assertRaises(ValueError, vl_decode, b'\x40', 'foo')

    def test_decode_trailing(self):
        # decoding stops after one array; the next item (65, i.e. 'A')
        # is still available from the iterator
        for data, bits in [(b'\x40ABC', ''),
                           (b'\xe0\x40A', '00001')]:
            stream = iter(data)
            self.assertEqual(vl_decode(stream), bitarray(bits))
            self.assertEqual(next(stream), 65)

    def test_decode_ambiguity(self):
        for blob in b'\x40', b'\x4f', b'\x45':
            self.assertEqual(vl_decode(blob), bitarray())

        for blob in b'\x1e', b'\x1f':
            self.assertEqual(vl_decode(blob), bitarray('111'))

    def test_decode_stream(self):
        stream = iter(b'\x40\x30\x38\x40\x2c\xe0\x40\xd3\x20')
        for bits in '', '0', '1', '', '11', '0000 1', '0011 01':
            self.assertEqual(vl_decode(stream), bitarray(bits))

        # many encoded arrays back to back in a single stream
        arrays = [urandom(randrange(30)) for _ in range(1000)]
        stream = iter(b''.join(vl_encode(a) for a in arrays))
        for expected in arrays:
            self.assertEqual(vl_decode(stream), expected)

    def test_decode_errors(self):
        # decode empty bytes
        self.assertRaises(StopIteration, vl_decode, b'')

        # invalid head byte
        invalid_heads = [
            b'\x70', b'\xf0',           # padding = 7
            b'\x50', b'\x60', b'\x70',  # no second byte, but padding > 4
        ]
        for blob in invalid_heads:
            expected_msg = "invalid head byte: 0x%02x" % blob[0]
            self.assertRaisesMessage(ValueError, expected_msg,
                                     vl_decode, blob)

        # high bit set, but no terminating byte
        for blob in b'\x80', b'\x80\x80':
            self.assertRaises(StopIteration, vl_decode, blob)

        # decode list with out of range items
        for item in -1, 256:
            self.assertRaises(ValueError, vl_decode, [item])

        # wrong type
        self.assertRaises(TypeError, vl_decode, [None])

    def test_decode_invalid_stream(self):
        N = 100
        stream = iter(N * (3 * [0x80] + ['XX']) + ['end.'])
        for _ in range(N):
            result = None
            try:
                result = vl_decode(stream)
            except TypeError:
                pass
            # the 'XX' item must always abort decoding
            self.assertTrue(result is None)
        self.assertEqual(next(stream), 'end.')

    def test_explicit_zeros(self):
        for n in range(100):
            a = zeros(4 + n * 7)
            blob = n * b'\x80' + b'\x00'
            self.assertEqual(vl_encode(a), blob)
            self.assertEqual(vl_decode(blob), a)

    def round_trip(self, a):
        # Encode and decode `a`, then verify the encoded length and the
        # padding count stored in the head byte.
        snapshot = a.copy()
        blob = vl_encode(a)
        decoded = vl_decode(blob)
        self.check_obj(decoded)
        self.assertTrue(a == decoded == snapshot)

        LEN_PAD_BITS = 3
        self.assertEqual(len(blob), (len(a) + LEN_PAD_BITS + 6) // 7)

        padding = (blob[0] & 0x70) >> 4
        self.assertEqual(len(a) + padding, 7 * len(blob) - LEN_PAD_BITS)

    def test_large(self):
        for _ in range(10):
            self.round_trip(urandom(randrange(100_000)))

    def test_random(self):
        for a in self.randombitarrays():
            self.round_trip(a)
- # ---------------------------------------------------------------------------
class IntegerizationTests(unittest.TestCase, Util):
    # Tests for ba2int() and int2ba(): bitarray <-> integer conversion,
    # both unsigned and signed (two's complement).

    def test_ba2int(self):
        self.assertEqual(ba2int(bitarray('0')), 0)
        self.assertEqual(ba2int(bitarray('1')), 1)
        self.assertEqual(ba2int(bitarray('00101', 'big')), 5)
        self.assertEqual(ba2int(bitarray('00101', 'little')), 20)
        self.assertEqual(ba2int(frozenbitarray('11')), 3)
        self.assertRaises(ValueError, ba2int, bitarray())
        self.assertRaises(ValueError, ba2int, frozenbitarray())
        self.assertRaises(TypeError, ba2int, '101')
        a = bitarray('111')
        b = a.copy()
        self.assertEqual(ba2int(a), 7)
        # ensure original object wasn't altered
        self.assertEQUAL(a, b)

    def test_ba2int_frozen(self):
        for a in self.randombitarrays(start=1):
            b = frozenbitarray(a)
            self.assertEqual(ba2int(b), ba2int(a))
            self.assertEQUAL(a, b)

    def test_ba2int_random(self):
        for a in self.randombitarrays(start=1):
            b = bitarray(a, 'big')
            self.assertEqual(a, b)
            self.assertEqual(ba2int(b), int(b.to01(), 2))

    def test_ba2int_bytes(self):
        for n in range(1, 50):
            a = urandom_2(8 * n)
            c = bytearray(a.tobytes())
            # rebuild the integer byte by byte, most significant byte first
            i = 0
            for x in (c if a.endian == 'big' else reversed(c)):
                i <<= 8
                i |= x
            self.assertEqual(ba2int(a), i)

    def test_int2ba(self):
        self.assertEqual(int2ba(0), bitarray('0'))
        self.assertEqual(int2ba(1), bitarray('1'))
        self.assertEqual(int2ba(5), bitarray('101'))
        self.assertEQUAL(int2ba(6, endian='big'), bitarray('110', 'big'))
        self.assertEQUAL(int2ba(6, endian='little'),
                         bitarray('011', 'little'))
        self.assertRaises(TypeError, int2ba, 1.0)
        self.assertRaises(TypeError, int2ba, 1, 3.0)
        self.assertRaises(ValueError, int2ba, 1, 0)
        self.assertRaises(TypeError, int2ba, 1, 10, 123)
        self.assertRaises(ValueError, int2ba, 1, 10, 'asd')
        # signed integer requires length
        self.assertRaises(TypeError, int2ba, 100, signed=True)

    def test_signed(self):
        # each bit string is written in little-endian order; the
        # big-endian checks below use the reversed string
        for s, i in [
                ('0',  0),
                ('1', -1),
                ('00',  0),
                ('10',  1),
                ('01', -2),
                ('11', -1),
                ('000',  0),
                ('100',  1),
                ('010',  2),
                ('110',  3),
                ('001', -4),
                ('101', -3),
                ('011', -2),
                ('111', -1),
                ('00000',   0),
                ('11110',  15),
                ('00001', -16),
                ('11111',  -1),
                ('00000000 0',    0),
                ('11111111 0',  255),
                ('00000000 1', -256),
                ('11111111 1',   -1),
        ]:
            self.assertEqual(ba2int(bitarray(s, 'little'), signed=1), i)
            self.assertEqual(ba2int(bitarray(s[::-1], 'big'), signed=1), i)

            len_s = len(bitarray(s))
            self.assertEQUAL(int2ba(i, len_s, 'little', signed=1),
                             bitarray(s, 'little'))
            self.assertEQUAL(int2ba(i, len_s, 'big', signed=1),
                             bitarray(s[::-1], 'big'))

    def test_zero(self):
        for endian in "little", "big":
            a = int2ba(0, endian=endian)
            self.assertEQUAL(a, bitarray('0', endian=endian))

            for n in range(1, 100):
                a = int2ba(0, length=n, endian=endian, signed=True)
                b = bitarray(n * '0', endian)
                self.assertEQUAL(a, b)
                for signed in 0, 1:
                    self.assertEqual(ba2int(b, signed=signed), 0)

    def test_negative_one(self):
        for endian in "little", "big":
            for n in range(1, 100):
                a = int2ba(-1, length=n, endian=endian, signed=True)
                b = bitarray(n * '1', endian)
                self.assertEQUAL(a, b)
                self.assertEqual(ba2int(b, signed=True), -1)

    def test_int2ba_overflow(self):
        self.assertRaises(OverflowError, int2ba, -1)
        self.assertRaises(OverflowError, int2ba, -1, 4)

        self.assertRaises(OverflowError, int2ba, 128, 7)
        self.assertRaises(OverflowError, int2ba, 64, 7, signed=1)
        self.assertRaises(OverflowError, int2ba, -65, 7, signed=1)

        for n in range(1, 20):
            self.assertRaises(OverflowError, int2ba, 1 << n, n)
            self.assertRaises(OverflowError, int2ba, 1 << (n - 1), n,
                              signed=1)
            self.assertRaises(OverflowError, int2ba, -(1 << (n - 1)) - 1, n,
                              signed=1)

    def test_int2ba_length(self):
        self.assertRaises(TypeError, int2ba, 0, 1.0)
        self.assertRaises(ValueError, int2ba, 0, 0)
        self.assertEqual(int2ba(5, length=6, endian='big'),
                         bitarray('000101'))
        for n in range(1, 100):
            ab = int2ba(1, n, 'big')
            al = int2ba(1, n, 'little')
            self.assertEqual(ab.endian, 'big')
            self.assertEqual(al.endian, 'little')
            # (a stray trailing comma after this assertion was removed)
            self.assertEqual(len(ab), n)
            self.assertEqual(len(al), n)
            self.assertEqual(ab, bitarray((n - 1) * '0') + bitarray('1'))
            self.assertEqual(al, bitarray('1') + bitarray((n - 1) * '0'))

            ab = int2ba(0, n, 'big')
            al = int2ba(0, n, 'little')
            self.assertEqual(len(ab), n)
            self.assertEqual(len(al), n)
            self.assertEqual(ab, bitarray(n * '0', 'big'))
            self.assertEqual(al, bitarray(n * '0', 'little'))

            self.assertEqual(int2ba(2 ** n - 1), bitarray(n * '1'))
            self.assertEqual(int2ba(2 ** n - 1, endian='little'),
                             bitarray(n * '1'))

    def test_explicit(self):
        _set_default_endian('big')
        for i, sa in [( 0,     '0'),    (1,         '1'),
                      ( 2,    '10'),    (3,        '11'),
                      (25, '11001'),  (265, '100001001'),
                      (3691038, '1110000101001000011110')]:
            ab = bitarray(sa, 'big')
            al = bitarray(sa[::-1], 'little')
            self.assertEQUAL(int2ba(i), ab)
            self.assertEQUAL(int2ba(i, endian='big'), ab)
            self.assertEQUAL(int2ba(i, endian='little'), al)
            # The third positional argument of assertEqual() is the
            # failure message, so each conversion is compared against
            # the expected integer in its own assertion.
            self.assertEqual(ba2int(ab), i)
            self.assertEqual(ba2int(al), i)

    def check_round_trip(self, i):
        # Convert the non-negative integer `i` to a bitarray and back for
        # both endiannesses, also after padding with extra zero bits.
        for endian in 'big', 'little':
            a = int2ba(i, endian=endian)
            self.check_obj(a)
            self.assertEqual(a.endian, endian)
            self.assertTrue(len(a) > 0)
            # ensure we have no leading zeros
            if a.endian == 'big':
                self.assertTrue(len(a) == 1 or a.index(1) == 0)
            self.assertEqual(ba2int(a), i)
            if i > 0:
                self.assertEqual(i.bit_length(), len(a))
            # add a few trailing / leading zeros to bitarray
            if endian == 'big':
                a = zeros(randrange(4), endian) + a
            else:
                a = a + zeros(randrange(4), endian)
            self.assertEqual(a.endian, endian)
            self.assertEqual(ba2int(a), i)

    def test_many(self):
        for _ in range(20):
            self.check_round_trip(randrange(10 ** randint(3, 300)))

    @staticmethod
    def twos_complement(i, num_bits):
        # Interpret the unsigned integer `i` as a `num_bits` wide two's
        # complement number and return its signed value.
        # https://en.wikipedia.org/wiki/Two%27s_complement
        mask = 2 ** (num_bits - 1)
        return -(i & mask) + (i & ~mask)

    def test_random_signed(self):
        for a in self.randombitarrays(start=1):
            i = ba2int(a, signed=True)
            b = int2ba(i, len(a), a.endian, signed=True)
            self.assertEQUAL(a, b)

            j = ba2int(a, signed=False)  # unsigned
            if i >= 0:
                self.assertEqual(i, j)

            self.assertEqual(i, self.twos_complement(j, len(a)))
- # ---------------------------------------------------------------------------
class MixedTests(unittest.TestCase, Util):
    # Cross-checks between the bitarray conversion helpers and Python's
    # built-in integer formatting and bitwise operators.

    def test_bin(self):
        for _ in range(20):
            value = randrange(1000)
            text = bin(value)
            self.assertEqual(text[:2], '0b')
            digits = text[2:]
            a = bitarray(digits, 'big')
            self.assertEqual(ba2int(a), value)
            back = a.to01()
            self.assertEqual(back, digits)
            self.assertEqual(int(back, 2), value)

    def test_oct(self):
        for _ in range(20):
            value = randrange(1000)
            text = oct(value)
            self.assertEqual(text[:2], '0o')
            digits = text[2:]
            a = base2ba(8, digits, 'big')
            self.assertEqual(ba2int(a), value)
            back = ba2base(8, a)
            self.assertEqual(back, digits)
            self.assertEqual(int(back, 8), value)

    def test_hex(self):
        for _ in range(20):
            value = randrange(1000)
            text = hex(value)
            self.assertEqual(text[:2], '0x')
            digits = text[2:]
            a = hex2ba(digits, 'big')
            self.assertEqual(ba2int(a), value)
            back = ba2hex(a)
            self.assertEqual(back, digits)
            self.assertEqual(int(back, 16), value)

    def test_bitwise(self):
        for a in self.randombitarrays(start=1):
            b = urandom(len(a), a.endian)
            a_before = a.copy()
            b_before = b.copy()
            i = ba2int(a)
            j = ba2int(b)
            self.assertEqual(ba2int(a & b), i & j)
            self.assertEqual(ba2int(a | b), i | j)
            self.assertEqual(ba2int(a ^ b), i ^ j)

            shift = randint(0, len(a))
            if a.endian == 'big':
                self.assertEqual(ba2int(a >> shift), i >> shift)
                # prepend zeros so no bit is shifted out on the left
                widened = zeros(len(a), 'big') + a
                self.assertEqual(ba2int(widened << shift), i << shift)

            # the operands themselves must be unchanged
            self.assertEQUAL(a, a_before)
            self.assertEQUAL(b, b_before)

    def test_bitwise_inplace(self):
        for a in self.randombitarrays(start=1):
            b = urandom(len(a), a.endian)
            b_before = b.copy()
            i = ba2int(a)
            j = ba2int(b)

            work = a.copy()
            work &= b
            self.assertEqual(ba2int(work), i & j)

            work = a.copy()
            work |= b
            self.assertEqual(ba2int(work), i | j)

            work = a.copy()
            work ^= b
            self.assertEqual(ba2int(work), i ^ j)

            # the right-hand operand must be unchanged
            self.assertEQUAL(b, b_before)

            shift = randint(0, len(a))
            if a.endian == 'big':
                work = a.copy()
                work >>= shift
                self.assertEqual(ba2int(work), i >> shift)

                # prepend zeros so no bit is shifted out on the left
                work = zeros(len(a), 'big') + a
                work <<= shift
                self.assertEqual(ba2int(work), i << shift)
- # ---------------------- serialize() deserialize() -----------------------
class SerializationTests(unittest.TestCase, Util):
    # Tests for serialize() and deserialize().

    def test_explicit(self):
        known = [
            (b'\x00', 'little', ''),
            (b'\x07\x01', 'little', '1'),
            (b'\x17\x80', 'big', '1'),
            (b'\x13\xf8', 'big', '11111'),
            (b'\x00\x0f', 'little', '11110000'),
            (b'\x10\xf0', 'big', '11110000'),
            (b'\x12\x87\xd8', 'big', '10000111 110110'),
        ]
        for blob, endian, bits in known:
            a = bitarray(bits, endian)
            encoded = serialize(a)
            self.assertEqual(blob, encoded)
            self.assertEqual(type(encoded), bytes)

            decoded = deserialize(blob)
            self.assertEqual(decoded, a)
            self.assertEqual(decoded.endian, endian)
            self.assertEqual(type(decoded), bitarray)

    def test_serialize_args(self):
        for bad in '0', 0, 1, b'\x00', 0.0, [0, 1], bytearray([0]):
            self.assertRaises(TypeError, serialize, bad)
        # no arguments
        self.assertRaises(TypeError, serialize)
        # too many arguments
        self.assertRaises(TypeError, serialize, bitarray(), 1)

        for a in bitarray('0111', 'big'), frozenbitarray('0111', 'big'):
            self.assertEqual(serialize(a), b'\x14\x70')

    def test_deserialize_args(self):
        for bad in 0, 1, False, True, None, '', '01', 0.0, [0, 1]:
            self.assertRaises(TypeError, deserialize, bad)
        # no arguments
        self.assertRaises(TypeError, deserialize)
        # too many arguments
        self.assertRaises(TypeError, deserialize, b'\x00', 1)

        blob = b'\x03\x06'
        x = bitarray(blob)
        for src in blob, bytearray(blob), memoryview(blob), x:
            a = deserialize(src)
            self.assertEqual(a.to01(), '01100')
            self.assertEqual(a.endian, 'little')

    def test_invalid_bytes(self):
        self.assertRaises(ValueError, deserialize, b'')

        def check_msg(b):
            # the error message quotes the offending header byte
            msg = "invalid header byte: 0x%02x" % b[0]
            self.assertRaisesMessage(ValueError, msg, deserialize, b)

        for head in range(256):
            b = bytearray([head])
            # header byte alone
            if head in (0, 16):
                self.assertEqual(deserialize(b), bitarray())
            else:
                self.assertRaises(ValueError, deserialize, b)
                check_msg(b)

            # header byte followed by a single zero byte
            b.append(0)
            if head < 32 and head % 16 < 8:
                self.assertEqual(deserialize(b), zeros(8 - head % 8))
            else:
                self.assertRaises(ValueError, deserialize, b)
                check_msg(b)

    def test_padbits_ignored(self):
        # all of these blobs decode to the single bit '1'
        cases = [
            (b'\x07\x01', 'little'),
            (b'\x07\x03', 'little'),
            (b'\x07\xff', 'little'),
            (b'\x17\x80', 'big'),
            (b'\x17\xc0', 'big'),
            (b'\x17\xff', 'big'),
        ]
        for blob, endian in cases:
            a = deserialize(blob)
            self.assertEqual(a.to01(), '1')
            self.assertEqual(a.endian, endian)

    def test_random(self):
        for a in self.randombitarrays():
            restored = deserialize(serialize(a))
            self.assertEqual(a, restored)
            self.assertEqual(a.endian, restored.endian)
            self.check_obj(restored)
- # ---------------------------------------------------------------------------
class HuffmanTreeTests(unittest.TestCase):  # tests for _huffman_tree()

    def test_empty(self):
        self.assertRaises(IndexError, _huffman_tree, {})

    def test_one_symbol(self):
        # a single symbol gives a lone node with no 'child' attribute
        root = _huffman_tree({"A": 1})
        self.assertEqual(root.symbol, "A")
        self.assertEqual(root.freq, 1)
        self.assertRaises(AttributeError, getattr, root, 'child')

    def test_two_symbols(self):
        # two symbols give a root without 'symbol' and two child nodes
        root = _huffman_tree({"A": 1, "B": 1})
        self.assertRaises(AttributeError, getattr, root, 'symbol')
        self.assertEqual(root.freq, 2)
        left, right = root.child[0], root.child[1]
        self.assertEqual(left.symbol, "A")
        self.assertEqual(left.freq, 1)
        self.assertEqual(right.symbol, "B")
        self.assertEqual(right.freq, 1)
class HuffmanTests(unittest.TestCase):
    # Tests for huffman_code().

    def test_simple(self):
        freq = {0: 10, 'as': 2, None: 1.6}
        code = huffman_code(freq)
        self.assertEqual(len(code), 3)
        self.assertEqual(len(code[0]), 1)
        self.assertEqual(len(code['as']), 2)
        self.assertEqual(len(code[None]), 2)

    def test_endianness(self):
        freq = {'A': 10, 'B': 2, 'C': 5}
        for endian in 'big', 'little':
            code = huffman_code(freq, endian)
            self.assertEqual(len(code), 3)
            for codeword in code.values():
                self.assertEqual(codeword.endian, endian)

    def test_wrong_arg(self):
        self.assertRaises(TypeError, huffman_code, [('a', 1)])
        self.assertRaises(TypeError, huffman_code, 123)
        self.assertRaises(TypeError, huffman_code, None)
        # cannot compare 'a' with 1
        self.assertRaises(TypeError, huffman_code, {'A': 'a', 'B': 1})
        # frequency map cannot be empty
        self.assertRaises(ValueError, huffman_code, {})

    def test_one_symbol(self):
        code = huffman_code({'a': 1})
        self.assertEqual(code, {'a': bitarray('0')})
        for n in range(4):
            msg = n * ['a']
            a = bitarray()
            a.encode(code, msg)
            self.assertEqual(a.to01(), n * '0')
            self.assertEqual(list(a.decode(code)), msg)
            # a 1 bit cannot be decoded with this code
            a.append(1)
            self.assertRaises(ValueError, list, a.decode(code))

    def check_tree(self, code):
        # Build a decodetree from `code` and verify its structure.
        nsymbols = len(code)
        tree = decodetree(code)
        self.assertEqual(tree.todict(), code)
        # ensure tree has 2n-1 nodes (n symbol nodes and n-1 internal nodes)
        self.assertEqual(tree.nodes(), 2 * nsymbols - 1)
        # a proper Huffman tree is complete
        self.assertTrue(tree.complete())

    def test_balanced(self):
        n = 6
        # equal frequencies for 2**n symbols
        freq = {i: 1 for i in range(1 << n)}
        code = huffman_code(freq)
        self.assertEqual(len(code), 1 << n)
        self.assertTrue(all(len(v) == n for v in code.values()))
        self.check_tree(code)

    def test_unbalanced(self):
        n = 27
        # frequencies double with every symbol
        freq = {i: 1 << i for i in range(n)}
        code = huffman_code(freq)
        self.assertEqual(len(code), n)
        for i in range(n):
            self.assertEqual(len(code[i]), n - max(1, i))
        self.check_tree(code)

    def test_counter(self):
        message = 'the quick brown fox jumps over the lazy dog.'
        code = huffman_code(Counter(message))
        a = bitarray()
        a.encode(code, message)
        self.assertEqual(''.join(a.decode(code)), message)
        self.check_tree(code)

    def test_random_list(self):
        plain = choices(range(100), k=500)
        code = huffman_code(Counter(plain))
        a = bitarray()
        a.encode(code, plain)
        self.assertEqual(list(a.decode(code)), plain)
        self.check_tree(code)

    def test_random_freq(self):
        for n in 2, 3, 4, randint(5, 200):
            # create Huffman code for n symbols
            freq = {i: random() for i in range(n)}
            self.check_tree(huffman_code(freq))
- # ---------------------------------------------------------------------------
class CanonicalHuffmanTests(unittest.TestCase, Util):
    # Tests for canonical_huffman() and canonical_decode().

    def test_basic(self):
        plain = bytearray(b'the quick brown fox jumps over the lazy dog.')
        chc, count, symbol = canonical_huffman(Counter(plain))
        self.assertEqual(type(chc), dict)
        self.assertEqual(type(count), list)
        self.assertEqual(type(symbol), list)
        a = bitarray()
        a.encode(chc, plain)
        self.assertEqual(bytearray(a.decode(chc)), plain)
        self.assertEqual(bytearray(canonical_decode(a, count, symbol)), plain)

    def test_example(self):
        cnt = {'a': 5, 'b': 3, 'c': 1, 'd': 1, 'r': 2}
        codedict, count, symbol = canonical_huffman(cnt)
        expected_code = {
            'a': bitarray('0'),
            'b': bitarray('10'),
            'c': bitarray('1110'),
            'd': bitarray('1111'),
            'r': bitarray('110'),
        }
        self.assertEqual(codedict, expected_code)
        self.assertEqual(count, [0, 1, 1, 1, 2])
        self.assertEqual(symbol, ['a', 'b', 'r', 'c', 'd'])

        a = bitarray('01011001110011110101100')
        msg = "abracadabra"
        self.assertEqual(''.join(a.decode(codedict)), msg)
        self.assertEqual(''.join(canonical_decode(a, count, symbol)), msg)

    def test_canonical_huffman_errors(self):
        self.assertRaises(TypeError, canonical_huffman, [])
        # frequency map cannot be empty
        self.assertRaises(ValueError, canonical_huffman, {})
        self.assertRaises(TypeError, canonical_huffman)
        cnt = huffman_code(Counter('aabc'))
        self.assertRaises(TypeError, canonical_huffman, cnt, 'a')

    def test_one_symbol(self):
        chc, count, symbol = canonical_huffman({'a': 1})
        self.assertEqual(chc, {'a': bitarray('0')})
        self.assertEqual(count, [0, 1])
        self.assertEqual(symbol, ['a'])
        for n in range(4):
            msg = n * ['a']
            a = bitarray()
            a.encode(chc, msg)
            self.assertEqual(a.to01(), n * '0')
            self.assertEqual(list(canonical_decode(a, count, symbol)), msg)
            # a 1 bit cannot be decoded with this code
            a.append(1)
            self.assertRaises(ValueError, list,
                              canonical_decode(a, count, symbol))

    def test_canonical_decode_errors(self):
        a = bitarray('1101')
        s = ['a']
        # bitarray not of bitarray type
        self.assertRaises(TypeError, canonical_decode, '11', [0, 1], s)
        # count not sequence
        self.assertRaises(TypeError, canonical_decode, a, {0, 1}, s)
        # count element not an int
        self.assertRaises(TypeError, canonical_decode, a, [0, 1.0], s)
        # count element overflow
        self.assertRaises(OverflowError, canonical_decode, a, [0, 1 << 65], s)
        # symbol not sequence
        self.assertRaises(TypeError, canonical_decode, a, [0, 1], 43)

        symbol = ['a', 'b', 'c', 'd']
        # sum(count) != len(symbol)
        self.assertRaisesMessage(ValueError,
                                 "sum(count) = 3, but len(symbol) = 4",
                                 canonical_decode, a, [0, 1, 2], symbol)
        # count list too long
        self.assertRaisesMessage(ValueError,
                                 "len(count) cannot be larger than 32",
                                 canonical_decode, a, 33 * [0], symbol)

    def test_canonical_decode_count_range(self):
        a = bitarray()
        for i in range(1, 32):
            count = 32 * [0]
            maxbits = 1 << i

            # negative count
            count[i] = -1
            self.assertRaisesMessage(
                ValueError,
                "count[%d] not in [0..%d], got -1" % (i, 1 << i),
                canonical_decode, a, count, [])

            # largest count allowed at this position
            count[i] = maxbits
            if i == 31 and PTRSIZE == 4:
                self.assertRaises(OverflowError,
                                  canonical_decode, a, count, [])
                continue
            self.assertRaisesMessage(
                ValueError,
                "sum(count) = %d, but len(symbol) = 0" % maxbits,
                canonical_decode, a, count, [])

            # one more than allowed
            count[i] = maxbits + 1
            self.assertRaisesMessage(
                ValueError,
                "count[%d] not in [0..%d], got %d" % (i, maxbits, count[i]),
                canonical_decode, a, count, [])

        decoder = canonical_decode(a, 32 * [0], [])
        self.assertEqual(list(decoder), [])

    def test_canonical_decode_simple(self):
        # symbols can be anything, they do not even have to be hashable here
        cnt = [0, 0, 4]
        s = ['A', 42, [1.2-3.7j, 4j], {'B': 6}]
        a = bitarray('00 01 10 11')
        # count can be a list
        self.assertEqual(list(canonical_decode(a, cnt, s)), s)
        # count can also be a tuple (any sequence object in fact)
        self.assertEqual(list(canonical_decode(a, (0, 0, 4), s)), s)
        self.assertEqual(list(canonical_decode(7 * a, cnt, s)), 7 * s)
        # the count list may have extra 0's at the end (but not too many)
        count = [0, 0, 4, 0, 0, 0, 0, 0]
        self.assertEqual(list(canonical_decode(a, count, s)), s)
        # the element count[0] is unused
        self.assertEqual(list(canonical_decode(a, [-47, 0, 4], s)), s)
        # in fact it can be anything, as it is entirely ignored
        self.assertEqual(list(canonical_decode(a, [None, 0, 4], s)), s)

        # the symbol argument can be any sequence object
        s = [65, 66, 67, 98]
        self.assertEqual(list(canonical_decode(a, cnt, s)), s)
        self.assertEqual(list(canonical_decode(a, cnt, bytearray(s))), s)
        self.assertEqual(list(canonical_decode(a, cnt, tuple(s))), s)
        self.assertEqual(list(canonical_decode(a, cnt, bytes(s))), s)
        # Implementation Note:
        #   The symbol can even be an iterable.  This was done because we
        #   want to use PySequence_Fast in order to convert sequence
        #   objects (like bytes and bytearray) to a list.  This is faster
        #   as all objects are now elements in an array of pointers (as
        #   opposed to having the object's __getitem__ method called on
        #   every iteration).
        self.assertEqual(list(canonical_decode(a, cnt, iter(s))), s)

    def test_canonical_decode_empty(self):
        a = bitarray()
        # count and symbol are empty, ok because sum([]) == len([])
        self.assertEqual(list(canonical_decode(a, [], [])), [])

        a.append(0)
        self.assertRaisesMessage(ValueError, "reached end of bitarray",
                                 list, canonical_decode(a, [], []))

        a = bitarray(31 * '0')
        self.assertRaisesMessage(ValueError, "ran out of codes",
                                 list, canonical_decode(a, [], []))

    def test_canonical_decode_one_symbol(self):
        symbols = ['A']
        count = [0, 1]
        a = bitarray('000')
        self.assertEqual(list(canonical_decode(a, count, symbols)),
                         3 * symbols)

        a.append(1)
        a.extend(bitarray(10 * '0'))
        decoder = canonical_decode(a, count, symbols)
        self.assertRaisesMessage(ValueError, "reached end of bitarray",
                                 list, decoder)

        a.extend(bitarray(20 * '0'))
        decoder = canonical_decode(a, count, symbols)
        self.assertRaisesMessage(ValueError, "ran out of codes",
                                 list, decoder)

    def test_canonical_decode_large(self):
        # use the bytes of this very file as the message
        with open(__file__, 'rb') as f:
            msg = bytearray(f.read())
        self.assertTrue(len(msg) > 50000)
        codedict, count, symbol = canonical_huffman(Counter(msg))
        a = bitarray()
        a.encode(codedict, msg)
        self.assertEqual(bytearray(canonical_decode(a, count, symbol)), msg)
        self.check_code(codedict, count, symbol)

    def test_canonical_decode_symbol_change(self):
        msg = bytearray(b"Hello World!")
        codedict, count, symbol = canonical_huffman(Counter(msg))
        self.check_code(codedict, count, symbol)
        a = bitarray()
        a.encode(codedict, 10 * msg)

        it = canonical_decode(a, count, symbol)

        def decode_one_msg():
            # pull exactly len(msg) symbols from the shared iterator
            return bytearray(next(it) for _ in range(len(msg)))

        self.assertEqual(decode_one_msg(), msg)
        # changing the symbol list changes what later decoding yields
        symbol[symbol.index(ord("l"))] = ord("k")
        self.assertEqual(decode_one_msg(), bytearray(b"Hekko Workd!"))
        del symbol[:]
        self.assertRaises(IndexError, decode_one_msg)

    def ensure_sorted(self, chc, symbol):
        # ensure codes are sorted
        for cur, nxt in zip(symbol, symbol[1:]):
            self.assertTrue(ba2int(chc[cur]) < ba2int(chc[nxt]))

    def ensure_consecutive(self, chc, count, symbol):
        first = 0
        for nbits, cnt in enumerate(count):
            last = first + cnt - 1
            for i in range(first, last):
                # ensure two consecutive codes (with same bit length) have
                # consecutive integer values
                lo = chc[symbol[i]]
                hi = chc[symbol[i + 1]]
                self.assertTrue(len(lo) == len(hi) == nbits)
                self.assertEqual(ba2int(lo) + 1, ba2int(hi))
            first += cnt

    def ensure_count(self, chc, count):
        # ensure count list corresponds to length counts from codedict
        maxbits = len(count) - 1
        self.assertEqual(maxbits, max(len(a) for a in chc.values()))
        tally = (maxbits + 1) * [0]
        for codeword in chc.values():
            self.assertEqual(codeword.endian, 'big')
            tally[len(codeword)] += 1
        self.assertEqual(tally, count)

    def ensure_complete(self, count):
        # ensure code is complete and not oversubscribed
        len_c = len(count)
        total = 0
        for i in range(1, len_c):
            total += count[i] << (len_c - i)
        self.assertEqual(total, 1 << len_c)

    def ensure_complete_2(self, chc):
        # ensure code is complete
        self.assertTrue(decodetree(chc).complete())

    def ensure_round_trip(self, chc, count, symbol):
        # create a short test message, encode and decode
        msg = choices(symbol, k=10)
        a = bitarray()
        a.encode(chc, msg)
        it = canonical_decode(a, count, symbol)
        # the iterator holds a reference to the bitarray and symbol list
        del a, count, symbol
        self.assertEqual(type(it).__name__, 'canonical_decodeiter')
        self.assertEqual(list(it), msg)

    def check_code(self, chc, count, symbol):
        # Run every consistency check on a canonical Huffman code.
        self.assertTrue(len(chc) == len(symbol) == sum(count))
        self.assertEqual(count[0], 0)  # no codes have length 0
        self.assertTrue(set(chc) == set(symbol))
        # the code of the last symbol has all 1 bits
        self.assertTrue(chc[symbol[-1]].all())
        # the code of the first symbol starts with bit 0
        self.assertFalse(chc[symbol[0]][0])

        self.ensure_sorted(chc, symbol)
        self.ensure_consecutive(chc, count, symbol)
        self.ensure_count(chc, count)
        self.ensure_complete(count)
        self.ensure_complete_2(chc)
        self.ensure_round_trip(chc, count, symbol)

    def test_simple_counter(self):
        plain = bytearray(b'the quick brown fox jumps over the lazy dog.')
        self.check_code(*canonical_huffman(Counter(plain)))

    def test_no_comp(self):
        freq = {None: 1, "A": 1}  # None and "A" are not comparable
        self.check_code(*canonical_huffman(freq))

    def test_balanced(self):
        n = 7
        # equal frequencies for 2**n symbols
        freq = {i: 1 for i in range(1 << n)}
        code, count, sym = canonical_huffman(freq)
        self.assertEqual(len(code), 1 << n)
        self.assertTrue(all(len(v) == n for v in code.values()))
        self.check_code(code, count, sym)

    def test_unbalanced(self):
        n = 32
        # frequencies double with every symbol
        freq = {i: 1 << i for i in range(n)}
        code = canonical_huffman(freq)[0]
        self.assertEqual(len(code), n)
        for i in range(n):
            self.assertEqual(len(code[i]), n - max(1, i))
        self.check_code(*canonical_huffman(freq))

    def test_random_freq(self):
        for n in 2, 3, 4, randint(5, 200):
            freq = {i: random() for i in range(n)}
            self.check_code(*canonical_huffman(freq))
- # ---------------------------------------------------------------------------
# Run the whole test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|