
Passing a SparseTensor object to an optimizer #379

Open

alireza78h opened this issue Jul 31, 2024 · 1 comment

alireza78h commented Jul 31, 2024

I'm playing around with pytorch-sparse on some toy examples to get a sense of it, since there isn't complete documentation for it.
Suppose I have a sparse matrix A and two vectors x and b with Ax = b.
Given b, x, and the exact locations of the nonzeros in A, I want to recover A through SGD. Here is the code I wrote for that, and it seems to work:

import numpy as np
import torch
from torch_sparse import SparseTensor

n = 10
lr = 0.01

# Random sparsity pattern and ground-truth values.
row = torch.from_numpy(np.random.choice(n, n, replace=True)).to(torch.long)
col = torch.from_numpy(np.random.choice(n, n, replace=True)).to(torch.long)
data = torch.from_numpy(
    np.random.choice(np.arange(1, n + 1), n, replace=True)).to(torch.float)

A_ground_truth = SparseTensor(row=row, col=col, value=data, sparse_sizes=(n, n))

# The matrix A that we want to learn: same sparsity pattern, trainable values.
A_approx = SparseTensor(row=row, col=col, sparse_sizes=(n, n)).requires_grad_(True)

x = torch.randn(n, 1)         # the vector x that we have; we know Ax = b
b = A_ground_truth.matmul(x)  # the vector b that we have

for i in range(10000):
    output = A_approx.matmul(x)
    l = torch.norm(b - output)
    l.backward()
    row, col, value = A_approx.coo()
    print(f"Epoch: {i+1}, Loss: {l.item():.3f}")
    # Manual SGD step: rebuild A_approx from the updated values each epoch.
    newvalue = value - lr * value.grad
    A_approx = SparseTensor(row=row, col=col, value=newvalue.detach(),
                            sparse_sizes=(n, n)).requires_grad_(True)

However, I have some questions about this:

  1. As you can see, in each epoch I recreate A_approx from the updated values. Is there any way to update the values of A_approx in place?
  2. What if I want to use an optimizer from torch.optim, such as torch.optim.Adam? How can I pass the values of A_approx to the optimizer? Should I pass the value tensor, or can I pass the whole SparseTensor object? Could you please provide a snippet for that?

KukumavMozolo commented

Hi there. Regarding 2: I think you can just pass the value tensor to the optimizer, call l.backward() and then optimizer.step(), and A_approx should be updated accordingly, since coo() returns the underlying value tensor rather than a copy:

row, col, val = A_approx.coo()
optimizer = torch.optim.AdamW([val], lr=0.001)
for i in range(10000):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    output = A_approx.matmul(x)
    l = torch.norm(b - output)
    l.backward()
    optimizer.step()       # updates val in place, so A_approx sees the change
    print(f"Epoch: {i+1}, Loss: {l.item():.3f}")

There might be other ways, though.
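
For question 1, here is a minimal, untested sketch of one such alternative: since coo() hands back the storage tensors themselves, you can fetch the value tensor once and update it in place under torch.no_grad(), without rebuilding the SparseTensor every epoch (reusing A_approx, x, b, and lr from above):

# Untested sketch: manual in-place SGD on the value tensor.
row, col, val = A_approx.coo()  # val is the underlying storage tensor, not a copy
for i in range(10000):
    output = A_approx.matmul(x)
    l = torch.norm(b - output)
    l.backward()
    with torch.no_grad():
        val -= lr * val.grad  # in-place update, visible through A_approx
    val.grad.zero_()          # clear accumulated gradients for the next step
    print(f"Epoch: {i+1}, Loss: {l.item():.3f}")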
