Skip to content

Commit

Permalink
move function placement
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 10, 2024
1 parent 3b0b49f commit 2c035b3
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,6 @@ def get_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]:
)


def check_assumption(graph: Graph) -> bool:
for node in graph.nodes:
for user in node.users:
if node not in user.all_input_nodes:
return False

for input_node in node.all_input_nodes:
if node not in input_node.users:
return False

if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len(
set(node.all_input_nodes)
):
return False

return True


def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]:
assert check_assumption(graph.graph)
target_nodes = get_target_nodes(graph, targets)
Expand Down Expand Up @@ -279,3 +261,21 @@ def trace_consumed_names(subgraphs: List[Dict[str, Any]]):
break
else:
assert False


def check_assumption(graph: Graph) -> bool:
for node in graph.nodes:
for user in node.users:
if node not in user.all_input_nodes:
return False

for input_node in node.all_input_nodes:
if node not in input_node.users:
return False

if len(node.users) != len(set(node.users)) or len(node.all_input_nodes) != len(
set(node.all_input_nodes)
):
return False

return True

0 comments on commit 2c035b3

Please sign in to comment.