frontend: add torch.Tensor.masked_select + fix torch.Tensor.cholesky supported dtypes (#28780)

Co-authored-by: Jin Wang <[email protected]>
Daniel4078 and Jin Wang authored Jul 1, 2024
1 parent 12d421a commit 1e0bad6
Showing 3 changed files with 44 additions and 5 deletions.
ivy/functional/frontends/torch/linalg.py (1 addition & 1 deletion)
@@ -9,7 +9,7 @@

 @to_ivy_arrays_and_back
 @with_supported_dtypes(
-    {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
+    {"2.2 and below": ("float32", "float64", "complex32", "complex64", "complex128")}, "torch"
 )
 def cholesky(input, *, upper=False, out=None):
     return ivy.cholesky(input, upper=upper, out=out)
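For context, the newly whitelisted complex128 dtype can be exercised end to end through this function. A minimal sketch, assuming ivy is installed with a torch backend and that the torch frontend exposes a tensor constructor and the linalg submodule as imported below; the matrix values are illustrative:

import numpy as np
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("torch")  # assumption: a torch backend is available

# Build a Hermitian positive-definite complex128 matrix: A^H A + c*I.
a = np.random.rand(3, 3) + 1j * np.random.rand(3, 3)
hpd = a.conj().T @ a + 3.0 * np.eye(3)

x = torch_frontend.tensor(hpd.astype(np.complex128))
L = torch_frontend.linalg.cholesky(x)  # lower-triangular factor (upper=False by default)
print(L)

Before this change, complex128 was simply missing from the whitelist, which is all the one-line diff above adds.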
ivy/functional/frontends/torch/tensor.py (6 additions & 1 deletion)
@@ -1071,6 +1071,9 @@ def masked_fill_(self, mask, value):
         self.ivy_array = self.masked_fill(mask, value).ivy_array
         return self
 
+    def masked_select(self, mask):
+        return torch_frontend.masked_select(self, mask)
+
     @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch")
     def index_add_(self, dim, index, source, *, alpha=1):
         self.ivy_array = torch_frontend.index_add(
@@ -1891,7 +1894,9 @@ def addcdiv_(self, tensor1, tensor2, *, value=1):
         ).ivy_array
         return self
 
-    @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, "torch")
+    @with_supported_dtypes(
+        {"2.2 and below": ("float32", "float64", "complex32", "complex64", "complex128")}, "torch"
+    )
     def cholesky(self, upper=False):
         return torch_frontend.cholesky(self, upper=upper)

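The two method-level changes above can be tried directly on a frontend Tensor. A minimal sketch under the same assumptions as the previous example (torch backend available, frontend tensor constructor); the data values are illustrative:

import numpy as np
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("torch")  # assumption: a torch backend is available

t = torch_frontend.tensor(np.array([[1.0, -2.0], [3.0, -4.0]], dtype=np.float32))
mask = torch_frontend.tensor(np.array([[True, False], [False, True]]))

# New method: returns a 1-D tensor of the elements where mask is True.
print(t.masked_select(mask))  # expected values: [ 1., -4.]

# Reworked decorator: Tensor.cholesky now whitelists float and complex dtypes.
spd = torch_frontend.tensor(np.array([[4.0, 2.0], [2.0, 3.0]], dtype=np.float64))
print(spd.cholesky())  # lower-triangular Cholesky factor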
ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py (37 additions & 3 deletions)
@@ -4730,10 +4730,9 @@ def test_torch_cholesky(
     backend_fw,
 ):
     input_dtype, x = dtype_and_x
-    x = x[0]
-    # make symmetric positive-definite
-    x = np.matmul(x.swapaxes(-1, -2), x) + np.identity(x.shape[-1]) * 1e-3
+    x = np.asarray(x[0], dtype=input_dtype[0])
+    x = np.matmul(np.conjugate(x.T), x) + np.identity(x.shape[0], dtype=input_dtype[0])
 
     helpers.test_frontend_method(
         init_input_dtypes=input_dtype,
         backend_to_test=backend_fw,
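As a sanity check on the updated construction, A^H A plus a multiple of the identity is Hermitian positive-definite, which is what a Cholesky factorization requires, including for the newly supported complex dtypes. A minimal NumPy sketch, illustrative only and not part of the test suite:

import numpy as np

a = np.array([[1 + 2j, 0.5], [3j, -1.0]], dtype=np.complex128)
h = np.matmul(np.conjugate(a.T), a) + np.identity(2, dtype=np.complex128)

print(np.allclose(h, np.conjugate(h.T)))   # True: Hermitian
print(np.all(np.linalg.eigvalsh(h) > 0))   # True: positive-definite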
@@ -9312,6 +9311,41 @@ def test_torch_masked_fill(
 )
 
 
+# masked_select
+@handle_frontend_method(
+    class_tree=CLASS_TREE,
+    init_tree="torch.tensor",
+    method_name="masked_select",
+    x_mask_val=_masked_fill_helper(),
+)
+def test_torch_masked_select(
+    x_mask_val,
+    frontend_method_data,
+    init_flags,
+    method_flags,
+    frontend,
+    on_device,
+    backend_fw,
+):
+    dtype, x, mask, _ = x_mask_val
+    helpers.test_frontend_method(
+        init_input_dtypes=[dtype],
+        backend_to_test=backend_fw,
+        init_all_as_kwargs_np={
+            "data": x,
+        },
+        method_input_dtypes=["bool", dtype],
+        method_all_as_kwargs_np={
+            "mask": mask,
+        },
+        frontend_method_data=frontend_method_data,
+        init_flags=init_flags,
+        method_flags=method_flags,
+        frontend=frontend,
+        on_device=on_device,
+    )
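For reference, torch.Tensor.masked_select returns a flattened 1-D tensor holding the input elements at positions where the boolean mask is True. A minimal NumPy sketch of the behaviour this new test exercises; the arrays below are illustrative and not produced by _masked_fill_helper:

import numpy as np

x = np.array([[0.5, -1.0], [2.0, 3.5]])
mask = np.array([[True, False], [True, True]])

# Boolean indexing flattens in row-major order, mirroring masked_select.
print(x[mask])  # [0.5 2.  3.5]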


 # matmul
 @handle_frontend_method(
     class_tree=CLASS_TREE,
