Skip to content

Commit

Permalink
Add missing methods in PyGraph and PyDiGraph stubs (#967)
Browse files Browse the repository at this point in the history
* Add missing methods in PyGraph and PyDiGraph stubs

* Remove TODOs

* Add new methods

* Add comments from review

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
IvanIsCoding and mtreinish authored Jan 13, 2024
1 parent 6d82a11 commit f75291d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
28 changes: 26 additions & 2 deletions rustworkx/digraph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,19 @@

import numpy as np
from .iterators import *
from .graph import PyGraph

from typing import Any, Callable, Generic, TypeVar, Sequence
from typing import Any, Callable, Generic, TypeVar, Sequence, TYPE_CHECKING

if TYPE_CHECKING:
from .graph import PyGraph

__all__ = ["PyDiGraph"]

S = TypeVar("S")
T = TypeVar("T")

class PyDiGraph(Generic[S, T]):
attrs: Any
check_cycle: bool = ...
multigraph: bool = ...
def __init__(
Expand All @@ -44,6 +47,8 @@ class PyDiGraph(Generic[S, T]):
def add_parent(self, child: int, obj: S, edge: T, /) -> int: ...
def adj(self, node: int, /) -> dict[int, T]: ...
def adj_direction(self, node: int, direction: bool, /) -> dict[int, T]: ...
def clear(self) -> None: ...
def clear_edges(self) -> None: ...
def compose(
self,
other: PyDiGraph[S, T],
Expand All @@ -52,11 +57,20 @@ class PyDiGraph(Generic[S, T]):
node_map_func: Callable[[S], int] | None = ...,
edge_map_func: Callable[[T], int] | None = ...,
) -> dict[int, int]: ...
def contract_nodes(
self,
nodes: Sequence[int],
obj: S,
/,
check_cycle: bool | None = ...,
weight_combo_fn: Callable[[T, T], T] | None = ...,
) -> int: ...
def copy(self) -> PyDiGraph[S, T]: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyDiGraph[S, T]: ...
def extend_from_edge_list(
self: PyDiGraph[S | None, T | None], edge_list: Sequence[tuple[int, int]], /
) -> None: ...
Expand All @@ -65,6 +79,8 @@ class PyDiGraph(Generic[S, T]):
edge_list: Sequence[tuple[int, int, T]],
/,
) -> None: ...
def filter_edges(self, filter_function: Callable[[T], bool]) -> EdgeIndices: ...
def filter_nodes(self, filter_function: Callable[[S], bool]) -> NodeIndices: ...
def find_adjacent_node_by_edge(self, node: int, predicate: Callable[[T], bool], /) -> S: ...
def find_node_by_weight(
self,
Expand All @@ -74,6 +90,7 @@ class PyDiGraph(Generic[S, T]):
def find_predecessors_by_edge(
self, node: int, filter_fn: Callable[[T], bool], /
) -> list[S]: ...
def find_predecessor_node_by_edge(self, node: int, predicate: Callable[[T], bool], /) -> S: ...
def find_successors_by_edge(self, node: int, filter_fn: Callable[[T], bool], /) -> list[S]: ...
@staticmethod
def from_adjacency_matrix(
Expand All @@ -86,17 +103,24 @@ class PyDiGraph(Generic[S, T]):
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[T]: ...
def get_edge_data(self, node_a: int, node_b: int, /) -> T: ...
def get_node_data(self, node: int, /) -> S: ...
def get_edge_data_by_index(self, edge_index: int, /) -> T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_degree(self, node: int, /) -> int: ...
def in_edges(self, node: int, /) -> WeightedEdgeList[T]: ...
def incident_edge_index_map(self, node: int, /, all_edges: bool = ...) -> EdgeIndexMap: ...
def incident_edges(self, node: int, /, all_edges: bool = ...) -> EdgeIndices: ...
def insert_node_on_in_edges(self, node: int, ref_node: int, /) -> None: ...
def insert_node_on_in_edges_multiple(self, node: int, ref_nodes: Sequence[int], /) -> None: ...
def insert_node_on_out_edges(self, node: int, ref_node: int, /) -> None: ...
def insert_node_on_out_edges_multiple(self, node: int, ref_nodes: Sequence[int], /) -> None: ...
def is_symmetric(self) -> bool: ...
def make_symmetric(self, edge_payload_fn: Callable[[T], T] | None = ...) -> None: ...
def merge_nodes(self, u: int, v: int, /) -> None: ...
def neighbors(self, node: int, /) -> NodeIndices: ...
def node_indexes(self) -> NodeIndices: ...
def node_indices(self) -> NodeIndices: ...
def nodes(self) -> list[S]: ...
def num_edges(self) -> int: ...
def num_nodes(self) -> int: ...
Expand Down
31 changes: 31 additions & 0 deletions rustworkx/graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ from typing import (
Generic,
TypeVar,
Sequence,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from .digraph import PyDiGraph

__all__ = ["PyGraph"]

S = TypeVar("S")
T = TypeVar("T")

class PyGraph(Generic[S, T]):
attrs: Any
multigraph: bool = ...
def __init__(self, /, multigraph: bool = ...) -> None: ...
def add_edge(self, node_a: int, node_b: int, edge: T, /) -> int: ...
Expand All @@ -40,6 +45,8 @@ class PyGraph(Generic[S, T]):
def add_node(self, obj: S, /) -> int: ...
def add_nodes_from(self, obj_list: Sequence[S], /) -> NodeIndices: ...
def adj(self, node: int, /) -> dict[int, T]: ...
def clear(self) -> None: ...
def clear_edges(self) -> None: ...
def compose(
self,
other: PyGraph[S, T],
Expand All @@ -48,12 +55,20 @@ class PyGraph(Generic[S, T]):
node_map_func: Callable[[S], int] | None = ...,
edge_map_func: Callable[[T], int] | None = ...,
) -> dict[int, int]: ...
def contract_nodes(
self,
nodes: Sequence[int],
obj: S,
/,
weight_combo_fn: Callable[[T, T], T] | None = ...,
) -> int: ...
def copy(self) -> PyGraph[S, T]: ...
def degree(self, node: int, /) -> int: ...
def edge_index_map(self) -> EdgeIndexMap[T]: ...
def edge_indices(self) -> EdgeIndices: ...
def edge_list(self) -> EdgeList: ...
def edges(self) -> list[T]: ...
def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyGraph[S, T]: ...
def extend_from_edge_list(
self: PyGraph[S | None, T | None], edge_list: Sequence[tuple[int, int]], /
) -> None: ...
Expand All @@ -62,6 +77,13 @@ class PyGraph(Generic[S, T]):
edge_list: Sequence[tuple[int, int, T]],
/,
) -> None: ...
def filter_edges(self, filter_function: Callable[[T], bool]) -> EdgeIndices: ...
def filter_nodes(self, filter_function: Callable[[S], bool]) -> NodeIndices: ...
def find_node_by_weight(
self,
obj: Callable[[S], bool],
/,
) -> int | None: ...
@staticmethod
def from_adjacency_matrix(
matrix: np.ndarray, /, null_value: float = ...
Expand All @@ -72,13 +94,21 @@ class PyGraph(Generic[S, T]):
) -> PyGraph[int, complex]: ...
def get_all_edge_data(self, node_a: int, node_b: int, /) -> list[T]: ...
def get_edge_data(self, node_a: int, node_b: int, /) -> T: ...
def get_edge_data_by_index(self, edge_index: int, /) -> T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def get_node_data(self, node: int, /) -> S: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_edges(self, node: int, /) -> WeightedEdgeList[T]: ...
def incident_edge_index_map(self, node: int, /) -> EdgeIndexMap: ...
def incident_edges(self, node: int, /) -> EdgeIndices: ...
def neighbors(self, node: int, /) -> NodeIndices: ...
def node_indexes(self) -> NodeIndices: ...
def node_indices(self) -> NodeIndices: ...
def nodes(self) -> list[S]: ...
def num_edges(self) -> int: ...
def num_nodes(self) -> int: ...
def out_edges(self, node: int, /) -> WeightedEdgeList[T]: ...
@staticmethod
def read_edge_list(
path: str,
Expand Down Expand Up @@ -110,6 +140,7 @@ class PyGraph(Generic[S, T]):
graph_attr: dict[str, str] | None = ...,
filename: str | None = ...,
) -> str | None: ...
def to_directed(self) -> PyDiGraph[S, T]: ...
def update_edge(
self,
source: int,
Expand Down

0 comments on commit f75291d

Please sign in to comment.