MNIST lora layer test (#913)
Add test to make sure LoRA layer can be implemented and used in a full
training pipeline.

Problem encountered
Hitting `KeyError: 'gradient_lora1.b.consteval_graph.output'` when implementing the LoRA layer with raw `nn.Parameter` weights:
```
class LoraLayer(nn.Module):
    def __init__(self, input_size, output_size, rank=8, alpha=4, dtype=torch.float32):
        super(LoraLayer, self).__init__()
        self.a = nn.Parameter(torch.empty(input_size, rank, dtype=dtype), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(rank, output_size, dtype=dtype), requires_grad=True)
        self.alpha = alpha / rank

        nn.init.normal_(self.a, mean=0, std=1)

    def forward(self, x):
        return self.alpha * (x @ self.a @ self.b)
```

Raised issue: #929
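
For reference, the workaround adopted in `forge/test/mlir/mnist/utils.py` below replaces the raw `nn.Parameter` matrices with a pair of bias-free `nn.Linear` layers. A minimal standalone sketch (the sizes, tolerance, and `allclose` check are illustrative and not part of this commit) showing that both formulations compute the same scaled low-rank update:
```
import torch
from torch import nn

input_size, output_size, rank, alpha = 784, 512, 8, 4
scale = alpha / rank

# Bias-free Linear layers play the role of the LoRA factors A and B.
# nn.Linear stores its weight as (out_features, in_features), so
# a.weight.T corresponds to A (input_size x rank) and b.weight.T to B (rank x output_size).
a = nn.Linear(input_size, rank, bias=False)
b = nn.Linear(rank, output_size, bias=False)

x = torch.rand(16, input_size)

param_style = scale * (x @ a.weight.T @ b.weight.T)  # nn.Parameter formulation
linear_style = scale * b(a(x))                       # nn.Linear workaround

assert torch.allclose(param_style, linear_style, atol=1e-6)
```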
pmarkovicTT authored Dec 18, 2024
1 parent 9fd58b4 commit c00cc94
Showing 2 changed files with 106 additions and 0 deletions.
59 changes: 59 additions & 0 deletions forge/test/mlir/mnist/training/test_training.py
@@ -410,3 +410,62 @@ def test_loss_device():
        break

    print(f"Test (total) loss: {test_loss}")


@pytest.mark.push
def test_lora():
    torch.manual_seed(0)

    # Config
    num_epochs = 3
    batch_size = 64
    learning_rate = 0.001

    # Load dataset
    test_loader, train_loader = load_dataset(batch_size)

    framework_model = MNISTLora(bias=False)
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)

    tt_model = forge.compile(framework_model, sample_inputs=[torch.rand(batch_size, 784)], training=True)

    loss_fn = CrossEntropyLoss(name="cross_entropy_loss")

    loss_inputs = [torch.rand(batch_size, 10).requires_grad_(True), torch.rand(batch_size, 10)]
    loss_inputs = to_forge_tensors(loss_inputs)
    tt_loss = forge.compile(loss_fn, sample_inputs=loss_inputs, attach_to=tt_model, training=True)

    logger.info("Starting training loop... (logger will be disabled)")
    logger.disable("")
    for epoch_idx in range(num_epochs):
        total_loss = 0
        for _, (data, target) in enumerate(train_loader):
            framework_optimizer.zero_grad()

            # Create target tensor and leave on CPU
            target = nn.functional.one_hot(target, num_classes=10).float()

            # Forward pass (prediction) on device
            pred = tt_model(data)[0]
            golden_pred = framework_model(data)
            assert compare_with_golden(golden_pred, pred, pcc=0.95)

            loss = tt_loss(pred, target)
            total_loss += loss[0].item()

            # Run backward pass on device
            tt_loss.backward()

            # Adjust weights (on CPU)
            framework_optimizer.step()

        print(f"epoch: {epoch_idx} loss: {total_loss}")

    test_loss = 0
    for _, (data, target) in enumerate(test_loader):
        pred = tt_model(data)[0]
        target = nn.functional.one_hot(target, num_classes=10).float()

        test_loss += tt_loss(pred, target)[0]

    print(f"Test (total) loss: {test_loss}")
47 changes: 47 additions & 0 deletions forge/test/mlir/mnist/utils.py
@@ -32,6 +32,53 @@ def forward(self, x):
        return logits


class LoraLayer(nn.Module):
    def __init__(self, input_size, output_size, rank=8, alpha=4, dtype=torch.float32):
        super(LoraLayer, self).__init__()
        self.a = nn.Linear(in_features=input_size, out_features=rank, bias=False, dtype=dtype)
        self.b = nn.Linear(in_features=rank, out_features=output_size, bias=False, dtype=dtype)
        self.alpha = alpha / rank

        nn.init.kaiming_uniform_(self.a.weight, a=torch.sqrt(torch.tensor([5])).item())
        nn.init.zeros_(self.b.weight)

    def forward(self, x):
        logits = self.a(x)
        logits = self.alpha * self.b(logits)

        return logits


class MNISTLora(nn.Module):
    def __init__(
        self, input_size=784, output_size=10, hidden_size=512, bias=True, rank=8, alpha=16, dtype=torch.float32
    ):
        super(MNISTLora, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size, bias=bias, dtype=dtype)
        self.lora1 = LoraLayer(input_size, hidden_size, rank=rank, alpha=alpha, dtype=dtype)
        self.relu1 = nn.ReLU()

        self.linear2 = nn.Linear(hidden_size, hidden_size, bias=bias, dtype=dtype)
        self.lora2 = LoraLayer(hidden_size, hidden_size, rank=rank, alpha=alpha, dtype=dtype)
        self.relu2 = nn.ReLU()

        self.linear3 = nn.Linear(hidden_size, output_size, bias=bias, dtype=dtype)

        self.freeze_linear_layers()

    def forward(self, x):
        first_layer_logits = self.relu1(self.linear1(x) + self.lora1(x))
        second_layer_logits = self.relu2(self.linear2(first_layer_logits) + self.lora2(first_layer_logits))
        final_logits = self.linear3(second_layer_logits)

        return final_logits

    def freeze_linear_layers(self):
        for layer in [self.linear1, self.linear2, self.linear3]:
            for param in layer.parameters():
                param.requires_grad = False


class EarlyStopping:
    def __init__(self, patience=3, mode="max"):
        assert mode in ["min", "max"]
