[train] fix param tensor references (#906)
Remove the `.detach()` call when storing a tensor that doesn't have a
consteval trace. This makes the tensor in our runtime and its analogous
tensor in torch share the same reference, so any change to the
underlying data is reflected in both our runtime and torch.

Before this change, the tensors did not reference the same object, but
they still shared the underlying data; that is why the optimizer on
CPU (torch) worked even before. However, replacing the underlying data
of our tensor (via `our_tensor.data = new_data`), as happens when
running the optimizer on the device, would not be reflected in the
original tensor (in torch).
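
For illustration only (not part of this change; the names are made up), a minimal PyTorch sketch of the behaviour described above: a `.detach()`-ed copy initially shares storage with the original, but it does not follow a `.data` swap, whereas holding the same object does.

```python
import torch

# Minimal sketch, assuming plain PyTorch semantics; names are illustrative.
param = torch.ones(2, requires_grad=True)

detached_copy = param.detach()  # new tensor object that shares the current storage
same_object = param             # the same object, as after removing the .detach() call

# Swap the underlying data, as a device-side optimizer step would.
param.data = torch.zeros(2)

print(detached_copy)  # still ones: the detached copy keeps pointing at the old storage
print(same_object)    # zeros: the shared reference sees the new data
```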

Add assertions to make sure we are sharing the same tensor with torch.

Also fixes an issue with registering hooks for multiple gradients (the
lambda was capturing the gradient index by reference instead of by
value).
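
A small self-contained sketch (hypothetical, with a dict standing in for `tie_grad_fn`) of the closure pitfall behind the hook fix: without binding `idx` as a default argument, every hook closes over the same loop variable and fires with the last index.

```python
import torch

# Minimal sketch of the by-value capture fix; `seen` stands in for tie_grad_fn.
outputs = [torch.tensor(1.0, requires_grad=True) * (k + 1) for k in range(3)]
seen = {}

for idx, out in enumerate(outputs):
    # The default argument binds idx by value; a plain `lambda grad: ...`
    # would capture idx by reference and every hook would report idx == 2.
    out.register_hook(lambda grad, idx=idx: seen.update({idx: float(grad)}))

for out in outputs:
    out.backward()

print(sorted(seen))  # [0, 1, 2] -- each hook kept its own index
```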
pilkicTT authored Dec 16, 2024
1 parent 79e186f commit 492308b
Showing 2 changed files with 29 additions and 20 deletions.
23 changes: 20 additions & 3 deletions forge/forge/compiled_graph_state.py
@@ -255,6 +255,18 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
         logger.info("Converting inputs and parameters to PyTorch tensors...")
         inputs_and_parameters = to_pt_tensors(inputs_and_parameters)
 
+        if self.training() and isinstance(self.framework_module, PyTorchModule):
+            for name, param in self.framework_module.module.named_parameters():
+                if param.requires_grad:
+                    our_tensor = self.fwd_compiled_graph_state.get_parameter_tensor(name)
+
+                    # NOTE: for parameters that require gradients, we want to share the same tensor with the PyTorch module.
+                    # This is because we want to be able to optimize the parameters both on the device (through our runtime)
+                    # and via the torch optimizers. So this ensures that whichever side updates the parameter value, the other side can see the change.
+                    #
+                    # This could change in the future, but for now ensure that our premise is correct.
+                    assert param is our_tensor
+
         logger.info(
             f"Running model {self.framework_module.get_name()} {self.fwd_compiled_graph_state.graph.get_name()} on device..."
         )
@@ -272,19 +284,21 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
             self.outputs[output_name] = output
             model_outputs.append(output)
 
-        if self.fwd_compiled_graph_state.graph.training():
+        if self.training():
             # For executing loss and its backward graph on CPU, we need to tell torch to compute gradients.
             for idx, output in enumerate(model_outputs):
                 output.requires_grad = True
-                output.register_hook(lambda grad: self.tie_grad_fn(idx, grad))
+                # NOTE: the default idx parameter for the lambda is used to capture the idx by value. Otherwise, the lambda
+                # would capture the idx by reference, and all the lambdas would have the same idx value.
+                output.register_hook(lambda grad, idx=idx: self.tie_grad_fn(idx, grad))
 
         return model_outputs
 
     def forward(self, *inputs: AnyTensor) -> List[torch.Tensor]:
         return self(inputs)
 
     def backward(self) -> List[torch.Tensor]:
-        assert self.fwd_compiled_graph_state.graph.training(), "Model not compiled for training."
+        assert self.training(), "Model not compiled for training."
         assert self.bwd_compiled_graph_state is not None, "Backward graph should be present for training."
         consts_and_params = [
             *self.bwd_compiled_graph_state.get_ordered_constant_tensors(),
@@ -342,3 +356,6 @@ def backward(self) -> List[torch.Tensor]:
             self.attached_module.backward()
 
         return grads
+
+    def training(self) -> bool:
+        return self.fwd_compiled_graph_state.graph.training()
26 changes: 9 additions & 17 deletions forge/forge/tensor.py
@@ -1378,12 +1378,9 @@ def compare_tensors(t0, t1):
 def const_eval_tensor(inputs, consteval_trace, input_name, is_forge=True):
     contains_recorded_operations = consteval_trace[input_name]
     if contains_recorded_operations:
-        value = consteval_input(
-            consteval_trace,
-            input_name,
-            inputs,
-            is_forge,
-        )
+        value = detach_tensors(
+            [consteval_input(consteval_trace, input_name, inputs, is_forge)], fix_non_contiguos=True
+        )[0]
     else:
         value = pad_pytorch_tensor_to_forge(inputs[input_name], []) if is_forge else inputs[input_name]
     # cast if necessary
@@ -1426,17 +1423,12 @@ def get_post_const_eval_tensors(
             constant_nodes, device_constant_and_parameters, consteval_trace, input_name, is_forge
         )
 
-        post_const_eval_constants[input_name] = detach_tensors(
-            [
-                const_eval_tensor(
-                    inputs,
-                    consteval_trace,
-                    input_name,
-                    is_forge,
-                )
-            ],
-            fix_non_contiguos=True,
-        )[0]
+        post_const_eval_constants[input_name] = const_eval_tensor(
+            inputs,
+            consteval_trace,
+            input_name,
+            is_forge,
+        )
 
     return post_const_eval_constants
