Fixed reduce_sum lowering not working properly in backward graph (#408)
reduce_sum, when emitted during creation of the backward graph, did not have the named parameter keep_dim.
Added named parameters to AutogradContext.op(), and passed keep_dim to reduce_sum in the broadcast and element-wise backward operations.
Added test_batch_size for inference and training.
Fix for #354

* Added named_args to graphlib::OpType so that reduce_sum can be passed the keep_dim parameter

* Added a minimal test that reproduces the failure

* Parametrized test_batch_size

* Added test case for batch_size

* Fixed missing asserts, added flag for training

* Moved test_batch_size to mlir/test_features.py and split it into two tests
ndrakulicTT authored Oct 18, 2024
1 parent 971ff1f commit 38fadee
Showing 6 changed files with 83 additions and 14 deletions.
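
For context before the diffs: a minimal plain-PyTorch sketch (not Forge code; the tensor shapes are made up for illustration) of the shape problem the fix addresses. Without keep_dim, the reduction that undoes a broadcast drops the reduced dimension, so the gradient emitted in the backward graph no longer matches the operand's shape.

import torch

a = torch.rand(1, 10)        # operand that was broadcast along dim 0 in the forward op
grad = torch.rand(32, 10)    # incoming gradient arrives in the broadcast shape

without_keep = grad.sum(dim=0)                # shape (10,)   -- rank drops
with_keep = grad.sum(dim=0, keepdim=True)     # shape (1, 10) -- matches `a`

assert without_keep.shape != a.shape
assert with_keep.shape == a.shape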
2 changes: 1 addition & 1 deletion forge/csrc/autograd/autograd.cpp
@@ -371,7 +371,7 @@ void autograd_engine::create_backward_graph(const grad_map &requires_grad_map)

NodeContext src = last_out;
last_out = create_op(
OpType("reduce_sum", {dim}),
OpType("reduce_sum", {dim}, {}, {{"keep_dim", true}}),
{src},
node,
edge.consumer_input_port_id,
13 changes: 8 additions & 5 deletions forge/csrc/autograd/python_bindings.cpp
@@ -35,11 +35,13 @@ void AutogradModule(py::module &m_autograd)
[](tt::autograd::autograd_context &self,
std::variant<std::string, py::object> const &type,
std::vector<tt::autograd::NodeContext> operands,
std::vector<graphlib::OpType::Attr> attributes)
std::vector<graphlib::OpType::Attr> attributes,
ForgeOpAttrs named_attrs = {})
{
graphlib::OpType op_type = std::holds_alternative<std::string>(type)
? graphlib::OpType(std::get<std::string>(type), attributes)
: std::get<py::object>(type).attr("op_type").cast<graphlib::OpType>();
graphlib::OpType op_type =
std::holds_alternative<std::string>(type)
? graphlib::OpType(std::get<std::string>(type), attributes, {}, named_attrs)
: std::get<py::object>(type).attr("op_type").cast<graphlib::OpType>();

if (std::holds_alternative<std::string>(type))
TT_LOG_ASSERT(
@@ -61,7 +63,8 @@
},
py::arg("type"),
py::arg("operands"),
py::arg("attributes") = std::vector<graphlib::OpType::Attr>())
py::arg("attributes") = std::vector<graphlib::OpType::Attr>(),
py::arg("named_attrs") = ForgeOpAttrs())
.def(
"create_optimizer_op",
[](tt::autograd::autograd_context &self,
6 changes: 3 additions & 3 deletions forge/forge/op/eval/forge/eltwise_binary.py
@@ -225,14 +225,14 @@ def backward(op_type, attr, ac, operand, inputs, output, grad):
for i in range(len(shapes[operand])):
if shapes[operand][i] < grad_shape[i]:
# Negative indexing for reduce axis
grad = ac.op("reduce_sum", (grad,), (i - grad_shape_len,))
grad = ac.op("reduce_sum", (grad,), (i - grad_shape_len,), {"keep_dim": True})
return ac.op(Nop.create(), (grad,)) # pass gradient through

elif op_type == "subtract":
if inputs[operand].shape != grad.shape:
for i in range(len(shapes[operand])):
if shapes[operand][i] < grad.shape[i]:
grad = ac.op("reduce_sum", (grad,), (i,))
grad = ac.op("reduce_sum", (grad,), (i,), {"keep_dim": True})
if operand == 0:
return ac.op(Nop.create(), (grad,))
else:
@@ -243,7 +243,7 @@ def backward(op_type, attr, ac, operand, inputs, output, grad):
if inputs[operand].shape != grad.shape:
for i in range(len(shapes[operand])):
if shapes[operand][i] < grad_shape[i]:
op_grad = ac.op("reduce_sum", (op_grad,), (i - grad_shape_len,))
op_grad = ac.op("reduce_sum", (op_grad,), (i - grad_shape_len,), {"keep_dim": True})
return op_grad

elif op_type == "maximum":
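
A framework-free sketch of the same bookkeeping the element-wise backward above performs (a hypothetical helper, assuming the operand shape has already been aligned to the gradient's rank): every dimension along which the operand was broadcast is summed out of the gradient with keepdim=True, so the result matches the operand's shape.

import torch

def reduce_grad_to_operand_shape(grad: torch.Tensor, operand_shape: torch.Size) -> torch.Tensor:
    # Sum over each broadcast dimension, keeping it, so ranks stay aligned.
    for i in range(len(operand_shape)):
        if operand_shape[i] < grad.shape[i]:
            grad = grad.sum(dim=i, keepdim=True)
    return grad

operand = torch.rand(1, 5, 1)    # operand broadcast along dims 0 and 2
grad = torch.rand(8, 5, 4)       # gradient in the broadcast shape
assert reduce_grad_to_operand_shape(grad, operand.shape).shape == operand.shape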
4 changes: 2 additions & 2 deletions forge/forge/op/eval/forge/tm.py
@@ -1018,15 +1018,15 @@ def backward(type, attr, ac, operand, inputs, output, grad):
assert attr[0] >= 0 and attr[0] <= 3, f"Invalid broadcast dim after lowering: {attr[0]}"

if attr[0] == 2 or attr[0] == 3:
ret = ac.op("reduce_sum", (grad,), (attr[0],))
ret = ac.op("reduce_sum", (grad,), (attr[0],), {"keep_dim": True})
else:
ret = ac.op(
TransposeTM.create(attr[0], -2, z_dim_slice=grad.shape[-2]),
[
grad,
],
)
ret = ac.op("reduce_sum", (ret,), (-2,))
ret = ac.op("reduce_sum", (ret,), (-2,), {"keep_dim": True})
ret = ac.op(
TransposeTM.create(attr[0], -2, z_dim_slice=ret.shape[-2]),
[
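
A plain-PyTorch sanity check (made-up shapes) of the transpose path in the broadcast backward above: reducing a leading dim by transposing it into position -2, applying reduce_sum over -2 with keep_dim, and transposing back is equivalent to reducing that dim directly with keepdim=True.

import torch

g = torch.rand(2, 3, 4, 5)
d = 1  # broadcast dim outside the last two, handled via the transpose path

direct = g.sum(dim=d, keepdim=True)
via_transpose = g.transpose(d, -2).sum(dim=-2, keepdim=True).transpose(d, -2)

assert torch.allclose(direct, via_transpose)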
6 changes: 4 additions & 2 deletions forge/test/mlir/mnist/training/test_training.py
@@ -2,6 +2,8 @@

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch
from torch import nn

@@ -58,14 +60,14 @@ def test_mnist_training():
# Forward pass (prediction) on device
pred = tt_model(data)[0]
golden_pred = framework_model(data)
compare_with_golden(golden_pred, pred)
assert compare_with_golden(golden_pred, pred, pcc=0.95)

# Compute loss on CPU
loss = loss_fn(pred, target)
total_loss += loss.item()

golden_loss = loss_fn(golden_pred, target)
compare_with_golden(golden_loss, loss)
assert torch.allclose(loss, golden_loss, rtol=1e-2)

# Run backward pass on device
loss.backward()
66 changes: 65 additions & 1 deletion forge/test/mlir/test_features.py
@@ -10,7 +10,7 @@
from torch import nn

import forge
from forge.op.eval.common import compare_with_golden_pcc
from forge.op.eval.common import compare_with_golden_pcc, compare_with_golden


def test_multiple_inputs():
@@ -90,3 +90,67 @@ def forward(self, a):
co_out = [co.to("cpu") for co in co_out]
fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize("batch_size", [1, 2, 16, 64, 512])
@pytest.mark.parametrize("in_features", [784])
@pytest.mark.parametrize("out_features", [10])
def test_batch_size_inference(batch_size, in_features, out_features):
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_features, out_features)

def forward(self, x):
y = self.linear(x)
return nn.functional.softmax(y, dim=-1)

in_data = torch.rand(batch_size, in_features)
out_data = torch.randint(0, out_features, (batch_size,))

model = SimpleModel()

tt_model = forge.compile(model, sample_inputs=[torch.rand(batch_size, in_features)])

pred = tt_model(in_data)[0]
golden_pred = model(in_data)
assert compare_with_golden(golden_pred, pred, pcc=0.95) # 0.95 is the minimum value for which the test passes


@pytest.mark.parametrize("batch_size", [1, 2, 16, 64, 512])
@pytest.mark.parametrize("in_features", [784])
@pytest.mark.parametrize("out_features", [10])
def test_batch_size_training(batch_size, in_features, out_features):
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_features, out_features)

def forward(self, x):
y = self.linear(x)
return nn.functional.softmax(y, dim=-1)

in_data = torch.rand(batch_size, in_features)
out_data = torch.randint(0, out_features, (batch_size,))
target = nn.functional.one_hot(out_data, num_classes=out_features).float()

model = SimpleModel()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
tt_model = forge.compile(
model, sample_inputs=[torch.rand(batch_size, in_features)], loss=loss_fn, optimizer=optimizer
)

optimizer.zero_grad()

pred = tt_model(in_data)[0]
golden_pred = model(in_data)
assert compare_with_golden(golden_pred, pred, pcc=0.95) # 0.95 is the minimum value for which the test passes

loss = loss_fn(pred, target)
golden_loss = loss_fn(golden_pred, target)
assert torch.allclose(loss, golden_loss, rtol=1e-2) # 1e-2 is the minimum value for which the test passes

loss.backward()
tt_model.backward(pred.grad)
