Skip to content

Commit

Permalink
More assert fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Sep 5, 2023
1 parent 2a3abb9 commit dc76ff7
Show file tree
Hide file tree
Showing 19 changed files with 101 additions and 101 deletions.
12 changes: 6 additions & 6 deletions tests/test_bisect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
class FuncTestCase(unittest.TestCase):
def test_funcs(self):
l = [0, 1, 2, 3, 4]
self.assertEqual(index(l, 1), 1)
self.assertEqual(find_lt(l, 1), 0)
self.assertEqual(find_gt(l, 1), 2)
self.assertEqual(find_le(l, 1), 1)
self.assertEqual(find_ge(l, 2), 2)
# self.assertEqual(index([0, 1, 1.5, 2], 1.501, atol=0.1), 4)
assert index(l, 1) == 1
assert find_lt(l, 1) == 0
assert find_gt(l, 1) == 2
assert find_le(l, 1) == 1
assert find_ge(l, 2) == 2
# assert index([0, 1, 1.5, 2], 1.501, atol=0.1) == 4


if __name__ == "__main__":
Expand Down
16 changes: 8 additions & 8 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ def test_frozen_dict(self):
d = frozendict({"hello": "world"})
self.assertRaises(KeyError, d.__setitem__, "k", "v")
self.assertRaises(KeyError, d.update, {"k": "v"})
self.assertEqual(d["hello"], "world")
assert d["hello"] == "world"

def test_namespace_dict(self):
d = Namespace(foo="bar")
d["hello"] = "world"
self.assertEqual(d["foo"], "bar")
assert d["foo"] == "bar"
self.assertRaises(KeyError, d.__setitem__, "foo", "spam")

def test_attr_dict(self):
d = AttrDict(foo=1, bar=2)
self.assertEqual(d.bar, 2)
self.assertEqual(d["foo"], d.foo)
assert d.bar == 2
assert d["foo"] == d.foo
d.bar = "hello"
self.assertEqual(d["bar"], "hello")
assert d["bar"] == "hello"

def test_frozen_attrdict(self):
d = FrozenAttrDict({"hello": "world", 1: 2})
self.assertEqual(d["hello"], "world")
self.assertEqual(d.hello, "world")
assert d["hello"] == "world"
assert d.hello == "world"
self.assertRaises(KeyError, d.update, {"updating": 2})
with self.assertRaises(KeyError):
d["foo"] = "bar"
Expand All @@ -46,7 +46,7 @@ def test_tree(self):
self.assertIn("b", x["a"])
self.assertNotIn("c", x["a"])
self.assertIn("c", x["a"]["b"])
self.assertEqual(x["a"]["b"]["c"]["d"], 1)
assert x["a"]["b"]["c"]["d"] == 1


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions tests/test_design_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class A:
a1 = A()
a2 = A()

self.assertEqual(id(a1), id(a2))
assert id(a1) == id(a2)


@cached_class
Expand All @@ -36,13 +36,13 @@ def test_cached_class(self):
a1b = A(1)
a2 = A(2)

self.assertEqual(id(a1a), id(a1b))
assert id(a1a) == id(a1b)
self.assertNotEqual(id(a1a), id(a2))

# def test_pickle(self):
# a = A(2)
# o = pickle.dumps(a)
# self.assertEqual(a, pickle.loads(o))
# assert a == pickle.loads(o)


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions tests/test_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def func_a(self):
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
self.assertEqual(a().property_b, "b")
assert a().property_b == "b"
# Verify some things
assert issubclass(w[-1].category, FutureWarning)

with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
self.assertEqual(a().func_a(), "a")
assert a().func_a() == "a"
# Verify some things
assert issubclass(w[-1].category, FutureWarning)

Expand All @@ -86,7 +86,7 @@ def classmethod_b(self):
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
self.assertEqual(A().classmethod_b(), "b")
assert A().classmethod_b() == "b"
# Verify some things
assert issubclass(w[-1].category, FutureWarning)

Expand All @@ -107,7 +107,7 @@ def classmethod_b(self):
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
self.assertEqual(A().classmethod_b(), "b")
assert A().classmethod_b() == "b"
# Verify some things
assert issubclass(w[-1].category, DeprecationWarning)

Expand All @@ -127,10 +127,10 @@ def use_fictitious_mod():
def use_unittest():
return "success"

self.assertEqual(use_unittest(), "success")
assert use_unittest() == "success"

def test_get_ncpus(self):
self.assertEqual(get_ncpus(), multiprocessing.cpu_count())
assert get_ncpus() == multiprocessing.cpu_count()

def test_install_except_hook(self):
install_excepthook()
Expand Down
Binary file modified tests/test_files/3000_lines.txt.gz
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/test_fractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

class FuncTestCase(unittest.TestCase):
def test_gcd(self):
self.assertEqual(gcd(7, 14, 63), 7)
assert gcd(7, 14, 63) == 7

def test_lcm(self):
self.assertEqual(lcm(2, 3, 4), 12)
assert lcm(2, 3, 4) == 12

