Lower select and embedding bw ops (#836)

* SelectOp
- Lower the select op directly to ttir::SelectOp.
- Don't decompose select in the post-autograd pass.
- Create the select op with named attrs so that lowering to MLIR works
  (the slicing semantics these attrs encode are sketched right after this message).

* EmbeddingOp
- The embedding op now creates a dedicated embedding_bw op in the autograd pass,
  which lowers directly to ttir::EmbeddingBackwardOp.

* Llama backward
- Added a simple UT to test the autograd pass and lowering to TTMLIR.
- Currently fails because of dtype constraints (issue opened in tt-mlir).

Misc: a few changes related to the `named_attrs` property of ForgeOp.
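
For reference, the (dim, begin, length, stride) attributes that the backward passes now spell out as named attrs describe a strided slice-and-concat: in every stride-sized window along dim, keep length elements starting at begin. A minimal PyTorch sketch of that reading (illustrative only, not Forge code; the helper name is made up and it assumes the size along dim is a multiple of stride):

import torch


def select_reference(x: torch.Tensor, dim: int, begin: int, length: int, stride: int) -> torch.Tensor:
    # In every stride-sized window along `dim`, keep `length` elements starting
    # at `begin`, then concatenate the kept slices along that same dimension.
    dim = dim % x.dim()
    slices = [x.narrow(dim, offset + begin, length) for offset in range(0, x.shape[dim], stride)]
    return torch.cat(slices, dim=dim)


# e.g. select_reference(torch.rand(64, 64), dim=-1, begin=16, length=4, stride=32)
# keeps columns 16..19 and 48..51, giving an output of shape (64, 8).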
rpavlovicTT authored Dec 16, 2024
1 parent a08b563 commit b82c130
Showing 9 changed files with 175 additions and 29 deletions.
3 changes: 3 additions & 0 deletions forge/csrc/passes/lower_to_mlir.cpp
@@ -560,6 +560,8 @@ class MLIRGenerator
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
lowering_handler_map["embedding_bw"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingBackwardOp>;
lowering_handler_map["equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EqualOp>;
lowering_handler_map["exp"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ExpOp>;
lowering_handler_map["gelu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GeluOp>;
@@ -580,6 +582,7 @@ class MLIRGenerator
lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
lowering_handler_map["remainder"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::RemainderOp>;
lowering_handler_map["reshape"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReshapeOp>;
lowering_handler_map["select"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SelectOp>;
lowering_handler_map["sigmoid"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SigmoidOp>;
lowering_handler_map["sine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SinOp>;
lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
1 change: 1 addition & 0 deletions forge/forge/op/eval/forge/__init__.py
@@ -77,6 +77,7 @@
"sparse_matmul": "matmul",
"depthwise": "depthwise",
"embedding": "embedding",
"embedding_bw": "embedding_bw",
"ethernet_datacopy": EthernetDatacopy,
"transpose": TransposeTM,
"adv_index": "tm",
32 changes: 27 additions & 5 deletions forge/forge/op/eval/forge/eltwise_nary.py
@@ -47,7 +47,8 @@ def eval(type, attr, ops):
res = 0
for i in range(len(t_ops)):
res += torch.nn.functional.pad(
t_ops[i], (shifts[2 * i], -shifts[2 * i], shifts[2 * i + 1], -shifts[2 * i + 1])
t_ops[i],
(shifts[2 * i], -shifts[2 * i], shifts[2 * i + 1], -shifts[2 * i + 1]),
)

# To forge shape
@@ -108,7 +109,13 @@ def get_eltwise_shape_and_broadcast():
assert (
ops[op_index][dim_index] == 1
), f"Eltwise nary ops must have same shape or operand must be 1 wide to broadcast: {ops}"
broadcast.append((op_index, dim_index - len(output_shape), output_shape[dim_index]))
broadcast.append(
(
op_index,
dim_index - len(output_shape),
output_shape[dim_index],
)
)

return tuple(output_shape), broadcast

@@ -212,11 +219,21 @@ def backward(op_type, attr, ac, operand, inputs, output, grad):
dim_offset = grad.shape[axis]

index_offset = 0
for (i, input_) in enumerate(inputs):
for i, input_ in enumerate(inputs):
if operand is not i:
index_offset += input_.shape[axis]
continue
return ac.op("select", (grad,), (axis, index_offset, input_.shape[axis], dim_offset))
return ac.op(
"select",
(grad,),
(axis, index_offset, input_.shape[axis], dim_offset),
named_attrs={
"dim": axis,
"begin": index_offset,
"length": input_.shape[axis],
"stride": dim_offset,
},
)

elif op_type == "interleave":
axis = attr[0]
@@ -235,7 +252,12 @@ def backward(op_type, attr, ac, operand, inputs, output, grad):
result = ac.op(
"select",
(result,),
(-1, operand * align_up_tile(grad.shape[-1]), align_up_tile(grad.shape[-1]), result.shape[-1]),
(
-1,
operand * align_up_tile(grad.shape[-1]),
align_up_tile(grad.shape[-1]),
result.shape[-1],
),
)
if grad.shape[-1] % TILE_DIM != 0:
result = ac.op("narrow", (result,), (-1, 0, grad.shape[-1], result.shape[-1]))
3 changes: 1 addition & 2 deletions forge/forge/op/eval/forge/embedding.py
@@ -47,5 +47,4 @@ def decompose(type, attr, dc, inputs):

def backward(type, attr, ac, operand, inputs, output, grad):
assert type == "embedding"
assert len(ops) == 2
raise NotImplementedError("embedding backwards not implemented")
return ac.op("embedding_bw", [inputs[0], inputs[1], grad])
39 changes: 39 additions & 0 deletions forge/forge/op/eval/forge/embedding_bw.py
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
from ..common import to_torch_operands
from forge._C import DataFormat
from forge._C.graph import RuntimeTensorTransform, RuntimeTensorTransformType
from ....forgeglobal import TILE_DIM


def eval(type, attr, ops):
assert type == "embedding_bw"
assert len(ops) == 3  # operands: input indices, weight, gradient of the embedding output
t_ops = to_torch_operands(*ops)
input = t_ops[0]
weight = t_ops[1]
grad = t_ops[2]

result = torch.zeros(weight.shape)
for i, idx in enumerate(input):
result[idx] = grad[i]
return result


def shape(type, attr, tensor_shapes):
assert type == "embedding_bw"
return tensor_shapes[1], []


def lower(type, attr, lc, ops, outputs):
assert False, "embedding_bw should not be lowered"


def decompose(type, attr, dc, inputs):
assert False, "embedding_bw should not be decomposed"


def backward(type, attr, ac, operand, inputs, output, grad):
assert False, "embedding_bw should not be backwarded"
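
As a quick sanity check of the eval above (a hedged sketch, not part of the commit), the result can be compared against PyTorch's own embedding gradient. It assumes 1-D indices with no repeats, since the loop above overwrites rows rather than accumulating them, whereas PyTorch accumulates gradients for repeated indices:

import torch

indices = torch.tensor([3, 0, 2])                   # 1-D, no repeated indices
weight = torch.rand(5, 8, requires_grad=True)

out = torch.nn.functional.embedding(indices, weight)
out.backward(torch.ones_like(out))                  # upstream gradient of ones

# Same computation as eval(): scatter the gradient rows by index into a
# zero tensor shaped like the embedding weight.
grad = torch.ones(indices.shape[0], weight.shape[1])
expected = torch.zeros(weight.shape)
for i, idx in enumerate(indices):
    expected[idx] = grad[i]

assert torch.allclose(weight.grad, expected)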
15 changes: 8 additions & 7 deletions forge/forge/op/eval/forge/tm.py
@@ -848,7 +848,7 @@ def backward(type, attr, ac, operand, inputs, output, grad):

elif type == "reshape":
shape = inputs[0].shape
return ac.op(type, (grad,), attributes=(shape))
return ac.op(type, (grad,), attributes=(shape), named_attrs={"shape": shape})

elif type == "conv2d_depthwise_weights":
return ac.op("conv2d_depthwise_weights_bw", (grad,), attributes=attr)
@@ -881,7 +881,12 @@ def backward(type, attr, ac, operand, inputs, output, grad):
break

# pass the gradient for selected part
grad_slice = ac.op("select", (grad,), (dim, grad_offset, length, current_size))
grad_slice = ac.op(
"select",
(grad,),
(dim, grad_offset, length, current_size),
named_attrs={"dim": dim, "begin": grad_offset, "length": length, "stride": current_size},
)
if grad_return is None:
grad_return = grad_slice
else:
@@ -954,7 +959,7 @@ def backward(type, attr, ac, operand, inputs, output, grad):
dim = attr[0]
if grad.shape.len() == 4: # Cannot unsqueeze beyond 4D
return ac.op(Nop.create(), (grad,))
return ac.op("unsqueeze", (grad,), attributes=(dim, grad.shape.len()))
return ac.op("unsqueeze", (grad,), attributes=(dim, grad.shape.len()), named_attrs={"dim": dim})

elif type == "broadcast":
assert len(attr) == 3
@@ -2009,10 +2014,6 @@ def decompose_post_optimize(type, attr, dc, inputs):


def decompose_post_autograd(type, attr, dc, inputs):
if type == "select":
decompose_select(attr, dc, inputs)
return

if type == "reshape":
assert len(inputs) == 1
input_shape = inputs[0].shape.as_list()
59 changes: 44 additions & 15 deletions forge/forge/op/tm.py
@@ -10,7 +10,6 @@


def Transpose(name: str, operandA: Tensor, dim0: int, dim1: int) -> Tensor:

"""
Transpose X and Y (i.e. rows and columns) dimensions.
@@ -163,7 +162,13 @@ def AdvIndex(
return op("adv_index", name, operandA, operandB, attrs=(dim,)).get_tensor()


def Select(name: str, operandA: Tensor, dim: int, index: Union[int, Tuple[int, int]], stride: int = 0) -> Tensor:
def Select(
name: str,
operandA: Tensor,
dim: int,
index: Union[int, Tuple[int, int]],
stride: int = 0,
) -> Tensor:
"""
TM
@@ -190,7 +195,7 @@ def Select(name: str, operandA: Tensor, dim: int, index: Union[int, Tuple[int, i
Tensor
Forge tensor
"""
dims = len(operandA.shape.dims)
dims = len(operandA.shape)
if dim < 0:
dim += dims

@@ -200,22 +205,31 @@ def Select(name: str, operandA: Tensor, dim: int, index: Union[int, Tuple[int, i
index = (index, 1)

if stride == 0:
stride = operandA.shape.get_pytorch_shape()[dim]
stride = operandA.shape[dim]

start, length = index
assert (
start < operandA.shape.get_pytorch_shape()[dim]
), f"start = {start} should be < operandA.shape.get_pytorch_shape()[{dim}] = {operandA.shape.get_pytorch_shape()[dim]}"
assert (start + length) <= operandA.shape.get_pytorch_shape()[
assert start < operandA.shape[dim], f"start = {start} should be < operandA.shape[{dim}] = {operandA.shape[dim]}"
assert (start + length) <= operandA.shape[
dim
], f"(start = {start} + length = {length}) should be <= operandA.shape.get_pytorch_shape()[{dim}] = {operandA.shape.get_pytorch_shape()[dim]}"
], f"(start = {start} + length = {length}) should be <= operandA.shape[{dim}] = {operandA.shape[dim]}"
assert (
stride <= operandA.shape.get_pytorch_shape()[dim]
), f"stride = {stride} should be <= operandA.shape.get_pytorch_shape()[{dim}] = {operandA.shape.get_pytorch_shape()[dim]}"
stride <= operandA.shape[dim]
), f"stride = {stride} should be <= operandA.shape[{dim}] = {operandA.shape[dim]}"
assert (start + length) <= stride, f"(start = {start} + length = {length}) should be <= stride = {stride}"
assert (start + length) > 0, f"(start = {start} + length = {length}) should be > 0"

return op("select", name, operandA, attrs=(dim, index[0], index[1], stride)).get_tensor()
return op(
"select",
name,
operandA,
attrs=(dim, index[0], index[1], stride),
**{
"dim": dim,
"begin": index[0],
"length": index[1],
"stride": stride,
},
).get_tensor()


def Pad(
@@ -388,7 +402,14 @@ def RepeatInterleave(name: str, operandA: Tensor, repeats: int, dim: int) -> Ten
Tensor
Forge tensor
"""
return op("repeat_interleave", name, operandA, attrs=(repeats, dim), repeats=repeats, dim=dim).get_tensor()
return op(
"repeat_interleave",
name,
operandA,
attrs=(repeats, dim),
repeats=repeats,
dim=dim,
).get_tensor()


def Unsqueeze(name: str, operandA: Tensor, dim: int) -> Tensor:
@@ -513,7 +534,12 @@ def ForgePad(name: str, operandA: Tensor, paddings: Tuple[int, int], value: floa
return op("forge_pad", name, operandA, attrs=(paddings[0], paddings[1], value)).get_tensor()


def ForgeUnpad(name: str, operandA: Tensor, original_length: Tuple[int, ...], paddings: Tuple[int, int]) -> Tensor:
def ForgeUnpad(
name: str,
operandA: Tensor,
original_length: Tuple[int, ...],
paddings: Tuple[int, int],
) -> Tensor:
"""
Unpad operation that removes arbitrary number of tiles by any dimension.
@@ -532,5 +558,8 @@ def ForgeUnpad(name: str, operandA: Tensor, original_length: Tuple[int, ...], pa
Tuple of paddings for R and C dimensions
"""
return op(
"forge_unpad", name, operandA, attrs=(paddings[0], paddings[1], original_length[0], original_length[1])
"forge_unpad",
name,
operandA,
attrs=(paddings[0], paddings[1], original_length[0], original_length[1]),
).get_tensor()
28 changes: 28 additions & 0 deletions forge/test/mlir/llama/test_llama_backward.py
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import forge
from test.mlir.llama.utils.utils import load_model


# TODO(tt-mlir issue #1503): This test is failing because the embedding op doesn't work with FP32.
# It should be fixed in the tt-mlir compiler soon.
@pytest.mark.parametrize("model_path", ["openlm-research/open_llama_3b"])
@pytest.mark.xfail()
def test_llama_backward(model_path):
# Load Model and Tokenizer
framework_model, tokenizer = load_model(model_path)

prompt = "Q: What is the largest animal?\nA:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

loss_fn = torch.nn.CrossEntropyLoss()
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=1e-3)

# Compile the model with loss and optimizer; this invokes the autograd pass, which produces the backward graph.
compiled_model = forge.compile(framework_model, input_ids, loss=loss_fn, optimizer=framework_optimizer)
24 changes: 24 additions & 0 deletions forge/test/mlir/test_ops.py
@@ -1979,3 +1979,27 @@ def forward(self, a):
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)


@pytest.mark.parametrize("shape", [(1, 32, 64, 64), (32, 64, 64), (64, 64)])
@pytest.mark.parametrize("dim", [-1, -2])
@pytest.mark.parametrize("begin", [0, 16])
@pytest.mark.parametrize("length", [4, 16])
@pytest.mark.parametrize("stride", [16, 32])
def test_select(shape, dim, begin, length, stride):
if stride <= begin + length:
pytest.skip("Skipping since stride <= begin + length")

class Select(forge.ForgeModule):
def __init__(self):
super().__init__("Select")

def forward(self, x):
x = forge.op.Select("select_op", x, dim, [begin, length], stride)
return x

inputs = to_forge_tensors([torch.rand(*shape)])
framework_model = Select()
compiled_model = forge.compile(framework_model, sample_inputs=inputs)

verify(inputs, framework_model, compiled_model)
