Skip to content

Commit

Permalink
[XLA:CPU] Add shape method python binding to Literal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705130592
  • Loading branch information
WillFroom authored and Google-ML-Automation committed Dec 13, 2024
1 parent e6acbe6 commit 720f85d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
9 changes: 4 additions & 5 deletions xla/backends/cpu/testlib/elemental_kernel_emitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from xla.backends.cpu.testlib import kernel_runner
from xla.codegen.testlib import kernel_runner as kernel_runner_base
from xla.python import xla_extension

HloOpcode = kernel_runner_base.HloOpcode
create_literal = kernel_runner_base.create_literal_from_np
Expand Down Expand Up @@ -132,11 +131,11 @@ def test_elemental_kernel_emitter(
np.ndarray(shape, dtype=expected_output.dtype)
)

# TODO(willfroom): Add support to get the shape directly from the Literal.
input_shape = xla_extension.Shape.array_shape(dtype, shape)
output_shape = xla_extension.Shape.array_shape(expected_output.dtype, shape)
emitter = kernel_runner.ElementalKernelEmitter(
op.name, op, [input_shape] * num_inputs, output_shape
op.name,
op,
[input.shape() for input in input_literals],
output_literal.shape(),
)

runner = kernel_runner.KernelRunner.create(emitter.emit_kernel_spec())
Expand Down
3 changes: 2 additions & 1 deletion xla/python/xla_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
nb::cast(obj));
},
nb::arg("dtype").none() = nb::none(),
nb::arg("copy").none() = nb::none());
nb::arg("copy").none() = nb::none())
.def("shape", &Literal::shape);

nb::class_<XlaComputation>(m, "XlaComputation")
.def("__init__",
Expand Down
1 change: 1 addition & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class Literal:
def __array__(
self, dtype: Optional[np.dtype] = None, copy: Optional[bool] = None
) -> np.ndarray: ...
def shape(self) -> Shape: ...

class XlaComputation:
def __init__(self, serialized_hlo_module_proto: bytes) -> None: ...
Expand Down

0 comments on commit 720f85d

Please sign in to comment.