From e074c13c1144bc84c1831a25cdbb6bf5b19429d4 Mon Sep 17 00:00:00 2001 From: Viviane Potocnik Date: Mon, 16 Oct 2023 14:55:18 +0200 Subject: [PATCH] transformer: improved data generation for transformer with automatic tiling --- target/snitch_cluster/sw/apps/dnn/datagen.py | 217 ++++++++++++++++-- .../sw/apps/transformer/src/params.hjson | 9 +- 2 files changed, 209 insertions(+), 17 deletions(-) diff --git a/target/snitch_cluster/sw/apps/dnn/datagen.py b/target/snitch_cluster/sw/apps/dnn/datagen.py index 04f0e2a1db..4accc40076 100755 --- a/target/snitch_cluster/sw/apps/dnn/datagen.py +++ b/target/snitch_cluster/sw/apps/dnn/datagen.py @@ -72,6 +72,17 @@ def emit_header_file(file, layer_type: str, **kwargs): def emit_transformer_layer(name='transformer', **kwargs): ifmap = kwargs['ifmap'] + ifmap_ln = kwargs['ifmap_ln'] + # retrieve tiling parameters + S_tile_ln = kwargs['S_tile_ln'] + S_tile_lin1 = kwargs['S_tile_lin1'] + P_tile_lin1 = kwargs['P_tile_lin1'] + Br_tile_fa = kwargs['Br_tile_fa'] + Bc_tile_fa = kwargs['Bc_tile_fa'] + positional_embeddings_fa = kwargs['positional_embeddings_fa'] + q_fa = kwargs['q_fa'] + k_fa = kwargs['k_fa'] + v_fa = kwargs['v_fa'] weights_q = kwargs['weights_q'] weights_k = kwargs['weights_k'] weights_v = kwargs['weights_v'] @@ -81,10 +92,14 @@ def emit_transformer_layer(name='transformer', **kwargs): key = kwargs['key'] value = kwargs['value'] + ofmap = kwargs['ofmap'] + # Get the dimensions: sequence length S, # embedding size E, and position embedding size P S, E = ifmap.shape _, P = weights_q.shape + _, P_fa = q_fa.shape + HP, _ = weights_o.shape ctypes = { '64': 'double', @@ -98,10 +113,19 @@ def emit_transformer_layer(name='transformer', **kwargs): # layer_str = '#include \n' layer_str = '' layer_str += '#include "transformer.h"\n\n' - layer_str += f'transformer_layer_t {name}_l = {{\n' + layer_str += f'transformer_layer_fp{kwargs["prec"]}_t {name}_l = {{\n' layer_str += f'\t.seq_len = {S},\n' + layer_str += f'\t.S_tile_ln = {S_tile_ln},\n' + layer_str += f'\t.S_tile_lin1 = {S_tile_lin1},\n' + layer_str += f'\t.P_tile_lin1 = {P_tile_lin1},\n' + layer_str += f'\t.Br_tile_fa = {Br_tile_fa},\n' + layer_str += f'\t.Bc_tile_fa = {Bc_tile_fa},\n' layer_str += f'\t.embeddings = {E},\n' layer_str += f'\t.positional_embeddings = {P},\n' + layer_str += f'\t.positional_embeddings_fa = {positional_embeddings_fa},\n' + layer_str += f'\t.feedforward_len = {kwargs["feedforward_len"]},\n' + layer_str += f'\t.heads = {int(HP / P)},\n' + layer_str += f'\t.eps = {kwargs["eps"]},\n' layer_str += f'\t.dtype = FP{kwargs["prec"]},\n' layer_str += '};\n\n\n' @@ -110,14 +134,20 @@ def emit_transformer_layer(name='transformer', **kwargs): layer_str += f'static {dtype} {name}_K_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n' layer_str += f'static {dtype} {name}_V_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n' layer_str += f'static {dtype} {name}_O_dram[{S}][{P}] __attribute__((section(".data")));\n\n' + layer_str += f'static {dtype} {name}_OH_dram[{S}][{HP}] __attribute__((section(".data")));\n\n' layer_str += f'static {dtype} {name}_ifmap_dram[{S}][{E}] = ' + array_to_cstr(ifmap) + ';\n\n' + layer_str += f'static {dtype} {name}_ifmap_ln_dram[{S}][{E}] = ' + array_to_cstr(ifmap_ln) + ';\n\n' + # layer_str += f'static {dtype} {name}_ofmap_dram[{S}][{HP}] = ' + array_to_cstr(ofmap) + ';\n\n' layer_str += f'static {dtype} {name}_weights_q_dram[{E}][{P}] = ' + array_to_cstr(weights_q) + ';\n\n' layer_str += f'static {dtype} {name}_weights_k_dram[{E}][{P}] = ' + array_to_cstr(weights_k) + ';\n\n' layer_str += f'static {dtype} {name}_weights_v_dram[{E}][{P}] = ' + array_to_cstr(weights_v) + ';\n\n' - layer_str += f'static {dtype} {name}_weights_o_dram[{P}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n' - layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n' - layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n' - layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n' + layer_str += f'static {dtype} {name}_q_fa_dram[{S}][{P_fa}] = ' + array_to_cstr(q_fa) + ';\n\n' + layer_str += f'static {dtype} {name}_k_fa_dram[{P_fa}][{S}] = ' + array_to_cstr(k_fa) + ';\n\n' + layer_str += f'static {dtype} {name}_v_fa_dram[{S}][{P_fa}] = ' + array_to_cstr(v_fa) + ';\n\n' + # layer_str += f'static {dtype} {name}_weights_o_dram[{HP}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n' + # layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n' + # layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n' + # layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n' return layer_str @@ -919,37 +949,194 @@ def main(): emit_header_file(args.output, 'LayerNorm', **kwargs) elif param['kernel'] == 'Transformer': - ifmap = torch.randn(param['input_dim']['seq_len'], param['input_dim']['embeddings'], + seq_len = param['input_dim']['seq_len'] + heads = param['input_dim']['heads'] + embeddings = param['input_dim']['embeddings'] + positional_embeddings = param['input_dim']['positional_embeddings'] + feedforward_len = param['input_dim']['feedforward_len'] + + # check if we want to run a brief test + brief = param['brief'] + num_iters = param['num_iters'] + print("Brief test: ", brief) + print("Number of iterations: ", num_iters) + + # tcdm capacity in bytes + tcdm_size = 125 * 1024 + # data type size in bytes + data_type_size = torch.tensor(1, dtype=dtype).element_size() + print("Data type size: ", data_type_size) + # initialize the best solution parameters + best_dram_accessed_data = float('inf') + dram_accessed_data_list = [] + best_tcdm_storage = 0 + best_s_tile_ln = 0 + for S_tile in range(8, seq_len, 8): + dram_accessed_data = 2 * (seq_len // S_tile) * embeddings * data_type_size + tcdm_storage = (S_tile) * embeddings * data_type_size + if tcdm_storage <= tcdm_size: + if tcdm_storage > best_tcdm_storage or dram_accessed_data < best_dram_accessed_data: + best_dram_accessed_data = dram_accessed_data + best_tcdm_storage = tcdm_storage + best_s_tile_ln = S_tile + + print("LayerNorm Best S_tile: ", best_s_tile_ln) + if(brief == True): + seq_len = num_iters * best_s_tile_ln + embeddings = embeddings // 70 + + print("LayerNorm Sequence length: ", seq_len) + + # Layer 1: LayerNorm layer + ifmap = torch.randn(seq_len, embeddings, requires_grad=False, dtype=dtype) - - weights_q = torch.randn(param['input_dim']['embeddings'], param['input_dim']['positional_embeddings'], + + eps = param['eps'] + + m = nn.LayerNorm(ifmap.size()[1:]) + + # TODO: due to a bug in PyTorch, we need to cast the input to float32 or BFloat16 + ifmap = ifmap.type(torch.float32) + + ifmap_ln = m(ifmap) + + # cast back to the original data type + ifmap_ln = ifmap_ln.to(dtype).detach() + ifmap = ifmap.to(dtype) + + # Layer 2: Linear layer 1 + # TODO: check whether we go for min DRAM accesses or min DRAM accessed data + # we reset the best solution parameters + embeddings = param['input_dim']['embeddings'] + best_dram_accessed_data = float('inf') + best_dram_accesses = float('inf') + best_tcdm_storage = 0 + best_s_tile_lin1 = 0 + best_p_tile_lin1 = 0 + + for S_tile in range(8, seq_len, 8): + for P_tile in range(1, positional_embeddings, 1): + dram_accessed_data = ((seq_len // S_tile) * (S_tile * embeddings) + \ + (seq_len // S_tile) * 3 * (positional_embeddings // P_tile) * embeddings * P_tile \ + + (seq_len // S_tile) * (positional_embeddings // P_tile) * S_tile * P_tile) * data_type_size + dram_accesses = (seq_len // S_tile) + (seq_len // S_tile) * 2 * 3 * (positional_embeddings // P_tile) + tcdm_storage = (S_tile * embeddings + 3 * embeddings * P_tile + 3 * S_tile * P_tile) * data_type_size + if tcdm_storage <= tcdm_size: + if tcdm_storage > best_tcdm_storage or dram_accesses < best_dram_accesses:#or dram_accessed_data < best_dram_accessed_data: + best_dram_accessed_data = dram_accessed_data + best_dram_accesses = dram_accesses + best_tcdm_storage = tcdm_storage + best_s_tile_lin1 = S_tile + best_p_tile_lin1 = P_tile + + print("Layer 1 Best S_tile: ", best_s_tile_lin1) + print("Layer 1 Best P_tile: ", best_p_tile_lin1) + if(brief == True): + seq_len = num_iters * best_s_tile_lin1 + positional_embeddings = num_iters * best_p_tile_lin1 + embeddings = embeddings // 70 + + print("Layer 1 Sequence length: ", seq_len) + print("Layer 1 Positional embeddings: ", positional_embeddings) + print("Layer 1 Embeddings: ", embeddings) + + weights_q = torch.randn(embeddings, positional_embeddings, + requires_grad=False, dtype=dtype) + weights_k = torch.randn(embeddings, positional_embeddings, + requires_grad=False, dtype=dtype) + weights_v = torch.randn(embeddings, positional_embeddings, requires_grad=False, dtype=dtype) - weights_k = torch.randn(param['input_dim']['embeddings'], param['input_dim']['positional_embeddings'], + + # Layer 3: FlashAttention-2 layer + # TODO: check whether we go for min DRAM accesses or min DRAM accessed data + + # we reset the best solution parameters + # TODO: For the full model, we must also reset the sequence length + # seq_len = param['input_dim']['seq_len'] + embeddings = param['input_dim']['embeddings'] + positional_embeddings = param['input_dim']['positional_embeddings'] + best_dram_accessed_data = float('inf') + best_dram_accesses = float('inf') + best_tcdm_storage = 0 + best_br_tile_fa = 0 + best_bc_tile_fa = 0 + + for B_r in range(8, seq_len, 8): + for B_c in range(2, seq_len, 2): + dram_accesses = (seq_len // B_r) + (seq_len // B_r) * 2 * (seq_len // B_c) + (seq_len // B_r) + dram_accessed_data = ((seq_len // B_r) * (B_r * positional_embeddings) + (seq_len // B_r) * 2 * (seq_len // B_c) * (B_c * positional_embeddings) + (seq_len // B_r) * (B_r * B_c)) * data_type_size + tcdm_storage = (B_r * positional_embeddings + 2 * positional_embeddings * B_c + 4 * B_r + 2 * B_r * B_c ) * data_type_size + if tcdm_storage <= tcdm_size: + if tcdm_storage > best_tcdm_storage or dram_accesses < best_dram_accesses: + best_dram_accessed_data = dram_accessed_data + best_dram_accesses = dram_accesses + best_tcdm_storage = tcdm_storage + best_br_tile_fa = B_r + best_bc_tile_fa = B_c + + print("FlashAttention Layer 2 Best B_r: ", best_br_tile_fa) + print("FlashAttention Layer 2 Best P_tile: ", best_bc_tile_fa) + if(brief == True): + # seq_len = num_iters * best_br_tile_fa + seq_len = num_iters * best_s_tile_ln + positional_embeddings = positional_embeddings // 10 + embeddings = embeddings // 70 + + print("FlashAttention Layer 2 Sequence length: ", seq_len) + print("FlashAttention Layer 2 Positional embeddings: ", positional_embeddings) + print("FlashAttention Layer 2 Embeddings: ", embeddings) + + positional_embeddings_fa = positional_embeddings + + # q_fa = torch.randn(seq_len, positional_embeddings_fa, + # requires_grad=False, dtype=dtype) + # k_fa = torch.randn(positional_embeddings_fa, seq_len, + # requires_grad=False, dtype=dtype) + # v_fa = torch.randn(seq_len, positional_embeddings_fa, + # requires_grad=False, dtype=dtype) + # create tensors where the values are within 0 and 1 + q_fa = torch.rand(seq_len, positional_embeddings_fa, requires_grad=False, dtype=dtype) - weights_v = torch.randn(param['input_dim']['embeddings'], param['input_dim']['positional_embeddings'], + k_fa = torch.rand(positional_embeddings_fa, seq_len, requires_grad=False, dtype=dtype) - weights_o = torch.randn(param['input_dim']['positional_embeddings'], param['input_dim']['embeddings'], + v_fa = torch.rand(seq_len, positional_embeddings_fa, + requires_grad=False, dtype=dtype) + + weights_o = torch.randn(param['input_dim']['heads'] * positional_embeddings, embeddings, requires_grad=False, dtype=dtype) query = transformer(ifmap, weights_q.T, None, None, None, False) key = transformer(ifmap, weights_k.T, None, None, None, False) value = transformer(ifmap, weights_v.T, None, None, None, False) - sm = torch.nn.Softmax(dim=-1) - attention = sm(torch.matmul(query, key.T)) - attention_out = torch.matmul(attention, value) - transformer_out = torch.matmul(attention_out, weights_o) + ofmap = torch.randn(param['input_dim']['seq_len'], param['input_dim']['heads'] * positional_embeddings, + requires_grad=False, dtype=dtype) + + feedforward_len = param['input_dim']['feedforward_len'] kwargs = { 'ifmap': ifmap, + 'ifmap_ln': ifmap_ln, + 'S_tile_ln': best_s_tile_ln, + 'S_tile_lin1': best_s_tile_lin1, + 'P_tile_lin1': best_p_tile_lin1, + 'Br_tile_fa': best_br_tile_fa, + 'Bc_tile_fa': best_bc_tile_fa, + 'positional_embeddings_fa': positional_embeddings_fa, 'weights_q': weights_q, 'weights_k': weights_k, 'weights_v': weights_v, + 'q_fa': q_fa, + 'k_fa': k_fa, + 'v_fa': v_fa, 'weights_o': weights_o, + 'ofmap': ofmap, 'query': query, 'key': key, 'value': value, 'prec': param['prec'], + 'eps': eps, + 'feedforward_len': feedforward_len, } emit_header_file(args.output, 'Transformer', **kwargs) diff --git a/target/snitch_cluster/sw/apps/transformer/src/params.hjson b/target/snitch_cluster/sw/apps/transformer/src/params.hjson index 39d36b77df..2c8cf14d4b 100644 --- a/target/snitch_cluster/sw/apps/transformer/src/params.hjson +++ b/target/snitch_cluster/sw/apps/transformer/src/params.hjson @@ -8,9 +8,14 @@ { kernel: "Transformer" input_dim: { - seq_len: 128, + seq_len: 2048, embeddings: 768, positional_embeddings: 64, + feedforward_len: 3072, + heads: 12, } - prec: 32 + eps: 1e-5 + prec: 64 + brief: true + num_iters: 2 } \ No newline at end of file