Skip to content

Commit

Permalink
Rewrite syr (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamirDroubi authored Feb 16, 2024
1 parent 9352184 commit 34ccb21
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 166 deletions.
147 changes: 10 additions & 137 deletions src/level2/syr.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
from __future__ import annotations

from exo import *
from exo.libs.memories import DRAM_STATIC
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *

import exo_blas_config as C

from composed_schedules import *
from blaslib import *
from codegen_helpers import *


### EXO_LOC ALGORITHM START ###
@proc
def syr_row_major_Upper(n: size, alpha: R, x: [R][n], A: [R][n, n]):
def syr_rm_u(n: size, alpha: R, x: [R][n], A: [R][n, n]):
assert stride(A, 1) == 1

for i in seq(0, n):
for j in seq(0, n - i):
A[i, i + j] += alpha * x[i] * x[i + j]
for j in seq(0, i + 1):
A[n - i - 1, n - j - 1] += (alpha * x[n - i - 1]) * x[n - j - 1]


@proc
def syr_row_major_Lower(n: size, alpha: R, x: [R][n], A: [R][n, n]):
def syr_rm_l(n: size, alpha: R, x: [R][n], A: [R][n, n]):
assert stride(A, 1) == 1

for i in seq(0, n):
Expand All @@ -34,133 +29,11 @@ def syr_row_major_Lower(n: size, alpha: R, x: [R][n], A: [R][n, n]):


### EXO_LOC SCHEDULE START ###
def specialize_syr(syr, precision):
prefix = "s" if precision == "f32" else "d"
name = syr.name()
name = name.replace("", "")
specialized = rename(syr, "exo_" + prefix + name)

args = ["alpha", "x", "A"]

for arg in args:
specialized = set_precision(specialized, arg, precision)

return simplify(specialized)


def schedule_interleave_syr_row_major_stride_1(
syr, VEC_W, INTERLEAVE_FACTOR, memory, instructions, precision
):
stride_1 = specialize_syr(syr, precision)
stride_1 = rename(stride_1, stride_1.name() + "_stride_1")
stride_1 = stride_1.add_assertion("stride(x, 0) == 1")

j_loop = stride_1.find_loop("j")
stride_1 = scalar_to_simd(stride_1, j_loop, VEC_W, memory, precision)
stride_1 = hoist_from_loop(stride_1, j_loop)
stride_1 = replace_all_stmts(stride_1, instructions)
return simplify(stride_1)


#################################################
# Kernel Parameters
#################################################

#################################################
# Generate specialized kernels for f32 precision
#################################################

exo_ssyr_row_major_Upper_stride_any = specialize_syr(syr_row_major_Upper, "f32")
exo_ssyr_row_major_Upper_stride_any = rename(
exo_ssyr_row_major_Upper_stride_any,
exo_ssyr_row_major_Upper_stride_any.name() + "_stride_any",
variants_generator(optimize_level_2)(
syr_rm_u, "i", 4, 2, round_up=False, globals=globals()
)
exo_ssyr_row_major_Lower_stride_any = specialize_syr(syr_row_major_Lower, "f32")
exo_ssyr_row_major_Lower_stride_any = rename(
exo_ssyr_row_major_Lower_stride_any,
exo_ssyr_row_major_Lower_stride_any.name() + "_stride_any",
variants_generator(optimize_level_2)(
syr_rm_l, "i", 4, 2, round_up=False, globals=globals()
)
f32_instructions = [
C.Machine.load_instr_f32,
C.Machine.store_instr_f32,
C.Machine.mul_instr_f32,
C.Machine.fmadd_reduce_instr_f32,
C.Machine.broadcast_instr_f32,
C.Machine.broadcast_scalar_instr_f32,
]

exo_ssyr_row_major_Upper_stride_1 = schedule_interleave_syr_row_major_stride_1(
syr_row_major_Upper,
C.Machine.f32_vec_width,
1,
C.Machine.mem_type,
f32_instructions,
"f32",
)
exo_ssyr_row_major_Lower_stride_1 = schedule_interleave_syr_row_major_stride_1(
syr_row_major_Lower,
C.Machine.f32_vec_width,
1,
C.Machine.mem_type,
f32_instructions,
"f32",
)

#################################################
# Generate specialized kernels for f64 precision
#################################################

exo_dsyr_row_major_Upper_stride_any = specialize_syr(syr_row_major_Upper, "f64")
exo_dsyr_row_major_Upper_stride_any = rename(
exo_dsyr_row_major_Upper_stride_any,
exo_dsyr_row_major_Upper_stride_any.name() + "_stride_any",
)
exo_dsyr_row_major_Lower_stride_any = specialize_syr(syr_row_major_Lower, "f64")
exo_dsyr_row_major_Lower_stride_any = rename(
exo_dsyr_row_major_Lower_stride_any,
exo_dsyr_row_major_Lower_stride_any.name() + "_stride_any",
)

f64_instructions = [
C.Machine.load_instr_f64,
C.Machine.store_instr_f64,
C.Machine.mul_instr_f64,
C.Machine.fmadd_reduce_instr_f64,
C.Machine.broadcast_instr_f64,
C.Machine.broadcast_scalar_instr_f64,
]

exo_dsyr_row_major_Upper_stride_1 = schedule_interleave_syr_row_major_stride_1(
syr_row_major_Upper,
C.Machine.f32_vec_width // 2,
1,
C.Machine.mem_type,
f64_instructions,
"f64",
)
exo_dsyr_row_major_Lower_stride_1 = schedule_interleave_syr_row_major_stride_1(
syr_row_major_Lower,
C.Machine.f32_vec_width // 2,
1,
C.Machine.mem_type,
f64_instructions,
"f64",
)
### EXO_LOC SCHEDULE END ###

entry_points = [
exo_ssyr_row_major_Upper_stride_any,
exo_ssyr_row_major_Upper_stride_1,
exo_dsyr_row_major_Upper_stride_any,
exo_dsyr_row_major_Upper_stride_1,
exo_ssyr_row_major_Lower_stride_any,
exo_ssyr_row_major_Lower_stride_1,
exo_dsyr_row_major_Lower_stride_any,
exo_dsyr_row_major_Lower_stride_1,
]

if __name__ == "__main__":
for p in entry_points:
print(p)

__all__ = [p.name() for p in entry_points]
2 changes: 1 addition & 1 deletion test/codegen/reference/sha256/avx2.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"exo_swap": "805f2a357a4709c9ad3c548d7c1225d2d630b453269190f1d6a66b765c22c5bf",
"exo_symm": "3676ce1cbd96dbd378b277e26bea9514afa44dfad54e06e22d6434cf93c9b629",
"exo_symv": "d84d487822d9839dd6caf34c77acc3461644facaa938a78648b3d78c1bbc48aa",
"exo_syr": "22b4c2eaa19929159563ddaf2b5d6b4e6ee0967d58e039689e3371f05792dc9a",
"exo_syr": "ec81e1b55e18367b5a941d2f4fbff18a303a432efca0752b8e540a11c2d2eb7f",
"exo_syr2": "8bd9def6d6aba8be6a763a5970721c048b6e5a102c5fe1434fc740ef36636450",
"exo_syrk": "1d751306796e0f091f6049dd28e48ec9fc7ecfd231a3a0072c96e8e7452a09ad",
"exo_tbmv": "dc55c55ecc566cf14a3e59d28884dc900c5476f3294b5b1010f47a0c6fe2b170",
Expand Down
4 changes: 2 additions & 2 deletions test/level2/dsyr/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ static void BM_cblas_dsyr(benchmark::State &state) {
? CBLAS_ORDER::CblasRowMajor
: CBLAS_ORDER::CblasColMajor;
const enum CBLAS_UPLO Uplo =
state.range(2) == 0 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
state.range(2) == 1 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
const double alpha = state.range(3);
const int lda = N + state.range(4);
const int incX = state.range(5);
Expand All @@ -33,7 +33,7 @@ static void BM_exo_dsyr(benchmark::State &state) {
? CBLAS_ORDER::CblasRowMajor
: CBLAS_ORDER::CblasColMajor;
const enum CBLAS_UPLO Uplo =
state.range(2) == 0 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
state.range(2) == 1 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
const double alpha = state.range(3);
const int lda = N + state.range(4);
const int incX = state.range(5);
Expand Down
24 changes: 12 additions & 12 deletions test/level2/dsyr/exo_dsyr.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,29 @@ void exo_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
}
if (Uplo == CBLAS_UPLO::CblasUpper) {
if (incX == 1) {
exo_dsyr_row_major_Upper_stride_1(
nullptr, N, &alpha, exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
exo_dsyr_rm_u_stride_1(nullptr, N, &alpha,
exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
} else {
if (incX < 0) {
X = X + (1 - N) * incX;
}
exo_dsyr_row_major_Upper_stride_any(
nullptr, N, &alpha, exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
exo_dsyr_rm_u_stride_any(nullptr, N, &alpha,
exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
}
} else {
if (incX == 1) {
exo_dsyr_row_major_Lower_stride_1(
nullptr, N, &alpha, exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
exo_dsyr_rm_l_stride_1(nullptr, N, &alpha,
exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
} else {
if (incX < 0) {
X = X + (1 - N) * incX;
}
exo_dsyr_row_major_Lower_stride_any(
nullptr, N, &alpha, exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
exo_dsyr_rm_l_stride_any(nullptr, N, &alpha,
exo_win_1f64c{.data = X, .strides = {incX}},
exo_win_2f64{.data = A, .strides = {lda, 1}});
}
}
}
4 changes: 2 additions & 2 deletions test/level2/ssyr/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ static void BM_cblas_ssyr(benchmark::State &state) {
? CBLAS_ORDER::CblasRowMajor
: CBLAS_ORDER::CblasColMajor;
const enum CBLAS_UPLO Uplo =
state.range(2) == 0 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
state.range(2) == 1 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
const float alpha = state.range(3);
const int lda = N + state.range(4);
const int incX = state.range(5);
Expand All @@ -33,7 +33,7 @@ static void BM_exo_ssyr(benchmark::State &state) {
? CBLAS_ORDER::CblasRowMajor
: CBLAS_ORDER::CblasColMajor;
const enum CBLAS_UPLO Uplo =
state.range(2) == 0 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
state.range(2) == 1 ? CBLAS_UPLO::CblasUpper : CBLAS_UPLO::CblasLower;
const float alpha = state.range(3);
const int lda = N + state.range(4);
const int incX = state.range(5);
Expand Down
24 changes: 12 additions & 12 deletions test/level2/ssyr/exo_ssyr.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,29 @@ void exo_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
}
if (Uplo == CBLAS_UPLO::CblasUpper) {
if (incX == 1) {
exo_ssyr_row_major_Upper_stride_1(
nullptr, N, &alpha, exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
exo_ssyr_rm_u_stride_1(nullptr, N, &alpha,
exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
} else {
if (incX < 0) {
X = X + (1 - N) * incX;
}
exo_ssyr_row_major_Upper_stride_any(
nullptr, N, &alpha, exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
exo_ssyr_rm_u_stride_any(nullptr, N, &alpha,
exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
}
} else {
if (incX == 1) {
exo_ssyr_row_major_Lower_stride_1(
nullptr, N, &alpha, exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
exo_ssyr_rm_l_stride_1(nullptr, N, &alpha,
exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
} else {
if (incX < 0) {
X = X + (1 - N) * incX;
}
exo_ssyr_row_major_Lower_stride_any(
nullptr, N, &alpha, exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
exo_ssyr_rm_l_stride_any(nullptr, N, &alpha,
exo_win_1f32c{.data = X, .strides = {incX}},
exo_win_2f32{.data = A, .strides = {lda, 1}});
}
}
}

0 comments on commit 34ccb21

Please sign in to comment.