# test_curried.py (3.7 KB; the original listing has 117 lines)
  1. import cytoolz
  2. import cytoolz.curried
  3. from cytoolz.curried import (take, first, second, sorted, merge_with, reduce,
  4. merge, operator as cop)
  5. from collections import defaultdict
  6. from importlib import import_module
  7. from operator import add
  8. def test_take():
  9. assert list(take(2)([1, 2, 3])) == [1, 2]
  10. def test_first():
  11. assert first is cytoolz.itertoolz.first
  12. def test_merge():
  13. assert merge(factory=lambda: defaultdict(int))({1: 1}) == {1: 1}
  14. assert merge({1: 1}) == {1: 1}
  15. assert merge({1: 1}, factory=lambda: defaultdict(int)) == {1: 1}
  16. def test_merge_with():
  17. assert merge_with(sum)({1: 1}, {1: 2}) == {1: 3}
  18. def test_merge_with_list():
  19. assert merge_with(sum, [{'a': 1}, {'a': 2}]) == {'a': 3}
  20. def test_sorted():
  21. assert sorted(key=second)([(1, 2), (2, 1)]) == [(2, 1), (1, 2)]
  22. def test_reduce():
  23. assert reduce(add)((1, 2, 3)) == 6
  24. def test_module_name():
  25. assert cytoolz.curried.__name__ == 'cytoolz.curried'
  26. def should_curry(func):
  27. if not callable(func) or isinstance(func, cytoolz.curry):
  28. return False
  29. nargs = cytoolz.functoolz.num_required_args(func)
  30. if nargs is None or nargs > 1:
  31. return True
  32. return nargs == 1 and cytoolz.functoolz.has_keywords(func)
  33. def test_curried_operator():
  34. import operator
  35. for k, v in vars(cop).items():
  36. if not callable(v):
  37. continue
  38. if not isinstance(v, cytoolz.curry):
  39. try:
  40. # Make sure it is unary
  41. v(1)
  42. except TypeError:
  43. try:
  44. v('x')
  45. except TypeError:
  46. pass
  47. else:
  48. continue
  49. raise AssertionError(
  50. 'cytoolz.curried.operator.%s is not curried!' % k,
  51. )
  52. assert should_curry(getattr(operator, k)) == isinstance(v, cytoolz.curry), k
  53. # Make sure this isn't totally empty.
  54. assert len(set(vars(cop)) & {'add', 'sub', 'mul'}) == 3
  55. def test_curried_namespace():
  56. exceptions = import_module('cytoolz.curried.exceptions')
  57. namespace = {}
  58. def curry_namespace(ns):
  59. return {
  60. name: cytoolz.curry(f) if should_curry(f) else f
  61. for name, f in ns.items() if '__' not in name
  62. }
  63. from_cytoolz = curry_namespace(vars(cytoolz))
  64. from_exceptions = curry_namespace(vars(exceptions))
  65. namespace.update(cytoolz.merge(from_cytoolz, from_exceptions))
  66. namespace = cytoolz.valfilter(callable, namespace)
  67. curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)
  68. if namespace != curried_namespace:
  69. missing = set(namespace) - set(curried_namespace)
  70. if missing:
  71. raise AssertionError('There are missing functions in cytoolz.curried:\n %s'
  72. % ' \n'.join(sorted(missing)))
  73. extra = set(curried_namespace) - set(namespace)
  74. if extra:
  75. raise AssertionError('There are extra functions in cytoolz.curried:\n %s'
  76. % ' \n'.join(sorted(extra)))
  77. unequal = cytoolz.merge_with(list, namespace, curried_namespace)
  78. unequal = cytoolz.valfilter(lambda x: x[0] != x[1], unequal)
  79. messages = []
  80. for name, (orig_func, auto_func) in sorted(unequal.items()):
  81. if name in from_exceptions:
  82. messages.append('%s should come from cytoolz.curried.exceptions' % name)
  83. elif should_curry(getattr(cytoolz, name)):
  84. messages.append('%s should be curried from cytoolz' % name)
  85. else:
  86. messages.append('%s should come from cytoolz and NOT be curried' % name)
  87. raise AssertionError('\n'.join(messages))