Skip to content

Commit

Permalink
Update py_import macros for the ability to unpack additional wheels…
Browse files Browse the repository at this point in the history
… in the same folder as the main wheel.

Usage example: provide NVIDIA wheel dependencies for ML wheels that have rpaths pointing to NVIDIA folders. When a user executes `pip install tensorflow[and_cuda]`, NVIDIA wheels are installed together with Tensorflow wheel. To reproduce this behavior in hermetic Python approach, we need to define `py_import` as follows (provided NVIDIA dependencies are defined in `requirements.in` and requirements lock files):

        py_import(
            name = "tf_py_import",
            wheel = ":wheel",
            deps = [
                "@pypi_absl_py//:pkg",
                "@pypi_astunparse//:pkg",
                "@pypi_flatbuffers//:pkg",
                "@pypi_gast//:pkg",
                "@pypi_ml_dtypes//:pkg",
                "@pypi_numpy//:pkg",
                "@pypi_opt_einsum//:pkg",
                "@pypi_packaging//:pkg",
                "@pypi_protobuf//:pkg",
                "@pypi_requests//:pkg",
                "@pypi_termcolor//:pkg",
                "@pypi_typing_extensions//:pkg",
                "@pypi_wrapt//:pkg",
            ],
            wheel_deps = [
                "@pypi_nvidia_cublas_cu12//:whl",
                "@pypi_nvidia_cuda_cupti_cu12//:whl",
                "@pypi_nvidia_cuda_nvcc_cu12//:whl",
                "@pypi_nvidia_cuda_nvrtc_cu12//:whl",
                "@pypi_nvidia_cuda_runtime_cu12//:whl",
                "@pypi_nvidia_cudnn_cu12//:whl",
                "@pypi_nvidia_cufft_cu12//:whl",
                "@pypi_nvidia_curand_cu12//:whl",
                "@pypi_nvidia_cusolver_cu12//:whl",
                "@pypi_nvidia_cusparse_cu12//:whl",
                "@pypi_nvidia_nccl_cu12//:whl",
                "@pypi_nvidia_nvjitlink_cu12//:whl",
            ],
        )

PiperOrigin-RevId: 702375747
  • Loading branch information
Google-ML-Automation committed Dec 12, 2024
1 parent ebc8a58 commit ffdad50
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -21,12 +25,14 @@ cc_library(
name = "cublas",
visibility = ["//visibility:public"],
%{comment}deps = [":cublas_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"),
)

cc_library(
name = "cublasLt",
visibility = ["//visibility:public"],
%{comment}deps = [":cublasLt_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cublas/lib"),
)

