Skip to content

Commit

Permalink
sw: Replace AtA kernel with syrk
Browse files Browse the repository at this point in the history
  • Loading branch information
colluca committed Aug 20, 2024
1 parent d1ab148 commit 0385632
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 146 deletions.
6 changes: 3 additions & 3 deletions sw/apps/covariance/src/covariance.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// Luca Colagrande <[email protected]>

#include "args.h"
#include "blas.h"
#include "snrt.h"
#include "ata.h"

#define DOUBLE_BUFFER 1

Expand Down Expand Up @@ -41,7 +41,7 @@ void covariance_naive(uint32_t m, uint32_t n, double inv_n,
snrt_cluster_hw_barrier();

// Compute covariance matrix
ata_naive(inv_n_m1, m, n, data, datat, cov);
syrk_naive(m, n, inv_n_m1, data, datat, 0, cov);
}

void covariance_baseline(uint32_t m, uint32_t n, double inv_n,
Expand Down Expand Up @@ -74,7 +74,7 @@ void covariance_baseline(uint32_t m, uint32_t n, double inv_n,
snrt_cluster_hw_barrier();

// Compute covariance matrix
ata_baseline(inv_n_m1, m, n, data, datat, cov);
syrk_baseline(m, n, inv_n_m1, data, datat, 0, cov);
}

void covariance_opt(uint32_t m, uint32_t n, double inv_n,
Expand Down
14 changes: 14 additions & 0 deletions sw/blas/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

#pragma once

// Floating-point multiplications by zero cannot be optimized as in some
// edge cases they do not yield zero:
// - 0f * NaN = NaN
// - 0f * INFINITY == NaN
// Thus in order to optimize it, we need to test for zero. You can use this
// function for free when `multiplier` is a constant.
static inline double multiply_opt(double multiplicand, double multiplier) {
if (multiplier)
return multiplicand * multiplier;
else
return 0;
}

#include "axpy/src/axpy.h"
#include "dot/src/dot.h"
#include "gemm/src/gemm.h"
#include "syrk/src/syrk.h"
13 changes: 0 additions & 13 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ typedef float v2f32 __attribute__((vector_size(8)));
typedef __fp16 v4f16 __attribute__((vector_size(8)));
typedef char v8f8 __attribute__((vector_size(8)));

// Floating-point multiplications by zero cannot be optimized as in some
// edge cases they do not yield zero:
// - 0f * NaN = NaN
// - 0f * INFINITY == NaN
// Thus in order to optimize it, we need to test for zero. You can use this
// function for free when `multiplier` is a constant.
static inline double multiply_opt(double multiplicand, double multiplier) {
if (multiplier)
return multiplicand * multiplier;
else
return 0;
}

#include "gemm_fp16.h"
#include "gemm_fp32.h"
#include "gemm_fp64.h"
Expand Down
2 changes: 1 addition & 1 deletion sw/blas/gemm/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <math.h>
#include <stdint.h>

#include "gemm.h"
#include "blas.h"

#include "data.h"
#include "snrt.h"
Expand Down
File renamed without changes.
10 changes: 6 additions & 4 deletions sw/apps/ata/data/params.json → sw/blas/syrk/data/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
// SPDX-License-Identifier: Apache-2.0

{
"m": 16,
"n": 4,
"m_tiles": 2,
"funcptr": "ata_opt"
"m": 8,
"n": 2,
"alpha": 1.5,
"beta": 3.2,
"m_tiles": 1,
"funcptr": "syrk_opt"
}
46 changes: 29 additions & 17 deletions sw/apps/ata/scripts/datagen.py → sw/blas/syrk/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,27 @@

DOUBLE_BUFFER = True

class AtaDataGen(DataGen):
class SyrkDataGen(DataGen):

# Function pointers to alternative implementations
FUNCPTRS = ["ata_naive", "ata_baseline", "ata_opt"]
FUNCPTRS = ["syrk_naive", "syrk_baseline", "syrk_opt"]

def golden_model(self, alpha, A):
return alpha * np.matmul(A, A.transpose())
def golden_model(self, alpha, A, beta, C):
return alpha * np.matmul(A, A.transpose()) + beta * C

def validate(self, **kwargs):
n_cores = 8
assert (kwargs['m'] % kwargs['m_tiles']) == 0, "m must be an integer multiple of m_tiles"
m_frac = kwargs['m'] / kwargs['m_tiles']
assert (m_frac % 8) == 0, "m_frac must be an integer multiple of the number of cores"
assert (m_frac % 4) == 0, "m_frac must be an integer multiple of the unroll factor 4"
assert (m_frac % n_cores) == 0, "m_frac must be an integer multiple of the number of cores"
if kwargs['funcptr'] != "syrk_naive":
assert (m_frac % 4) == 0, "m_frac must be an integer multiple of the unroll factor 4"
assert kwargs['funcptr'] in self.FUNCPTRS, f"Function pointer must be among {self.FUNCPTRS}"

# Calculate total TCDM occupation
a_tile_size = m_frac * kwargs['n'] * 8
b_tile_size = m_frac * m_frac * 8
total_size = 2 * a_tile_size + b_tile_size
c_tile_size = m_frac * m_frac * 8
total_size = 2 * a_tile_size + c_tile_size
if DOUBLE_BUFFER:
total_size *= 2
data_utils.validate_tcdm_footprint(total_size)
Expand All @@ -42,33 +44,43 @@ def emit_header(self, **kwargs):

self.validate(**kwargs)

if 'alpha' in kwargs:
alpha = kwargs['alpha']
else:
alpha = np.random.randint(-200, 100)/100
if 'beta' in kwargs:
beta = kwargs['beta']
else:
beta = np.random.randint(-200, 100)/100

A = np.random.randint(-200, 100, size=(kwargs['m'], kwargs['n']))/100
alpha = np.random.randint(-200, 100)/100
B = self.golden_model(alpha, A)
C_in = np.random.randint(-200, 100, size=(kwargs['m'], kwargs['m']))/100
C_out = self.golden_model(alpha, A, beta, C_in)

A = A.flatten()
B = B.flatten()
C_in = C_in.flatten()

A_uid = 'A'
B_uid = 'B'
C_uid = 'C'

cfg = {
'alpha': alpha,
'm': kwargs['m'],
'n': kwargs['n'],
'alpha': alpha,
'beta': beta,
'a': A_uid,
'b': B_uid,
'c': C_uid,
'm_tiles': kwargs['m_tiles'],
'funcptr': kwargs['funcptr']
}

header += [format_array_definition('double', A_uid, A)]
header += [format_array_declaration('double', B_uid, B.shape)]
header += [format_struct_definition('ata_args_t', 'args', cfg)]
header += [format_array_definition('double', C_uid, C_in)]
header += [format_struct_definition('syrk_args_t', 'args', cfg)]
header = '\n\n'.join(header)

return header


if __name__ == '__main__':
AtaDataGen().main()
SyrkDataGen().main()
20 changes: 13 additions & 7 deletions sw/apps/ata/scripts/verify.py → sw/blas/syrk/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,24 @@

import numpy as np
import sys
from datagen import AtaDataGen
from datagen import SyrkDataGen

from snitch.util.sim.verif_utils import Verifier


class AtaVerifier(Verifier):
class SyrkVerifier(Verifier):

OUTPUT_UIDS = ['B']
OUTPUT_UIDS = ['C']

def __init__(self):
super().__init__()
self.func_args = {
'alpha': 'd',
'm': 'I',
'n': 'I',
'alpha': 'd',
'beta': 'd',
'A': 'I',
'B': 'I',
'C': 'I',
'm_tiles': 'I',
'funcptr': 'I'
}
Expand All @@ -34,12 +35,17 @@ def get_actual_results(self):

def get_expected_results(self):
A = self.get_input_from_symbol('A', 'double')
C = self.get_input_from_symbol('C', 'double')
A = np.reshape(A, (self.func_args['m'], self.func_args['n']))
return AtaDataGen().golden_model(self.func_args['alpha'], A).flatten()
C = np.reshape(C, (self.func_args['m'], self.func_args['m']))
return SyrkDataGen().golden_model(
self.func_args['alpha'], A,
self.func_args['beta'], C
).flatten()

def check_results(self, *args):
return super().check_results(*args, rtol=1e-10)


if __name__ == "__main__":
sys.exit(AtaVerifier().main())
sys.exit(SyrkVerifier().main())
13 changes: 7 additions & 6 deletions sw/apps/ata/src/args.h → sw/blas/syrk/src/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
#pragma once
#include <stdint.h>

typedef void (*ata_fp_t)(double alpha, uint32_t m, uint32_t n, double *a,
double *at, double *b);
typedef void (*syrk_fp_t)(uint32_t m, uint32_t n, double alpha, double *a,
double *at, double beta, double *b);

typedef struct {
double alpha;
uint32_t m;
uint32_t n;
double alpha;
double beta;
double *a;
double *b;
double *c;
uint32_t m_tiles;
ata_fp_t funcptr;
} ata_args_t;
syrk_fp_t funcptr;
} syrk_args_t;
4 changes: 2 additions & 2 deletions sw/apps/ata/src/main.c → sw/blas/syrk/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

#include "snrt.h"

#include "ata.h"
#include "blas.h"
#include "data.h"

int main() {

ata_job(&args);
syrk_job(&args);

return 0;
}
Loading

0 comments on commit 0385632

Please sign in to comment.