From 23f8574ee540e7e4b16ed6537098630ba1c404f4 Mon Sep 17 00:00:00 2001
From: Longjie Zheng <32992656+zhenglongjiepheonix@users.noreply.github.com>
Date: Thu, 29 Aug 2024 11:45:19 -0400
Subject: [PATCH] Add Param Cache For Recompilation (#2000)

add param cache
---
 optimum/fx/parallelization/core.py   |  6 ++++++
 optimum/fx/parallelization/passes.py | 15 ++++++++++-----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py
index cba7d454441..1d13b00b468 100644
--- a/optimum/fx/parallelization/core.py
+++ b/optimum/fx/parallelization/core.py
@@ -125,6 +125,11 @@ class ParallelExecutionCtx:
         because we have to make sure we don't initiate new parameters and replace original ones when
         recompilation happens in training process.
 
+    - param_cache (`Dict[str, nn.Parameter]`):
+        Cache which keeps record of newly created parameters. Similar to `parallel_layer_cache`, we
+        need to make sure all the newly created parameters in the first compilation will still be used
+        when recompilation happens.
+
     - weight_map (`Dict[str, str]`):
         Mapping between parameter names and their locations on disk, useful when loading weights from
         disk.
@@ -140,6 +145,7 @@ class ParallelExecutionCtx:
     current_device: torch.device
     example_inputs: List[Any] = field(default_factory=list)
     parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
+    param_cache: Dict[str, nn.Parameter] = field(default_factory=dict)
     weight_map: Dict[str, str] = field(default_factory=dict)
     last_optimized_graph_module: Optional[GraphModule] = None
     compile_times: int = 0
diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index 1b25e9e1233..379b027d400 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -480,18 +480,21 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
 
 class InitializeOrLoadWeightsPass(PassBase):
     """
-    Make weights loading/initialization a seperate pass for cleaner logic and easier extensibility. This
-    pass will only run once in the very first compilation step.
+    Weights loading and intialization pass, will initialize parameters on current rank and load weights from disk
+    if necessary.
     """
 
-    need_rerun_when_recompile = False
-
     def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule:
         world_size = dist.get_world_size(ctx.tp_group)
         tp_rank = dist.get_rank(ctx.tp_group)
 
-        new_parameters, tied_parameters = [], {}
+        new_parameters, tied_parameters, param_cache = [], {}, ctx.param_cache
         for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
+            # skip initializing new params when recompilation happens
+            if name in param_cache:
+                new_parameters.append((name, param_cache[name]))
+                continue
+
             param_meta: ParameterMeta = getattr(param, "meta")
             # skip already initialized/loaded tied parameters
             if param_meta.is_tied and id(param) in tied_parameters:
@@ -569,6 +572,8 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
             else:
                 parent_mod = graph_module
                 field = name
+            if name not in param_cache:
+                param_cache[name] = new_param
             setattr(parent_mod, field, new_param)
 
         return graph_module
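
A minimal, self-contained sketch of the caching behavior this patch introduces, assuming a hypothetical helper get_or_create_param and an illustrative parameter name; the actual pass keys ctx.param_cache by the qualified names yielded by graph_module.named_parameters() and performs sharded initialization or on-disk weight loading rather than the placeholder init used here:

import torch
import torch.nn as nn
from typing import Dict, Tuple

# Hypothetical, simplified stand-in for ParallelExecutionCtx.param_cache.
param_cache: Dict[str, nn.Parameter] = {}

def get_or_create_param(name: str, shape: Tuple[int, ...]) -> nn.Parameter:
    """Reuse the parameter cached under `name` if a previous compilation already
    created it, otherwise create a fresh one and record it in the cache."""
    if name in param_cache:
        # recompilation path: return the existing object, do not re-initialize
        return param_cache[name]
    new_param = nn.Parameter(torch.empty(shape))
    nn.init.normal_(new_param)  # placeholder init only; the real pass shards or loads weights
    param_cache[name] = new_param
    return new_param

# The first "compilation" creates the parameter; the second returns the very same
# object, so anything that already holds a reference to it (e.g. an optimizer)
# keeps seeing the same tensor across recompiles.
first = get_or_create_param("model.layers.0.q_proj.weight", (16, 16))
second = get_or_create_param("model.layers.0.q_proj.weight", (16, 16))
assert first is second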