diff --git a/target/snitch_cluster/sw/apps/dnn/datagen.py b/target/snitch_cluster/sw/apps/dnn/datagen.py
index 5c720f0c9e..04f0e2a1db 100755
--- a/target/snitch_cluster/sw/apps/dnn/datagen.py
+++ b/target/snitch_cluster/sw/apps/dnn/datagen.py
@@ -73,8 +73,9 @@ def emit_transformer_layer(name='transformer', **kwargs):
 
     ifmap = kwargs['ifmap']
     weights_q = kwargs['weights_q']
-    weights_k = kwargs['weights_k'].T
+    weights_k = kwargs['weights_k']
     weights_v = kwargs['weights_v']
+    weights_o = kwargs['weights_o']
 
     query = kwargs['query']
     key = kwargs['key']
@@ -104,10 +105,16 @@ def emit_transformer_layer(name='transformer', **kwargs):
     layer_str += f'\t.dtype = FP{kwargs["prec"]},\n'
     layer_str += '};\n\n\n'
 
+    # Declare the DRAM arrays
+    layer_str += f'static {dtype} {name}_Q_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
+    layer_str += f'static {dtype} {name}_K_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
+    layer_str += f'static {dtype} {name}_V_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
+    layer_str += f'static {dtype} {name}_O_dram[{S}][{P}] __attribute__((section(".data")));\n\n'
     layer_str += f'static {dtype} {name}_ifmap_dram[{S}][{E}] = ' + array_to_cstr(ifmap) + ';\n\n'
     layer_str += f'static {dtype} {name}_weights_q_dram[{E}][{P}] = ' + array_to_cstr(weights_q) + ';\n\n'
     layer_str += f'static {dtype} {name}_weights_k_dram[{E}][{P}] = ' + array_to_cstr(weights_k) + ';\n\n'
     layer_str += f'static {dtype} {name}_weights_v_dram[{E}][{P}] = ' + array_to_cstr(weights_v) + ';\n\n'
+    layer_str += f'static {dtype} {name}_weights_o_dram[{P}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n'
     layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n'
     layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n'
     layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n'
@@ -921,16 +928,24 @@ def main():
                                 requires_grad=False, dtype=dtype)
         weights_v = torch.randn(param['input_dim']['embeddings'], param['input_dim']['positional_embeddings'],
                                 requires_grad=False, dtype=dtype)
+        weights_o = torch.randn(param['input_dim']['positional_embeddings'], param['input_dim']['embeddings'],
+                                requires_grad=False, dtype=dtype)
 
         query = transformer(ifmap, weights_q.T, None, None, None, False)
         key = transformer(ifmap, weights_k.T, None, None, None, False)
         value = transformer(ifmap, weights_v.T, None, None, None, False)
 
+        sm = torch.nn.Softmax(dim=-1)
+        attention = sm(torch.matmul(query, key.T))
+        attention_out = torch.matmul(attention, value)
+        transformer_out = torch.matmul(attention_out, weights_o)
+
         kwargs = {
             'ifmap': ifmap,
             'weights_q': weights_q,
             'weights_k': weights_k,
             'weights_v': weights_v,
+            'weights_o': weights_o,
             'query': query,
             'key': key,
             'value': value,