From 56e93f9940c0f8afab0b265e5e7c93d134612c56 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 17:05:57 -0700
Subject: [PATCH 01/10] try if statement

---
 llmfoundry/callbacks/hf_checkpointer.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 7127d37f40..fbed36d1f0 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -435,6 +435,27 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
         cpu_offload = True
 
+        # def dtensor_to_tensor_hook(
+        #     module: nn.Module,
+        #     state_dict: Dict[str, Any],
+        #     prefix: str,
+        #     *args: Any,
+        # ) -> Dict[str, Any]:
+        #     dtensor_fqns = []
+        #     for fqn in state_dict.keys():
+        #         tensor = state_dict[fqn]
+        #         if isinstance(tensor, DTensor):
+        #             dtensor_fqns.append(fqn)
+        #             tensor = tensor.full_tensor()  # type: ignore
+        #             if dist.get_global_rank() == 0:
+        #                 if cpu_offload:
+        #                     tensor = tensor.cpu()
+        #                 state_dict[fqn] = tensor
+        #     if dist.get_global_rank() != 0:
+        #         for fqn in dtensor_fqns:
+        #             del state_dict[fqn]
+        #     return state_dict
+
         # Add hook to move tensors to cpu to avoid CUDA OOM
         def tensor_hook(
             module: nn.Module,
@@ -465,7 +486,8 @@ def tensor_hook(
 
         hooks = []
         for _, module in state_dict_model.named_modules():
-            hooks.append(module._register_state_dict_hook(tensor_hook),)
+            if isinstance(module, FSDP):
+                hooks.append(module._register_state_dict_hook(tensor_hook),)
 
         state_dict = get_model_state_dict(
             state_dict_model,

From aea32d6edb967c28425d8a9889b0f42f63fd741c Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 17:24:29 -0700
Subject: [PATCH 02/10] test

---
 llmfoundry/callbacks/hf_checkpointer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index fbed36d1f0..0a5eb236de 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -483,6 +483,8 @@ def tensor_hook(
             if dist.get_global_rank() != 0:
                 state_dict = {}
             return state_dict
+
+        assert False
 
         hooks = []
         for _, module in state_dict_model.named_modules():

From e4c9691c2dd98fe76d684246b532dc9f83c18e9e Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 17:29:58 -0700
Subject: [PATCH 03/10] put back

---
 llmfoundry/callbacks/hf_checkpointer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 0a5eb236de..fbed36d1f0 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -483,8 +483,6 @@ def tensor_hook(
             if dist.get_global_rank() != 0:
                 state_dict = {}
             return state_dict
-
-        assert False
 
         hooks = []
         for _, module in state_dict_model.named_modules():

From 3de4088677e47f5c19f9587ad10c68e043f7d377 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Wed, 24 Jul 2024 17:33:54 -0700
Subject: [PATCH 04/10] go back

---
 llmfoundry/callbacks/hf_checkpointer.py | 77 +++++++++++++++----------
 1 file changed, 45 insertions(+), 32 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index fbed36d1f0..94f1b6e47f 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -435,29 +435,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
         cpu_offload = True
 
-        # def dtensor_to_tensor_hook(
-        #     module: nn.Module,
-        #     state_dict: Dict[str, Any],
- # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # dtensor_fqns = [] - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, DTensor): - # dtensor_fqns.append(fqn) - # tensor = tensor.full_tensor() # type: ignore - # if dist.get_global_rank() == 0: - # if cpu_offload: - # tensor = tensor.cpu() - # state_dict[fqn] = tensor - # if dist.get_global_rank() != 0: - # for fqn in dtensor_fqns: - # del state_dict[fqn] - # return state_dict - - # Add hook to move tensors to cpu to avoid CUDA OOM - def tensor_hook( + def dtensor_to_tensor_hook( module: nn.Module, state_dict: Dict[str, Any], prefix: str, @@ -470,24 +448,59 @@ def tensor_hook( dtensor_fqns.append(fqn) tensor = tensor.full_tensor() # type: ignore if dist.get_global_rank() == 0: - # Offload any DTensors to CPU if cpu_offload: tensor = tensor.cpu() state_dict[fqn] = tensor - else: - state_dict[fqn] = None - # Convert the state dict to the requested precision - if isinstance(tensor, torch.Tensor): - state_dict[fqn] = tensor.to(dtype=self.dtype) - del tensor if dist.get_global_rank() != 0: - state_dict = {} + for fqn in dtensor_fqns: + del state_dict[fqn] return state_dict + + # def tensor_dtype_hook( + # module: nn.Module, + # state_dict: Dict[str, Any], + # prefix: str, + # *args: Any, + # ) -> Dict[str, Any]: + # for fqn in state_dict.keys(): + # tensor = state_dict[fqn] + # if isinstance(tensor, torch.Tensor): + # state_dict[fqn] = tensor.to(dtype=self.dtype) + # del tensor + # return state_dict + + # # Add hook to move tensors to cpu to avoid CUDA OOM + # def tensor_hook( + # module: nn.Module, + # state_dict: Dict[str, Any], + # prefix: str, + # *args: Any, + # ) -> Dict[str, Any]: + # dtensor_fqns = [] + # for fqn in state_dict.keys(): + # tensor = state_dict[fqn] + # if isinstance(tensor, DTensor): + # dtensor_fqns.append(fqn) + # tensor = tensor.full_tensor() # type: ignore + # if dist.get_global_rank() == 0: + # # Offload any DTensors to CPU + # if cpu_offload: + # tensor = tensor.cpu() + # state_dict[fqn] = tensor + # else: + # state_dict[fqn] = None + # # Convert the state dict to the requested precision + # if isinstance(tensor, torch.Tensor): + # state_dict[fqn] = tensor.to(dtype=self.dtype) + # del tensor + # if dist.get_global_rank() != 0: + # state_dict = {} + # return state_dict hooks = [] for _, module in state_dict_model.named_modules(): if isinstance(module, FSDP): - hooks.append(module._register_state_dict_hook(tensor_hook),) + hooks.append(module._register_state_dict_hook(dtensor_to_tensor_hook),) state_dict = get_model_state_dict( state_dict_model, From bff6de3ecf75a035fd0b1c0e6742fd133d592fcd Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 17:42:39 -0700 Subject: [PATCH 05/10] try --- llmfoundry/callbacks/hf_checkpointer.py | 88 ++++++++++++------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 94f1b6e47f..b5c3726854 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -435,42 +435,7 @@ def _save_checkpoint(self, state: State, logger: Logger): cpu_offload = True - def dtensor_to_tensor_hook( - module: nn.Module, - state_dict: Dict[str, Any], - prefix: str, - *args: Any, - ) -> Dict[str, Any]: - dtensor_fqns = [] - for fqn in state_dict.keys(): - tensor = state_dict[fqn] - if isinstance(tensor, DTensor): - dtensor_fqns.append(fqn) - tensor = tensor.full_tensor() # type: ignore - if 
dist.get_global_rank() == 0: - if cpu_offload: - tensor = tensor.cpu() - state_dict[fqn] = tensor - if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] - return state_dict - - # def tensor_dtype_hook( - # module: nn.Module, - # state_dict: Dict[str, Any], - # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, torch.Tensor): - # state_dict[fqn] = tensor.to(dtype=self.dtype) - # del tensor - # return state_dict - - # # Add hook to move tensors to cpu to avoid CUDA OOM - # def tensor_hook( + # def dtensor_to_tensor_hook( # module: nn.Module, # state_dict: Dict[str, Any], # prefix: str, @@ -483,24 +448,59 @@ def dtensor_to_tensor_hook( # dtensor_fqns.append(fqn) # tensor = tensor.full_tensor() # type: ignore # if dist.get_global_rank() == 0: - # # Offload any DTensors to CPU # if cpu_offload: # tensor = tensor.cpu() # state_dict[fqn] = tensor - # else: - # state_dict[fqn] = None - # # Convert the state dict to the requested precision + # if dist.get_global_rank() != 0: + # for fqn in dtensor_fqns: + # del state_dict[fqn] + # return state_dict + + # def tensor_dtype_hook( + # module: nn.Module, + # state_dict: Dict[str, Any], + # prefix: str, + # *args: Any, + # ) -> Dict[str, Any]: + # for fqn in state_dict.keys(): + # tensor = state_dict[fqn] # if isinstance(tensor, torch.Tensor): # state_dict[fqn] = tensor.to(dtype=self.dtype) # del tensor - # if dist.get_global_rank() != 0: - # state_dict = {} # return state_dict + # Add hook to move tensors to cpu to avoid CUDA OOM + def tensor_hook( + module: nn.Module, + state_dict: Dict[str, Any], + prefix: str, + *args: Any, + ) -> Dict[str, Any]: + dtensor_fqns = [] + for fqn in state_dict.keys(): + tensor = state_dict[fqn] + if isinstance(tensor, DTensor): + dtensor_fqns.append(fqn) + tensor = tensor.full_tensor() # type: ignore + if dist.get_global_rank() == 0: + # Offload any DTensors to CPU + if cpu_offload: + tensor = tensor.cpu() + state_dict[fqn] = tensor + # Convert the state dict to the requested precision + if isinstance(tensor, torch.Tensor): + state_dict[fqn] = tensor.to(dtype=self.dtype) + del tensor + if dist.get_global_rank() != 0: + for fqn in dtensor_fqns: + del state_dict[fqn] + return state_dict + hooks = [] for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append(module._register_state_dict_hook(dtensor_to_tensor_hook),) + # if isinstance(module, FSDP): + hooks.append(module._register_state_dict_hook(tensor_hook),) + state_dict = get_model_state_dict( state_dict_model, From 04bbfb91c20ea894eda293aab646d1dd0fdb98f4 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 18:08:59 -0700 Subject: [PATCH 06/10] try --- llmfoundry/callbacks/hf_checkpointer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b5c3726854..6ec12de517 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -487,19 +487,20 @@ def tensor_hook( if cpu_offload: tensor = tensor.cpu() state_dict[fqn] = tensor + else: + state_dict[fqn] = None # Convert the state dict to the requested precision - if isinstance(tensor, torch.Tensor): - state_dict[fqn] = tensor.to(dtype=self.dtype) + # if isinstance(tensor, torch.Tensor): + # state_dict[fqn] = tensor.to(dtype=self.dtype) del tensor if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: 
- del state_dict[fqn] + state_dict = {} return state_dict hooks = [] for _, module in state_dict_model.named_modules(): - # if isinstance(module, FSDP): - hooks.append(module._register_state_dict_hook(tensor_hook),) + if isinstance(module, FSDP): + hooks.append(module._register_state_dict_hook(tensor_hook),) state_dict = get_model_state_dict( From 96e5194ba9089b655ab1db164d620b6662af3c35 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 18:22:01 -0700 Subject: [PATCH 07/10] try --- llmfoundry/callbacks/hf_checkpointer.py | 42 ++++++++++++++++++++----- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 6ec12de517..e8866b9f93 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -470,6 +470,33 @@ def _save_checkpoint(self, state: State, logger: Logger): # return state_dict # Add hook to move tensors to cpu to avoid CUDA OOM + # def tensor_hook( + # module: nn.Module, + # state_dict: Dict[str, Any], + # prefix: str, + # *args: Any, + # ) -> Dict[str, Any]: + # dtensor_fqns = [] + # for fqn in state_dict.keys(): + # tensor = state_dict[fqn] + # if isinstance(tensor, DTensor): + # dtensor_fqns.append(fqn) + # tensor = tensor.full_tensor() # type: ignore + # if dist.get_global_rank() == 0: + # # Offload any DTensors to CPU + # if cpu_offload: + # tensor = tensor.cpu() + # tensor = tensor.to(dtype=self.dtype) + # state_dict[fqn] = tensor + # else: + # state_dict[fqn] = None + # elif isinstance(tensor, torch.Tensor): + # state_dict[fqn] = tensor.to(dtype=self.dtype) + # del tensor + # if dist.get_global_rank() != 0: + # state_dict = {} + # return state_dict + def tensor_hook( module: nn.Module, state_dict: Dict[str, Any], @@ -483,18 +510,17 @@ def tensor_hook( dtensor_fqns.append(fqn) tensor = tensor.full_tensor() # type: ignore if dist.get_global_rank() == 0: - # Offload any DTensors to CPU if cpu_offload: tensor = tensor.cpu() state_dict[fqn] = tensor - else: - state_dict[fqn] = None - # Convert the state dict to the requested precision - # if isinstance(tensor, torch.Tensor): - # state_dict[fqn] = tensor.to(dtype=self.dtype) - del tensor if dist.get_global_rank() != 0: - state_dict = {} + for fqn in dtensor_fqns: + del state_dict[fqn] + + for fqn in state_dict.keys(): + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) + return state_dict hooks = [] From 710a3ed60b7450a0e797305c8d83a7d3e1b5be22 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 18:34:30 -0700 Subject: [PATCH 08/10] cleaner --- llmfoundry/callbacks/hf_checkpointer.py | 66 ++++++++++++------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index e8866b9f93..bad9084235 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -470,33 +470,6 @@ def _save_checkpoint(self, state: State, logger: Logger): # return state_dict # Add hook to move tensors to cpu to avoid CUDA OOM - # def tensor_hook( - # module: nn.Module, - # state_dict: Dict[str, Any], - # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # dtensor_fqns = [] - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, DTensor): - # dtensor_fqns.append(fqn) - # tensor = tensor.full_tensor() # type: ignore - # if dist.get_global_rank() == 0: - # # Offload any DTensors 
to CPU - # if cpu_offload: - # tensor = tensor.cpu() - # tensor = tensor.to(dtype=self.dtype) - # state_dict[fqn] = tensor - # else: - # state_dict[fqn] = None - # elif isinstance(tensor, torch.Tensor): - # state_dict[fqn] = tensor.to(dtype=self.dtype) - # del tensor - # if dist.get_global_rank() != 0: - # state_dict = {} - # return state_dict - def tensor_hook( module: nn.Module, state_dict: Dict[str, Any], @@ -510,18 +483,45 @@ def tensor_hook( dtensor_fqns.append(fqn) tensor = tensor.full_tensor() # type: ignore if dist.get_global_rank() == 0: + # Offload any DTensors to CPU if cpu_offload: tensor = tensor.cpu() + tensor = tensor.to(dtype=self.dtype) state_dict[fqn] = tensor + else: + state_dict[fqn] = None + elif isinstance(tensor, torch.Tensor): + state_dict[fqn] = tensor.to(dtype=self.dtype) + del tensor if dist.get_global_rank() != 0: - for fqn in dtensor_fqns: - del state_dict[fqn] + state_dict = {} + return state_dict + + # def tensor_hook( + # module: nn.Module, + # state_dict: Dict[str, Any], + # prefix: str, + # *args: Any, + # ) -> Dict[str, Any]: + # dtensor_fqns = [] + # for fqn in state_dict.keys(): + # tensor = state_dict[fqn] + # if isinstance(tensor, DTensor): + # dtensor_fqns.append(fqn) + # tensor = tensor.full_tensor() # type: ignore + # if dist.get_global_rank() == 0: + # if cpu_offload: + # tensor = tensor.cpu() + # state_dict[fqn] = tensor + # if dist.get_global_rank() != 0: + # for fqn in dtensor_fqns: + # del state_dict[fqn] - for fqn in state_dict.keys(): - if isinstance(state_dict[fqn], torch.Tensor): - state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) + # for fqn in state_dict.keys(): + # if isinstance(state_dict[fqn], torch.Tensor): + # state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) - return state_dict + # return state_dict hooks = [] for _, module in state_dict_model.named_modules(): From b7c44efb09915c0c674db49c0326d5e2e074161e Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 18:39:56 -0700 Subject: [PATCH 09/10] clean up --- llmfoundry/callbacks/hf_checkpointer.py | 67 ++----------------------- 1 file changed, 3 insertions(+), 64 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index bad9084235..2f858dd186 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -435,40 +435,6 @@ def _save_checkpoint(self, state: State, logger: Logger): cpu_offload = True - # def dtensor_to_tensor_hook( - # module: nn.Module, - # state_dict: Dict[str, Any], - # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # dtensor_fqns = [] - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, DTensor): - # dtensor_fqns.append(fqn) - # tensor = tensor.full_tensor() # type: ignore - # if dist.get_global_rank() == 0: - # if cpu_offload: - # tensor = tensor.cpu() - # state_dict[fqn] = tensor - # if dist.get_global_rank() != 0: - # for fqn in dtensor_fqns: - # del state_dict[fqn] - # return state_dict - - # def tensor_dtype_hook( - # module: nn.Module, - # state_dict: Dict[str, Any], - # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, torch.Tensor): - # state_dict[fqn] = tensor.to(dtype=self.dtype) - # del tensor - # return state_dict - # Add hook to move tensors to cpu to avoid CUDA OOM def tensor_hook( module: nn.Module, @@ -486,49 +452,22 @@ def tensor_hook( # Offload any DTensors to CPU if cpu_offload: tensor = 
tensor.cpu() - tensor = tensor.to(dtype=self.dtype) state_dict[fqn] = tensor else: state_dict[fqn] = None - elif isinstance(tensor, torch.Tensor): - state_dict[fqn] = tensor.to(dtype=self.dtype) + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) del tensor if dist.get_global_rank() != 0: state_dict = {} return state_dict - - # def tensor_hook( - # module: nn.Module, - # state_dict: Dict[str, Any], - # prefix: str, - # *args: Any, - # ) -> Dict[str, Any]: - # dtensor_fqns = [] - # for fqn in state_dict.keys(): - # tensor = state_dict[fqn] - # if isinstance(tensor, DTensor): - # dtensor_fqns.append(fqn) - # tensor = tensor.full_tensor() # type: ignore - # if dist.get_global_rank() == 0: - # if cpu_offload: - # tensor = tensor.cpu() - # state_dict[fqn] = tensor - # if dist.get_global_rank() != 0: - # for fqn in dtensor_fqns: - # del state_dict[fqn] - - # for fqn in state_dict.keys(): - # if isinstance(state_dict[fqn], torch.Tensor): - # state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) - - # return state_dict hooks = [] for _, module in state_dict_model.named_modules(): if isinstance(module, FSDP): hooks.append(module._register_state_dict_hook(tensor_hook),) - state_dict = get_model_state_dict( state_dict_model, options=StateDictOptions( From 244f8e1d1f640b6e39a29c41086b8187a55c3bef Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 Jul 2024 19:00:46 -0700 Subject: [PATCH 10/10] oops --- llmfoundry/callbacks/hf_checkpointer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 2f858dd186..35508cc0c7 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -465,8 +465,7 @@ def tensor_hook( hooks = [] for _, module in state_dict_model.named_modules(): - if isinstance(module, FSDP): - hooks.append(module._register_state_dict_hook(tensor_hook),) + hooks.append(module._register_state_dict_hook(tensor_hook),) state_dict = get_model_state_dict( state_dict_model,
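
Reference sketch (not part of the patches above): the hooks in this series are registered through torch.nn.Module._register_state_dict_hook, whose non-None return value replaces the state dict being assembled. The snippet below is a minimal, single-process illustration of that mechanism using the same dtype-casting idea as the final tensor_hook; the DTensor/full_tensor() and dist.get_global_rank() handling is omitted because it requires a distributed setup, and the names target_dtype and cast_hook are illustrative rather than taken from the patches.

    from typing import Any, Dict

    import torch
    import torch.nn as nn

    target_dtype = torch.bfloat16  # stands in for self.dtype in the callback

    def cast_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        # Cast every tensor entry, mirroring the final tensor_hook's dtype step.
        for fqn in state_dict.keys():
            if isinstance(state_dict[fqn], torch.Tensor):
                state_dict[fqn] = state_dict[fqn].to(dtype=target_dtype)
        return state_dict

    model = nn.Linear(4, 4)
    handle = model._register_state_dict_hook(cast_hook)
    print({name: t.dtype for name, t in model.state_dict().items()})
    # expected: {'weight': torch.bfloat16, 'bias': torch.bfloat16}
    handle.remove()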