Skip to content

Commit

Permalink
Fix path to network template
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Jan 31, 2024
1 parent ef7dcb7 commit f080507
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
4 changes: 2 additions & 2 deletions dory/Hardware_targets/PULP/Common/C_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def mapping_layers_to_C_files(self):

if n_memory_levels > 2 and (node.L3_input != 0 or (node.tiling_dimensions["L3"]["output_dimensions"] != node.tiling_dimensions["L2"]["output_dimensions"]) or (node.tiling_dimensions["L3"]["weights_dimensions"] != node.tiling_dimensions["L2"]["weights_dimensions"])):
tk = Layer2D_writer.print_template_layer_L3(node)
TemplateWriter.write(tk, {os.path.join(self.src_dir, node.prefixed_name + ".c"): os.path.join(self.tmpl_dir, "layer_L3_c_template.c"),
os.path.join(self.inc_dir, node.prefixed_name + ".h"): os.path.join(self.tmpl_dir, "layer_L3_h_template.h")})
TemplateWriter.write(tk, {os.path.join(self.src_dir, node.prefixed_name + ".c"): os.path.join(self.layer_tmpl_dir, "layer_L3_c_template.c"),
os.path.join(self.inc_dir, node.prefixed_name + ".h"): os.path.join(self.layer_tmpl_dir, "layer_L3_h_template.h")})
if node.tiling_dimensions["L3"]["input_dimensions"][1] > node.tiling_dimensions["L2"]["input_dimensions"][1]:
node.tiling_dimensions["L2"]["output_dimensions"][1] = int(np.floor((node.tiling_dimensions["L2"]["input_dimensions"][1] - node.kernel_shape[0] + node.strides[0]) / node.strides[0]))
if node.tiling_dimensions["L3"]["output_dimensions"][1] > node.tiling_dimensions["L2"]["output_dimensions"][1]:
Expand Down
4 changes: 2 additions & 2 deletions dory/Hardware_targets/PULP/GAP9_NE16/C_Parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def mapping_layers_to_C_files(self):
if n_memory_levels > 2 and (node.L3_input != 0 or (node.tiling_dimensions["L3"]["output_dimensions"] != node.tiling_dimensions["L2"]["output_dimensions"]) or (node.tiling_dimensions["L3"]["weights_dimensions"] != node.tiling_dimensions["L2"]["weights_dimensions"])):
#breakpoint()
tk = Layer2D_writer.print_template_layer_L3(node)
TemplateWriter.write(tk, {os.path.join(self.src_dir, node.prefixed_name + ".c"): os.path.join(self.tmpl_dir, "layer_L3_c_template.c"),
os.path.join(self.inc_dir, node.prefixed_name + ".h"): os.path.join(self.tmpl_dir, "layer_L3_h_template.h")})
TemplateWriter.write(tk, {os.path.join(self.src_dir, node.prefixed_name + ".c"): os.path.join(self.layer_tmpl_dir, "layer_L3_c_template.c"),
os.path.join(self.inc_dir, node.prefixed_name + ".h"): os.path.join(self.layer_tmpl_dir, "layer_L3_h_template.h")})
if node.tiling_dimensions["L3"]["input_dimensions"][1] > node.tiling_dimensions["L2"]["input_dimensions"][1]:
node.tiling_dimensions["L2"]["output_dimensions"][1] = (node.tiling_dimensions["L2"]["input_dimensions"][1] - node.kernel_shape[0] + node.strides[0]) // node.strides[0]
if node.tiling_dimensions["L3"]["output_dimensions"][1] > node.tiling_dimensions["L2"]["output_dimensions"][1]:
Expand Down
11 changes: 8 additions & 3 deletions dory/Parsers/Parser_HW_to_C.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def mapping_network_to_C_file(self):
self.config_file,
self.verbose_level,
self.perf_layer,
self.tmpl_dir,
self.app_directory,
self.inc_dir_rel,
self.src_dir_rel)
Expand Down Expand Up @@ -86,8 +87,8 @@ def l2_c_template(self, node, backend_library):
def l2_template_mapping(self, node, backend_library):
tmpl_c = self.l2_c_template(node, backend_library)
return {
os.path.join(self.src_dir, node.name + ".c"): os.path.join(self.tmpl_dir, tmpl_c),
os.path.join(self.inc_dir, node.name + ".h"): os.path.join(self.tmpl_dir, "layer_L2_h_template.h"),
os.path.join(self.src_dir, node.name + ".c"): os.path.join(self.layer_tmpl_dir, tmpl_c),
os.path.join(self.inc_dir, node.name + ".h"): os.path.join(self.layer_tmpl_dir, "layer_L2_h_template.h"),
}

def mapping_layers_to_C_files(self):
Expand Down Expand Up @@ -170,7 +171,11 @@ def get_file_path(self):

@property
def tmpl_dir(self):
return os.path.realpath(os.path.join(self.get_file_path(), 'Templates/layer_templates'))
return os.path.realpath(os.path.join(self.get_file_path(), 'Templates'))

@property
def layer_tmpl_dir(self):
return os.path.realpath(os.path.join(self.tmpl_dir, 'layer_templates'))

@property
def utils_files_dir(self):
Expand Down
7 changes: 4 additions & 3 deletions dory/Utils/Templates_writer/Network_template_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def print_template_network(
config_file,
verbose_level,
perf_layer,
tmpl_dir,
app_directory,
inc_dir_rel,
src_dir_rel
Expand Down Expand Up @@ -86,19 +87,19 @@ def print_template_network(
l += "// %s %s\n" % (k.ljust(30), v)
tk['DORY_HW_graph'] = graph
root = os.path.realpath(os.path.dirname(__file__))
tmpl = Template(filename=os.path.join(root, "../../Hardware_targets", HW_description["name"], "Templates/network_c_template.c"))
tmpl = Template(filename=os.path.join(tmpl_dir, "network_c_template.c"))
s = tmpl.render(verbose_log=l, **tk)
save_string = os.path.join(app_directory, src_dir_rel, prefix + 'network.c')
with open(save_string, "w") as f:
f.write(s)

tmpl = Template(filename=os.path.join(root, "../../Hardware_targets", HW_description["name"], "Templates/network_h_template.h"))
tmpl = Template(filename=os.path.join(tmpl_dir, "network_h_template.h"))
s = tmpl.render(verbose_log=l, **tk)
save_string = os.path.join(app_directory, inc_dir_rel, prefix + 'network.h')
with open(save_string, "w") as f:
f.write(s)

tmpl = Template(filename=os.path.join(root, "../../Hardware_targets", HW_description["name"], "Templates/main_template.c"))
tmpl = Template(filename=os.path.join(tmpl_dir, "main_template.c"))
s = tmpl.render(verbose_log=l, **tk)
save_string = os.path.join(app_directory, src_dir_rel, prefix + 'main.c')
with open(save_string, "w") as f:
Expand Down

0 comments on commit f080507

Please sign in to comment.