From 04d8bb2fcb8457297a662f108f8f2871807c4fa2 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 14 Sep 2023 12:23:25 -0400
Subject: [PATCH 1/2] fix optimizer logging

---
 llmfoundry/optim/adaptive_lion.py | 58 ++++++++-----------------
 llmfoundry/optim/lion.py          | 50 ++++++++------------------
 2 files changed, 28 insertions(+), 80 deletions(-)

diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py
index 58c0f93ad5..e7ded752cf 100644
--- a/llmfoundry/optim/adaptive_lion.py
+++ b/llmfoundry/optim/adaptive_lion.py
@@ -206,28 +206,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
 
     def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         """Preprocess metrics to reduce across ranks correctly."""
-        # Sort L2 norms first so they are squared before other metrics, which depend on squared values
-        metrics = optimizer_metrics.keys()
-        metrics = sorted(metrics,
-                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
-        for metric in metrics:
-            if metric.startswith('l2_norm'):
-                # L2 norms need to be squared, before they are reduced via summation
-                optimizer_metrics[metric] = optimizer_metrics[metric]**2
-            elif metric.startswith('cosine'):
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                # L2 norm would've been squared in previous branch
-                A_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
-                B_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
-
-                optimizer_metrics[
-                    metric] *= A_rank_subset_norm * B_rank_subset_norm
-
+        # Only L2 norm metric keys are present, can skip sorting at this stage
+        for metric in optimizer_metrics:
+            # L2 norms need to be squared, before they are reduced via summation
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
@@ -287,14 +269,6 @@ class DecoupledClipLion(Optimizer):
         'l2_norm/grad':
         lambda param, optim_state, step_tensor: torch.linalg.vector_norm(
             param.grad),
-        'cosine/update_grad':
-        lambda param, optim_state, step_tensor: torch.nn.functional.
-        cosine_similarity(
-            param.grad.flatten(), step_tensor.flatten(), dim=0),
-        'cosine/moment_grad':
-        lambda param, optim_state, step_tensor: torch.nn.functional.
-        cosine_similarity(
-            param.grad.flatten(), optim_state['exp_avg'].flatten(), dim=0),
     }
 
     def __init__(self,
@@ -384,26 +358,22 @@ def step(self, closure: Optional[Callable] = None):
 
         return loss
 
     def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
-        for metric in optimizer_metrics:
+        local_keys = list(optimizer_metrics.keys())
+        all_gathered_keys = dist.all_gather_object(local_keys)
+        all_keys = set()
+        for keys in all_gathered_keys:
+            all_keys.update(keys)
+
+        # Sort keys to ensure every rank has the same keys order
+        # Only L2 norm metric keys are present, can apply regular sort
+        all_keys = sorted(all_keys)
+        for metric in all_keys:
             if metric.startswith('l2_norm'):
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
                 optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
-            elif metric.startswith('cosine'):
-                reduced = optimizer_metrics[metric]
-                if dist.get_world_size() > 1:
-                    dist.all_reduce(reduced, reduce_operation='SUM')
-
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
-                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
-                optimizer_metrics[metric] = reduced / (A_reduced_norm *
-                                                       B_reduced_norm)
             elif metric.startswith('clipped_batches'):
                 continue
             else:
diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py
index cc171290b7..a26ae712e5 100644
--- a/llmfoundry/optim/lion.py
+++ b/llmfoundry/optim/lion.py
@@ -99,26 +99,22 @@ def step(self, closure: Optional[Callable] = None):
 
         return loss
 
     def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
-        for metric in optimizer_metrics:
+        local_keys = list(optimizer_metrics.keys())
+        all_gathered_keys = dist.all_gather_object(local_keys)
+        all_keys = set()
+        for keys in all_gathered_keys:
+            all_keys.update(keys)
+
+        # Sort keys to ensure every rank has the same keys order
+        # Only L2 norm metric keys are present, can apply regular sort
+        all_keys = sorted(all_keys)
+        for metric in all_keys:
             if metric.startswith('l2_norm'):
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
                 optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced))
-            elif metric.startswith('cosine'):
-                reduced = optimizer_metrics[metric]
-                if dist.get_world_size() > 1:
-                    dist.all_reduce(reduced, reduce_operation='SUM')
-
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                A_reduced_norm = optimizer_metrics[f'l2_norm/{A}/{layer}']
-                B_reduced_norm = optimizer_metrics[f'l2_norm/{B}/{layer}']
-                optimizer_metrics[metric] = reduced / (A_reduced_norm *
-                                                       B_reduced_norm)
             else:
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
@@ -129,28 +125,10 @@ def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
 
     def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         """Preprocess metrics to reduce across ranks correctly."""
-        # Sort L2 norms first so they are squared before other metrics, which depend on squared values
-        metrics = optimizer_metrics.keys()
-        metrics = sorted(metrics,
-                         key=lambda metric: 0 if 'l2_norm' in metric else 1)
-        for metric in metrics:
-            if metric.startswith('l2_norm'):
-                # L2 norms need to be squared, before they are reduced via summation
-                optimizer_metrics[metric] = optimizer_metrics[metric]**2
-            elif metric.startswith('cosine'):
-                _, vectors, layer = tuple(metric.split('/'))
-
-                A, B = tuple(vectors.split('_'))
-
-                # L2 norm would've been squared in previous branch
-                A_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{A}/{layer}'])
-                B_rank_subset_norm = math.sqrt(
-                    optimizer_metrics[f'l2_norm/{B}/{layer}'])
-
-                optimizer_metrics[
-                    metric] *= A_rank_subset_norm * B_rank_subset_norm
-
+        # Only L2 norm metric keys are present, can skip sorting at this stage
+        for metric in optimizer_metrics:
+            # L2 norms need to be squared, before they are reduced via summation
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,

From e0f111dbd9f01f37cccfc02fb2ba256246113a92 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Thu, 14 Sep 2023 12:40:57 -0400
Subject: [PATCH 2/2] lint

---
 llmfoundry/optim/adaptive_lion.py | 2 +-
 llmfoundry/optim/lion.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py
index e7ded752cf..f6b74cbafe 100644
--- a/llmfoundry/optim/adaptive_lion.py
+++ b/llmfoundry/optim/adaptive_lion.py
@@ -209,7 +209,7 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         # Only L2 norm metric keys are present, can skip sorting at this stage
         for metric in optimizer_metrics:
             # L2 norms need to be squared, before they are reduced via summation
-            optimizer_metrics[metric] = optimizer_metrics[metric]**2
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py
index a26ae712e5..0caa7d2877 100644
--- a/llmfoundry/optim/lion.py
+++ b/llmfoundry/optim/lion.py
@@ -128,7 +128,7 @@ def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]):
         # Only L2 norm metric keys are present, can skip sorting at this stage
         for metric in optimizer_metrics:
             # L2 norms need to be squared, before they are reduced via summation
-            optimizer_metrics[metric] = optimizer_metrics[metric]**2
+            optimizer_metrics[metric] = optimizer_metrics[metric]**2
         return optimizer_metrics
 
     def report_per_parameter_metrics(self, param: torch.Tensor, name: str,
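
---

Note on the approach: collectives such as all_reduce must be issued in the
same order, and the same number of times, on every rank. With FSDP, each
rank computes optimizer metrics only for the parameter shards it owns, so
iterating a rank-local dict (the old `for metric in optimizer_metrics:`)
can issue mismatched collectives across ranks and hang the run. The patch
therefore all-gathers every rank's key list and reduces over the sorted
union, and it relies on the pre-reduce squaring so that summing squared
rank-local L2 norms over disjoint shards yields the squared global norm,
with the square root applied after the reduce. Below is a minimal
standalone sketch of the same pattern; it assumes raw torch.distributed in
place of the composer.utils.dist wrappers the patch uses, and the helper
name reduce_l2_norm_metrics plus the zero default for keys a rank never
saw are illustrative assumptions, not part of the patch.

# Minimal sketch (not the patch itself): deterministic reduction of
# sharded L2-norm metrics with raw torch.distributed. The helper name
# and the zero default for missing keys are assumptions.
import math
from typing import Dict

import torch
import torch.distributed as torch_dist


def reduce_l2_norm_metrics(
        optimizer_metrics: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Pre-reduce: square each rank-local L2 norm so that summing the
    # squares across ranks yields the squared global norm.
    for key in optimizer_metrics:
        optimizer_metrics[key] = optimizer_metrics[key]**2

    initialized = torch_dist.is_available() and torch_dist.is_initialized()
    world_size = torch_dist.get_world_size() if initialized else 1

    if world_size > 1:
        # Gather every rank's key list; under FSDP each rank only holds
        # metrics for its own parameter shards, so the dicts differ.
        gathered = [None for _ in range(world_size)]
        torch_dist.all_gather_object(gathered, list(optimizer_metrics.keys()))
        all_keys = sorted({key for keys in gathered for key in keys})
    else:
        all_keys = sorted(optimizer_metrics.keys())

    # Iterate the sorted union so every rank issues the same collectives
    # in the same order; mismatched orders across ranks can deadlock.
    for key in all_keys:
        # A rank that never saw this parameter contributes zero.
        reduced = optimizer_metrics.get(key, torch.tensor(0.0))
        if world_size > 1:
            torch_dist.all_reduce(reduced, op=torch_dist.ReduceOp.SUM)
        # Post-reduce: the square root undoes the pre-reduce squaring.
        optimizer_metrics[key] = torch.tensor(math.sqrt(reduced))
    return optimizer_metrics

On a single process (no process group initialized) the collectives are
skipped, so reduce_l2_norm_metrics({'l2_norm/grad/layer0': torch.tensor(3.)})
squares the value to 9.0 and square-roots it back to 3.0, the local norm.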