Skip to content

Commit

Permalink
util/sim: Extend struct definition to include array initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Nov 12, 2023
1 parent 4f674e0 commit 074e76b
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions util/sim/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ def floating_point_ctype(precision):
def flatten(array):
if isinstance(array, np.ndarray):
return array.flatten()
if isinstance(array, torch.Tensor):
elif isinstance(array, torch.Tensor):
return array.numpy().flatten()
elif isinstance(array, list):
return np.array(array).flatten()


def variable_attributes(alignment=None, section=None):
Expand Down Expand Up @@ -78,7 +80,19 @@ def format_array_declaration(dtype, uid, shape, alignment=None, section=None):
def format_array_definition(dtype, uid, array, alignment=None, section=None):
# Definition starts with the declaration stripped off of the terminating semicolon
s = format_array_declaration(dtype, uid, array.shape, alignment, section)[:-1]
s += ' = {\n'
s += ' = '
s += format_array_initializer(dtype, array)
s += ';'
return s


def format_scalar_definition(dtype, uid, scalar):
s = f'{alias_dtype(dtype)} {uid} = {scalar};'
return s


def format_array_initializer(dtype, array):
s = '{\n'
# Flatten array
if dtype == '__fp8':
array = zip(flatten(array['sign']),
Expand All @@ -95,18 +109,18 @@ def format_array_definition(dtype, uid, array, alignment=None, section=None):
else:
el_str = f'{el}'
s += f'\t{el_str},\n'
s += '};'
return s


def format_scalar_definition(dtype, uid, scalar):
s = f'{alias_dtype(dtype)} {uid} = {scalar};'
s += '}'
return s


def format_struct_definition(dtype, uid, map):
def format_value(value):
if isinstance(value, list):
return format_array_initializer(str, value)
else:
return str(value)
s = f'{alias_dtype(dtype)} {uid} = {{\n'
s += ',\n'.join([f'\t.{key} = {value}' for (key, value) in map.items()])
s += ',\n'.join([f'\t.{key} = {format_value(value)}' for (key, value) in map.items()])
s += '\n};'
return s

Expand Down

0 comments on commit 074e76b

Please sign in to comment.