cc_library(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand Down Expand Up @@ -36,6 +40,7 @@ cc_library(
%{comment}}) + [
%{comment}":cudart_shared_library",
%{comment}],
%{comment}linkopts = cuda_rpath_flags("nvidia/cuda_runtime/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand Down Expand Up @@ -58,6 +62,7 @@ cc_library(
%{comment}"@cuda_nvrtc//:nvrtc",
%{comment}":cudnn_main",
%{comment}],
%{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand Down Expand Up @@ -65,6 +69,7 @@ cc_library(
%{comment}"@cuda_nvrtc//:nvrtc",
%{comment}":cudnn_main",
%{comment}],
%{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -13,6 +17,7 @@ cc_import(
cc_library(
name = "cufft",
%{comment}deps = [":cufft_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cufft/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
licenses(["restricted"]) # NVIDIA proprietary license
load("@local_config_cuda//cuda:build_defs.bzl", "if_version_equal_or_greater_than")
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
])
Expand All @@ -13,6 +18,7 @@ cc_import(
cc_library(
name = "cupti",
%{comment}deps = [":cupti_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cuda_cupti/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -13,6 +17,7 @@ cc_import(
cc_library(
name = "curand",
%{comment}deps = [":curand_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/curand/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -19,6 +23,7 @@ cc_import(
cc_library(
name = "cusolver",
%{comment}deps = [":cusolver_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cusolver/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -14,6 +18,7 @@ cc_import(
cc_library(
name = "cusparse",
%{comment}deps = [":cusparse_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cusparse/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -13,6 +17,7 @@ cc_import(
cc_library(
name = "nvjitlink",
%{comment}deps = [":nvjitlink_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/nvjitlink/lib"),
visibility = ["//visibility:public"],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

%{multiline_comment}
cc_import(
name = "nvrtc_main",
Expand All @@ -16,5 +21,6 @@ cc_library(
%{comment}":nvrtc_main",
%{comment}":nvrtc_builtins",
%{comment}],
%{comment}linkopts = cuda_rpath_flags("nvidia/cuda_nvrtc/lib"),
visibility = ["//visibility:public"],
)
5 changes: 5 additions & 0 deletions third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@xla//xla/tsl/platform/default:cuda_build_defs.bzl",
"cuda_rpath_flags",
)

exports_files([
"version.txt",
Expand All @@ -14,6 +18,7 @@ cc_import(
cc_library(
name = "nccl",
%{comment}deps = [":nccl_shared_library"],
%{comment}linkopts = cuda_rpath_flags("nvidia/nccl/lib"),
visibility = ["//visibility:public"],
)

Expand Down
32 changes: 16 additions & 16 deletions third_party/tsl/third_party/py/py_import.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,27 @@

def _unpacked_wheel_impl(ctx):
output_dir = ctx.actions.declare_directory(ctx.label.name)
libs = []
for dep in ctx.attr.cc_deps:
linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list()
for linker_input in linker_inputs:
if linker_input.libraries and linker_input.libraries[0].dynamic_library:
lib = linker_input.libraries[0].dynamic_library
libs.append(lib)
wheel = None
for w in ctx.files.wheel_rule_outputs:
if w.basename.endswith(".whl"):
wheel = w
break
script = """
{zipper} x {wheel} -d {output}
for lib in {libs}; do
cp $lib {output}/tensorflow
for wheel_dep in {wheel_deps}; do
{zipper} x $wheel_dep -d {output}
done
""".format(
zipper = ctx.executable.zipper.path,
wheel = wheel.path,
output = output_dir.path,
libs = " ".join(["'%s'" % lib.path for lib in libs]),
wheel_deps = " ".join([
"'%s'" % wheel_dep.path
for wheel_dep in ctx.files.wheel_deps
]),
)
ctx.actions.run_shell(
inputs = ctx.files.wheel_rule_outputs + libs,
inputs = ctx.files.wheel_rule_outputs + ctx.files.wheel_deps,
command = script,
outputs = [output_dir],
tools = [ctx.executable.zipper],
Expand All @@ -45,16 +41,20 @@ _unpacked_wheel = rule(
cfg = "exec",
executable = True,
),
"cc_deps": attr.label_list(providers = [CcInfo]),
"wheel_deps": attr.label_list(allow_files = True),
},
)

def py_import(name, wheel, deps = [], cc_deps = []):
def py_import(
name,
wheel,
deps = [],
wheel_deps = []):
unpacked_wheel_name = name + "_unpacked_wheel"
_unpacked_wheel(
name = unpacked_wheel_name,
wheel_rule_outputs = wheel,
cc_deps = cc_deps,
wheel_deps = wheel_deps,
)
native.py_library(
name = name,
Expand All @@ -68,6 +68,6 @@ def py_import(name, wheel, deps = [], cc_deps = []):
Args:
wheel: wheel file to unpack.
deps: dependencies of the py_library.
cc_deps: dependencies that will be copied in the folder
with the unpacked wheel content.
wheel_deps: additional wheels to unpack. These wheels will be unpacked in the
same folder as the wheel.
""" # buildifier: disable=no-effect

0 comments on commit ffdad50

Please sign in to comment.