transformer: improved data generation for transformer with automatic tiling
Viviane Potocnik committed Oct 16, 2023
1 parent f13ba2f commit e074c13
Showing 2 changed files with 209 additions and 17 deletions.
217 changes: 202 additions & 15 deletions target/snitch_cluster/sw/apps/dnn/datagen.py
@@ -72,6 +72,17 @@ def emit_header_file(file, layer_type: str, **kwargs):
def emit_transformer_layer(name='transformer', **kwargs):

ifmap = kwargs['ifmap']
ifmap_ln = kwargs['ifmap_ln']
# retrieve tiling parameters
S_tile_ln = kwargs['S_tile_ln']
S_tile_lin1 = kwargs['S_tile_lin1']
P_tile_lin1 = kwargs['P_tile_lin1']
Br_tile_fa = kwargs['Br_tile_fa']
Bc_tile_fa = kwargs['Bc_tile_fa']
positional_embeddings_fa = kwargs['positional_embeddings_fa']
q_fa = kwargs['q_fa']
k_fa = kwargs['k_fa']
v_fa = kwargs['v_fa']
weights_q = kwargs['weights_q']
weights_k = kwargs['weights_k']
weights_v = kwargs['weights_v']
@@ -81,10 +92,14 @@ def emit_transformer_layer(name='transformer', **kwargs):
key = kwargs['key']
value = kwargs['value']

ofmap = kwargs['ofmap']

# Get the dimensions: sequence length S,
# embedding size E, and position embedding size P
S, E = ifmap.shape
_, P = weights_q.shape
_, P_fa = q_fa.shape
HP, _ = weights_o.shape

ctypes = {
'64': 'double',
@@ -98,10 +113,19 @@ def emit_transformer_layer(name='transformer', **kwargs):
# layer_str = '#include <stdint.h>\n'
layer_str = ''
layer_str += '#include "transformer.h"\n\n'
layer_str += f'transformer_layer_fp{kwargs["prec"]}_t {name}_l = {{\n'
layer_str += f'\t.seq_len = {S},\n'
layer_str += f'\t.S_tile_ln = {S_tile_ln},\n'
layer_str += f'\t.S_tile_lin1 = {S_tile_lin1},\n'
layer_str += f'\t.P_tile_lin1 = {P_tile_lin1},\n'
layer_str += f'\t.Br_tile_fa = {Br_tile_fa},\n'
layer_str += f'\t.Bc_tile_fa = {Bc_tile_fa},\n'
layer_str += f'\t.embeddings = {E},\n'
layer_str += f'\t.positional_embeddings = {P},\n'
layer_str += f'\t.positional_embeddings_fa = {positional_embeddings_fa},\n'
layer_str += f'\t.feedforward_len = {kwargs["feedforward_len"]},\n'
layer_str += f'\t.heads = {int(HP / P)},\n'
layer_str += f'\t.eps = {kwargs["eps"]},\n'
layer_str += f'\t.dtype = FP{kwargs["prec"]},\n'
layer_str += '};\n\n\n'

@@ -110,14 +134,20 @@ def emit_transformer_layer(name='transformer', **kwargs):
layer_str += f'static {dtype} {name}_K_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_V_lin_dram[{S}][{E}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_O_dram[{S}][{P}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_OH_dram[{S}][{HP}] __attribute__((section(".data")));\n\n'
layer_str += f'static {dtype} {name}_ifmap_dram[{S}][{E}] = ' + array_to_cstr(ifmap) + ';\n\n'
layer_str += f'static {dtype} {name}_ifmap_ln_dram[{S}][{E}] = ' + array_to_cstr(ifmap_ln) + ';\n\n'
# layer_str += f'static {dtype} {name}_ofmap_dram[{S}][{HP}] = ' + array_to_cstr(ofmap) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_q_dram[{E}][{P}] = ' + array_to_cstr(weights_q) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_k_dram[{E}][{P}] = ' + array_to_cstr(weights_k) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_v_dram[{E}][{P}] = ' + array_to_cstr(weights_v) + ';\n\n'
layer_str += f'static {dtype} {name}_weights_o_dram[{P}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n'
layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n'
layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n'
layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n'
layer_str += f'static {dtype} {name}_q_fa_dram[{S}][{P_fa}] = ' + array_to_cstr(q_fa) + ';\n\n'
layer_str += f'static {dtype} {name}_k_fa_dram[{P_fa}][{S}] = ' + array_to_cstr(k_fa) + ';\n\n'
layer_str += f'static {dtype} {name}_v_fa_dram[{S}][{P_fa}] = ' + array_to_cstr(v_fa) + ';\n\n'
# layer_str += f'static {dtype} {name}_weights_o_dram[{HP}][{E}] = ' + array_to_cstr(weights_o) + ';\n\n'
# layer_str += f'static {dtype} {name}_query_dram[{S}][{P}] = ' + array_to_cstr(query) + ';\n\n'
# layer_str += f'static {dtype} {name}_key_dram[{S}][{P}] = ' + array_to_cstr(key) + ';\n\n'
# layer_str += f'static {dtype} {name}_value_dram[{S}][{P}] = ' + array_to_cstr(value) + ';\n\n'

return layer_str
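
Note: the emitter above derives its dimensions from the shapes of the tensors it receives: S and E from ifmap, P from weights_q, P_fa from q_fa (used to size the *_fa_dram arrays), and HP from weights_o, from which heads = HP // P. The following minimal sketch mirrors those dimension reads; it is not part of this commit, and the helper name check_transformer_shapes is purely illustrative.

import torch

def check_transformer_shapes(ifmap, weights_q, weights_o, q_fa):
    S, E = ifmap.shape        # sequence length, embedding size
    _, P = weights_q.shape    # positional embedding size per head
    _, P_fa = q_fa.shape      # positional embeddings of the FlashAttention-2 inputs
    HP, _ = weights_o.shape   # heads * positional embeddings
    assert HP % P == 0, "weights_o rows must be heads * positional_embeddings"
    return dict(seq_len=S, embeddings=E, positional_embeddings=P,
                positional_embeddings_fa=P_fa, heads=HP // P)

# Toy example with S=8, E=16, P=4, heads=2, P_fa=4:
print(check_transformer_shapes(torch.randn(8, 16), torch.randn(16, 4),
                               torch.randn(8, 16), torch.randn(8, 4)))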

@@ -919,37 +949,194 @@ def main():
emit_header_file(args.output, 'LayerNorm', **kwargs)

elif param['kernel'] == 'Transformer':
seq_len = param['input_dim']['seq_len']
heads = param['input_dim']['heads']
embeddings = param['input_dim']['embeddings']
positional_embeddings = param['input_dim']['positional_embeddings']
feedforward_len = param['input_dim']['feedforward_len']

# check if we want to run a brief test
brief = param['brief']
num_iters = param['num_iters']
print("Brief test: ", brief)
print("Number of iterations: ", num_iters)

# tcdm capacity in bytes
tcdm_size = 125 * 1024
# data type size in bytes
data_type_size = torch.tensor(1, dtype=dtype).element_size()
print("Data type size: ", data_type_size)
# initialize the best solution parameters
best_dram_accessed_data = float('inf')
dram_accessed_data_list = []
best_tcdm_storage = 0
best_s_tile_ln = 0
for S_tile in range(8, seq_len, 8):
    dram_accessed_data = 2 * (seq_len // S_tile) * embeddings * data_type_size
    tcdm_storage = S_tile * embeddings * data_type_size
    if tcdm_storage <= tcdm_size:
        if tcdm_storage > best_tcdm_storage or dram_accessed_data < best_dram_accessed_data:
            best_dram_accessed_data = dram_accessed_data
            best_tcdm_storage = tcdm_storage
            best_s_tile_ln = S_tile

print("LayerNorm Best S_tile: ", best_s_tile_ln)
if brief:
    seq_len = num_iters * best_s_tile_ln
    embeddings = embeddings // 70

print("LayerNorm Sequence length: ", seq_len)

# Layer 1: LayerNorm layer
ifmap = torch.randn(seq_len, embeddings,
                    requires_grad=False, dtype=dtype)

eps = param['eps']

m = nn.LayerNorm(ifmap.size()[1:])

# TODO: due to a bug in PyTorch, we need to cast the input to float32 or BFloat16
ifmap = ifmap.type(torch.float32)

ifmap_ln = m(ifmap)

# cast back to the original data type
ifmap_ln = ifmap_ln.to(dtype).detach()
ifmap = ifmap.to(dtype)

# Layer 2: Linear layer 1
# TODO: check whether we go for min DRAM accesses or min DRAM accessed data
# we reset the best solution parameters
embeddings = param['input_dim']['embeddings']
best_dram_accessed_data = float('inf')
best_dram_accesses = float('inf')
best_tcdm_storage = 0
best_s_tile_lin1 = 0
best_p_tile_lin1 = 0

for S_tile in range(8, seq_len, 8):
    for P_tile in range(1, positional_embeddings, 1):
        dram_accessed_data = ((seq_len // S_tile) * (S_tile * embeddings)
                              + (seq_len // S_tile) * 3 * (positional_embeddings // P_tile) * embeddings * P_tile
                              + (seq_len // S_tile) * (positional_embeddings // P_tile) * S_tile * P_tile) * data_type_size
        dram_accesses = (seq_len // S_tile) + (seq_len // S_tile) * 2 * 3 * (positional_embeddings // P_tile)
        tcdm_storage = (S_tile * embeddings + 3 * embeddings * P_tile + 3 * S_tile * P_tile) * data_type_size
        if tcdm_storage <= tcdm_size:
            if tcdm_storage > best_tcdm_storage or dram_accesses < best_dram_accesses:  # or dram_accessed_data < best_dram_accessed_data
                best_dram_accessed_data = dram_accessed_data
                best_dram_accesses = dram_accesses
                best_tcdm_storage = tcdm_storage
                best_s_tile_lin1 = S_tile
                best_p_tile_lin1 = P_tile

print("Layer 1 Best S_tile: ", best_s_tile_lin1)
print("Layer 1 Best P_tile: ", best_p_tile_lin1)
if brief:
    seq_len = num_iters * best_s_tile_lin1
    positional_embeddings = num_iters * best_p_tile_lin1
    embeddings = embeddings // 70

print("Layer 1 Sequence length: ", seq_len)
print("Layer 1 Positional embeddings: ", positional_embeddings)
print("Layer 1 Embeddings: ", embeddings)

weights_q = torch.randn(embeddings, positional_embeddings,
                        requires_grad=False, dtype=dtype)
weights_k = torch.randn(embeddings, positional_embeddings,
                        requires_grad=False, dtype=dtype)
weights_v = torch.randn(embeddings, positional_embeddings,
                        requires_grad=False, dtype=dtype)

# Layer 3: FlashAttention-2 layer
# TODO: check whether we go for min DRAM accesses or min DRAM accessed data

# we reset the best solution parameters
# TODO: For the full model, we must also reset the sequence length
# seq_len = param['input_dim']['seq_len']
embeddings = param['input_dim']['embeddings']
positional_embeddings = param['input_dim']['positional_embeddings']
best_dram_accessed_data = float('inf')
best_dram_accesses = float('inf')
best_tcdm_storage = 0
best_br_tile_fa = 0
best_bc_tile_fa = 0

for B_r in range(8, seq_len, 8):
    for B_c in range(2, seq_len, 2):
        dram_accesses = (seq_len // B_r) + (seq_len // B_r) * 2 * (seq_len // B_c) + (seq_len // B_r)
        dram_accessed_data = ((seq_len // B_r) * (B_r * positional_embeddings)
                              + (seq_len // B_r) * 2 * (seq_len // B_c) * (B_c * positional_embeddings)
                              + (seq_len // B_r) * (B_r * B_c)) * data_type_size
        tcdm_storage = (B_r * positional_embeddings + 2 * positional_embeddings * B_c
                        + 4 * B_r + 2 * B_r * B_c) * data_type_size
        if tcdm_storage <= tcdm_size:
            if tcdm_storage > best_tcdm_storage or dram_accesses < best_dram_accesses:
                best_dram_accessed_data = dram_accessed_data
                best_dram_accesses = dram_accesses
                best_tcdm_storage = tcdm_storage
                best_br_tile_fa = B_r
                best_bc_tile_fa = B_c

print("FlashAttention Layer 2 Best B_r: ", best_br_tile_fa)
print("FlashAttention Layer 2 Best P_tile: ", best_bc_tile_fa)
if brief:
    # seq_len = num_iters * best_br_tile_fa
    seq_len = num_iters * best_s_tile_ln
    positional_embeddings = positional_embeddings // 10
    embeddings = embeddings // 70

print("FlashAttention Layer 2 Sequence length: ", seq_len)
print("FlashAttention Layer 2 Positional embeddings: ", positional_embeddings)
print("FlashAttention Layer 2 Embeddings: ", embeddings)

positional_embeddings_fa = positional_embeddings

# q_fa = torch.randn(seq_len, positional_embeddings_fa,
# requires_grad=False, dtype=dtype)
# k_fa = torch.randn(positional_embeddings_fa, seq_len,
# requires_grad=False, dtype=dtype)
# v_fa = torch.randn(seq_len, positional_embeddings_fa,
# requires_grad=False, dtype=dtype)
# create tensors where the values are within 0 and 1
q_fa = torch.rand(seq_len, positional_embeddings_fa,
                  requires_grad=False, dtype=dtype)
k_fa = torch.rand(positional_embeddings_fa, seq_len,
                  requires_grad=False, dtype=dtype)
v_fa = torch.rand(seq_len, positional_embeddings_fa,
                  requires_grad=False, dtype=dtype)

weights_o = torch.randn(param['input_dim']['heads'] * positional_embeddings, embeddings,
                        requires_grad=False, dtype=dtype)

query = transformer(ifmap, weights_q.T, None, None, None, False)
key = transformer(ifmap, weights_k.T, None, None, None, False)
value = transformer(ifmap, weights_v.T, None, None, None, False)

sm = torch.nn.Softmax(dim=-1)
attention = sm(torch.matmul(query, key.T))
attention_out = torch.matmul(attention, value)
transformer_out = torch.matmul(attention_out, weights_o)
ofmap = torch.randn(param['input_dim']['seq_len'], param['input_dim']['heads'] * positional_embeddings,
                    requires_grad=False, dtype=dtype)

feedforward_len = param['input_dim']['feedforward_len']

kwargs = {
'ifmap': ifmap,
'ifmap_ln': ifmap_ln,
'S_tile_ln': best_s_tile_ln,
'S_tile_lin1': best_s_tile_lin1,
'P_tile_lin1': best_p_tile_lin1,
'Br_tile_fa': best_br_tile_fa,
'Bc_tile_fa': best_bc_tile_fa,
'positional_embeddings_fa': positional_embeddings_fa,
'weights_q': weights_q,
'weights_k': weights_k,
'weights_v': weights_v,
'q_fa': q_fa,
'k_fa': k_fa,
'v_fa': v_fa,
'weights_o': weights_o,
'ofmap': ofmap,
'query': query,
'key': key,
'value': value,
'prec': param['prec'],
'eps': eps,
'feedforward_len': feedforward_len,
}

emit_header_file(args.output, 'Transformer', **kwargs)
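
All three tile searches above follow the same pattern: enumerate candidate tile sizes, discard candidates whose working set exceeds the 125 KiB TCDM, and among the remaining ones prefer larger TCDM utilization or fewer DRAM accesses. A minimal sketch of that heuristic as a standalone helper is shown below; the helper name and signature are illustrative and not part of the diff, and the acceptance condition deliberately mirrors the `or` used in the loops above.

def pick_best_tile(candidates, tcdm_bytes, dram_cost, tcdm_footprint):
    """Return the candidate tile that fits in TCDM, mirroring the searches above.

    candidates:     iterable of tile tuples, e.g. (S_tile,) or (B_r, B_c)
    dram_cost:      callable estimating DRAM traffic (accesses or bytes) for a tile
    tcdm_footprint: callable estimating the tile's TCDM working set in bytes
    """
    best_tile, best_cost, best_storage = None, float('inf'), 0
    for tile in candidates:
        storage = tcdm_footprint(*tile)
        if storage > tcdm_bytes:
            continue  # tile does not fit into the L1 scratchpad
        cost = dram_cost(*tile)
        if storage > best_storage or cost < best_cost:
            best_tile, best_cost, best_storage = tile, cost, storage
    return best_tile

# Usage mirroring the LayerNorm search (8-byte FP64 elements assumed):
# pick_best_tile(((s,) for s in range(8, seq_len, 8)),
#                tcdm_bytes=125 * 1024,
#                dram_cost=lambda s: 2 * (seq_len // s) * embeddings * 8,
#                tcdm_footprint=lambda s: s * embeddings * 8)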
9 changes: 7 additions & 2 deletions target/snitch_cluster/sw/apps/transformer/src/params.hjson
@@ -8,9 +8,14 @@
{
kernel: "Transformer"
input_dim: {
seq_len: 2048,
embeddings: 768,
positional_embeddings: 64,
feedforward_len: 3072,
heads: 12,
}
eps: 1e-5
prec: 64
brief: true
num_iters: 2
}
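
With this configuration and an 8-byte element size (assuming prec: 64 selects torch.float64 in the datagen script), the LayerNorm tile search admits only small row tiles: S_tile * 768 * 8 bytes must fit into the 125 KiB (128 000 byte) TCDM, so S_tile = 16 is the largest multiple of 8 that fits (16 * 768 * 8 = 98 304 bytes, while 24 * 768 * 8 = 147 456 bytes overflows). With brief: true and num_iters: 2, the generated LayerNorm test data then uses seq_len = 2 * 16 = 32 and embeddings = 768 // 70 = 10. A quick, illustrative check of that arithmetic:

tcdm = 125 * 1024              # 128000 bytes
print(16 * 768 * 8 <= tcdm)    # True  -> S_tile = 16 fits
print(24 * 768 * 8 <= tcdm)    # False -> next multiple of 8 overflows
print(2 * 16, 768 // 70)       # brief run: seq_len = 32, embeddings = 10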