def test_gcd_float(self):
vs = [6.2, 12.4, 15.5 + 5e-9]
Expand Down
50 changes: 25 additions & 25 deletions tests/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ def cached_func(a, b, c=3):
return a + b + c

# call a few times to get some stats
self.assertEqual(cached_func(1, 2, c=4), 7)
self.assertEqual(cached_func(3, 2), 8)
self.assertEqual(cached_func(3, 2), 8)
self.assertEqual(cached_func(1, 2, c=4), 7)
self.assertEqual(cached_func(4, 2), 9)
self.assertEqual(cached_func(4, 2), 9)
self.assertEqual(cached_func(3, 2), 8)
self.assertEqual(cached_func(1, 2), 6)

self.assertEqual(cached_func.cache_info().hits, 3)
self.assertEqual(cached_func.cache_info().misses, 5)
assert cached_func(1, 2, c=4) == 7
assert cached_func(3, 2) == 8
assert cached_func(3, 2) == 8
assert cached_func(1, 2, c=4) == 7
assert cached_func(4, 2) == 9
assert cached_func(4, 2) == 9
assert cached_func(3, 2) == 8
assert cached_func(1, 2) == 6

assert cached_func.cache_info().hits == 3
assert cached_func.cache_info().misses == 5

def test_class_method(self):
class TestClass:
Expand All @@ -42,14 +42,14 @@ def cached_func(self, x):
a = TestClass()
b = TestClass()

self.assertEqual(a.cached_func(1), 1)
self.assertEqual(b.cached_func(2), 2)
self.assertEqual(b.cached_func(3), 3)
self.assertEqual(a.cached_func(3), 3)
self.assertEqual(a.cached_func(1), 1)
assert a.cached_func(1) == 1
assert b.cached_func(2) == 2
assert b.cached_func(3) == 3
assert a.cached_func(3) == 3
assert a.cached_func(1) == 1

self.assertEqual(a.cached_func.cache_info().hits, 1)
self.assertEqual(a.cached_func.cache_info().misses, 4)
assert a.cached_func.cache_info().hits == 1
assert a.cached_func.cache_info().misses == 4

class TestClass2:
@lru_cache(None)
Expand All @@ -59,14 +59,14 @@ def cached_func(self, x):
a = TestClass2()
b = TestClass2()

self.assertEqual(a.cached_func(1), 1)
self.assertEqual(b.cached_func(2), 2)
self.assertEqual(b.cached_func(3), 3)
self.assertEqual(a.cached_func(3), 3)
self.assertEqual(a.cached_func(1), 1)
assert a.cached_func(1) == 1
assert b.cached_func(2) == 2
assert b.cached_func(3) == 3
assert a.cached_func(3) == 3
assert a.cached_func(1) == 1

self.assertEqual(a.cached_func.cache_info().hits, 1)
self.assertEqual(a.cached_func.cache_info().misses, 4)
assert a.cached_func.cache_info().hits == 1
assert a.cached_func.cache_info().misses == 4


