Skip to content

Commit

Permalink
dnn: Refactor and verify layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Oct 28, 2023
1 parent f869b2e commit 3cb1bff
Show file tree
Hide file tree
Showing 54 changed files with 1,614 additions and 612 deletions.
3 changes: 3 additions & 0 deletions .clang-format-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@

# Ignore vendored third-party code
./sw/math/*
./target/snitch_cluster/sw/apps/transformer/src/transformer.c
./target/snitch_cluster/sw/apps/transformer/src/data.h
./sw/apps/transformer/src/transformer.h
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ jobs:
with:
flake8-version: "6.0.0"
max-line-length: "100"
exclude: "target/snitch_cluster/sw/apps/dnn/datagen.py"

######################
# Clang-Format Check #
Expand Down
14 changes: 7 additions & 7 deletions sw/blas/axpy/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import os

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_vector_definition, \
format_vector_declaration, format_ifdef_wrapper # noqa: E402
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

MIN = -1000
MAX = +1000
Expand Down Expand Up @@ -47,16 +47,16 @@ def main():
a = np.random.uniform(MIN, MAX, 1)
x = np.random.uniform(MIN, MAX, length)
y = np.random.uniform(MIN, MAX, length)
z = np.zeros(length)
g = golden_model(a, x, y)

# Format header file
l_str = format_scalar_definition('const uint32_t', 'l', length)
a_str = format_scalar_definition('const double', 'a', a[0])
x_str = format_vector_definition('double', 'x', x, alignment=BURST_ALIGNMENT, section=section)
y_str = format_vector_definition('double', 'y', y, alignment=BURST_ALIGNMENT, section=section)
z_str = format_vector_declaration('double', 'z', z, alignment=BURST_ALIGNMENT, section=section)
g_str = format_vector_definition('double', 'g', g)
x_str = format_array_definition('double', 'x', x, alignment=BURST_ALIGNMENT, section=section)
y_str = format_array_definition('double', 'y', y, alignment=BURST_ALIGNMENT, section=section)
z_str = format_array_declaration('double', 'z', [length],
alignment=BURST_ALIGNMENT, section=section)
g_str = format_array_definition('double', 'g', g)
g_str = format_ifdef_wrapper('BIST', g_str)
f_str = '\n\n'.join([l_str, a_str, x_str, y_str, z_str, g_str])
f_str += '\n'
Expand Down
16 changes: 8 additions & 8 deletions sw/blas/gemm/data/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import emit_license, format_scalar_definition, \
format_vector_definition, format_ifdef_wrapper # noqa: E402
format_array_definition, format_ifdef_wrapper # noqa: E402


np.random.seed(42)
Expand Down Expand Up @@ -100,18 +100,18 @@ def emit_header(**kwargs):
data_str += [format_scalar_definition('uint32_t', 'BETA', kwargs['beta'])]
data_str += [format_scalar_definition('uint32_t', 'dtype_size', kwargs['prec']//8)]
data_str += [format_scalar_definition('uint32_t', 'expand', kwargs['expand'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'a', a.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'b', b.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
data_str += [format_vector_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
data_str += [format_array_definition(C_TYPES[str(kwargs['prec'])], 'c', c.flatten(),
alignment=BURST_ALIGNMENT, section=kwargs['section'])]
if kwargs['prec'] == 8:
result_def = format_vector_definition(C_TYPES['64'], 'result', result.flatten())
result_def = format_array_definition(C_TYPES['64'], 'result', result.flatten())
else:
result_def = format_vector_definition(C_TYPES[str(kwargs['prec'])],
'result',
result.flatten())
result_def = format_array_definition(C_TYPES[str(kwargs['prec'])],
'result',
result.flatten())
data_str += [format_ifdef_wrapper('BIST', result_def)]
data_str = '\n\n'.join(data_str)

Expand Down
36 changes: 17 additions & 19 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ typedef char v8f8 __attribute__((vector_size(8)));
dump_float(gemm, 8);
dump_uint(index, 9);


void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
uint32_t ldA, uint32_t ta, double* B, uint32_t ldB,
uint32_t tb, double* C, uint32_t ldC, double BETA) {
Expand Down Expand Up @@ -74,24 +73,23 @@ void gemm_fp64_baseline(uint32_t M, uint32_t N, uint32_t K, double* A,
}

/* params:
* M: number of rows of A and C
* N: number of columns of B and C
* K: number of columns of A and rows of B
* A: pointer to matrix A
* ldA: row stride of A
* ta: transpose A
* B: pointer to matrix B
* ldB: row stride of B
* tb: transpose B
* C: pointer to matrix C
* ldC: row stride of C
* ALPHA: scalar alpha
* A is MxK, B is KxN, C is MxN
*/
* M: number of rows of A and C
* N: number of columns of B and C
* K: number of columns of A and rows of B
* A: pointer to matrix A
* ldA: row stride of A
* ta: transpose A
* B: pointer to matrix B
* ldB: row stride of B
* tb: transpose B
* C: pointer to matrix C
* ldC: row stride of C
* ALPHA: scalar alpha
* A is MxK, B is KxN, C is MxN
*/
void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
uint32_t ldA, uint32_t ta, float* B, uint32_t ldB,
uint32_t tb, float* C, uint32_t ldC, float ALPHA) {

// float c0, c1, c2, c3 = 0;
float c0 = 0.0f;
float c1 = 0.0f;
Expand All @@ -110,7 +108,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
c0 += A[(k + 0) + m * ldA] * B[(k + 0) * ldB + n];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) * ldB + n];
c2 += A[(k + 2) + m * ldA] * B[(k + 2) * ldB + n];
Expand All @@ -131,7 +129,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
c0 += A[(k + 0) * M * ldA + m * ldA] * B[(k + 0) * ldB + n];
c1 += A[(k + 1) * M * ldA + m * ldA] * B[(k + 1) * ldB + n];
c2 += A[(k + 2) * M * ldA + m * ldA] * B[(k + 2) * ldB + n];
Expand All @@ -152,7 +150,7 @@ void gemm_fp32_baseline(uint32_t M, uint32_t N, uint32_t K, float* A,
c1 = 0.0f;
c2 = 0.0f;
c3 = 0.0f;
for (uint32_t k = 0; k < K; k+=4) {
for (uint32_t k = 0; k < K; k += 4) {
// c0 += A[k + m * ldA] * B[k + n * ldB];
c0 += A[(k + 0) + m * ldA] * B[(k + 0) + n * ldB];
c1 += A[(k + 1) + m * ldA] * B[(k + 1) + n * ldB];
Expand Down
1 change: 1 addition & 0 deletions sw/dnn/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*/data/data.h
136 changes: 136 additions & 0 deletions sw/dnn/batchnorm/data/datagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright 2023 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Tim Fischer <[email protected]>
# Viviane Potocnik <[email protected]>
# Luca Colagrande <[email protected]>

import argparse
import pathlib
import hjson
import sys
import os
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import emit_license, \
format_struct_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper # noqa: E402

# Seed the global torch RNG so every run generates the same random data
torch.manual_seed(42)

# AXI splits bursts crossing 4KB address boundaries. To minimize
# the occurrence of these splits the data should be aligned to 4KB
BURST_ALIGNMENT = 4096

# Lookup from a precision key (string) to a precision name.
# NOTE(review): keys are presumably the floating-point bit width, matching
# the `prec` field of the param config file -- confirm against the users
# of this table.
PRECISION_T = {
'64': 'FP64',
'32': 'FP32',
'16': 'FP16',
'8': 'FP8'
}


def golden_model(ifmap):
    """Compute the reference output of a batch normalization with fixed statistics.

    Random running statistics are drawn from the global torch RNG and folded,
    together with the affine parameters of a freshly constructed
    ``torch.nn.BatchNorm2d``, into one scale and one shift per channel.

    Args:
        ifmap: 4-D input tensor whose second dimension is the channel count.

    Returns:
        Tuple ``(ofmap, gamma, beta)`` where ``gamma`` and ``beta`` hold one
        value per channel and ``ofmap = ifmap * gamma + beta``, with ``gamma``
        and ``beta`` broadcast over the last two dimensions of ``ifmap``.

    Raises:
        ValueError: if ``ifmap`` does not have exactly four dimensions.
    """
    # Only the channel count is needed; the unpacking still rejects inputs
    # that are not 4-D.
    _, ci, _, _ = ifmap.shape

    bn = torch.nn.BatchNorm2d(ci)
    bn.weight.requires_grad = False
    bn.bias.requires_grad = False

    # Keep the draw order (mean first, then variance) so that seeded runs
    # reproduce the same statistics.
    running_mean = torch.randn_like(bn.running_mean, requires_grad=False)
    running_var = torch.rand_like(bn.running_var, requires_grad=False)

    # Shared denominator of scale and shift, computed once
    std = torch.sqrt(running_var + bn.eps)
    gamma = bn.weight / std
    beta = bn.bias - running_mean * bn.weight / std

    # Append two singleton dimensions so the per-channel vectors broadcast
    # over the last two dimensions of the input
    ofmap = ifmap * gamma.unsqueeze(-1).unsqueeze(-1) + beta.unsqueeze(-1).unsqueeze(-1)
    return ofmap, gamma, beta


def emit_header(**kwargs):
    """Build the text of the data header for the batchnorm kernel.

    Keyword Args:
        input_dim: mapping with the keys ``channels``, ``height`` and
            ``width`` giving the size of the random input feature map.
        tile_ci: value stored in the ``TILE_CI`` field of the layer struct.
        prec: precision selector, converted to ``str`` and passed to the
            ``data_utils`` type lookups.

    Any other keyword argument is ignored.

    Returns:
        The header contents as one string, sections separated by blank lines.
    """
    dims = kwargs['input_dim']
    shape = (1, dims['channels'], dims['height'], dims['width'])
    tile_ci = kwargs['tile_ci']
    prec = str(kwargs['prec'])

    torch_type = data_utils.floating_point_torch_type(prec)
    ctype = data_utils.floating_point_ctype(prec)

    # Random input and its reference output
    ifmap = torch.randn(*shape, requires_grad=False, dtype=torch_type)
    ofmap, gamma, beta = golden_model(ifmap)

    # Move the channel dimension last (NCHW -> NHWC)
    ifmap = ifmap.permute(0, 2, 3, 1)
    ofmap = ofmap.permute(0, 2, 3, 1)

    _, ih, iw, ci = ifmap.shape

    # Symbol names of the emitted arrays
    names = {'ifmap': 'ifmap', 'ofmap': 'ofmap', 'beta': 'beta', 'gamma': 'gamma'}

    layer_cfg = {'CI': ci, 'IH': ih, 'IW': iw, 'TILE_CI': tile_ci}
    layer_cfg.update(names)

    sections = [emit_license()]

    # Forward declarations of all arrays
    for uid, tensor in ((names['ifmap'], ifmap), (names['ofmap'], ofmap),
                        (names['beta'], beta), (names['gamma'], gamma)):
        sections.append(format_array_declaration(ctype, uid, tensor.shape))

    # Layer struct
    sections.append(format_struct_definition('batchnorm_layer_t', 'layer', layer_cfg))

    # Definitions of the arrays emitted with initial values
    for uid, tensor in ((names['ifmap'], ifmap), (names['beta'], beta),
                        (names['gamma'], gamma)):
        sections.append(format_array_definition(ctype, uid, tensor))

    # Expected output, wrapped in an ifdef on BIST
    golden_def = format_array_definition(ctype, 'golden', ofmap)
    sections.append(format_ifdef_wrapper('BIST', golden_def))

    return '\n\n'.join(sections)


def main():
    """Generate the batchnorm data header from the command line.

    Parses the arguments, loads the hjson parameter file given with
    ``-c/--cfg``, adds the ``--section`` value to the parameters under the
    key ``section`` and writes the result of ``emit_header`` to the output
    path.
    """
    # The description previously named the layernorm kernel, although this
    # script generates the batchnorm data.
    parser = argparse.ArgumentParser(description='Generate data for batchnorm kernel')
    parser.add_argument(
        "-c", "--cfg",
        type=pathlib.Path,
        required=True,
        help='Select param config file kernel'
    )
    parser.add_argument(
        '--section',
        type=str,
        help='Section to store matrices in')
    parser.add_argument(
        'output',
        type=pathlib.Path,
        help='Path of the output header file')
    args = parser.parse_args()

    # Load param config file
    with args.cfg.open() as f:
        param = hjson.loads(f.read())
    # Passed on to emit_header as a keyword argument; None when not given
    param['section'] = args.section

    # Emit header file
    with open(args.output, 'w') as f:
        f.write(emit_header(**param))


if __name__ == '__main__':
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51

// Parameters for a single BatchNorm layer

{
kernel: "BatchNorm"
channels: {
out: 32,
in: 32
}
input_dim: {
channels: 32
height: 8,
width: 8
}
tile_ci: 32
prec: 64
}
Loading

0 comments on commit 3cb1bff

Please sign in to comment.