Skip to content

Commit

Permalink
almost
Browse files Browse the repository at this point in the history
  • Loading branch information
harisbal committed Nov 10, 2023
1 parent 933d0ac commit 7df22ce
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 29 deletions.
92 changes: 67 additions & 25 deletions networkx/algorithms/simple_paths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from heapq import heappop, heappush
from itertools import count
from itertools import count, product

import networkx as nx
from networkx.algorithms.shortest_paths.weighted import _weight_function
Expand Down Expand Up @@ -255,8 +255,11 @@ def all_simple_paths(G, source, target, cutoff=None):
return _empty_generator()
if cutoff is None:
cutoff = len(G) - 1
if cutoff < 1:
return _empty_generator()

if isinstance(cutoff, int | float):
if cutoff < 1:
return _empty_generator()

if G.is_multigraph():
return _all_simple_paths_multigraph(G, source, targets, cutoff)
else:
Expand All @@ -267,23 +270,39 @@ def _empty_generator():
yield from ()


def _is_path_under_cutoff(G, path, cutoff):
def _path_under_cutoff(G, path, cutoff):
if isinstance(cutoff, int | float):
if len(path) <= cutoff:
return True
cutoff = {None: cutoff}

for w in cutoff:
cost = 0
cutoffw = cutoff[w]

if w is None:
if len(path) - 1 > cutoffw:
return False
else:
return False
cost = nx.path_weight(G, path, w)
if cost >= cutoffw:
return False

return True


def _edge_path_under_cutoff(G, edge_path, cutoff):
if isinstance(cutoff, int | float):
cutoff = {None: cutoff}

for w in cutoff:
cost = 0
cutoffw = cutoff[w]
for u, v in pairwise(path):
if G.is_multigraph():
cost += min(k.get(w, 1) for k in G[u][v].values())
else:
cost += G[u][v].get(w, 1)

if cost > cutoffw:
if w is None:
if len(edge_path) > cutoffw:
return False
else:
cost = nx.edge_path_weight(G, edge_path, w)
if cost >= cutoffw:
return False

return True
Expand All @@ -298,19 +317,23 @@ def _all_simple_paths_graph(G, source, targets, cutoff):
if child is None:
stack.pop()
visited.popitem()
elif _is_path_under_cutoff(G, visited, cutoff):
elif _path_under_cutoff(G, visited, cutoff):
if child in visited:
continue
if child in targets:
yield list(visited) + [child]
path = list(visited) + [child]
if _path_under_cutoff(G, path, cutoff):
yield path
visited[child] = True
if targets - set(visited.keys()): # expand stack until find all targets
stack.append(iter(G[child]))
else:
visited.popitem() # maybe other ways to child
else: # len(visited) == cutoff:
for target in (targets & (set(children) | {child})) - set(visited.keys()):
yield list(visited) + [target]
path = list(visited) + [target]
if _path_under_cutoff(G, path, cutoff):
yield path
stack.pop()
visited.popitem()

Expand All @@ -324,7 +347,7 @@ def _all_simple_paths_multigraph(G, source, targets, cutoff):
if child is None:
stack.pop()
visited.popitem()
elif len(visited) < cutoff:
elif _path_under_cutoff(G, list(visited) + [child], cutoff):
if child in visited:
continue
if child in targets:
Expand All @@ -334,11 +357,16 @@ def _all_simple_paths_multigraph(G, source, targets, cutoff):
stack.append((v for u, v in G.edges(child)))
else:
visited.popitem()
else: # len(visited) == cutoff:
else: # len(visited) >= cutoff:
for target in targets - set(visited.keys()):
count = ([child] + list(children)).count(target)
for i in range(count):
yield list(visited) + [target]
path = list(visited) + [target]
edges = [
[(u, v, k) for k in G[u][v]] for u, v in nx.utils.pairwise(path)
]
for p in product(*edges):
if _edge_path_under_cutoff(G, p, cutoff):
yield [u for u, v, k in p] + [p[-1][1]]

stack.pop()
visited.popitem()

Expand Down Expand Up @@ -439,8 +467,18 @@ def all_simple_edge_paths(G, source, target, cutoff=None):


