From dc76ff7a4cc51160f0e4f74bf622a5f43c6f8e90 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Tue, 5 Sep 2023 07:10:27 -0700 Subject: [PATCH] More assert fixes. --- tests/test_bisect.py | 12 +++---- tests/test_collections.py | 16 ++++----- tests/test_design_patterns.py | 6 ++-- tests/test_dev.py | 12 +++---- tests/test_files/3000_lines.txt.gz | Bin 6496 -> 6496 bytes tests/test_fractions.py | 4 +-- tests/test_functools.py | 50 ++++++++++++++--------------- tests/test_inspect.py | 2 +- tests/test_io.py | 14 ++++---- tests/test_json.py | 26 +++++++-------- tests/test_math.py | 4 +-- tests/test_multiprocessing.py | 14 ++++---- tests/test_os.py | 4 +-- tests/test_re.py | 10 +++--- tests/test_serialization.py | 10 +++--- tests/test_shutil.py | 8 ++--- tests/test_string.py | 6 ++-- tests/test_subprocess.py | 2 +- tests/test_tempfile.py | 2 +- 19 files changed, 101 insertions(+), 101 deletions(-) diff --git a/tests/test_bisect.py b/tests/test_bisect.py index 3c0425fe9..c196e8fa8 100644 --- a/tests/test_bisect.py +++ b/tests/test_bisect.py @@ -6,12 +6,12 @@ class FuncTestCase(unittest.TestCase): def test_funcs(self): l = [0, 1, 2, 3, 4] - self.assertEqual(index(l, 1), 1) - self.assertEqual(find_lt(l, 1), 0) - self.assertEqual(find_gt(l, 1), 2) - self.assertEqual(find_le(l, 1), 1) - self.assertEqual(find_ge(l, 2), 2) - # self.assertEqual(index([0, 1, 1.5, 2], 1.501, atol=0.1), 4) + assert index(l, 1) == 1 + assert find_lt(l, 1) == 0 + assert find_gt(l, 1) == 2 + assert find_le(l, 1) == 1 + assert find_ge(l, 2) == 2 + # assert index([0, 1, 1.5, 2], 1.501, atol=0.1) == 4 if __name__ == "__main__": diff --git a/tests/test_collections.py b/tests/test_collections.py index 87de9e7e1..9ced27055 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -11,25 +11,25 @@ def test_frozen_dict(self): d = frozendict({"hello": "world"}) self.assertRaises(KeyError, d.__setitem__, "k", "v") self.assertRaises(KeyError, d.update, {"k": "v"}) - self.assertEqual(d["hello"], "world") + assert d["hello"] == "world" def test_namespace_dict(self): d = Namespace(foo="bar") d["hello"] = "world" - self.assertEqual(d["foo"], "bar") + assert d["foo"] == "bar" self.assertRaises(KeyError, d.__setitem__, "foo", "spam") def test_attr_dict(self): d = AttrDict(foo=1, bar=2) - self.assertEqual(d.bar, 2) - self.assertEqual(d["foo"], d.foo) + assert d.bar == 2 + assert d["foo"] == d.foo d.bar = "hello" - self.assertEqual(d["bar"], "hello") + assert d["bar"] == "hello" def test_frozen_attrdict(self): d = FrozenAttrDict({"hello": "world", 1: 2}) - self.assertEqual(d["hello"], "world") - self.assertEqual(d.hello, "world") + assert d["hello"] == "world" + assert d.hello == "world" self.assertRaises(KeyError, d.update, {"updating": 2}) with self.assertRaises(KeyError): d["foo"] = "bar" @@ -46,7 +46,7 @@ def test_tree(self): self.assertIn("b", x["a"]) self.assertNotIn("c", x["a"]) self.assertIn("c", x["a"]["b"]) - self.assertEqual(x["a"]["b"]["c"]["d"], 1) + assert x["a"]["b"]["c"]["d"] == 1 if __name__ == "__main__": diff --git a/tests/test_design_patterns.py b/tests/test_design_patterns.py index 2f7f93bf9..ea0ff6586 100644 --- a/tests/test_design_patterns.py +++ b/tests/test_design_patterns.py @@ -12,7 +12,7 @@ class A: a1 = A() a2 = A() - self.assertEqual(id(a1), id(a2)) + assert id(a1) == id(a2) @cached_class @@ -36,13 +36,13 @@ def test_cached_class(self): a1b = A(1) a2 = A(2) - self.assertEqual(id(a1a), id(a1b)) + assert id(a1a) == id(a1b) self.assertNotEqual(id(a1a), id(a2)) # def test_pickle(self): # a = A(2) # o = pickle.dumps(a) - # self.assertEqual(a, pickle.loads(o)) + # assert a == pickle.loads(o) if __name__ == "__main__": diff --git a/tests/test_dev.py b/tests/test_dev.py index 5901025fa..32dd88d87 100644 --- a/tests/test_dev.py +++ b/tests/test_dev.py @@ -56,7 +56,7 @@ def func_a(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger a warning. - self.assertEqual(a().property_b, "b") + assert a().property_b == "b" # Verify some things assert issubclass(w[-1].category, FutureWarning) @@ -64,7 +64,7 @@ def func_a(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger a warning. - self.assertEqual(a().func_a(), "a") + assert a().func_a() == "a" # Verify some things assert issubclass(w[-1].category, FutureWarning) @@ -86,7 +86,7 @@ def classmethod_b(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger a warning. - self.assertEqual(A().classmethod_b(), "b") + assert A().classmethod_b() == "b" # Verify some things assert issubclass(w[-1].category, FutureWarning) @@ -107,7 +107,7 @@ def classmethod_b(self): # Cause all warnings to always be triggered. warnings.simplefilter("always") # Trigger a warning. - self.assertEqual(A().classmethod_b(), "b") + assert A().classmethod_b() == "b" # Verify some things assert issubclass(w[-1].category, DeprecationWarning) @@ -127,10 +127,10 @@ def use_fictitious_mod(): def use_unittest(): return "success" - self.assertEqual(use_unittest(), "success") + assert use_unittest() == "success" def test_get_ncpus(self): - self.assertEqual(get_ncpus(), multiprocessing.cpu_count()) + assert get_ncpus() == multiprocessing.cpu_count() def test_install_except_hook(self): install_excepthook() diff --git a/tests/test_files/3000_lines.txt.gz b/tests/test_files/3000_lines.txt.gz index 9d183fdd210bf62c0c2561fc7c6f4f7f899d4228..770307de7df71769fbf033acd7d36194ca302932 100644 GIT binary patch delta 15 WcmaE0^uUNszMF&NwAn_sC`kY<=>-q~ delta 15 WcmaE0^uUNszMF%C*K{LWlq3Kta0F2R diff --git a/tests/test_fractions.py b/tests/test_fractions.py index 2f03a75c8..efc789951 100644 --- a/tests/test_fractions.py +++ b/tests/test_fractions.py @@ -5,10 +5,10 @@ class FuncTestCase(unittest.TestCase): def test_gcd(self): - self.assertEqual(gcd(7, 14, 63), 7) + assert gcd(7, 14, 63) == 7 def test_lcm(self): - self.assertEqual(lcm(2, 3, 4), 12) + assert lcm(2, 3, 4) == 12 def test_gcd_float(self): vs = [6.2, 12.4, 15.5 + 5e-9] diff --git a/tests/test_functools.py b/tests/test_functools.py index e3a756e86..30499255a 100644 --- a/tests/test_functools.py +++ b/tests/test_functools.py @@ -21,17 +21,17 @@ def cached_func(a, b, c=3): return a + b + c # call a few times to get some stats - self.assertEqual(cached_func(1, 2, c=4), 7) - self.assertEqual(cached_func(3, 2), 8) - self.assertEqual(cached_func(3, 2), 8) - self.assertEqual(cached_func(1, 2, c=4), 7) - self.assertEqual(cached_func(4, 2), 9) - self.assertEqual(cached_func(4, 2), 9) - self.assertEqual(cached_func(3, 2), 8) - self.assertEqual(cached_func(1, 2), 6) - - self.assertEqual(cached_func.cache_info().hits, 3) - self.assertEqual(cached_func.cache_info().misses, 5) + assert cached_func(1, 2, c=4) == 7 + assert cached_func(3, 2) == 8 + assert cached_func(3, 2) == 8 + assert cached_func(1, 2, c=4) == 7 + assert cached_func(4, 2) == 9 + assert cached_func(4, 2) == 9 + assert cached_func(3, 2) == 8 + assert cached_func(1, 2) == 6 + + assert cached_func.cache_info().hits == 3 + assert cached_func.cache_info().misses == 5 def test_class_method(self): class TestClass: @@ -42,14 +42,14 @@ def cached_func(self, x): a = TestClass() b = TestClass() - self.assertEqual(a.cached_func(1), 1) - self.assertEqual(b.cached_func(2), 2) - self.assertEqual(b.cached_func(3), 3) - self.assertEqual(a.cached_func(3), 3) - self.assertEqual(a.cached_func(1), 1) + assert a.cached_func(1) == 1 + assert b.cached_func(2) == 2 + assert b.cached_func(3) == 3 + assert a.cached_func(3) == 3 + assert a.cached_func(1) == 1 - self.assertEqual(a.cached_func.cache_info().hits, 1) - self.assertEqual(a.cached_func.cache_info().misses, 4) + assert a.cached_func.cache_info().hits == 1 + assert a.cached_func.cache_info().misses == 4 class TestClass2: @lru_cache(None) @@ -59,14 +59,14 @@ def cached_func(self, x): a = TestClass2() b = TestClass2() - self.assertEqual(a.cached_func(1), 1) - self.assertEqual(b.cached_func(2), 2) - self.assertEqual(b.cached_func(3), 3) - self.assertEqual(a.cached_func(3), 3) - self.assertEqual(a.cached_func(1), 1) + assert a.cached_func(1) == 1 + assert b.cached_func(2) == 2 + assert b.cached_func(3) == 3 + assert a.cached_func(3) == 3 + assert a.cached_func(1) == 1 - self.assertEqual(a.cached_func.cache_info().hits, 1) - self.assertEqual(a.cached_func.cache_info().misses, 4) + assert a.cached_func.cache_info().hits == 1 + assert a.cached_func.cache_info().misses == 4 class TestCase(unittest.TestCase): diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 299d6461e..72ea43343 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -26,7 +26,7 @@ def test_func(self): assert caller_name() def test_all_subclasses(self): - self.assertEqual(all_subclasses(LittleCatA), [LittleCatB, LittleCatD]) + assert all_subclasses(LittleCatA) == [LittleCatB, LittleCatD] if __name__ == "__main__": diff --git a/tests/test_io.py b/tests/test_io.py index 973388a56..6871bb0fb 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -125,24 +125,24 @@ def test_empty_file(self): class ZopenTest(unittest.TestCase): def test_zopen(self): with zopen(os.path.join(test_dir, "myfile_gz.gz"), mode="rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" with zopen(os.path.join(test_dir, "myfile_bz2.bz2"), mode="rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" with zopen(os.path.join(test_dir, "myfile_bz2.bz2"), "rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" with zopen(os.path.join(test_dir, "myfile_xz.xz"), "rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" with zopen(os.path.join(test_dir, "myfile_lzma.lzma"), "rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" with zopen(os.path.join(test_dir, "myfile"), mode="rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" @unittest.skipIf(Path is None, "Not Py3k") def test_Path_objects(self): p = Path(test_dir) / "myfile_gz.gz" with zopen(p, mode="rt") as f: - self.assertEqual(f.read(), "HelloWorld.\n\n") + assert f.read() == "HelloWorld.\n\n" class FileLockTest(unittest.TestCase): diff --git a/tests/test_json.py b/tests/test_json.py index 9d5c9d6c9..1ff93b8cf 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -206,12 +206,12 @@ def test_unsafe_hash(self): a_list[0].unsafe_hash().hexdigest(), "ea44de0e2ef627be582282c02c48e94de0d58ec6", ) - self.assertEqual(obj.unsafe_hash().hexdigest(), "44204c8da394e878f7562c9aa2e37c2177f28b81") + assert obj.unsafe_hash().hexdigest() == "44204c8da394e878f7562c9aa2e37c2177f28b81" def test_version(self): obj = self.good_cls("Hello", "World", "Python") d = obj.as_dict() - self.assertEqual(d["@version"], tests_version) + assert d["@version"] == tests_version def test_nested_to_from_dict(self): GMC = GoodMSONClass @@ -243,35 +243,35 @@ def test_nested_to_from_dict(self): obj2 = GoodNestedMSONClass.from_dict(obj_dict) assert [obj2.a_list[ii] == aa for ii, aa in enumerate(obj.a_list)] assert [obj2.b_dict[kk] == val for kk, val in obj.b_dict.items()] - self.assertEqual(len(obj.a_list), len(obj2.a_list)) - self.assertEqual(len(obj.b_dict), len(obj2.b_dict)) + assert len(obj.a_list) == len(obj2.a_list) + assert len(obj.b_dict) == len(obj2.b_dict) s = json.dumps(obj_dict) obj3 = json.loads(s, cls=MontyDecoder) assert [obj2.a_list[ii] == aa for ii, aa in enumerate(obj3.a_list)] assert [obj2.b_dict[kk] == val for kk, val in obj3.b_dict.items()] - self.assertEqual(len(obj3.a_list), len(obj2.a_list)) - self.assertEqual(len(obj3.b_dict), len(obj2.b_dict)) + assert len(obj3.a_list) == len(obj2.a_list) + assert len(obj3.b_dict) == len(obj2.b_dict) s = json.dumps(obj, cls=MontyEncoder) obj4 = json.loads(s, cls=MontyDecoder) assert [obj4.a_list[ii] == aa for ii, aa in enumerate(obj.a_list)] assert [obj4.b_dict[kk] == val for kk, val in obj.b_dict.items()] - self.assertEqual(len(obj.a_list), len(obj4.a_list)) - self.assertEqual(len(obj.b_dict), len(obj4.b_dict)) + assert len(obj.a_list) == len(obj4.a_list) + assert len(obj.b_dict) == len(obj4.b_dict) def test_enum_serialization(self): e = EnumTest.a d = e.as_dict() e_new = EnumTest.from_dict(d) - self.assertEqual(e_new.name, e.name) - self.assertEqual(e_new.value, e.value) + assert e_new.name == e.name + assert e_new.value == e.value d = {"123": EnumTest.a} f = jsanitize(d) - self.assertEqual(f["123"], "EnumTest.a") + assert f["123"] == "EnumTest.a" f = jsanitize(d, strict=True) - self.assertEqual(f["123"]["@module"], "tests.test_json") - self.assertEqual(f["123"]["@class"], "EnumTest") + assert f["123"]["@module"] == "tests.test_json" + assert f["123"]["@class"] == "EnumTest" self.assertEqual(f["123"]["value"], 1) f = jsanitize(d, strict=True, enum_values=True) diff --git a/tests/test_math.py b/tests/test_math.py index d5c93d864..1824babeb 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -5,10 +5,10 @@ class FuncTest(unittest.TestCase): def test_nCr(self): - self.assertEqual(nCr(4, 2), 6) + assert nCr(4, 2) == 6 def test_deprecated_property(self): - self.assertEqual(nPr(4, 2), 12) + assert nPr(4, 2) == 12 if __name__ == "__main__": diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index f4c76a166..f2b8817ca 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -7,14 +7,14 @@ class FuncCase(unittest.TestCase): def test_imap_tqdm(self): results = imap_tqdm(4, sqrt, range(10000)) - self.assertEqual(len(results), 10000) - self.assertEqual(results[0], 0) - self.assertEqual(results[400], 20) - self.assertEqual(results[9999], 99.99499987499375) + assert len(results) == 10000 + assert results[0] == 0 + assert results[400] == 20 + assert results[9999] == 99.99499987499375 results = imap_tqdm(4, sqrt, (i**2 for i in range(10000))) - self.assertEqual(len(results), 10000) - self.assertEqual(results[0], 0) - self.assertEqual(results[400], 400) + assert len(results) == 10000 + assert results[0] == 0 + assert results[400] == 400 if __name__ == "__main__": diff --git a/tests/test_os.py b/tests/test_os.py index 0b1c4971b..dfc377e61 100644 --- a/tests/test_os.py +++ b/tests/test_os.py @@ -10,11 +10,11 @@ class PathTest(unittest.TestCase): def test_zpath(self): fullzpath = zpath(os.path.join(test_dir, "myfile_gz")) - self.assertEqual(os.path.join(test_dir, "myfile_gz.gz"), fullzpath) + assert os.path.join(test_dir, "myfile_gz.gz") == fullzpath def test_find_exts(self): assert len(find_exts(os.path.dirname(__file__), "py")) >= 18 - self.assertEqual(len(find_exts(os.path.dirname(__file__), "bz2")), 2) + assert len(find_exts(os.path.dirname(__file__), "bz2")) == 2 self.assertEqual( len(find_exts(os.path.dirname(__file__), "bz2", exclude_dirs="test_files")), 0, diff --git a/tests/test_re.py b/tests/test_re.py index de4fae07a..4611b08ab 100644 --- a/tests/test_re.py +++ b/tests/test_re.py @@ -15,9 +15,9 @@ def test_regrep(self): """ fname = os.path.join(test_dir, "3000_lines.txt") matches = regrep(fname, {"1": r"1(\d+)", "3": r"3(\d+)"}, postprocess=int) - self.assertEqual(len(matches["1"]), 1380) - self.assertEqual(len(matches["3"]), 571) - self.assertEqual(matches["1"][0][0][0], 0) + assert len(matches["1"]) == 1380 + assert len(matches["3"]) == 571 + assert matches["1"][0][0][0] == 0 matches = regrep( fname, @@ -26,8 +26,8 @@ def test_regrep(self): terminate_on_match=True, postprocess=int, ) - self.assertEqual(len(matches["1"]), 1) - self.assertEqual(len(matches["3"]), 11) + assert len(matches["1"]) == 1 + assert len(matches["3"]) == 11 if __name__ == "__main__": diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 05b4532cf..ab8f7b83c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -46,11 +46,11 @@ def test_dumpfn_loadfn(self): # Test custom kwarg configuration dumpfn(d, "monte_test.json", indent=4) d2 = loadfn("monte_test.json") - self.assertEqual(d, d2) + assert d == d2 os.remove("monte_test.json") dumpfn(d, "monte_test.yaml") d2 = loadfn("monte_test.yaml") - self.assertEqual(d, d2) + assert d == d2 os.remove("monte_test.yaml") # Check if fmt override works. @@ -58,7 +58,7 @@ def test_dumpfn_loadfn(self): with self.assertRaises(json.decoder.JSONDecodeError): d2 = loadfn("monte_test.json") d2 = loadfn("monte_test.json", fmt="yaml") - self.assertEqual(d, d2) + assert d == d2 os.remove("monte_test.json") with self.assertRaises(TypeError): @@ -73,7 +73,7 @@ def test_mpk(self): # Test automatic format detection dumpfn(d, "monte_test.mpk") d2 = loadfn("monte_test.mpk") - self.assertEqual(d, {k: v for k, v in d2.items()}) + assert d, {k: v for k, v in d2.items()} os.remove("monte_test.mpk") # Test to ensure basename is respected, and not directory @@ -84,7 +84,7 @@ def test_mpk(self): dumpfn({"test": 1}, fname) with open("test_file.json") as f: reloaded = json.loads(f.read()) - self.assertEqual(reloaded["test"], 1) + assert reloaded["test"] == 1 if __name__ == "__main__": diff --git a/tests/test_shutil.py b/tests/test_shutil.py index 2a6b1aae4..4695f0a79 100644 --- a/tests/test_shutil.py +++ b/tests/test_shutil.py @@ -42,7 +42,7 @@ def test_recursive_copy_and_compress(self): assert os.path.exists(os.path.join(test_dir, "cpr_src", "sub", "testr")) with open(os.path.join(test_dir, "cpr_src", "test")) as f: txt = f.read() - self.assertEqual(txt, "what") + assert txt == "what" def test_pathlib(self): test_path = Path(test_dir) @@ -71,7 +71,7 @@ def test_compress_and_decompress_file(self): assert not os.path.exists(fname + "." + fmt) with open(fname) as f: txt = f.read() - self.assertEqual(txt, "hello world") + assert txt == "hello world" self.assertRaises(ValueError, compress_file, "whatever", "badformat") # test decompress non-existent/non-compressed file @@ -99,7 +99,7 @@ def test_gzip(self): assert not os.path.exists(full_f) with GzipFile(f"{full_f}.gz") as g: - self.assertEqual(g.readline().decode("utf-8"), "what") + assert g.readline().decode("utf-8") == "what" self.assertAlmostEqual(os.path.getmtime(f"{full_f}.gz"), self.mtime, 4) @@ -116,7 +116,7 @@ def test_handle_sub_dirs(self): assert not os.path.exists(sub_file) with GzipFile(f"{sub_file}.gz") as g: - self.assertEqual(g.readline().decode("utf-8"), "anotherwhat") + assert g.readline().decode("utf-8") == "anotherwhat" def tearDown(self): shutil.rmtree(os.path.join(test_dir, "gzip_dir")) diff --git a/tests/test_string.py b/tests/test_string.py index 4f6a755a5..841c44973 100644 --- a/tests/test_string.py +++ b/tests/test_string.py @@ -14,13 +14,13 @@ def test_remove_non_ascii(self): s = "".join(chr(random.randint(0, 127)) for i in range(10)) s += "".join(chr(random.randint(128, 150)) for i in range(10)) clean = remove_non_ascii(s) - self.assertEqual(len(clean), 10) + assert len(clean) == 10 def test_unicode2str(self): if sys.version_info.major < 3: - self.assertEqual(type(unicode2str("a")), str) + assert type(unicode2str("a")) == str else: - self.assertEqual(type(unicode2str("a")), str) + assert type(unicode2str("a")) == str if __name__ == "__main__": diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index c005ac7cf..2bcd09fcd 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -10,7 +10,7 @@ def test_command(self): sleep05.run(timeout=1) print(sleep05) - self.assertEqual(sleep05.retcode, 0) + assert sleep05.retcode == 0 assert not sleep05.killed sleep05.run(timeout=0.1) diff --git a/tests/test_tempfile.py b/tests/test_tempfile.py index 0f2e549cd..9f006c640 100644 --- a/tests/test_tempfile.py +++ b/tests/test_tempfile.py @@ -138,7 +138,7 @@ def test_symlink(self): def test_bad_root(self): with ScratchDir("bad_groot") as d: - self.assertEqual(d, test_dir) + assert d == test_dir def tearDown(self): os.chdir(self.cwd)