diff --git a/nix/overlay.nix b/nix/overlay.nix index 9306cffd7f..e4e226e491 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -19,7 +19,14 @@ rec { dramsim3 = final.callPackage ./pkgs/dramsim3.nix { }; libspike = final.callPackage ./pkgs/libspike.nix { }; libspike_interfaces = final.callPackage ../difftest/spike_interfaces { }; - buddy-mlir = final.callPackage ./pkgs/buddy-mlir.nix { }; + + # DynamoCompiler doesn't support python 3.12+ yet + buddy-mlir = final.callPackage ./pkgs/buddy-mlir.nix { python3 = final.python311; }; + buddy-mlir-pyenv = final.buddy-mlir.pythonModule.withPackages (ps: [ + final.buddy-mlir + ps.torch + ]); + fetchMillDeps = final.callPackage ./pkgs/mill-builder.nix { }; circt-full = final.callPackage ./pkgs/circt-full.nix { }; rvv-codegen = final.callPackage ./pkgs/rvv-codegen.nix { }; diff --git a/nix/pkgs/buddy-mlir.nix b/nix/pkgs/buddy-mlir.nix index f2f899cea7..4ec3a4ff92 100644 --- a/nix/pkgs/buddy-mlir.nix +++ b/nix/pkgs/buddy-mlir.nix @@ -3,13 +3,14 @@ , llvmPackages_17 , fetchFromGitHub , fetchpatch +, python3 , callPackage }: let stdenv = llvmPackages_17.stdenv; bintools = llvmPackages_17.bintools; - buddy-llvm = callPackage ./buddy-llvm.nix { inherit stdenv; }; + buddy-llvm = callPackage ./buddy-llvm.nix { inherit stdenv python3; }; in stdenv.mkDerivation { pname = "buddy-mlir"; @@ -37,4 +38,20 @@ stdenv.mkDerivation { # No need to do check, and it also takes too much time to finish. doCheck = false; + + # Here we concatenate the LLVM and Buddy python module into one directory for easier import + postFixup = '' + mkdir -p $out/lib/python${python3.pythonVersion}/site-packages + cp -vr $out/python_packages/buddy $out/lib/python${python3.pythonVersion}/site-packages/ + cp -vr ${buddy-llvm}/python_packages/mlir_core/mlir $out/lib/python${python3.pythonVersion}/site-packages/ + ''; + + passthru = { + llvm = buddy-llvm; + + # Below three fields are black magic that allow site-packages automatically imported with nixpkgs hooks + pythonModule = python3; + pythonPath = [ ]; + requiredPythonModules = [ ]; + }; } diff --git a/tests/pytorch/demo/demo.py b/tests/pytorch/demo/demo.py new file mode 100644 index 0000000000..f189e5ebbd --- /dev/null +++ b/tests/pytorch/demo/demo.py @@ -0,0 +1,31 @@ +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + +# Define the target function or model. +def foo(x, y): + return x * y + x + +# Define the input data. +float32_in1 = torch.randn(10).to(torch.float32) +float32_in2 = torch.randn(10).to(torch.float32) +int32_in1 = torch.randint(0, 10, (10,)).to(torch.int32) +int32_in2 = torch.randint(0, 10, (10,)).to(torch.int32) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +# Pass the function and input data to the dynamo compiler's importer, the +# importer will first build a graph. Then, lower the graph to top-level IR. +# (tosa, linalg, etc.). Finally, accepts the generated module and weight parameters. +graphs = dynamo_compiler.importer(foo, *(float32_in1, float32_in2)) +graph = graphs[0] +graph.lower_to_top_level_ir() + +print(graph._imported_module)