class TestCase(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_func(self):
assert caller_name()

def test_all_subclasses(self):
self.assertEqual(all_subclasses(LittleCatA), [LittleCatB, LittleCatD])
assert all_subclasses(LittleCatA) == [LittleCatB, LittleCatD]


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,24 +125,24 @@ def test_empty_file(self):
class ZopenTest(unittest.TestCase):
def test_zopen(self):
with zopen(os.path.join(test_dir, "myfile_gz.gz"), mode="rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"
with zopen(os.path.join(test_dir, "myfile_bz2.bz2"), mode="rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"
with zopen(os.path.join(test_dir, "myfile_bz2.bz2"), "rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"
with zopen(os.path.join(test_dir, "myfile_xz.xz"), "rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"
with zopen(os.path.join(test_dir, "myfile_lzma.lzma"), "rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"
with zopen(os.path.join(test_dir, "myfile"), mode="rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"

@unittest.skipIf(Path is None, "Not Py3k")
def test_Path_objects(self):
p = Path(test_dir) / "myfile_gz.gz"

with zopen(p, mode="rt") as f:
self.assertEqual(f.read(), "HelloWorld.\n\n")
assert f.read() == "HelloWorld.\n\n"


class FileLockTest(unittest.TestCase):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ def test_unsafe_hash(self):
a_list[0].unsafe_hash().hexdigest(),
"ea44de0e2ef627be582282c02c48e94de0d58ec6",
)
self.assertEqual(obj.unsafe_hash().hexdigest(), "44204c8da394e878f7562c9aa2e37c2177f28b81")
assert obj.unsafe_hash().hexdigest() == "44204c8da394e878f7562c9aa2e37c2177f28b81"

def test_version(self):
obj = self.good_cls("Hello", "World", "Python")
d = obj.as_dict()
self.assertEqual(d["@version"], tests_version)
assert d["@version"] == tests_version

def test_nested_to_from_dict(self):
GMC = GoodMSONClass
Expand Down Expand Up @@ -243,35 +243,35 @@ def test_nested_to_from_dict(self):
obj2 = GoodNestedMSONClass.from_dict(obj_dict)
assert [obj2.a_list[ii] == aa for ii, aa in enumerate(obj.a_list)]
assert [obj2.b_dict[kk] == val for kk, val in obj.b_dict.items()]
self.assertEqual(len(obj.a_list), len(obj2.a_list))
self.assertEqual(len(obj.b_dict), len(obj2.b_dict))
assert len(obj.a_list) == len(obj2.a_list)
assert len(obj.b_dict) == len(obj2.b_dict)
s = json.dumps(obj_dict)
obj3 = json.loads(s, cls=MontyDecoder)
assert [obj2.a_list[ii] == aa for ii, aa in enumerate(obj3.a_list)]
assert [obj2.b_dict[kk] == val for kk, val in obj3.b_dict.items()]
self.assertEqual(len(obj3.a_list), len(obj2.a_list))
self.assertEqual(len(obj3.b_dict), len(obj2.b_dict))
assert len(obj3.a_list) == len(obj2.a_list)
assert len(obj3.b_dict) == len(obj2.b_dict)
s = json.dumps(obj, cls=MontyEncoder)
obj4 = json.loads(s, cls=MontyDecoder)
assert [obj4.a_list[ii] == aa for ii, aa in enumerate(obj.a_list)]
assert [obj4.b_dict[kk] == val for kk, val in obj.b_dict.items()]
self.assertEqual(len(obj.a_list), len(obj4.a_list))
self.assertEqual(len(obj.b_dict), len(obj4.b_dict))
assert len(obj.a_list) == len(obj4.a_list)
assert len(obj.b_dict) == len(obj4.b_dict)

def test_enum_serialization(self):
e = EnumTest.a
d = e.as_dict()
e_new = EnumTest.from_dict(d)
self.assertEqual(e_new.name, e.name)
self.assertEqual(e_new.value, e.value)
assert e_new.name == e.name
assert e_new.value == e.value

d = {"123": EnumTest.a}
f = jsanitize(d)
self.assertEqual(f["123"], "EnumTest.a")
assert f["123"] == "EnumTest.a"

f = jsanitize(d, strict=True)
self.assertEqual(f["123"]["@module"], "tests.test_json")
self.assertEqual(f["123"]["@class"], "EnumTest")
assert f["123"]["@module"] == "tests.test_json"
assert f["123"]["@class"] == "EnumTest"
self.assertEqual(f["123"]["value"], 1)

f = jsanitize(d, strict=True, enum_values=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

class FuncTest(unittest.TestCase):
def test_nCr(self):
self.assertEqual(nCr(4, 2), 6)
assert nCr(4, 2) == 6

def test_deprecated_property(self):
self.assertEqual(nPr(4, 2), 12)
assert nPr(4, 2) == 12


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
class FuncCase(unittest.TestCase):
def test_imap_tqdm(self):
results = imap_tqdm(4, sqrt, range(10000))
self.assertEqual(len(results), 10000)
self.assertEqual(results[0], 0)
self.assertEqual(results[400], 20)
self.assertEqual(results[9999], 99.99499987499375)
assert len(results) == 10000
assert results[0] == 0
assert results[400] == 20
assert results[9999] == 99.99499987499375
results = imap_tqdm(4, sqrt, (i**2 for i in range(10000)))
self.assertEqual(len(results), 10000)
self.assertEqual(results[0], 0)
self.assertEqual(results[400], 400)
assert len(results) == 10000
assert results[0] == 0
assert results[400] == 400


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/test_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
class PathTest(unittest.TestCase):
def test_zpath(self):
fullzpath = zpath(os.path.join(test_dir, "myfile_gz"))
self.assertEqual(os.path.join(test_dir, "myfile_gz.gz"), fullzpath)
assert os.path.join(test_dir, "myfile_gz.gz") == fullzpath

def test_find_exts(self):
assert len(find_exts(os.path.dirname(__file__), "py")) >= 18
self.assertEqual(len(find_exts(os.path.dirname(__file__), "bz2")), 2)
assert len(find_exts(os.path.dirname(__file__), "bz2")) == 2
self.assertEqual(
len(find_exts(os.path.dirname(__file__), "bz2", exclude_dirs="test_files")),
0,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def test_regrep(self):
"""
fname = os.path.join(test_dir, "3000_lines.txt")
matches = regrep(fname, {"1": r"1(\d+)", "3": r"3(\d+)"}, postprocess=int)
self.assertEqual(len(matches["1"]), 1380)
self.assertEqual(len(matches["3"]), 571)
self.assertEqual(matches["1"][0][0][0], 0)
assert len(matches["1"]) == 1380
assert len(matches["3"]) == 571
assert matches["1"][0][0][0] == 0

matches = regrep(
fname,
Expand All @@ -26,8 +26,8 @@ def test_regrep(self):
terminate_on_match=True,
postprocess=int,
)
self.assertEqual(len(matches["1"]), 1)
self.assertEqual(len(matches["3"]), 11)
assert len(matches["1"]) == 1
assert len(matches["3"]) == 11


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit dc76ff7

Please sign in to comment.