From d6fe912e4b4d5a34b8663ab6b3459db8a223b6bd Mon Sep 17 00:00:00 2001
From: Avimitin
Date: Tue, 27 Aug 2024 23:29:37 +0800
Subject: [PATCH] [tests] add pytorch.llama

Signed-off-by: Avimitin
---
 nix/pkgs/buddy-mlir.nix              |   6 ++
 tests/pytorch/tinyllama/build.nix    | 101 +++++++++++++++++++++++++++
 tests/pytorch/tinyllama/tinyllama.cc |  60 ++++++++++++++++
 tests/pytorch/tinyllama/tinyllama.py |  60 ++++++++++++++++
 4 files changed, 227 insertions(+)
 create mode 100644 tests/pytorch/tinyllama/build.nix
 create mode 100644 tests/pytorch/tinyllama/tinyllama.cc
 create mode 100644 tests/pytorch/tinyllama/tinyllama.py

diff --git a/nix/pkgs/buddy-mlir.nix b/nix/pkgs/buddy-mlir.nix
index b15cb9397..be44bcaae 100644
--- a/nix/pkgs/buddy-mlir.nix
+++ b/nix/pkgs/buddy-mlir.nix
@@ -57,7 +57,13 @@ let
     pyenv = python3.withPackages (ps: [
       self
       ps.torch
+
+      # mobilenet
       ps.torchvision
+
+      # tinyllama
+      ps.transformers
+      ps.accelerate
     ]);
   };
 };
diff --git a/tests/pytorch/tinyllama/build.nix b/tests/pytorch/tinyllama/build.nix
new file mode 100644
index 000000000..6b5030b26
--- /dev/null
+++ b/tests/pytorch/tinyllama/build.nix
@@ -0,0 +1,101 @@
+{ buildBuddyE2ETest, fetchgit }:
+let
+  model = fetchgit {
+    url = "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0";
+    rev = "fe8a4ea1ffedaf415f4da2f062534de366a451e6";
+    fetchLFS = true;
+    hash = "sha256-vp/aUHKX+NJZZMIk2CgSh2czeGD0HeQGS30p/If2pA0=";
+  };
+in
+buildBuddyE2ETest {
+  caseName = "tinyllama";
+
+  passthru.model = model;
+
+  env.LLAMA_MODEL_PATH = "${model}";
+  optPhase = ''
+    python ./tinyllama.py
+
+    echo "Lowering forward.mlir"
+    buddy-opt forward.mlir -pass-pipeline \
+      "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),\
+      func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
+      | buddy-opt --arith-expand \
+          --eliminate-empty-tensors \
+          --empty-tensor-to-alloc-tensor \
+          --one-shot-bufferize \
+          --batchmatmul-optimize \
+          --convert-linalg-to-affine-loops \
+          --affine-loop-fusion \
+          --lower-affine \
+          --func-bufferize \
+          --arith-bufferize \
+          --tensor-bufferize \
+          --buffer-deallocation \
+          --finalizing-bufferize \
+          --convert-vector-to-scf \
+          --expand-strided-metadata \
+          --convert-vector-to-llvm \
+          --memref-expand \
+          --arith-expand \
+          --convert-arith-to-llvm \
+          --finalize-memref-to-llvm \
+          --convert-scf-to-cf \
+          --llvm-request-c-wrappers \
+          --convert-openmp-to-llvm \
+          --convert-arith-to-llvm \
+          --convert-math-to-llvm \
+          --convert-math-to-libm \
+          --convert-func-to-llvm \
+          --reconcile-unrealized-casts \
+          > forward-lowered.mlir
+
+    echo "Lowering subgraphs[0]"
+    buddy-opt subgraphs0.mlir -pass-pipeline \
+      "builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" \
+      | buddy-opt \
+          --arith-expand \
+          --eliminate-empty-tensors \
+          --empty-tensor-to-alloc-tensor \
+          --one-shot-bufferize \
+          --batchmatmul-optimize \
+          --convert-linalg-to-affine-loops \
+          --affine-loop-fusion \
+          --lower-affine \
+          --func-bufferize-dynamic-offset \
+          --tensor-bufferize \
+          --arith-bufferize \
+          --buffer-deallocation \
+          --finalizing-bufferize \
+          --convert-vector-to-scf \
+          --expand-strided-metadata \
+          --cse \
+          --lower-vector-exp \
+          --lower-rvv=rv32 \
+          --convert-vector-to-llvm \
+          --memref-expand \
+          --arith-expand \
+          --convert-arith-to-llvm \
+          --finalize-memref-to-llvm \
+          --convert-scf-to-cf \
+          --llvm-request-c-wrappers \
+          --convert-openmp-to-llvm \
+          --convert-arith-to-llvm \
+          --convert-math-to-llvm \
+          --convert-math-to-libm \
+          --convert-func-to-llvm \
+          --reconcile-unrealized-casts \
+          > subgraphs0-lowered.mlir
+
+    echo "Compiling memrefCopy library"
+    $CXX -nostdlib -c ${../lib/MemrefCopy.cc} -o memrefCopy.o
+    llcArtifacts+=(
+      memrefCopy.o
+    )
+
+    optArtifacts+=(
+      "forward-lowered.mlir"
+      "subgraphs0-lowered.mlir"
+    )
+  '';
+}
diff --git a/tests/pytorch/tinyllama/tinyllama.cc b/tests/pytorch/tinyllama/tinyllama.cc
new file mode 100644
index 000000000..705d16985
--- /dev/null
+++ b/tests/pytorch/tinyllama/tinyllama.cc
@@ -0,0 +1,60 @@
+//===- tinyllama.cc -------------------------------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the test harness for the TinyLlama model.
+//
+//===----------------------------------------------------------------------===//
+
+#include "memref.hpp"
+
+constexpr size_t ParamsSize = 110581;
+// constexpr size_t ParamsSize = 11058;
+constexpr size_t MaxVocabSize = 32000;
+constexpr size_t MaxTokenLength = 40;
+constexpr size_t HiddenSize = 2048;
+
+// Each backing buffer holds the product of its memref dimensions.
+
+// resultContainer[0]
+__attribute__((
+    section(".vdata"))) float result0[1 * MaxTokenLength * HiddenSize];
+static constexpr int32_t sizesResult0[3] = {1, MaxTokenLength, HiddenSize};
+
+// resultContainer[1]
+__attribute__((
+    section(".vdata"))) float result1[1 * MaxTokenLength * MaxVocabSize];
+static constexpr int32_t sizesResult1[3] = {1, MaxTokenLength, MaxVocabSize};
+
+// inputContainer
+__attribute__((section(".vdata"))) int32_t input[1 * MaxTokenLength];
+static constexpr int32_t sizesInput[2] = {1, MaxTokenLength};
+
+// paramsContainer
+__attribute__((section(".vdata"))) float param[ParamsSize];
+static constexpr int32_t sizesParam[1] = {ParamsSize};
+
+extern "C" {
+void _mlir_ciface_forward(MemRef<float, 3> *a, MemRef<float, 1> *b,
+                          MemRef<int32_t, 2> *c);
+}
+
+MemRef<float, 3> resultContainer[2] = {
+    MemRef<float, 3>(result0, 2.0, sizesResult0),
+    MemRef<float, 3>(result1, 3.0, sizesResult1)};
+MemRef<int32_t, 2> inputContainer(input, 4, sizesInput);
+MemRef<float, 1> paramsContainerf32(param, 5.0, sizesParam);
+
+extern "C" int test() {
+  _mlir_ciface_forward(resultContainer, &paramsContainerf32, &inputContainer);
+  return 0;
+}
diff --git a/tests/pytorch/tinyllama/tinyllama.py b/tests/pytorch/tinyllama/tinyllama.py
new file mode 100644
index 000000000..bd60c6567
--- /dev/null
+++ b/tests/pytorch/tinyllama/tinyllama.py
@@ -0,0 +1,60 @@
+# ===- tinyllama.py -------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===----------------------------------------------------------------------
+#
+# This is the TinyLlama model AOT importer.
+#
+# ===----------------------------------------------------------------------
+
+import os
+import sys
+
+import torch
+from torch._inductor.decomposition import decompositions as inductor_decomp
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import tosa
+from buddy.compiler.graph import GraphDriver
+from buddy.compiler.graph.transform import simply_fuse
+
+checkpoint = os.environ.get("LLAMA_MODEL_PATH")
+if checkpoint is None:
+    sys.exit("Error: No model path was provided. Please set $LLAMA_MODEL_PATH")
+
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
+model.config.use_cache = False
+
+# Initialize Dynamo Compiler with specific configurations as an importer.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=tosa.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+
+# Import the model into MLIR module and parameters.
+with torch.no_grad():
+    data = torch.tensor([[1 for i in range(40)]], dtype=torch.int64)
+    graphs = dynamo_compiler.importer(model, data)
+
+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+pattern_list = [simply_fuse]
+graph.fuse_ops(pattern_list)
+driver = GraphDriver(graph)
+driver.subgraphs[0].lower_to_top_level_ir()
+
+with open("subgraphs0.mlir", "w") as module_file:
+    print(driver.subgraphs[0]._imported_module, file=module_file)
+with open("forward.mlir", "w") as module_file:
+    print(driver.construct_main_graph(True), file=module_file)
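
For reviewers, below is a minimal illustrative sketch (not part of the patch) of how reference output could be captured from the unmodified checkpoint using the same all-ones 40-token input that tinyllama.py feeds to the importer. The logits produced here have shape (1, 40, 32000), matching sizesResult1 in tinyllama.cc. The script name reference.py and the output file reference-logits.bin are assumptions for illustration only; nothing in this patch consumes them.

    # reference.py -- illustrative sketch, not part of the patch.
    # Runs the same dummy prompt (batch 1, 40 tokens, all token id 1) through
    # the original TinyLlama checkpoint and dumps the float32 logits.
    import os
    import sys

    import torch
    from transformers import AutoModelForCausalLM

    checkpoint = os.environ.get("LLAMA_MODEL_PATH")
    if checkpoint is None:
        sys.exit("Error: No model path was provided. Please set $LLAMA_MODEL_PATH")

    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model.config.use_cache = False
    model.eval()

    with torch.no_grad():
        # Same input as tinyllama.py: batch 1, 40 tokens, all set to id 1.
        data = torch.ones((1, 40), dtype=torch.int64)
        logits = model(data).logits  # shape: (1, 40, 32000)

    # Dump as raw float32 so it can be compared element-wise against the
    # buffer filled by the lowered forward.mlir.
    logits.to(torch.float32).numpy().tofile("reference-logits.bin")
    print("Saved logits with shape", tuple(logits.shape))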