[training] add torch gradient hooks (#421)
Currently, the loss and optimizer parts are run on the CPU. In order to
run the backward pass (on the device), we need to know the gradients coming
from `loss.backward()`. Before this change, the gradients from the
backward part of the loss function had to be manually passed to the backward
pass by the user.

This change adds torch gradient hooks which are triggered every time
a gradient is computed. Each hook ties the gradient calculated on the CPU
to our runtime, so that the forge runtime can pass it to the backward
program when needed.
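
As an illustration of this mechanism, here is a minimal, standalone sketch (the names `captured_grads` and `tie_grad`, and the toy tensors, are illustrative, not forge API): each externally visible output gets a `torch.Tensor.register_hook` callback that stashes the gradient produced by `loss.backward()` so a later device-side backward program can pick it up.

import torch

# Buffer standing in for a per-output gradient slot (one slot per external output).
captured_grads = [None, None]

def tie_grad(grad_id: int, grad: torch.Tensor) -> None:
    # Called by autograd once the gradient w.r.t. the hooked tensor is ready.
    captured_grads[grad_id] = grad

# Pretend these are the forward outputs returned from the device.
outputs = [torch.randn(2, 3), torch.randn(2, 3)]
for idx, out in enumerate(outputs):
    out.requires_grad = True
    # Bind idx eagerly via a default argument so each hook keeps its own output index.
    out.register_hook(lambda grad, idx=idx: tie_grad(idx, grad))

loss = sum(o.sum() for o in outputs)   # loss computed on the CPU
loss.backward()                        # fires the hooks and fills captured_grads
assert all(g is not None for g in captured_grads)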

To distinguish between graph outputs which are used/defined by the user and
those which we use internally (to pass information between different
graphs/programs), an output type enum is introduced. There is still a lot of
refactoring/redesigning to be done in this area once we get a firmer grasp of
what we want/need.
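
For intuition, a small hedged sketch of how such an output type can be used to split a forward program's outputs into user-facing results and internal values kept for the backward graph (the Python `OutputType` mirror and the sample data are illustrative, not forge code):

from enum import Enum

class OutputType(Enum):
    # Mirrors the C++ enum added in node_types.hpp below.
    INTERNAL = 0   # i/o between graphs, e.g. forward intermediates consumed by backward
    EXTERNAL = 1   # results defined by/returned to the user

# Hypothetical (name, type) pairs for a forward program's outputs.
forward_outputs = [
    ("layer1_intermediate", OutputType.INTERNAL),
    ("logits", OutputType.EXTERNAL),
]

# Only external outputs are handed back to the caller; internal ones stay with the runtime.
user_visible = [name for name, kind in forward_outputs if kind is OutputType.EXTERNAL]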

Closes #364
pilkicTT authored Oct 21, 2024
1 parent 94a20e4 commit 0b9b8dc
Showing 10 changed files with 83 additions and 29 deletions.
13 changes: 13 additions & 0 deletions forge/csrc/graph_lib/graph.cpp
@@ -1128,6 +1128,19 @@ std::vector<std::string> Graph::get_ordered_output_names() const
return ordered_outputs;
}

std::vector<std::string> Graph::get_ordered_external_output_names() const
{
std::vector<std::string> ordered_outputs;
for (auto output_node_id : this->ordered_module_output_node_ids_)
{
if (this->node_by_id(output_node_id)->as<OutputNode>()->output_type() == OutputType::External)
{
ordered_outputs.push_back(this->node_by_id(output_node_id)->name());
}
}
return ordered_outputs;
}

bool Graph::contains_nodes_of_epoch_type(NodeEpochType node_epoch_type) const
{
// Cache if it starts getting slow?
1 change: 1 addition & 0 deletions forge/csrc/graph_lib/graph.hpp
@@ -255,6 +255,7 @@ class Graph
std::vector<std::string> get_ordered_input_names() const;
std::vector<std::string> get_ordered_intermediate_names() const;
std::vector<std::string> get_ordered_output_names() const;
std::vector<std::string> get_ordered_external_output_names() const;
std::vector<std::string> get_ordered_input_gradient_names() const;
std::vector<std::string> get_ordered_output_gradient_names() const;
std::vector<unsigned int> get_ordered_input_subgraph_indices() const;
16 changes: 15 additions & 1 deletion forge/csrc/graph_lib/node_types.hpp
@@ -304,6 +304,16 @@ class ConstantInputNode : public InputNode
bool equivalent(const ConstantInputNode *other) const;
};

enum class OutputType
{
// Internal is used for outputs that are not exposed directly to the user, but used and handled internally for i/o
// between different graphs, e.g. passing intermediate values from forward graph to the backward graph.
Internal,

// Outputs which are defined by/exposed to the user, e.g. result of the forward pass.
External,
};

class OutputNode : public QueueNode
{
protected:
@@ -315,24 +325,28 @@ class OutputNode : public QueueNode
// The golden info is needed if we fractured the output and need to reconstruct it for golden comparison
std::optional<int> partial_datacopy_golden_output_index;
std::vector<OpType> golden_transforms;
OutputType output_type_;

public:
OutputNode(std::string name) :
QueueNode(name, QueueNodeType::Output, NodeType::kOutput),
requires_grad_(false),
is_loss_output_(false),
is_intermediate_(false),
- untilize_(true)
+ untilize_(true),
+ output_type_(OutputType::External)
{
}
bool requires_grad() const { return requires_grad_; }
bool is_loss_output() const { return is_loss_output_; }
bool is_intermediate() const { return is_intermediate_; }
bool untilize() const { return untilize_; }
OutputType output_type() const { return output_type_; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
void set_loss_output() { is_loss_output_ = true; }
void set_intermediate(bool intermediate) { is_intermediate_ = intermediate; }
void set_untilize(bool should_untilize) { untilize_ = should_untilize; }
void set_output_type(OutputType output_type) { output_type_ = output_type; }
virtual std::unique_ptr<Node> clone(std::string const &name = "") const override;

void set_runtime_tensor_transform(RuntimeTensorTransform transform) { this->runtime_tensor_transform = transform; }
1 change: 1 addition & 0 deletions forge/csrc/graph_lib/python_bindings.cpp
@@ -91,6 +91,7 @@ void GraphModule(py::module &m_graph)
.def("get_ordered_input_names", &Graph::get_ordered_input_names)
.def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
.def("get_ordered_output_names", &Graph::get_ordered_output_names)
.def("get_ordered_external_output_names", &Graph::get_ordered_external_output_names)
.def("get_ordered_target_names", &Graph::get_ordered_target_names)
.def("get_ordered_intermediate_names", &Graph::get_ordered_intermediate_names)
.def("get_ordered_input_gradient_names", &Graph::get_ordered_input_gradient_names)
2 changes: 2 additions & 0 deletions forge/csrc/passes/split_graph.cpp
@@ -136,6 +136,7 @@ std::unique_ptr<Graph> extract_forward_graph(const Graph *graph, const std::vect
intermediate_output->set_intermediate(true);
intermediate_output->set_shape(node->shape());
intermediate_output->set_output_df(node->output_df());
intermediate_output->set_output_type(graphlib::OutputType::Internal);

auto intermediate_output_node = fwd_graph->add_node(std::move(intermediate_output), 0 /*subgraph_id=*/);
fwd_graph->add_edge(fwd_graph->get_node_by_name(node->name()), intermediate_output_node);
@@ -228,6 +229,7 @@ std::unique_ptr<Graph> extract_backward_graph(
auto operand = graph->data_operands(queue_node)[0];

auto output_node = graphlib::create_node<graphlib::OutputNode>(queue_node->name() + "_grad_accumulator");
output_node->set_output_type(graphlib::OutputType::Internal);
output_node->set_shape(queue_node->shape());
output_node->set_output_df(queue_node->output_df());
auto grad_out = bwd_graph->add_node(std::move(output_node), 0 /*subgraph_id=*/);
2 changes: 0 additions & 2 deletions forge/forge/compile.py
@@ -355,8 +355,6 @@ def forge_compile_from_context(context: CompileContext) -> CompiledModel:
bwd_compiled_graph_state,
context.compiled_binary,
context.modules[0],
- loss_module=context.loss_module,
- optimizer=context.optimizer,
)

logger.info("Compilation completed.")
67 changes: 46 additions & 21 deletions forge/forge/compiled_graph_state.py
@@ -61,6 +61,7 @@ class CompiledGraphState:
graph: Graph
ordered_input_names: List[str]
ordered_output_names: List[str]
ordered_external_output_names: List[str]
ordered_target_names: List[str]
ordered_constant_node_names: List[str]
ordered_parameter_node_names: List[str]
@@ -86,8 +87,10 @@ class CompiledGraphState:
def from_compiled_graph(module: Module, graph: Graph) -> "CompiledGraphState":
ordered_input_names = graph.get_ordered_input_names()
ordered_output_names = graph.get_ordered_output_names()
ordered_external_output_names = graph.get_ordered_external_output_names()
ordered_target_names = graph.get_ordered_target_names()
ordered_intermediate_names = graph.get_ordered_intermediate_names()
ordered_output_requires_grad = graph.get_ordered_output_requires_grad()
ordered_constant_node_names = [constant_node.name for constant_node in graph.get_constant_nodes()]
ordered_parameter_node_names = [parameter_node.name for parameter_node in graph.get_parameter_nodes()]

@@ -122,6 +125,7 @@ def from_compiled_graph(module: Module, graph: Graph) -> "CompiledGraphState":
graph=graph,
ordered_input_names=ordered_input_names,
ordered_output_names=ordered_output_names,
ordered_external_output_names=ordered_external_output_names,
ordered_target_names=ordered_target_names,
ordered_constant_node_names=ordered_constant_node_names,
ordered_parameter_node_names=ordered_parameter_node_names,
@@ -166,35 +170,54 @@ class ProgramId(IntEnum):

class CompiledModel:
"""
- Callable object for running inference on the compiled model.
+ Callable object for running the compiled model on the device(s).
+ If the model is compiled for inference, only the forward pass can be executed.
+ In case of training, the forward, backward, loss and optimizer steps can be executed, depending on which of these
+ are compiled for the device and which are set up to be run separately on the CPU.
"""

fwd_compiled_graph_state: CompiledGraphState
bwd_compiled_graph_state: Optional[CompiledGraphState]

# Compiled flatbuffer binary composed of programs which execute compiled graphs (e.g., forward, backward, etc.)
compiled_binary: Binary

inputs: List[torch.Tensor]
outputs: Dict[str, torch.Tensor]
intermediates: List[torch.Tensor]

# Original user-defined module.
framework_module: AnyModule
loss_module: Optional[Module]
optimizer: Optional[torch.optim.Optimizer]

# Gradients coming from the loss.backward() - only used when loss is computed on CPU.
loss_grad: List[Optional[torch.Tensor]]

def __init__(
self,
fwd_compiled_graph_state: CompiledGraphState,
- bwd_compiled_graph_state: CompiledGraphState,
+ bwd_compiled_graph_state: Optional[CompiledGraphState],
compiled_binary: Binary,
framework_module: AnyModule,
loss_module: Optional[Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
):
self.fwd_compiled_graph_state = fwd_compiled_graph_state
self.bwd_compiled_graph_state = bwd_compiled_graph_state
self.compiled_binary = compiled_binary
self.inputs = []
self.framework_module = framework_module
self.loss_module = loss_module
self.optimizer = optimizer
self.intermediates = []
self.loss_grad = [None] * len(fwd_compiled_graph_state.ordered_external_output_names)
self.outputs = {}

def tie_grad_fn(self, grad_id: int, grad: torch.Tensor):
"""
Hook function to tie the gradients produced by torch as inputs to the backward pass which will be run on the
TT device.
NOTE: Should be used only when loss is computed on CPU (outside of our runtime).
"""
assert len(self.loss_grad) > grad_id, "More gradients than expected."
self.loss_grad[grad_id] = grad

def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
"""
@@ -222,45 +245,47 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
inputs_and_parameters = to_pt_tensors(inputs_and_parameters)

logger.info(f"Running model {self.fwd_compiled_graph_state.graph.get_name()} on device...")
- model_outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters)
+ all_outputs = run_binary(self.compiled_binary, int(ProgramId.FORWARD), inputs_and_parameters)

self.intermediates = []

+ # The model_outputs will contain outputs that we need to return to the user, i.e. external outputs.
+ model_outputs = []
for idx, output_name in enumerate(self.fwd_compiled_graph_state.ordered_output_names):
+ output = all_outputs[idx]
if output_name in self.fwd_compiled_graph_state.ordered_intermediate_names:
- self.intermediates.append(model_outputs[idx])
-
- self.outputs = {}
- self.outputs[self.fwd_compiled_graph_state.ordered_output_names[0]] = model_outputs[0]
-
- model_outputs = [model_outputs[0]]
+ self.intermediates.append(output)
+ if output_name in self.fwd_compiled_graph_state.ordered_external_output_names:
+ self.outputs[output_name] = output
+ model_outputs.append(output)

if self.fwd_compiled_graph_state.graph.training():
# For executing loss and its backward graph on CPU, we need to tell torch to compute gradients.
- for output in model_outputs:
+ for idx, output in enumerate(model_outputs):
output.requires_grad = True
+ output.register_hook(lambda grad: self.tie_grad_fn(idx, grad))

return model_outputs

def forward(self, *inputs: AnyTensor) -> List[torch.Tensor]:
return self(inputs)

- def backward(self, loss_grad: torch.Tensor) -> List[torch.Tensor]:
+ def backward(self) -> List[torch.Tensor]:
assert self.fwd_compiled_graph_state.graph.training(), "Model not compiled for training."
assert self.bwd_compiled_graph_state is not None, "Backward graph should be present for training."
consts_and_params = [
*self.bwd_compiled_graph_state.get_ordered_constant_tensors(),
*self.bwd_compiled_graph_state.get_ordered_parameter_tensors(),
]

- # Make a list from gradients passed from loss function.
- if not isinstance(loss_grad, list):
- loss_grad = [loss_grad]
+ for grad in self.loss_grad:
+ assert grad is not None, "Gradients not provided for backward pass."

logger.info(f"Running backward pass on model {self.bwd_compiled_graph_state.graph.get_name()} on device...")
grads = run_binary(
self.compiled_binary,
int(ProgramId.BACKWARD),
- [*loss_grad, *self.intermediates, *self.inputs, *consts_and_params],
+ [*self.loss_grad, *self.intermediates, *self.inputs, *consts_and_params],
)

for name, param in self.framework_module.module.named_parameters():
2 changes: 1 addition & 1 deletion forge/test/mlir/mnist/training/test_training.py
@@ -72,7 +72,7 @@ def test_mnist_training():
# Run backward pass on device
loss.backward()

- tt_model.backward(pred.grad)
+ tt_model.backward()

if batch_idx >= limit_num_batches:
break
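
Taken together with the test change above, the training loop this enables looks roughly like the following (a hedged sketch; `tt_model`, `loss_fn`, `optimizer` and `dataloader` stand for a training-compiled CompiledModel and ordinary CPU-side torch objects):

for data, target in dataloader:
    pred = tt_model(data)[0]       # forward on device; gradient hooks are registered on the outputs
    loss = loss_fn(pred, target)   # loss computed on the CPU
    loss.backward()                # hooks capture pred's gradient for the runtime
    tt_model.backward()            # backward program on device; no manual gradient passing
    optimizer.step()
    optimizer.zero_grad()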
2 changes: 1 addition & 1 deletion forge/test/mlir/test_features.py
@@ -153,4 +153,4 @@ def forward(self, x):
assert torch.allclose(loss, golden_loss, rtol=1e-2) # 1e-2 is the minimum value for which the test passes

loss.backward()
- tt_model.backward(pred.grad)
+ tt_model.backward()
6 changes: 3 additions & 3 deletions forge/test/mlir/test_training.py
@@ -7,7 +7,7 @@

import forge
import forge.config
- from forge.op.eval.common import compare_with_golden_pcc
+ from forge.op.eval.common import compare_with_golden


def test_torch_training():
@@ -39,7 +39,7 @@ def forward(self, x):
output = tt_model(inputs)

output = [co.to("cpu") for co in output]
- assert compare_with_golden_pcc(golden=golden, calculated=output[0], pcc=0.99)
+ assert compare_with_golden(golden=golden, calculated=output[0])

optimizer.zero_grad()

@@ -53,7 +53,7 @@ def forward(self, x):

loss_grad = output[0].grad
assert loss_grad is not None
- grad = tt_model.backward(loss_grad)
+ grad = tt_model.backward()

# HACK to run the optimizer step
# i'm not sure what's the right way to tie the torch optimizer to our params,
