How can I run an LSTM model using ATen operations? #7454

Open
ChristophKarlHeck opened this issue Dec 28, 2024 · 3 comments
Labels
module: exir (Issues related to Export IR)
module: kernels (Issues related to kernel libraries, e.g. portable kernels and optimized kernels)

Comments


ChristophKarlHeck commented Dec 28, 2024

📚 The doc issue

Hi guys,
I want to run an LSTM model, non-delegated, on my Cortex-M4 architecture. This is the model:

import torch
from torch import nn
from torch.export import export, export_for_training, ExportedProgram
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
import executorch.exir as exir
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

# Define the LSTM Model
class LSTMModel(pl.LightningModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        return lstm_out

    def training_step(self, batch, batch_idx):
        x, y = batch
        lstm_out = self(x)  # Forward pass
        # Align the loss computation with the target shape
        lstm_last_step = lstm_out[:, -1, :]  # Take the output of the last time step
        loss = nn.functional.mse_loss(lstm_last_step, y)  # Compute MSE loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Dummy Data for Training
x_data = torch.randn(10, 5, 1)  # Input: [batch_size, seq_length, input_size]
y_data = torch.randn(10, 3)  # Target: [batch_size, hidden_size] for the last time step
dataset = TensorDataset(x_data, y_data)
train_loader = DataLoader(dataset, batch_size=3)

# Model Training
model = LSTMModel(input_size=1, hidden_size=3)
trainer = pl.Trainer(max_epochs=5, logger=False)
trainer.fit(model, train_loader)

# Forward-only export (torch.export ATen dialect)
model.eval()
example_input = torch.randn(1, 5, 1)  # Example input
pre_autograd_aten_dialect = export_for_training(
    model,  # export the trained model rather than a freshly initialized instance
    (example_input,)
).module()

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, (example_input,))
print(aten_dialect)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)

executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

print("LSTM-only model saved as model.pte")

The problem is that I cannot find the LSTM ATen operator in https://github.com/pytorch/executorch/blob/main/kernels/aten/functions.yaml
What am I doing wrong here?
I appreciate your help in advance :)
Cheers,
Chris

Suggest a potential alternative/fix

It would be great to have documentation that shows how to list the operators used in the ATen representation.
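
For reference, a minimal sketch (reusing the aten_dialect and edge_program objects from the script above) of how the operators that actually appear in the exported graphs could be listed:

# Sketch: collect the call_function targets from the exported graphs.
# Reuses `aten_dialect` and `edge_program` from the script above.
def list_ops(graph_module):
    return sorted(
        {
            str(node.target)
            for node in graph_module.graph.nodes
            if node.op == "call_function"
        }
    )

print("ATen dialect ops:", list_ops(aten_dialect.graph_module))
print("Edge dialect ops:", list_ops(edge_program.exported_program().graph_module))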

cccclai added the module: kernels and module: exir labels on Dec 30, 2024
cccclai (Contributor) commented Dec 30, 2024

I remember there are some websites to look up the ATen operators. @larryliu0820, @manuelcandales, and @SS-JIA might know best.

kimishpatel (Contributor) commented
Are you able to export and run the program? The LSTM ATen op might be getting decomposed, hence it is not listed in kernels/aten.
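
One way to check whether that is happening, assuming the aten_dialect and edge_program objects from the original script, is to look for the LSTM op in each graph:

# Sketch: check whether aten::lstm is still present after lowering to edge
# dialect, or whether it has been decomposed into smaller ops.
def has_op(graph_module, name):
    return any(
        node.op == "call_function" and name in str(node.target)
        for node in graph_module.graph.nodes
    )

print("lstm in ATen dialect:", has_op(aten_dialect.graph_module, "lstm"))
print("lstm in edge dialect:", has_op(edge_program.exported_program().graph_module, "lstm"))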

kimishpatel (Contributor) commented
Also, there is a simple LSTM model here, https://github.com/pytorch/executorch/tree/main/examples/models/lstm, which can be exported and run by following https://github.com/pytorch/executorch/tree/main/examples/portable.

If that succeeds, it's a good indication that the LSTM model will run, although performance will likely not be great. I would like to see if you can try it out.
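
As a rough host-side smoke test (before moving to the Cortex-M4), the generated model.pte could also be loaded with the ExecuTorch Python bindings; the import path below is an assumption and may differ across ExecuTorch versions:

import torch
# Sketch: load and run model.pte on the host via the ExecuTorch pybindings.
# The module path is an assumption and may differ across ExecuTorch versions.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_module = _load_for_executorch("model.pte")
example_input = torch.randn(1, 5, 1)  # same shape as used for export
outputs = et_module.forward([example_input])
print(outputs[0].shape)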
