Skip to content

Commit

Permalink
trafo: add DRAM data sections to datagen
Browse files Browse the repository at this point in the history
  • Loading branch information
Viviane Potocnik committed Sep 16, 2023
1 parent f7e07b4 commit 3078f99
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion target/snitch_cluster/sw/apps/dnn/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def emit_transformer_layer(name='transformer', **kwargs):

ifmap = kwargs['ifmap']
weights_q = kwargs['weights_q']
weights_k = kwargs['weights_k'].T
weights_k = kwargs['weights_k']
weights_v = kwargs['weights_v']
weights_o = kwargs['weights_o']

query = kwargs['query']
key = kwargs['key']
Expand Down Expand Up @@ -104,10 +105,16 @@ def emit_transformer_layer(name='transformer', **kwargs):
layer_str += f'\t.dtype = FP{kwargs["prec"]},\n'
layer_str += '};\n\n\n'

# Declare the DRAM arrays
layer_str += f'static {dtype} {name}_Q_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_K_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_V_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_O_dram[{S}][{P}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_ifmap_dram[{S}][{E}] = ' + array_to_cstr(ifmap) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_q_dram[{E}][{P}] = ' + array_to_cstr(weights_q) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_k_dram[{E}][{P}] = ' + array_to_cstr(weights_k) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_v_dram[{E}][{P}] = ' + array_to_cstr(weights_v) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_o_dram[{P}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n'
layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n'
layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n'
layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n'
Expand Down Expand Up @@ -921,16 +928,24 @@ def main():
requires_grad=False, dtype=dtype)
weights_v = torch.randn(param['input_dim']['embeddings'], param['input_dim']['positional_embeddings'],
requires_grad=False, dtype=dtype)
weights_o = torch.randn(param['input_dim']['positional_embeddings'], param['input_dim']['embeddings'],
requires_grad=False, dtype=dtype)

query = transformer(ifmap, weights_q.T, None, None, None, False)
key = transformer(ifmap, weights_k.T, None, None, None, False)
value = transformer(ifmap, weights_v.T, None, None, None, False)

sm = torch.nn.Softmax(dim=-1)
attention = sm(torch.matmul(query, key.T))
attention_out = torch.matmul(attention, value)
transformer_out = torch.matmul(attention_out, weights_o)

kwargs = {
'ifmap': ifmap,
'weights_q': weights_q,
'weights_k': weights_k,
'weights_v': weights_v,
'weights_o': weights_o,
'query': query,
'key': key,
'value': value,
Expand Down

0 comments on commit 3078f99

Please sign in to comment.