def _all_simple_edge_paths_multigraph(G, source, targets, cutoff):
if not cutoff or cutoff < 1:
if not cutoff:
return []

if isinstance(cutoff, int | float):
if cutoff < 1:
return []
elif isinstance(cutoff, dict):
if cutoff.get(None, 1) < 1:
return []
else:
raise TypeError("cutoff should either be int or dict")

visited = [source]
stack = [iter(G.edges(source, keys=True))]

Expand All @@ -450,16 +488,20 @@ def _all_simple_edge_paths_multigraph(G, source, targets, cutoff):
if child is None:
stack.pop()
visited.pop()
elif len(visited) < cutoff:
elif _path_under_cutoff(G, visited, cutoff):
if child[1] in targets:
yield visited[1:] + [child]
path = visited[1:] + [child]
if _path_under_cutoff(G, path, cutoff):
yield path
elif child[1] not in [v[0] for v in visited[1:]]:
visited.append(child)
stack.append(iter(G.edges(child[1], keys=True)))
else: # len(visited) == cutoff:
for u, v, k in [child] + list(children):
if v in targets:
yield visited[1:] + [(u, v, k)]
path = visited[1:] + [(u, v, k)]
if _path_under_cutoff(G, path, cutoff):
yield path
stack.pop()
visited.pop()

Expand Down
14 changes: 10 additions & 4 deletions networkx/algorithms/tests/test_simple_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,16 @@ def test_all_simple_paths_weighted_multigraph_with_multiple_cutoffs():

paths = list(nx.all_simple_paths(G, 0, 4, cutoff={None: 3, c: 20}))

assert len(paths) == 4
assert paths[0] == [0, 1, 4]
assert paths[-1] == [0, 4]
assert max(len(p) for p in paths) == 3
assert {tuple(p) for p in paths} == {
(0, 1, 2, 4),
(0, 1, 3, 4),
(0, 1, 4),
(0, 2, 1, 4),
(0, 2, 4),
(0, 3, 1, 4),
(0, 3, 4),
(0, 4),
}


def test_all_simple_paths_directed():
Expand Down
41 changes: 41 additions & 0 deletions networkx/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"nodes_with_selfloops",
"number_of_selfloops",
"path_weight",
"edge_path_weight",
"is_path",
]

Expand Down Expand Up @@ -1311,3 +1312,43 @@ def path_weight(G, path, weight):
else:
cost += G[node][nbr][weight]
return cost


def edge_path_weight(G, edge_path, weight):
"""Returns total cost associated with specified path and weight
Parameters
----------
G : graph
A NetworkX graph.
path: list
A list of edges which defines the path to traverse
weight: string
A string indicating which edge attribute to use for path cost
Returns
-------
cost: int or float
An integer or a float representing the total cost with respect to the
specified weight of the specified path
Raises
------
NetworkXNoPath
If the specified edge does not exist.
"""
multigraph = G.is_multigraph()
cost = 0

path = [edge_path[0][0], edge_path[-1][1]]
if not nx.is_path(G, path):
raise nx.NetworkXNoPath("path does not exist")

if multigraph:
for u, v, k in edge_path:
cost += G[u][v][k][weight]
else:
cost += G[u][v][weight]
return cost
23 changes: 23 additions & 0 deletions tmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import networkx as nx

n = 5
c = "distance"
G = nx.complete_graph(n, create_using=nx.MultiGraph)
distances = list(range(1, (n**2) + 1))
edges = G.edges(keys=True)

d = {e: {c: dist} for e, dist in zip(edges, distances)}
nx.set_edge_attributes(G, d)

paths = list(nx.all_simple_paths(G, 0, 4, cutoff={None: 3, c: 20}))

assert {tuple(p) for p in paths} == {
(0, 1, 2, 4),
(0, 1, 3, 4),
(0, 1, 4),
(0, 2, 1, 4),
(0, 2, 4),
(0, 3, 1, 4),
(0, 3, 4),
(0, 4),
}

0 comments on commit 7df22ce

Please sign in to comment.