Fix torch frontend svds functions #28829

Open · wants to merge 14 commits into main
Changes from 11 commits
1 change: 0 additions & 1 deletion ivy/functional/backends/tensorflow/linear_algebra.py
@@ -535,7 +535,6 @@ def svd(
) -> Union[Union[tf.Tensor, tf.Variable], Tuple[Union[tf.Tensor, tf.Variable], ...]]:
if compute_uv:
results = namedtuple("svd", "U S Vh")

batch_shape = tf.shape(x)[:-2]
num_batch_dims = len(batch_shape)
transpose_dims = list(range(num_batch_dims)) + [
7 changes: 2 additions & 5 deletions ivy/functional/backends/torch/linear_algebra.py
@@ -415,15 +415,12 @@ def svd(
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
if compute_uv:
results = namedtuple("svd", "U S Vh")

U, D, VT = torch.linalg.svd(x, full_matrices=full_matrices)
return results(U, D, VT)
else:
results = namedtuple("svd", "S")
svd = torch.linalg.svd(x, full_matrices=full_matrices)
# torch.linalg.svd returns a tuple with U, S, and Vh
D = svd[1]
return results(D)
s = torch.linalg.svdvals(x)
return results(s)


@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
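A quick sketch of the equivalence the backend change relies on (illustrative only, not part of this diff): `torch.linalg.svdvals` returns the same singular values as the `S` factor of `torch.linalg.svd`, without computing `U` and `Vh` at all.

```python
import torch

x = torch.randn(4, 3, dtype=torch.float64)

# Full decomposition: U, S and Vh are all materialised.
U, S, Vh = torch.linalg.svd(x, full_matrices=False)

# Singular values only; skips the work of forming U and Vh.
s_only = torch.linalg.svdvals(x)

assert torch.allclose(S, s_only)
```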
31 changes: 27 additions & 4 deletions ivy/functional/frontends/torch/blas_and_lapack_ops.py
@@ -190,12 +190,35 @@ def slogdet(A, *, out=None):


@to_ivy_arrays_and_back
@with_supported_dtypes(
{
"2.2 and below": (
"float64",
"float32",
"half",
"complex32",
"complex64",
"complex128",
)
},
"torch",
)
def svd(input, some=True, compute_uv=True, *, out=None):
# TODO: add compute_uv
if some:
ret = ivy.svd(input, full_matrices=False)
retu = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
results = namedtuple("svd", "U S V")
if compute_uv:
ret = results(retu[0], retu[1], ivy.adjoint(retu[2]))
else:
ret = ivy.svd(input, full_matrices=True)
shape = list(input.shape)
shape1 = list(shape)
shape2 = list(shape)
shape1[-2] = shape[-1]
shape2[-1] = shape[-2]
ret = results(
ivy.zeros(shape1, device=input.device, dtype=input.dtype),
ivy.astype(retu[0], input.dtype),
ivy.zeros(shape2, device=input.device, dtype=input.dtype),
)
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret
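For context on the frontend rewrite above, a sketch against the stock PyTorch API (not code from this PR): `torch.svd` returns `V` while `torch.linalg.svd` returns its adjoint `Vh`, which is why the frontend applies `ivy.adjoint` to the third factor; and with `compute_uv=False`, `torch.svd` still returns `U` and `V`, but zero-filled, which the `ivy.zeros` branch mirrors.

```python
import torch

A = torch.randn(5, 3, dtype=torch.float64)

# torch.linalg.svd returns Vh, so A == U @ diag(S) @ Vh ...
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
assert torch.allclose(U @ torch.diag(S) @ Vh, A)

# ... while the older torch.svd returns V itself, so A == U @ diag(S) @ V^H.
U2, S2, V = torch.svd(A, some=True)
assert torch.allclose(U2 @ torch.diag(S2) @ V.mH, A)
assert torch.allclose(S, S2)

# With compute_uv=False, torch.svd returns the real singular values
# alongside zero-filled U and V tensors.
Uz, Sz, Vz = torch.svd(A, compute_uv=False)
assert torch.allclose(Sz, S)
assert not Uz.any() and not Vz.any()
```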
25 changes: 22 additions & 3 deletions ivy/functional/frontends/torch/linalg.py
@@ -347,11 +347,30 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None):

@to_ivy_arrays_and_back
@with_supported_dtypes(
{"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
{
"2.2 and below": (
"float64",
"float32",
"half",
"complex32",
"complex64",
"complex128",
)
},
"torch",
)
def svd(A, /, *, full_matrices=True, driver=None, out=None):
# TODO: add handling for driver and out
return ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
# TODO: add handling for driver
USVh = ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
if ivy.is_complex_dtype(A.dtype):
d = ivy.complex64
else:
d = ivy.float32
nt = namedtuple("svd", "U S Vh")
ret = nt(ivy.astype(USVh.U, d), ivy.astype(USVh.S, d), ivy.astype(USVh.Vh, d))
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret


@to_ivy_arrays_and_back
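One more piece of context for the dtype handling above (again a sketch against stock PyTorch, not part of the diff): for complex input, `torch.linalg.svd` keeps `U` and `Vh` complex but always returns the singular values `S` as a real tensor.

```python
import torch

A = torch.randn(3, 3, dtype=torch.complex128)

U, S, Vh = torch.linalg.svd(A)

# Singular values are real even for complex input.
assert S.dtype == torch.float64
assert U.dtype == torch.complex128 and Vh.dtype == torch.complex128

# Cast S back to the input dtype for the reconstruction check.
assert torch.allclose(U @ torch.diag(S.to(A.dtype)) @ Vh, A)
```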
14 changes: 13 additions & 1 deletion ivy/functional/frontends/torch/tensor.py
@@ -2109,7 +2109,19 @@ def adjoint(self):
def conj(self):
return torch_frontend.conj(self)

@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch")
@with_supported_dtypes(
{
"2.2 and below": (
"float64",
"float32",
"half",
"complex32",
"complex64",
"complex128",
)
},
"torch",
)
def svd(self, some=True, compute_uv=True, *, out=None):
return torch_frontend.svd(self, some=some, compute_uv=compute_uv, out=out)

5 changes: 1 addition & 4 deletions ivy/functional/ivy/linear_algebra.py
@@ -2130,15 +2130,12 @@ def svd(
If ``True`` then left and right singular vectors will be computed and returned
in ``U`` and ``Vh``, respectively. Otherwise, only the singular values will be
computed, which can be significantly faster.
.. note::
with backend set as torch, svd with still compute left and right singular
vectors irrespective of the value of compute_uv, however Ivy will still
only return the singular values.

Returns
-------
.. note::
once complex numbers are supported, each square matrix must be Hermitian.
In addition, the return will be a namedtuple ``(S)`` when compute_uv is ``False``.

ret
a namedtuple ``(U, S, Vh)`` whose
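A small usage sketch of the documented behaviour (assuming a torch backend is installed; any backend with SVD support works the same way): with `compute_uv=False`, `ivy.svd` returns a one-field namedtuple holding only the singular values.

```python
import ivy

ivy.set_backend("torch")

x = ivy.random_normal(shape=(4, 3))

# Default: namedtuple (U, S, Vh)
U, S, Vh = ivy.svd(x)

# compute_uv=False: namedtuple (S) with just the singular values
(S_only,) = ivy.svd(x, compute_uv=False)

assert ivy.all(ivy.isclose(S, S_only))
```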
ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_lapack_ops.py
@@ -848,37 +848,75 @@ def test_torch_qr(
@handle_frontend_test(
fn_tree="torch.svd",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float", index=1),
min_num_dims=3,
max_num_dims=5,
min_dim_size=2,
max_dim_size=5,
available_dtypes=helpers.get_dtypes("valid"),
min_value=0,
max_value=10,
shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
),
some=st.booleans(),
compute=st.booleans(),
compute_uv=st.booleans(),
)
def test_torch_svd(
dtype_and_x,
some,
compute,
compute_uv,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=dtype,
input_dtype, x = dtype_and_x
x = np.asarray(x[0], dtype=input_dtype[0])
# make symmetric positive definite
x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
ret, frontend_ret = helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
test_values=False,
input=x,
some=some,
compute_uv=compute,
)
compute_uv=compute_uv,
)
if backend_fw == "torch":
frontend_ret = [x.detach() for x in frontend_ret]
ret = [x.detach() for x in ret]
ret = [np.asarray(x) for x in ret]
frontend_ret = [
np.asarray(x.resolve_conj()) for x in frontend_ret
]
u, s, v = ret
frontend_u, frontend_s, frontend_v = frontend_ret
if not compute_uv:
helpers.assert_all_close(
ret_np=frontend_s,
ret_from_gt_np=s,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
elif not some:
helpers.assert_all_close(
ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
ret_from_gt_np=u @ np.diag(s) @ v.T,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
else:
helpers.assert_all_close(
ret_np=frontend_u[..., : frontend_s.shape[0]]
@ np.diag(frontend_s)
@ frontend_v.T,
ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ v.T,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)


@handle_frontend_test(
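The reason the rewritten test compares reconstructions rather than the factors themselves (illustrative sketch, not part of the diff): singular vectors are only determined up to a sign (or phase, for complex input), so two valid SVDs of the same matrix can differ in `U` and `V` while `U @ diag(S) @ V^T` agrees.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
A = A.T @ A + 1e-3 * np.eye(4)  # symmetric positive definite, as in the test

U, S, Vh = np.linalg.svd(A)

# Flipping the sign of a matched U column / Vh row gives another valid SVD,
# so factor-by-factor comparison can fail even when both results are correct.
U2, Vh2 = U.copy(), Vh.copy()
U2[:, 0] *= -1
Vh2[0, :] *= -1

assert not np.allclose(U, U2)
assert np.allclose(U2 @ np.diag(S) @ Vh2, A)
assert np.allclose(U @ np.diag(S) @ Vh, U2 @ np.diag(S) @ Vh2)
```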
46 changes: 31 additions & 15 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
@@ -1242,7 +1242,12 @@ def test_torch_solve_ex(
# svd
@handle_frontend_test(
fn_tree="torch.linalg.svd",
dtype_and_x=_get_dtype_and_matrix(square=True),
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
min_value=0,
max_value=10,
shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
),
full_matrices=st.booleans(),
)
def test_torch_svd(
@@ -1257,7 +1262,7 @@ def test_torch_svd(
):
dtype, x = dtype_and_x
x = np.asarray(x[0], dtype=dtype[0])
# make symmetric positive definite beforehand
# make symmetric positive definite
x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
ret, frontend_ret = helpers.test_frontend_function(
input_dtypes=dtype,
@@ -1267,25 +1272,36 @@ def test_torch_svd(
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
atol=1e-03,
rtol=1e-05,
A=x,
full_matrices=full_matrices,
)
ret = [ivy.to_numpy(x) for x in ret]
if backend_fw == "torch":
frontend_ret = [x.detach() for x in frontend_ret]
ret = [x.detach() for x in ret]
ret = [np.asarray(x) for x in ret]
frontend_ret = [np.asarray(x) for x in frontend_ret]

u, s, vh = ret
frontend_u, frontend_s, frontend_vh = frontend_ret

assert_all_close(
ret_np=u @ np.diag(s) @ vh,
ret_from_gt_np=frontend_u @ np.diag(frontend_s) @ frontend_vh,
rtol=1e-2,
atol=1e-2,
ground_truth_backend=frontend,
backend=backend_fw,
)
if full_matrices:
helpers.assert_all_close(
ret_np=(
frontend_u[..., : frontend_s.shape[0]]
@ np.diag(frontend_s)
@ frontend_vh
),
ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ vh,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
else:
helpers.assert_all_close(
ret_np=(frontend_u @ np.diag(frontend_s) @ frontend_vh).astype(d),
ret_from_gt_np=u @ np.diag(s) @ vh,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)


# svdvals
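And the reason for the `u[..., : s.shape[0]]` slicing in the `full_matrices=True` branch (sketch only): with full matrices, `U` is square and only its leading columns take part in the reconstruction.

```python
import numpy as np

A = np.random.default_rng(1).standard_normal((5, 3))

# full_matrices=True: U is (5, 5), S has 3 entries, Vh is (3, 3);
# only the first 3 columns of U enter the reconstruction.
U, S, Vh = np.linalg.svd(A, full_matrices=True)
assert U.shape == (5, 5) and Vh.shape == (3, 3)
assert np.allclose(U[:, : S.shape[0]] @ np.diag(S) @ Vh, A)

# full_matrices=False: U is already (5, 3), so no slicing is needed.
Ur, Sr, Vhr = np.linalg.svd(A, full_matrices=False)
assert Ur.shape == (5, 3)
assert np.allclose(Ur @ np.diag(Sr) @ Vhr, A)
```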
44 changes: 25 additions & 19 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -13095,13 +13095,13 @@ def test_torch_sum(
on_device=on_device,
)


# svd
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
method_name="svd",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("valid"),
min_value=0,
max_value=10,
shape=helpers.ints(min_value=2, max_value=5).map(lambda x: (x, x)),
@@ -13122,7 +13122,8 @@ def test_torch_svd(
):
input_dtype, x = dtype_and_x
x = np.asarray(x[0], dtype=input_dtype[0])

# make symmetric positive-definite
x = np.matmul(x.T, x) + np.identity(x.shape[0]) * 1e-3
ret, frontend_ret = helpers.test_frontend_method(
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={
@@ -13141,28 +13142,33 @@ def test_torch_svd(
on_device=on_device,
test_values=False,
)
with helpers.update_backend(backend_fw) as ivy_backend:
ret = [ivy_backend.to_numpy(x) for x in ret]
frontend_ret = [np.asarray(x) for x in frontend_ret]

u, s, vh = ret
frontend_u, frontend_s, frontend_vh = frontend_ret

if compute_uv:
ret = [np.asarray(x) for x in ret]
frontend_ret = [np.asarray(x.resolve_conj()) for x in frontend_ret]
u, s, v = ret
frontend_u, frontend_s, frontend_v = frontend_ret
if not compute_uv:
helpers.assert_all_close(
ret_np=frontend_u @ np.diag(frontend_s) @ frontend_vh.T,
ret_from_gt_np=u @ np.diag(s) @ vh,
rtol=1e-2,
atol=1e-2,
ret_np=frontend_s,
ret_from_gt_np=s,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
elif not some:
helpers.assert_all_close(
ret_np=frontend_u @ np.diag(frontend_s) @ frontend_v.T,
ret_from_gt_np=u @ np.diag(s) @ v.T,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
else:
helpers.assert_all_close(
ret_np=frontend_s,
ret_from_gt_np=s,
rtol=1e-2,
atol=1e-2,
ret_np=frontend_u[..., : frontend_s.shape[0]]
@ np.diag(frontend_s)
@ frontend_v.T,
ret_from_gt_np=u[..., : s.shape[0]] @ np.diag(s) @ v.T,
atol=1e-04,
backend=backend_fw,
ground_truth_backend=frontend,
)
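A note on the `x.resolve_conj()` calls in the updated tests (sketch against the stock PyTorch API): complex results from torch can carry a lazy conjugate bit, which has to be materialised before the tensor is converted to NumPy.

```python
import numpy as np
import torch

z = torch.randn(3, 3, dtype=torch.complex64)
zc = z.conj()  # lazy view: only the conjugate bit is set
assert zc.is_conj()

# resolve_conj() materialises the conjugation; converting a tensor with the
# conj bit still set to NumPy raises a RuntimeError.
zc_np = np.asarray(zc.resolve_conj())
assert np.allclose(zc_np, z.numpy().conj())
```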