Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Oct 29, 2024
1 parent cb32580 commit acb4088
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 27 deletions.
1 change: 0 additions & 1 deletion compiler/transforms/convert_kernel_to_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class LowerLinalgBody(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, linalg_op: linalg.Generic, rewriter: PatternRewriter):

# find the kernel op in linalg body
if not isinstance(kernel_op := linalg_op.body.block.first_op, Parsable):
return
Expand Down
1 change: 0 additions & 1 deletion compiler/transforms/frontend/preprocess_mlperf_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from compiler.dialects import snax
from compiler.transforms.alloc_to_global import AllocToGlobal
from compiler.transforms.convert_tosa_to_kernel import RescaleClampPattern
from compiler.transforms.test.insert_debugs import InsertDebugStatements


class InsertStaticFunctionCall(RewritePattern):
Expand Down
1 change: 1 addition & 0 deletions compiler/transforms/guarded_linalg_to_memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No
op.library_call = StringAttr(op.library_call.data[: -len("_stream")])
ConvertGenericOpPattern().match_and_rewrite(op, rewriter)


class GuardedYieldOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: linalg.YieldOp, rewriter: PatternRewriter):
Expand Down
5 changes: 3 additions & 2 deletions compiler/transforms/set_memory_space.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin, func, linalg, memref
from xdsl.ir import Operation, SSAValue
from xdsl.ir import SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand All @@ -26,7 +26,8 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
# Function must have memref arguments with an undefined memory space
if not any(
[
isinstance(x, builtin.MemRefType) and isinstance(x.memory_space, builtin.NoneAttr)
isinstance(x, builtin.MemRefType)
and isinstance(x.memory_space, builtin.NoneAttr)
for x in [*op.function_type.inputs, *op.function_type.outputs]
]
):
Expand Down
46 changes: 23 additions & 23 deletions kernels/mlperf_tiny_ad01/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@
* */
#include <snrt.h>


void _mlir_ciface_run_network(TwoDMemrefI8_t *output, TwoDMemrefI8_t *input);

void _mlir_ciface_debug_kernel_qmac(int32_t _ptr_a, int32_t _ptr_b, int32_t _ptr_c, int32_t when) {
void _mlir_ciface_debug_kernel_qmac(int32_t _ptr_a, int32_t _ptr_b,
int32_t _ptr_c, int32_t when) {
// gemm
int8_t *ptr_a, *ptr_b;
int32_t *ptr_c;
ptr_a = (int8_t*) _ptr_a;
ptr_b = (int8_t*) _ptr_b;
ptr_c = (int32_t*) _ptr_c;
ptr_a = (int8_t *)_ptr_a;
ptr_b = (int8_t *)_ptr_b;
ptr_c = (int32_t *)_ptr_c;

int thisc = snrt_cluster_core_idx();

if (thisc == 0) {
printf("Debugging GeMM at t = %d with A at %p, B at %p, C at %p\n", when, ptr_a, ptr_b, ptr_c);
printf("Debugging GeMM at t = %d with A at %p, B at %p, C at %p\n", when,
ptr_a, ptr_b, ptr_c);

for (int i = 0; i < 5; i++) {
printf("i%d -> A=%d, B=%d, C=%d\n", i, ptr_a[i], ptr_b[i], ptr_c[i]);
}

}

for(uint8_t i = 0; i < 20; i++) {
for (uint8_t i = 0; i < 20; i++) {
if (thisc == i) {
printf("Core %d present.\n", thisc);
if (snrt_is_dm_core()) {
Expand All @@ -45,26 +45,27 @@ void _mlir_ciface_debug_kernel_qmac(int32_t _ptr_a, int32_t _ptr_b, int32_t _ptr
}
snrt_cluster_hw_barrier();
}

}

void _mlir_ciface_debug_kernel_add(int32_t _ptr_a, int32_t _ptr_b, int32_t _ptr_c, int32_t when) {
void _mlir_ciface_debug_kernel_add(int32_t _ptr_a, int32_t _ptr_b,
int32_t _ptr_c, int32_t when) {
// bias addition
int32_t *ptr_a, *ptr_b, *ptr_c;
ptr_a = (int32_t*) _ptr_a;
ptr_b = (int32_t*) _ptr_b;
ptr_c = (int32_t*) _ptr_c;
ptr_a = (int32_t *)_ptr_a;
ptr_b = (int32_t *)_ptr_b;
ptr_c = (int32_t *)_ptr_c;

int thisc = snrt_cluster_core_idx();
if (thisc == 0) {
printf("Debugging bias at t = %d with A at %p, B at %p, C at %p\n", when, ptr_a, ptr_b, ptr_c);
printf("Debugging bias at t = %d with A at %p, B at %p, C at %p\n", when,
ptr_a, ptr_b, ptr_c);

for (int i = 0; i < 5; i++) {
printf("i%d -> A=%d, B=%d, C=%d\n", i, ptr_a[i], ptr_b[i], ptr_c[i]);
}
}

for(uint8_t i = 0; i < 20; i++) {
for (uint8_t i = 0; i < 20; i++) {
if (thisc == i) {
printf("Core %d present.\n", thisc);
if (snrt_is_dm_core()) {
Expand All @@ -73,26 +74,26 @@ void _mlir_ciface_debug_kernel_add(int32_t _ptr_a, int32_t _ptr_b, int32_t _ptr_
}
snrt_cluster_hw_barrier();
}

}

void _mlir_ciface_debug_kernel_rescale(int32_t _ptr_a, int32_t _ptr_b, int32_t _ptr_c, int32_t when) {
void _mlir_ciface_debug_kernel_rescale(int32_t _ptr_a, int32_t _ptr_b,
int32_t _ptr_c, int32_t when) {
// simd rescale
int32_t *ptr_a;
int8_t *ptr_c;
ptr_a = (int32_t*) _ptr_a;
ptr_c = (int8_t*) _ptr_c;

ptr_a = (int32_t *)_ptr_a;
ptr_c = (int8_t *)_ptr_c;

int thisc = snrt_cluster_core_idx();
if (thisc == 0) {
printf("Debugging SIMD at t = %d with A at %p, C at %p\n", when, ptr_a, ptr_c);
printf("Debugging SIMD at t = %d with A at %p, C at %p\n", when, ptr_a,
ptr_c);

for (int i = 0; i < 5; i++) {
printf("i%d -> A=%d, C=%d\n", i, ptr_a[i], ptr_c[i]);
}
}
for(uint8_t i = 0; i < 20; i++) {
for (uint8_t i = 0; i < 20; i++) {
if (thisc == i) {
printf("Core %d present.\n", thisc);
if (snrt_is_dm_core()) {
Expand All @@ -101,7 +102,6 @@ void _mlir_ciface_debug_kernel_rescale(int32_t _ptr_a, int32_t _ptr_b, int32_t _
}
snrt_cluster_hw_barrier();
}

}

int main() {
Expand Down

0 comments on commit acb4088

Please sign in to comment.