diff --git a/forge/forge/compiled_graph_state.py b/forge/forge/compiled_graph_state.py
index b138f12e5..ff82b527a 100644
--- a/forge/forge/compiled_graph_state.py
+++ b/forge/forge/compiled_graph_state.py
@@ -255,6 +255,18 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
         logger.info("Converting inputs and parameters to PyTorch tensors...")
         inputs_and_parameters = to_pt_tensors(inputs_and_parameters)
 
+        if self.training() and isinstance(self.framework_module, PyTorchModule):
+            for name, param in self.framework_module.module.named_parameters():
+                if param.requires_grad:
+                    our_tensor = self.fwd_compiled_graph_state.get_parameter_tensor(name)
+
+                    # NOTE: for parameters that require gradients, we want to share the same tensor with the PyTorch module.
+                    # This is because we want to be able to optimize the parameters both on the device (through our runtime)
+                    # and via the torch optimizers. So this ensures that whichever side updates the parameter value, the other side can see the change.
+                    #
+                    # This could change in the future, but for now ensure that our premise is correct.
+                    assert param is our_tensor
+
         logger.info(
             f"Running model {self.framework_module.get_name()} {self.fwd_compiled_graph_state.graph.get_name()} on device..."
         )
@@ -272,11 +284,13 @@ def __call__(self, *inputs: AnyTensor) -> List[torch.Tensor]:
             self.outputs[output_name] = output
             model_outputs.append(output)
 
-        if self.fwd_compiled_graph_state.graph.training():
+        if self.training():
             # For executing loss and its backward graph on CPU, we need to tell torch to compute gradients.
             for idx, output in enumerate(model_outputs):
                 output.requires_grad = True
-                output.register_hook(lambda grad: self.tie_grad_fn(idx, grad))
+                # NOTE: the default idx parameter for the lambda is used to capture the idx by value. Otherwise, the lambda
+                # would capture the idx by reference, and all the lambdas would have the same idx value.
+                output.register_hook(lambda grad, idx=idx: self.tie_grad_fn(idx, grad))
 
         return model_outputs
 
@@ -284,7 +298,7 @@ def forward(self, *inputs: AnyTensor) -> List[torch.Tensor]:
         return self(inputs)
 
     def backward(self) -> List[torch.Tensor]:
-        assert self.fwd_compiled_graph_state.graph.training(), "Model not compiled for training."
+        assert self.training(), "Model not compiled for training."
         assert self.bwd_compiled_graph_state is not None, "Backward graph should be present for training."
         consts_and_params = [
             *self.bwd_compiled_graph_state.get_ordered_constant_tensors(),
@@ -342,3 +356,6 @@ def backward(self) -> List[torch.Tensor]:
            self.attached_module.backward()
 
         return grads
+
+    def training(self) -> bool:
+        return self.fwd_compiled_graph_state.graph.training()
diff --git a/forge/forge/tensor.py b/forge/forge/tensor.py
index 3efbc4aeb..8c8b657e6 100644
--- a/forge/forge/tensor.py
+++ b/forge/forge/tensor.py
@@ -1378,12 +1378,9 @@ def compare_tensors(t0, t1):
 def const_eval_tensor(inputs, consteval_trace, input_name, is_forge=True):
     contains_recorded_operations = consteval_trace[input_name]
     if contains_recorded_operations:
-        value = consteval_input(
-            consteval_trace,
-            input_name,
-            inputs,
-            is_forge,
-        )
+        value = detach_tensors(
+            [consteval_input(consteval_trace, input_name, inputs, is_forge)], fix_non_contiguos=True
+        )[0]
     else:
         value = pad_pytorch_tensor_to_forge(inputs[input_name], []) if is_forge else inputs[input_name]
     # cast if necessary
@@ -1426,17 +1423,12 @@ def get_post_const_eval_tensors(
             constant_nodes, device_constant_and_parameters, consteval_trace, input_name, is_forge
         )
 
-        post_const_eval_constants[input_name] = detach_tensors(
-            [
-                const_eval_tensor(
-                    inputs,
-                    consteval_trace,
-                    input_name,
-                    is_forge,
-                )
-            ],
-            fix_non_contiguos=True,
-        )[0]
+        post_const_eval_constants[input_name] = const_eval_tensor(
+            inputs,
+            consteval_trace,
+            input_name,
+            is_forge,
+        )
 
     return post_const_eval_constants
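
For context on the new `assert param is our_tensor` check: the premise is that the framework module's parameter and the tensor returned by `get_parameter_tensor` are the very same torch.Tensor object, so an in-place update made by either side is visible to the other. Below is a minimal standalone sketch (not part of the patch; the `linear` module and the aliasing of `our_tensor` are purely illustrative) of that aliasing behavior.

    import torch

    linear = torch.nn.Linear(2, 2)
    param = linear.weight        # parameter owned by the torch module
    our_tensor = linear.weight   # illustrative stand-in for get_parameter_tensor(name)

    assert param is our_tensor   # same object, not merely equal values

    with torch.no_grad():
        our_tensor.add_(1.0)     # "runtime-side" in-place update

    # The module observes the update because no copy was ever made.
    assert torch.equal(linear.weight, our_tensor)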
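Similarly, a standalone sketch (not part of the patch) of the late-binding pitfall that the `idx=idx` default argument in the `register_hook` lambda avoids: a closure created in a loop looks up the loop variable when it is called, not when it is defined, so without the default every hook would see the final value of `idx`.

    hooks = []
    for idx in range(3):
        hooks.append(lambda grad: (idx, grad))           # late binding: all closures share the same idx
    print([h(None)[0] for h in hooks])                   # [2, 2, 2]

    hooks = []
    for idx in range(3):
        hooks.append(lambda grad, idx=idx: (idx, grad))  # default argument freezes idx per iteration
    print([h(None)[0] for h in hooks])                   # [0, 1, 2]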