Lower select and embedding bw ops (#836)
* SelectOp
  - Lower select directly to ttir::SelectOp.
  - Don't decompose select in the post-autograd pass.
  - Create the select op with named attrs so lowering to MLIR works.
* EmbeddingOp
  - The embedding op now creates a dedicated embedding_bw op in the autograd pass, which lowers directly to ttir::EmbeddingBackwardOp.
* Llama backward
  - Added a simple unit test for the autograd pass and lowering to TTMLIR.
  - Currently doesn't work because of dtype constraints (issue opened in tt-mlir).

Misc: a few changes related to the `named_attrs` property of ForgeOp.
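A minimal sketch of the EmbeddingOp change described above, assuming the standard Forge op-module signature; the `ac.op` call, the operand order, and the `operand == 1` check are assumptions for illustration, not the commit's actual code:

```python
# Hypothetical sketch: embedding's backward rule emits one embedding_bw node
# instead of decomposing; that node lowers 1:1 to ttir::EmbeddingBackwardOp.
def backward(type, attr, ac, operand, inputs, output, grad):
    assert type == "embedding"
    # Only the weight operand (index 1) receives a gradient; indices do not.
    assert operand == 1, "embedding gradient flows only to the weight"
    input_ids, weight = inputs
    # Three operands, matching the eval module below: (indices, weight, grad).
    return ac.op("embedding_bw", (input_ids, weight, grad))
```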
1 parent a08b563 · commit b82c130
Showing 9 changed files with 175 additions and 29 deletions.
New file (+39 lines): Forge op eval module for embedding_bw
```python
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from ..common import to_torch_operands
from forge._C import DataFormat
from forge._C.graph import RuntimeTensorTransform, RuntimeTensorTransformType
from ....forgeglobal import TILE_DIM


def eval(type, attr, ops):
    assert type == "embedding_bw"
    assert len(ops) == 3  # (input indices, weight, incoming gradient)
    t_ops = to_torch_operands(*ops)
    input = t_ops[0]
    weight = t_ops[1]
    grad = t_ops[2]

    # The gradient of an embedding lookup scatters each incoming gradient row
    # back onto the weight row selected by the corresponding index; repeated
    # indices accumulate.
    result = torch.zeros(weight.shape)
    for i, idx in enumerate(input):
        result[idx] += grad[i]
    return result


def shape(type, attr, tensor_shapes):
    assert type == "embedding_bw"
    # The output gradient has the weight's shape; no broadcast dims.
    return tensor_shapes[1], []


def lower(type, attr, lc, ops, outputs):
    assert False, "embedding_bw should not be lowered"


def decompose(type, attr, dc, inputs):
    assert False, "embedding_bw should not be decomposed"


def backward(type, attr, ac, operand, inputs, output, grad):
    assert False, "embedding_bw should not be backwarded"
```
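As a quick sanity check (not part of the commit), the scatter in `eval` can be compared against the gradient PyTorch's autograd produces for an embedding lookup; the tensor sizes and the 1-D index tensor below are arbitrary choices matching the loop's assumption:

```python
import torch

# Hypothetical sanity check: the scatter-based gradient should match what
# autograd computes for torch.nn.functional.embedding.
vocab, dim = 10, 4
weight = torch.randn(vocab, dim, requires_grad=True)
input_ids = torch.tensor([1, 3, 3, 7])  # 1-D indices; index 3 repeats

out = torch.nn.functional.embedding(input_ids, weight)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# Reference scatter, accumulating rows for the repeated index.
manual = torch.zeros_like(weight)
for i, idx in enumerate(input_ids):
    manual[idx] += grad_out[i]

assert torch.allclose(weight.grad, manual)
```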
New file (+28 lines): Llama backward unit test
```python
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import forge
from test.mlir.llama.utils.utils import load_model


# TODO(tt-mlir issue #1503): This test fails because the embedding op doesn't
# work with FP32. It should be fixed in the tt-mlir compiler soon.
@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b"])
@pytest.mark.xfail()
def test_llama_backward(model_path):
    # Load model and tokenizer
    framework_model, tokenizer = load_model(model_path)

    prompt = "Q: What is the largest animal?\nA:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    loss_fn = torch.nn.CrossEntropyLoss()
    framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=1e-3)

    # Compiling the model with a loss and an optimizer invokes the autograd
    # pass, which produces the backward graph.
    compiled_model = forge.compile(framework_model, input_ids, loss=loss_fn, optimizer=framework_optimizer)
```
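For reference, a minimal eager-mode sketch of the training step the compiled forward/backward graph is meant to reproduce, continuing from the variables in the test above; the next-token shift is an assumption about the intended causal-LM setup, not part of the commit:

```python
# Hypothetical eager-mode reference (a sketch, not part of the commit).
logits = framework_model(input_ids).logits               # (1, seq_len, vocab)
targets = input_ids[:, 1:].reshape(-1)                   # next-token labels
loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), targets)

framework_optimizer.zero_grad()
loss.backward()   # gradient reaches the embedding weights via embedding_bw semantics
framework_optimizer.step()
```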