Skip to content

Commit

Permalink
Add support for is_spill annotations for inputs
Browse files Browse the repository at this point in the history
Previously, SLOTHY would fail when a global input was spilled via
an instruction annotated with `// @slothy:is_spill`.

This commit fixes this by ad-hoc creating a VirtualInstruction node
for the input in this case -- as is done during ordinary DFG parsing
already.
  • Loading branch information
hanno-becker authored and dop-amin committed Jan 3, 2025
1 parent 60b8cb8 commit 8c6fca2
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions slothy/core/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,14 +874,41 @@ def _add_node_from_candidates(self, candidates, sourceline):
# Add the single valid candidate parsing to the CFG
self._add_node(valid_candidates[0])

def _find_source_single(self,ty,name):
self.logger.debug("Finding source of register %s of type %s", name, ty)

# Check if the inputs have been produced by the data flow graph
if name not in self.reg_state:
# If not, treat them as a global input
self.logger.debug("-> %s is a global input", name)
# Create a virtual instruction producing the output add that first
# Since the virtual instruction does not have any inputs, there is
# no risk of infinite recursion here
self._add_node(VirtualInputInstruction(name, ty))
# Fall through

# At this point, the source _must_ be produced by an instruction in the graph
assert name in self.reg_state

# Return a reference to the node producing the input
origin = self.reg_state[name]
self.logger.debug(f"-> {name} has been produced by {origin}")

if origin.get_type() != ty:
warnstr = f"Type mismatch: Output {name} of {type(origin.src.inst).__name__} has "\
f"type {origin.get_type()} but {type(s).__name__} expects it to have type {ty}"
self.logger.debug(warnstr)
raise DataFlowGraphException(warnstr)

return self.reg_state[name]

def _process_restore_instruction(self, reg, loc):
assert loc in self.spilled_reg_state.keys()
self.reg_state[reg] = self.spilled_reg_state.pop(loc)

def _process_spill_instruction(self, reg, loc):
def _process_spill_instruction(self, reg, loc, ty):
assert loc not in self.spilled_reg_state.keys()
assert reg in self.reg_state.keys()
self.spilled_reg_state[loc] = self.reg_state.pop(reg)
self.spilled_reg_state[loc] = self._find_source_single(ty, reg)

def _add_node(self, s):
"""Add a node to the data flow graph
Expand All @@ -901,7 +928,8 @@ def _add_node(self, s):
self.logger.debug("Handling spill instruction: %s", s)
reg = s.args_in[0]
loc = s.args_out[0]
self._process_spill_instruction(reg, loc)
ty = s.arg_types_in[0]
self._process_spill_instruction(reg, loc, ty)
return
if self.config._absorb_spills is True and \
s.source_line.tags.get("is_restore", False) is True:
Expand All @@ -918,36 +946,8 @@ def _add_node(self, s):
elif isinstance(s, VirtualOutputInstruction):
self.logger.debug("Adding virtual instruction for output %s", s.orig_reg)

def find_source_single(ty,name):
self.logger.debug("Finding source of register %s of type %s", name, ty)

# Check if the inputs have been produced by the data flow graph
if name not in self.reg_state:
# If not, treat them as a global input
self.logger.debug("-> %s is a global input", name)
# Create a virtual instruction producing the output add that first
# Since the virtual instruction does not have any inputs, there is
# no risk of infinite recursion here
self._add_node(VirtualInputInstruction(name, ty))
# Fall through

# At this point, the source _must_ be produced by an instruction in the graph
assert name in self.reg_state

# Return a reference to the node producing the input
origin = self.reg_state[name]
self.logger.debug(f"-> {name} has been produced by {origin}")

if origin.get_type() != ty:
warnstr = f"Type mismatch: Output {name} of {type(origin.src.inst).__name__} has "\
f"type {origin.get_type()} but {type(s).__name__} expects it to have type {ty}"
self.logger.debug(warnstr)
raise DataFlowGraphException(warnstr)

return self.reg_state[name]

def find_sources(types,names):
return [ find_source_single(t,n) for t,n in zip(types,names) ]
return [ self._find_source_single(t,n) for t,n in zip(types,names) ]

# Lookup computation nodes for inputs
src_in = find_sources(s.arg_types_in, s.args_in)
Expand Down

0 comments on commit 8c6fca2

Please sign in to comment.