diff --git a/tests/test_bisect.py b/tests/test_bisect.py index c196e8fa8..898e28f89 100644 --- a/tests/test_bisect.py +++ b/tests/test_bisect.py @@ -3,7 +3,7 @@ from monty.bisect import find_ge, find_gt, find_le, find_lt, index -class FuncTestCase(unittest.TestCase): +class TestFunc: def test_funcs(self): l = [0, 1, 2, 3, 4] assert index(l, 1) == 1 diff --git a/tests/test_collections.py b/tests/test_collections.py index 6bbe94350..c1768083d 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,23 +1,25 @@ import os -import unittest + +import pytest from monty.collections import AttrDict, FrozenAttrDict, Namespace, frozendict, tree test_dir = os.path.join(os.path.dirname(__file__), "test_files") -class FrozenDictTest(unittest.TestCase): +class TestFrozenDict: def test_frozen_dict(self): d = frozendict({"hello": "world"}) - self.assertRaises(KeyError, d.__setitem__, "k", "v") - self.assertRaises(KeyError, d.update, {"k": "v"}) + with pytest.raises(KeyError): + d["k"] == "v" assert d["hello"] == "world" def test_namespace_dict(self): d = Namespace(foo="bar") d["hello"] = "world" assert d["foo"] == "bar" - self.assertRaises(KeyError, d.__setitem__, "foo", "spam") + with pytest.raises(KeyError): + d.update({"foo": "spam"}) def test_attr_dict(self): d = AttrDict(foo=1, bar=2) @@ -30,16 +32,18 @@ def test_frozen_attrdict(self): d = FrozenAttrDict({"hello": "world", 1: 2}) assert d["hello"] == "world" assert d.hello == "world" - self.assertRaises(KeyError, d.update, {"updating": 2}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): + d["updating"] == 2 + + with pytest.raises(KeyError): d["foo"] = "bar" - with self.assertRaises(KeyError): + with pytest.raises(KeyError): d.foo = "bar" - with self.assertRaises(KeyError): + with pytest.raises(KeyError): d.hello = "new" -class TreeTest(unittest.TestCase): +class TestTree: def test_tree(self): x = tree() x["a"]["b"]["c"]["d"] = 1 @@ -47,7 +51,3 @@ def test_tree(self): assert "c" not in x["a"] assert "c" in x["a"]["b"] assert x["a"]["b"]["c"]["d"] == 1 - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_design_patterns.py b/tests/test_design_patterns.py index 3a8b62080..d6ded8160 100644 --- a/tests/test_design_patterns.py +++ b/tests/test_design_patterns.py @@ -3,7 +3,7 @@ from monty.design_patterns import cached_class, singleton -class SingletonTest(unittest.TestCase): +class TestSingleton: def test_singleton(self): @singleton class A: @@ -30,7 +30,7 @@ def __getnewargs__(self): return (self.val,) -class CachedClassTest(unittest.TestCase): +class TestCachedClass: def test_cached_class(self): a1a = A(1) a1b = A(1) diff --git a/tests/test_dev.py b/tests/test_dev.py index 3c2f2f603..2309e4576 100644 --- a/tests/test_dev.py +++ b/tests/test_dev.py @@ -2,6 +2,8 @@ import unittest import warnings +import pytest + from monty.dev import deprecated, get_ncpus, install_excepthook, requires @@ -16,7 +18,7 @@ def prop(self): pass -class DecoratorTest(unittest.TestCase): +class TestDecorator: def test_deprecated(self): def func_a(): pass @@ -121,7 +123,8 @@ def test_requires(self): def use_fictitious_mod(): print("success") - self.assertRaises(RuntimeError, use_fictitious_mod) + with pytest.raises(RuntimeError): + use_fictitious_mod() @requires(unittest is not None, "scipy is not present.") def use_unittest(): diff --git a/tests/test_files/3000_lines.txt.gz b/tests/test_files/3000_lines.txt.gz index 4a61dbc11..b0551deb7 100644 Binary files a/tests/test_files/3000_lines.txt.gz and b/tests/test_files/3000_lines.txt.gz differ diff --git a/tests/test_fnmatch.py b/tests/test_fnmatch.py index b274ecc99..1a5fdb735 100644 --- a/tests/test_fnmatch.py +++ b/tests/test_fnmatch.py @@ -3,7 +3,7 @@ from monty.fnmatch import WildCard -class FuncTest(unittest.TestCase): +class TestFunc: def test_match(self): wc = WildCard("*.pdf") assert wc.match("A.pdf") diff --git a/tests/test_fractions.py b/tests/test_fractions.py index efc789951..3c4be41a2 100644 --- a/tests/test_fractions.py +++ b/tests/test_fractions.py @@ -1,9 +1,11 @@ import unittest +import pytest + from monty.fractions import gcd, gcd_float, lcm -class FuncTestCase(unittest.TestCase): +class TestFunc: def test_gcd(self): assert gcd(7, 14, 63) == 7 @@ -12,7 +14,7 @@ def test_lcm(self): def test_gcd_float(self): vs = [6.2, 12.4, 15.5 + 5e-9] - self.assertAlmostEqual(gcd_float(vs, 1e-8), 3.1) + assert gcd_float(vs, 1e-8) == pytest.approx(3.1) if __name__ == "__main__": diff --git a/tests/test_inspect.py b/tests/test_inspect.py index 72ea43343..2d2173926 100644 --- a/tests/test_inspect.py +++ b/tests/test_inspect.py @@ -19,7 +19,7 @@ class LittleCatD(LittleCatB): pass -class InspectTest(unittest.TestCase): +class TestInspect: def test_func(self): # Not a real test. Need something better. assert find_top_pyfile() diff --git a/tests/test_io.py b/tests/test_io.py index 4b2ac201b..6f07008e0 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,6 +1,8 @@ import os import unittest +import pytest + try: from pathlib import Path except ImportError: @@ -17,7 +19,7 @@ test_dir = os.path.join(os.path.dirname(__file__), "test_files") -class ReverseReadlineTest(unittest.TestCase): +class TestReverseReadline: NUMLINES = 3000 def test_reverse_readline(self): @@ -28,11 +30,9 @@ def test_reverse_readline(self): """ with open(os.path.join(test_dir, "3000_lines.txt")) as f: for idx, line in enumerate(reverse_readline(f)): - self.assertEqual( - int(line), - self.NUMLINES - idx, - "read_backwards read {} whereas it should " "have read {}".format(int(line), self.NUMLINES - idx), - ) + assert int(line) == self.NUMLINES - idx, "read_backwards read {} whereas it should "( + "have read {" "}" + ).format(int(line), self.NUMLINES - idx) def test_reverse_readline_fake_big(self): """ @@ -40,11 +40,9 @@ def test_reverse_readline_fake_big(self): """ with open(os.path.join(test_dir, "3000_lines.txt")) as f: for idx, line in enumerate(reverse_readline(f, max_mem=0)): - self.assertEqual( - int(line), - self.NUMLINES - idx, - "read_backwards read {} whereas it should " "have read {}".format(int(line), self.NUMLINES - idx), - ) + assert int(line) == self.NUMLINES - idx, "read_backwards read {} whereas it should "( + "have read {" "}" + ).format(int(line), self.NUMLINES - idx) def test_reverse_readline_bz2(self): """ @@ -68,7 +66,7 @@ def test_empty_file(self): raise ValueError("an empty file is being read!") -class ReverseReadfileTest(unittest.TestCase): +class TestReverseReadfile: NUMLINES = 3000 def test_reverse_readfile(self): @@ -79,11 +77,7 @@ def test_reverse_readfile(self): """ fname = os.path.join(test_dir, "3000_lines.txt") for idx, line in enumerate(reverse_readfile(fname)): - self.assertEqual( - int(line), - self.NUMLINES - idx, - "read_backwards read {} whereas it should " "have read {}".format(int(line), self.NUMLINES - idx), - ) + assert int(line) == self.NUMLINES - idx def test_reverse_readfile_gz(self): """ @@ -93,11 +87,7 @@ def test_reverse_readfile_gz(self): """ fname = os.path.join(test_dir, "3000_lines.txt.gz") for idx, line in enumerate(reverse_readfile(fname)): - self.assertEqual( - int(line), - self.NUMLINES - idx, - "read_backwards read {} whereas it should " "have read {}".format(int(line), self.NUMLINES - idx), - ) + assert int(line) == self.NUMLINES - idx def test_reverse_readfile_bz2(self): """ @@ -107,11 +97,7 @@ def test_reverse_readfile_bz2(self): """ fname = os.path.join(test_dir, "3000_lines.txt.bz2") for idx, line in enumerate(reverse_readfile(fname)): - self.assertEqual( - int(line), - self.NUMLINES - idx, - "read_backwards read {} whereas it should " "have read {}".format(int(line), self.NUMLINES - idx), - ) + assert int(line) == self.NUMLINES - idx def test_empty_file(self): """ @@ -122,7 +108,7 @@ def test_empty_file(self): raise ValueError("an empty file is being read!") -class ZopenTest(unittest.TestCase): +class TestZopen: def test_zopen(self): with zopen(os.path.join(test_dir, "myfile_gz.gz"), mode="rt") as f: assert f.read() == "HelloWorld.\n\n" @@ -145,18 +131,18 @@ def test_Path_objects(self): assert f.read() == "HelloWorld.\n\n" -class FileLockTest(unittest.TestCase): - def setUp(self): +class TestFileLock: + def setup_method(self): self.file_name = "__lock__" self.lock = FileLock(self.file_name, timeout=1) self.lock.acquire() def test_raise(self): - with self.assertRaises(FileLockException): + with pytest.raises(FileLockException): new_lock = FileLock(self.file_name, timeout=1) new_lock.acquire() - def tearDown(self): + def teardown_method(self): self.lock.release() diff --git a/tests/test_json.py b/tests/test_json.py index 8159ac4ec..135ae2b2f 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -10,6 +10,7 @@ import numpy as np import pandas as pd +import pytest import torch from bson.objectid import ObjectId @@ -129,8 +130,8 @@ class NestedDataClass: points: list[Point] -class MSONableTest(unittest.TestCase): - def setUp(self): +class TestMSONable: + def setup_method(self): self.good_cls = GoodMSONClass class BadMSONClass(MSONable): @@ -161,17 +162,19 @@ def __init__(self, a, b): def test_to_from_dict(self): obj = self.good_cls("Hello", "World", "Python") d = obj.as_dict() - self.assertIsNotNone(d) + assert d is not None self.good_cls.from_dict(d) jsonstr = obj.to_json() d = json.loads(jsonstr) assert d["@class"], "GoodMSONClass" obj = self.bad_cls("Hello", "World") d = obj.as_dict() - self.assertIsNotNone(d) - self.assertRaises(TypeError, self.bad_cls.from_dict, d) + assert d is not None + with pytest.raises(TypeError): + self.bad_cls.from_dict(d) obj = self.bad_cls2("Hello", "World") - self.assertRaises(NotImplementedError, obj.as_dict) + with pytest.raises(NotImplementedError): + obj.as_dict() obj = self.auto_mson(2, 3) d = obj.as_dict() self.auto_mson.from_dict(d) @@ -278,7 +281,7 @@ def test_enum_serialization(self): assert f["123"] == 1 -class JsonTest(unittest.TestCase): +class TestJson: def test_as_from_dict(self): obj = GoodMSONClass(1, 2, 3, hello="world") s = json.dumps(obj, cls=MontyEncoder) @@ -350,7 +353,8 @@ def test_nan(self): def test_numpy(self): x = np.array([1, 2, 3], dtype="int64") - self.assertRaises(TypeError, json.dumps, x) + with pytest.raises(TypeError): + json.dumps(x) djson = json.dumps(x, cls=MontyEncoder) d = json.loads(djson) assert d["@class"] == "array" @@ -360,10 +364,12 @@ def test_numpy(self): x = json.loads(djson, cls=MontyDecoder) assert isinstance(x, np.ndarray) x = np.min([1, 2, 3]) > 2 - self.assertRaises(TypeError, json.dumps, x) + with pytest.raises(TypeError): + json.dumps(x) x = np.array([1 + 1j, 2 + 1j, 3 + 1j], dtype="complex64") - self.assertRaises(TypeError, json.dumps, x) + with pytest.raises(TypeError): + json.dumps(x) djson = json.dumps(x, cls=MontyEncoder) d = json.loads(djson) assert d["@class"] == "array" @@ -375,7 +381,8 @@ def test_numpy(self): assert x.dtype == "complex64" x = np.array([[1 + 1j, 2 + 1j], [3 + 1j, 4 + 1j]], dtype="complex64") - self.assertRaises(TypeError, json.dumps, x) + with pytest.raises(TypeError): + json.dumps(x) djson = json.dumps(x, cls=MontyEncoder) d = json.loads(djson) assert d["@class"] == "array" @@ -398,15 +405,12 @@ def test_numpy(self): assert d["np_a"]["a"][0]["b"]["@module"] == "numpy" assert d["np_a"]["a"][0]["b"]["@class"] == "array" - self.assertEqual( - d["np_a"]["a"][0]["b"]["data"], - [[[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]], - ) + assert d["np_a"]["a"][0]["b"]["data"] == [[[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]] assert d["np_a"]["a"][0]["b"]["dtype"] == "complex64" obj = ClassContainingNumpyArray.from_dict(d) - self.assertIsInstance(obj, ClassContainingNumpyArray) - self.assertIsInstance(obj.np_a["a"][0]["b"], np.ndarray) + assert isinstance(obj, ClassContainingNumpyArray) + assert isinstance(obj.np_a["a"][0]["b"], np.ndarray) assert obj.np_a["a"][0]["b"][0][1] == 2 + 1j def test_pandas(self): @@ -418,8 +422,8 @@ def test_pandas(self): assert d["df"]["@class"] == "DataFrame" obj = ClassContainingDataFrame.from_dict(d) - self.assertIsInstance(obj, ClassContainingDataFrame) - self.assertIsInstance(obj.df, pd.DataFrame) + assert isinstance(obj, ClassContainingDataFrame) + assert isinstance(obj.df, pd.DataFrame) assert list(obj.df.a), [1 == 1] cls = ClassContainingSeries(s=pd.Series({"a": [1, 2, 3], "b": [4, 5, 6]})) @@ -430,8 +434,8 @@ def test_pandas(self): assert d["s"]["@class"] == "Series" obj = ClassContainingSeries.from_dict(d) - self.assertIsInstance(obj, ClassContainingSeries) - self.assertIsInstance(obj.s, pd.Series) + assert isinstance(obj, ClassContainingSeries) + assert isinstance(obj.s, pd.Series) assert list(obj.s.a), [1, 2 == 3] cls = ClassContainingSeries(s={"df": [pd.Series({"a": [1, 2, 3], "b": [4, 5, 6]})]}) @@ -442,8 +446,8 @@ def test_pandas(self): assert d["s"]["df"][0]["@class"] == "Series" obj = ClassContainingSeries.from_dict(d) - self.assertIsInstance(obj, ClassContainingSeries) - self.assertIsInstance(obj.s["df"][0], pd.Series) + assert isinstance(obj, ClassContainingSeries) + assert isinstance(obj.s["df"][0], pd.Series) assert list(obj.s["df"][0].a), [1, 2 == 3] def test_callable(self): @@ -468,7 +472,8 @@ def test_callable(self): MethodSerializationClass, Enum, ]: - self.assertRaises(TypeError, json.dumps, function) + with pytest.raises(TypeError): + json.dumps(function) djson = json.dumps(function, cls=MontyEncoder) d = json.loads(djson) assert "@callable" in d @@ -478,7 +483,8 @@ def test_callable(self): # test method bound to instance for function in [instance.method]: - self.assertRaises(TypeError, json.dumps, function) + with pytest.raises(TypeError): + json.dumps(function) djson = json.dumps(function, cls=MontyEncoder) d = json.loads(djson) assert "@callable" in d @@ -493,7 +499,8 @@ def test_callable(self): # test method bound to object that is not serializable for function in [MethodNonSerializationClass(1).method]: - self.assertRaises(TypeError, json.dumps, function, cls=MontyEncoder) + with pytest.raises(TypeError): + json.dumps(function, cls=MontyEncoder) # test that callable MSONable objects still get serialized as the objects # rather than as a callable @@ -502,7 +509,8 @@ def test_callable(self): def test_objectid(self): oid = ObjectId("562e8301218dcbbc3d7d91ce") - self.assertRaises(TypeError, json.dumps, oid) + with pytest.raises(TypeError): + json.dumps(oid) djson = json.dumps(oid, cls=MontyEncoder) x = json.loads(djson, cls=MontyDecoder) assert isinstance(x, ObjectId) @@ -511,13 +519,14 @@ def test_jsanitize(self): # clean_json should have no effect on None types. d = {"hello": 1, "world": None} clean = jsanitize(d) - self.assertIsNone(clean["world"]) + assert clean["world"] is None assert json.loads(json.dumps(d)) == json.loads(json.dumps(clean)) d = {"hello": GoodMSONClass(1, 2, 3)} - self.assertRaises(TypeError, json.dumps, d) + with pytest.raises(TypeError): + json.dumps(d) clean = jsanitize(d) - self.assertIsInstance(clean["hello"], str) + assert isinstance(clean["hello"], str) clean_strict = jsanitize(d, strict=True) assert clean_strict["hello"]["a"] == 1 assert clean_strict["hello"]["b"] == 2 @@ -527,9 +536,9 @@ def test_jsanitize(self): d = {"dt": datetime.datetime.now()} clean = jsanitize(d) - self.assertIsInstance(clean["dt"], str) + assert isinstance(clean["dt"], str) clean = jsanitize(d, allow_bson=True) - self.assertIsInstance(clean["dt"], datetime.datetime) + assert isinstance(clean["dt"], datetime.datetime) d = { "a": ["b", np.array([1, 2, 3])], @@ -537,13 +546,13 @@ def test_jsanitize(self): } clean = jsanitize(d) assert clean["a"], ["b", [1, 2 == 3]] - self.assertIsInstance(clean["b"], str) + assert isinstance(clean["b"], str) rnd_bin = bytes(np.random.rand(10)) d = {"a": bytes(rnd_bin)} clean = jsanitize(d, allow_bson=True) assert clean["a"] == bytes(rnd_bin) - self.assertIsInstance(clean["a"], bytes) + assert isinstance(clean["a"], bytes) p = pathlib.Path("/home/user/") clean = jsanitize(p, strict=True) @@ -592,7 +601,8 @@ def test_jsanitize(self): assert isinstance(clean["f"], str) # test that strict checking gives an error - self.assertRaises(AttributeError, jsanitize, d, strict=True) + with pytest.raises(AttributeError): + jsanitize(d, strict=True) # test that callable MSONable objects still get serialized as the objects # rather than as a callable @@ -629,10 +639,7 @@ def test_redirect(self): def test_redirect_settings_file(self): data = _load_redirect(os.path.join(test_dir, "test_settings.yaml")) - self.assertEqual( - data, - {"old_module": {"old_class": {"@class": "new_class", "@module": "new_module"}}}, - ) + assert data == {"old_module": {"old_class": {"@class": "new_class", "@module": "new_module"}}} def test_pydantic_integrations(self): from pydantic import BaseModel @@ -691,8 +698,8 @@ def test_dataclass(self): c2 = Coordinates.from_dict(d) assert d["points"][0]["x"] == 1 assert d["points"][1]["y"] == 4 - self.assertIsInstance(c2, Coordinates) - self.assertIsInstance(c2.points[0], Point) + assert isinstance(c2, Coordinates) + assert isinstance(c2.points[0], Point) s = MontyEncoder().encode(Point(1, 2)) p = MontyDecoder().decode(s) @@ -702,7 +709,7 @@ def test_dataclass(self): ndc = NestedDataClass([Point(1, 2), Point(3, 4)]) str_ = json.dumps(ndc, cls=MontyEncoder) ndc2 = json.loads(str_, cls=MontyDecoder) - self.assertIsInstance(ndc2, NestedDataClass) + assert isinstance(ndc2, NestedDataClass) if __name__ == "__main__": diff --git a/tests/test_logging.py b/tests/test_logging.py index c32b9a457..d3819b4f3 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -10,7 +10,7 @@ def add(a, b): return a + b -class FuncTest(unittest.TestCase): +class TestFunc: def test_logged(self): s = StringIO() logging.basicConfig(level=logging.DEBUG, stream=s) diff --git a/tests/test_math.py b/tests/test_math.py index 1824babeb..3c5a2c3c8 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -3,7 +3,7 @@ from monty.math import nCr, nPr -class FuncTest(unittest.TestCase): +class TestFunc: def test_nCr(self): assert nCr(4, 2) == 6 diff --git a/tests/test_operator.py b/tests/test_operator.py index a186c3d2c..014d41ea6 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -3,7 +3,7 @@ from monty.operator import operator_from_str -class OperatorTestCase(unittest.TestCase): +class TestOperator: def test_something(self): assert operator_from_str("==")(1, 1) and operator_from_str("+")(1, 1) == 2 diff --git a/tests/test_os.py b/tests/test_os.py index 37698a626..5928297a6 100644 --- a/tests/test_os.py +++ b/tests/test_os.py @@ -1,13 +1,15 @@ import os import unittest +import pytest + from monty.os import cd, makedirs_p from monty.os.path import find_exts, zpath test_dir = os.path.join(os.path.dirname(__file__), "test_files") -class PathTest(unittest.TestCase): +class TestPath: def test_zpath(self): fullzpath = zpath(os.path.join(test_dir, "myfile_gz")) assert os.path.join(test_dir, "myfile_gz.gz") == fullzpath @@ -19,7 +21,7 @@ def test_find_exts(self): assert len(find_exts(os.path.dirname(__file__), "bz2", include_dirs="test_files")) == 2 -class CdTest(unittest.TestCase): +class TestCd: def test_cd(self): with cd(test_dir): assert os.path.exists("empty_file.txt") @@ -31,17 +33,18 @@ def test_cd_exception(self): assert not os.path.exists("empty_file.txt") -class Makedirs_pTest(unittest.TestCase): - def setUp(self): +class TestMakedirs_p: + def setup_method(self): self.test_dir_path = os.path.join(test_dir, "test_dir") def test_makedirs_p(self): makedirs_p(self.test_dir_path) assert os.path.exists(self.test_dir_path) makedirs_p(self.test_dir_path) - self.assertRaises(OSError, makedirs_p, os.path.join(test_dir, "myfile_txt")) + with pytest.raises(OSError): + makedirs_p(os.path.join(test_dir, "myfile_txt")) - def tearDown(self): + def teardown_method(self): os.rmdir(self.test_dir_path) diff --git a/tests/test_pprint.py b/tests/test_pprint.py index 4f2afff76..3e9febdcf 100644 --- a/tests/test_pprint.py +++ b/tests/test_pprint.py @@ -3,13 +3,13 @@ from monty.pprint import draw_tree, pprint_table -class PprintTableTest(unittest.TestCase): +class TestPprintTable: def test_print(self): table = [["one", "two"], ["1", "2"]] pprint_table(table) -class DrawTreeTest(unittest.TestCase): +class TestDrawTree: def test_draw_tree(self): class Node: def __init__(self, name, children): diff --git a/tests/test_re.py b/tests/test_re.py index 4611b08ab..3d038cf73 100644 --- a/tests/test_re.py +++ b/tests/test_re.py @@ -6,7 +6,7 @@ test_dir = os.path.join(os.path.dirname(__file__), "test_files") -class RegrepTest(unittest.TestCase): +class TestRegrep: def test_regrep(self): """ We are making sure a file containing line numbers is read in reverse diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ab8f7b83c..709f87c77 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,6 +3,8 @@ import os import unittest +import pytest + try: import msgpack except ImportError: @@ -12,7 +14,7 @@ from monty.tempfile import ScratchDir -class SerialTest(unittest.TestCase): +class TestSerial: @classmethod def tearDownClass(cls): # Cleans up test files if a test fails @@ -36,11 +38,7 @@ def test_dumpfn_loadfn(self): fn = f"monte_test.{ext}" dumpfn(d, fn) d2 = loadfn(fn) - self.assertEqual( - d, - d2, - msg=f"Test file with extension {ext} did not parse correctly", - ) + assert d == d2, f"Test file with extension {ext} did not parse correctly" os.remove(fn) # Test custom kwarg configuration @@ -55,15 +53,15 @@ def test_dumpfn_loadfn(self): # Check if fmt override works. dumpfn(d, "monte_test.json", fmt="yaml") - with self.assertRaises(json.decoder.JSONDecodeError): - d2 = loadfn("monte_test.json") + with pytest.raises(json.decoder.JSONDecodeError): + loadfn("monte_test.json") d2 = loadfn("monte_test.json", fmt="yaml") assert d == d2 os.remove("monte_test.json") - with self.assertRaises(TypeError): + with pytest.raises(TypeError): dumpfn(d, "monte_test.txt", fmt="garbage") - with self.assertRaises(TypeError): + with pytest.raises(TypeError): loadfn("monte_test.txt", fmt="garbage") @unittest.skipIf(msgpack is None, "msgpack-python not installed.") diff --git a/tests/test_shutil.py b/tests/test_shutil.py index 4695f0a79..43820e069 100644 --- a/tests/test_shutil.py +++ b/tests/test_shutil.py @@ -6,6 +6,8 @@ from gzip import GzipFile from pathlib import Path +import pytest + from monty.shutil import ( compress_dir, compress_file, @@ -19,8 +21,8 @@ test_dir = os.path.join(os.path.dirname(__file__), "test_files") -class CopyRTest(unittest.TestCase): - def setUp(self): +class TestCopyR: + def setup_method(self): os.mkdir(os.path.join(test_dir, "cpr_src")) with open(os.path.join(test_dir, "cpr_src", "test"), "w") as f: f.write("what") @@ -50,13 +52,13 @@ def test_pathlib(self): assert (test_path / "cpr_dst" / "test").exists() assert (test_path / "cpr_dst" / "sub" / "testr").exists() - def tearDown(self): + def teardown_method(self): shutil.rmtree(os.path.join(test_dir, "cpr_src")) shutil.rmtree(os.path.join(test_dir, "cpr_dst")) -class CompressFileDirTest(unittest.TestCase): - def setUp(self): +class TestCompressFileDir: + def setup_method(self): with open(os.path.join(test_dir, "tempfile"), "w") as f: f.write("hello world") @@ -72,19 +74,20 @@ def test_compress_and_decompress_file(self): with open(fname) as f: txt = f.read() assert txt == "hello world" - self.assertRaises(ValueError, compress_file, "whatever", "badformat") + with pytest.raises(ValueError): + compress_file("whatever", "badformat") # test decompress non-existent/non-compressed file - self.assertIsNone(decompress_file("non-existent")) - self.assertIsNone(decompress_file("non-existent.gz")) - self.assertIsNone(decompress_file("non-existent.bz2")) + assert decompress_file("non-existent") is None + assert decompress_file("non-existent.gz") is None + assert decompress_file("non-existent.bz2") is None - def tearDown(self): + def teardown_method(self): os.remove(os.path.join(test_dir, "tempfile")) -class GzipDirTest(unittest.TestCase): - def setUp(self): +class TestGzipDir: + def setup_method(self): os.mkdir(os.path.join(test_dir, "gzip_dir")) with open(os.path.join(test_dir, "gzip_dir", "tempfile"), "w") as f: f.write("what") @@ -101,7 +104,7 @@ def test_gzip(self): with GzipFile(f"{full_f}.gz") as g: assert g.readline().decode("utf-8") == "what" - self.assertAlmostEqual(os.path.getmtime(f"{full_f}.gz"), self.mtime, 4) + assert os.path.getmtime(f"{full_f}.gz") == pytest.approx(self.mtime, 4) def test_handle_sub_dirs(self): sub_dir = os.path.join(test_dir, "gzip_dir", "sub_dir") @@ -118,11 +121,11 @@ def test_handle_sub_dirs(self): with GzipFile(f"{sub_file}.gz") as g: assert g.readline().decode("utf-8") == "anotherwhat" - def tearDown(self): + def teardown_method(self): shutil.rmtree(os.path.join(test_dir, "gzip_dir")) -class RemoveTest(unittest.TestCase): +class TestRemove: @unittest.skipIf(platform.system() == "Windows", "Skip on windows") def test_remove_file(self): tempdir = tempfile.mkdtemp(dir=test_dir) diff --git a/tests/test_string.py b/tests/test_string.py index 841c44973..cff05a7dc 100644 --- a/tests/test_string.py +++ b/tests/test_string.py @@ -9,7 +9,7 @@ from monty.string import remove_non_ascii, unicode2str -class FuncTest(unittest.TestCase): +class TestFunc: def test_remove_non_ascii(self): s = "".join(chr(random.randint(0, 127)) for i in range(10)) s += "".join(chr(random.randint(128, 150)) for i in range(10)) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 888284b05..c7a682fb6 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -3,7 +3,7 @@ from monty.subprocess import Command -class CommandTest(unittest.TestCase): +class TestCommand: def test_command(self): """Test Command class""" sleep05 = Command("sleep 0.5") diff --git a/tests/test_tempfile.py b/tests/test_tempfile.py index cac6e1815..4fe078a2a 100644 --- a/tests/test_tempfile.py +++ b/tests/test_tempfile.py @@ -7,8 +7,8 @@ test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") -class ScratchDirTest(unittest.TestCase): - def setUp(self): +class TestScratchDir: + def setup_method(self): self.cwd = os.getcwd() os.chdir(test_dir) self.scratch_root = os.path.join(test_dir, "..", "..", "tempscratch") @@ -140,7 +140,7 @@ def test_bad_root(self): with ScratchDir("bad_groot") as d: assert d == test_dir - def tearDown(self): + def teardown_method(self): os.chdir(self.cwd) shutil.rmtree(self.scratch_root) diff --git a/tests/test_termcolor.py b/tests/test_termcolor.py index 9b9b7daec..a57f0c55c 100644 --- a/tests/test_termcolor.py +++ b/tests/test_termcolor.py @@ -15,7 +15,7 @@ ) -class FuncTest(unittest.TestCase): +class TestFunc: def test_remove_non_ascii(self): enable(True) print("Current terminal type: %s" % os.getenv("TERM"))