From 6b28e820c52570fcfb49994e7474cb75824c90b6 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 5 Jun 2024 10:28:03 -0400 Subject: [PATCH] ruff unignore PLR, only ignore specific PLR0912 PLR0913 PLR0915 --- chgnet/graph/graph.py | 10 +++++++--- chgnet/model/dynamics.py | 2 +- pyproject.toml | 15 +++++++++++++-- site/make_docs.py | 3 ++- tests/test_graph.py | 2 +- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index d14291e1..ba619a79 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -127,7 +127,9 @@ def __init__(self, nodes: list[Node]) -> None: self.undirected_edges: dict[frozenset[int], list[UndirectedEdge]] = {} self.undirected_edges_list: list[UndirectedEdge] = [] - def add_edge(self, center_index, neighbor_index, image, distance) -> None: + def add_edge( + self, center_index, neighbor_index, image, distance, dist_tol: float = 1e-6 + ) -> None: """Add an directed edge to the graph. Args: @@ -135,6 +137,8 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None: neighbor_index (int): neighbor node index image (np.array): the periodic cell image the neighbor is from distance (float): distance between center and neighbor. + dist_tol (float): tolerance for distance comparison between edges. + Default = 1e-6 """ # Create directed_edge (DE) index using the length of added DEs directed_edge_index = len(self.directed_edges_list) @@ -173,7 +177,7 @@ def add_edge(self, center_index, neighbor_index, image, distance) -> None: # different image and distance (this is possible consider periodicity) for undirected_edge in self.undirected_edges[tmp]: if ( - abs(undirected_edge.info["distance"] - distance) < 1e-6 + abs(undirected_edge.info["distance"] - distance) < dist_tol and len(undirected_edge.info["directed_edge_index"]) == 1 ): # There is an undirected edge with similar length and only one of @@ -286,7 +290,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]] # if encountered exception, # it means after Atom_Graph creation, the UDE has only 1 DE associated # This exception is not encountered from the develop team's experience - if len(u_edge.info["directed_edge_index"]) != 2: + if len(u_edge.info["directed_edge_index"]) != 2: # noqa: PLR2004 raise ValueError( "Did not find 2 Directed_edges !!!" f"undirected edge {u_edge} has:" diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 31a3ad33..29c915af 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -262,7 +262,7 @@ def relax( dict[str, Structure | TrajectoryObserver]: A dictionary with 'final_structure' and 'trajectory'. """ - import ase.filters as filters + from ase import filters from ase.filters import Filter valid_filter_names = [ diff --git a/pyproject.toml b/pyproject.toml index c409f492..db68c739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,9 @@ ignore = [ "ERA001", # found commented out code "ISC001", "NPY002", # TODO replace legacy np.random.seed - "PLR", # pylint refactor + "PLR0912", # too many branches + "PLR0913", # too many args in function def + "PLR0915", # too many statements "PLW2901", # Outer for loop variable overwritten by inner assignment target "PT006", # pytest-parametrize-names-wrong-type "PTH", # prefer Path to os.path @@ -93,7 +95,16 @@ docstring-code-format = true [tool.ruff.lint.per-file-ignores] "site/*" = ["INP001", "S602"] -"tests/*" = ["ANN201", "D100", "D103", "FBT001", "FBT002", "INP001", "S101"] +"tests/*" = [ + "ANN201", + "D100", + "D103", + "FBT001", + "FBT002", + "INP001", + "PLR2004", + "S101", +] # E402 Module level import not at top of file "examples/*" = ["E402", "I002", "INP001", "N816", "S101", "T201"] "chgnet/**/*" = ["T201"] diff --git a/site/make_docs.py b/site/make_docs.py index de9b1d2b..79b12b26 100644 --- a/site/make_docs.py +++ b/site/make_docs.py @@ -33,7 +33,8 @@ # remove all files with less than 20 lines # these correspond to mostly empty __init__.py files - if markdown.count("\n") < 20: + min_line_cnt = 20 + if markdown.count("\n") < min_line_cnt: os.remove(path) continue diff --git a/tests/test_graph.py b/tests/test_graph.py index fa9604e4..da4b3c25 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -100,7 +100,7 @@ def test_directed_edge() -> None: info = {"image": np.zeros(3), "distance": 1} edge = DirectedEdge([0, 1], index=0, info=info) undirected = edge.make_undirected(index=0, info=info) - assert edge == edge + assert edge == edge # noqa: PLR0124 assert edge == undirected assert edge.nodes == [0, 1] assert edge.index == 0