Skip to content

Commit

Permalink
chore: Remove cycle handling flat graph (synnada-ai#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
aturker-synnada authored Dec 11, 2024
1 parent 7e41524 commit d4bfad7
Showing 1 changed file with 1 addition and 80 deletions.
81 changes: 1 addition & 80 deletions mithril/framework/physical/flat_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,24 +126,18 @@ def all_keys(self):
| set(self.output_dict.values())
)

def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> bool:
def add_value(self, model: PrimitiveModel, keys: dict[str, str]):
output_key = keys[PrimitiveModel.output_key]
keys = {
key: self._temp_connection_info.get(value, value)
for key, value in keys.items()
}

# Check if the model output conn is already in the connections
cycle_occured = output_key in self.connections

# Buffer primitives are not added to the graph
if isinstance(model, Buffer):
self.update_output_keys(keys["output"], keys["input"])
self._temp_connection_info[keys["output"]] = keys["input"]

if cycle_occured:
self.handle_cycle(self.connections[keys["input"]], output_key)

if keys["input"] in self.connections:
self._update_connection_keys(self.connections[keys["input"]])

Expand All @@ -153,12 +147,6 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> bool:
# Create output connection of the new Node.
out_conn = Connection(node, output_key, [], [], set())

if cycle_occured:
# Model addition order is wrong, therefore a cycle occured.
# Output of this model is created by another model input, remove
# that connection and recreate connection as this model output.
self.handle_cycle(out_conn, output_key)

self.connections[output_key] = out_conn
node.connections[PrimitiveModel.output_key] = out_conn

Expand Down Expand Up @@ -189,31 +177,6 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> bool:

self._update_all_source_keys()
self._update_all_target_keys()
return cycle_occured

def handle_cycle(self, new_conn: Connection, key_name: str):
# Loop through all nodes and check if output connection is in the connections
for node in self.nodes.values():
output_conn = self.connections.get(key_name) # TODO: Why this is here?
if output_conn is None:
continue

if output_conn in node.connections.values():
new_conn.connections.add(node.connections[PrimitiveModel.output_key])

# Replace the connection with the new output connection
key = next(
key
for key, value in node.connections.items()
if value == output_conn
)
node.connections[key] = new_conn

# Remove old connection
self._remove_conn(output_conn)

# Update source and target keys
self._update_connection_keys(node.connections["output"])

def collapse_model_keys(self, output_key: str, new_reference_key: str):
# If a model removed, the models that uses the output of the removed model
Expand Down Expand Up @@ -245,48 +208,6 @@ def all_target_keys(self) -> set[str]:
def all_source_keys(self) -> set[str]:
return self._all_source_keys

def _reorder_connections(self):
queue = list(self._input_keys)
visited_keys: list[str] = []

while queue:
key = queue.pop()
if key in visited_keys:
continue

visited_keys.append(key)
# TODO: Cyclic extension bug is solved temporarily
# (see test_cyclic_extension in test_scripts.py)
# find a better solution for this.
new_target_keys = self.get_target_keys(key)
for target_key in new_target_keys:
source_keys = self.get_source_keys(target_key, True)

node = self.connections[target_key].node
local_keys = []
if node is not None:
local_keys = list(node.connections.keys())

if "cache" in local_keys:
source_keys.pop(local_keys.index("cache") - 1)
if set(source_keys).issubset(visited_keys):
queue.append(target_key)

for key in self._input_keys:
visited_keys.remove(key)

nodes: dict[PrimitiveModel, Node] = {}
for key in visited_keys:
model = self.get_model(key)
if model is None:
continue
nodes[model] = self.nodes[model]

# If graph is not completed do not reorder nodes!
if len(nodes) == len(self.nodes):
self.nodes = nodes
self._update_topological_order()

def _update_topological_order(self):
self._topological_order = [
node.connections[PrimitiveModel.output_key].key
Expand Down

0 comments on commit d4bfad7

Please sign in to comment.