Add simple matmul on gemm accelerator #39
Very cool! Some minor comments.
I mostly don't like how we use custom variables and libraries from another repo without clearly referencing their location, but I don't know what is a better solution
kernels/simple_matmul/gendata.py
Outdated
C_golden = np.matmul(A.astype(np.dtype("int32")), B.astype(np.dtype("int32")))
C = np.zeros(C_golden.shape, np.dtype("int32"))

assert A.shape[1] == B.shape[0]
The assert must come before the np.matmul computation. If the shapes do not match, np.matmul fails first anyway, so the assert placed after it is useless.
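A minimal sketch of the ordering this comment asks for, not the actual gendata.py: the helper name `golden_matmul` and the explicit 2-D check are assumptions made for illustration. The shape check runs first, so a mismatch is reported with a readable message instead of surfacing as an error from inside np.matmul.

```python
import numpy as np


def golden_matmul(a, b):
    """Return (C_golden, C) for C = A.B, checking shapes before computing.

    Hypothetical helper for illustration only; it is not part of the PR.
    """
    # Validate first: if this ran after np.matmul, a mismatch would already
    # have raised inside np.matmul and the assert would never be reached.
    assert a.ndim == 2 and b.ndim == 2, "expected 2-D operands"
    assert a.shape[1] == b.shape[0], (
        f"inner dimensions differ: A is {a.shape}, B is {b.shape}"
    )

    # Widen to int32 before multiplying, as in the snippet under review.
    c_golden = np.matmul(a.astype(np.int32), b.astype(np.int32))
    c = np.zeros(c_golden.shape, dtype=np.int32)
    return c_golden, c
```

With int8 inputs of shape (2, 3) and (3, 4) this returns a (2, 4) int32 golden result; passing two (2, 3) arrays trips the assert before any multiplication happens.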
kernels/simple_matmul/gendata.py
Outdated
# C = A.B
A = np.random.randint(low_bound, high_bound, size=A_size, dtype=np.dtype("int8"))
# A = np.ones(A_size, dtype=np.dtype("int8"))
remove comment
kernels/simple_matmul/gendata.py
Outdated
A = np.random.randint(low_bound, high_bound, size=A_size, dtype=np.dtype("int8"))
# A = np.ones(A_size, dtype=np.dtype("int8"))
B = np.random.randint(low_bound, high_bound, size=B_size, dtype=np.dtype("int8"))
# B = np.ones(B_size, dtype=np.dtype("int8"))
remove comment
runtime/include/memref.h
Outdated
The MLIR docs show quite a nice way to have only one memref struct specification for all types, in C++
template<typename T, size_t N>
struct MemRefDescriptor {
T *allocated;
T *aligned;
intptr_t offset;
intptr_t sizes[N];
intptr_t strides[N];
};
do you know if something similar is possible in C? if not, I guess this is fine for now
This is possible, but it looks a bit complicated (https://isocpp.org/wiki/faq/mixing-c-and-cpp). Let's leave this for a future PR?
Yes, good!
kernels/simple_matmul/main.c
Outdated
#include "snax-gemm-lib.h"
#include "snax-gemm-params.h"
To me it feels kind of weird to include header files from another repository without specifying where they come from 😕. I guess it also doesn't make sense to duplicate them... Is there a better way to share them?
if not, maybe make it clear with a comment where these files come from
Agreed! Clear like this?
kernels/simple_matmul/main.c
Outdated
memrefA.stride[0] = sizeof(int8_t);
memrefA.stride[1] = sizeof(int8_t);
If we were to use a standard layout, this would also not be correct, but rather:
- memrefA.stride[0] = sizeof(int8_t);
- memrefA.stride[1] = sizeof(int8_t);
+ memrefA.stride[0] = sizeof(int8_t);
+ memrefA.stride[1] = sizeof(int8_t) * M_size;
Maybe just set them to 0 to make it very clear we are not using them.
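To make the stride discussion concrete, here is a small NumPy sketch of what "standard layout" strides look like for an M_size x K_size int8 matrix. The sizes 8 and 16 and the helper `element_strides` are hypothetical, chosen only for illustration. NumPy reports strides in bytes, which for int8 coincides with a count in elements; the helper divides by the item size so the same check also works for wider types.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
M_size, K_size = 8, 16


def element_strides(arr):
    """Strides of a NumPy array counted in elements instead of bytes."""
    return tuple(s // arr.itemsize for s in arr.strides)


# Row-major (C order): stepping one row skips K_size elements.
row_major = np.zeros((M_size, K_size), dtype=np.int8)
row_strides = element_strides(row_major)

# Column-major (Fortran order): stepping one column skips M_size elements.
# This matches the stride[0] = 1, stride[1] = M_size pattern suggested above.
col_major = np.zeros((M_size, K_size), dtype=np.int8, order="F")
col_strides = element_strides(col_major)

# The same row-major shape with int32 gives identical element strides,
# even though the byte strides are four times larger.
wide_strides = element_strides(np.zeros((M_size, K_size), dtype=np.int32))
```

Under these assumptions the row-major strides come out as (K_size, 1) and the column-major ones as (1, M_size), so which pair is "correct" depends on the layout the kernel actually assumes. Neither describes a tiled-block layout, which supports the suggestion to zero the fields rather than fill them with misleading values.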
kernels/simple_matmul/main.c
Outdated
memrefA.aligned_data = memrefA.data;
memrefA.shape[0] = M_size;
memrefA.shape[1] = K_size;
// These are not considered correctly right now
Comment not very clear.
- // These are not considered correctly right now
+ // Strides are not used due to the tiled-block layout.
+ // Instead we use the variables strideInnermostA, ldA and strideA
Awesome!!
kernels/simple_mult/main.c
Outdated
@@ -55,6 +55,7 @@ int main() {

  int nerr = 0;
  for (int i = 0; i < N; i++) {
    printf("result: %d golden: %d\n", memrefD.aligned_data[i], G[i]);
To delete then?
Will make a separate PR for this
wait no
it's gone now haha
Update: This PR mainly puts everything in place (in terms of data layout and data movement) to execute on the gemm accelerator.
Right now it mostly bypasses MLIR stuff, because still missing:
WIP on adding a simple quantized matmul to this repository
Needs work:
* Figure out which parameters in set_batch_gemm can be extracted from the operation/memref itself and which ones we can assume to be hardcoded for now? Most are still hardcoded because we are not working on 4D memrefs yet.
* Add support for, or remove entirely, the call to set up the CSRs for the accelerator. Accelerator not considered in this PR.