diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py
index 35be5b54d40..1490848a6e4 100644
--- a/optimum/fx/parallelization/api.py
+++ b/optimum/fx/parallelization/api.py
@@ -136,11 +136,18 @@ def parallelize_model(
                     weight_map[key] = weight_file
         parallel_ctx.weight_map = weight_map
 
+    torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
+    if torch_dtype is not None:
+        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
+
     with MetaAwareMethodsPatcher():
         model = model_cls(model_config, *model_args, **kwargs)
         # TODO: remove this once support training-time trace
         model.eval()
 
+    if dtype_orig is not None:
+        torch.set_default_dtype(dtype_orig)
+
     move_model_to_device(model, device=parallel_ctx.current_device)
     initialize_parameter_meta(model)
     backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py
index cb4d6cc2e1f..6546ce622d0 100644
--- a/optimum/fx/parallelization/passes.py
+++ b/optimum/fx/parallelization/passes.py
@@ -493,25 +493,26 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
         new_parameters, tied_parameters = [], {}
         for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
             param_meta: ParameterMeta = getattr(param, "meta")
-            # skip already initialized parameters
-            if not param_meta.need_initialize:
-                continue
-            # skip already initialized tied parameters
+            # skip already initialized/loaded tied parameters
             if param_meta.is_tied and id(param) in tied_parameters:
                 new_parameters.append((name, tied_parameters[id(param)]))
                 continue
 
-            shape = [
+            shape = (
                 param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim)
                 for dim in range(param.ndim)
-            ]
-
-            new_param = nn.Parameter(
-                torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad
             )
 
+            if shape == tuple(param.size()) and param.device == ctx.current_device:
+                new_param = param
+            else:
+                new_param = nn.Parameter(
+                    torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device),
+                    requires_grad=param.requires_grad,
+                )
+
+            # load weights if possible
             for source, target in sorted(param_meta.mapping.items()):
-                # weights loading
                 if target.source in ctx.weight_map:
                     from safetensors import safe_open
 
@@ -530,29 +531,33 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
                     tensor = torch.empty_like(tensor).copy_(tensor)
                     with torch.no_grad():
                         new_param.data[source_index].copy_(tensor)
-                    continue
-
-                # initialization
-                if not param_meta.is_parallel or tp_rank == 0:
-                    # initialize weight on master rank
-                    weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
-                    init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
-                    init_fn(weight)
-                    weight = weight.to(ctx.current_device)
-                else:
-                    weight = None
-                index = [
-                    source.to_slice() if dim == param_meta.dim else slice(None, None, None)
-                    for dim in range(param.ndim)
-                ]
-                with torch.no_grad():
-                    if param_meta.is_parallel:
-                        scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
+            # weights initialization
+            if param_meta.need_initialize:
+                for source, target in sorted(param_meta.mapping.items()):
+                    if target.source in ctx.weight_map:
+                        continue
+                    if not param_meta.is_parallel or tp_rank == 0:
+                        # initialize weight on master rank
+                        weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
+                        init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
+                        init_fn(weight)
+                        weight = weight.to(ctx.current_device)
                     else:
-                        new_param.data[index].copy_(weight)
+                        weight = None
+                    index = [
+                        source.to_slice() if dim == param_meta.dim else slice(None, None, None)
+                        for dim in range(param.ndim)
+                    ]
+                    with torch.no_grad():
+                        if param_meta.is_parallel:
+                            scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
+                        else:
+                            new_param.data[index].copy_(weight)
 
             setattr(new_param, "meta", param_meta)
-            new_parameters.append((name, new_param))
+
+            if id(new_param) != id(param):
+                new_parameters.append((name, new_param))
             if param_meta.is_tied:
                 tied_parameters[id(param)] = new_param
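
The api.py hunk relies on _set_default_torch_dtype returning the previous global default so it can be restored once the model has been instantiated. Below is a minimal, hypothetical sketch of that save/restore pattern using only public torch calls; build_with_default_dtype and module_factory are illustrative names, not part of optimum, and the sketch assumes the temporary default only needs to cover parameter creation inside the factory.

import torch
import torch.nn as nn


def build_with_default_dtype(module_factory, torch_dtype=None):
    # Hypothetical helper mirroring the save/restore pattern in the api.py hunk:
    # temporarily switch the global default dtype so parameters created inside
    # module_factory() are allocated in torch_dtype, then restore the old default.
    dtype_orig = None
    if torch_dtype is not None:
        dtype_orig = torch.get_default_dtype()
        torch.set_default_dtype(torch_dtype)
    try:
        return module_factory()
    finally:
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)


layer = build_with_default_dtype(lambda: nn.Linear(8, 8), torch_dtype=torch.float64)
assert layer.weight.dtype == torch.float64
assert torch.get_default_dtype() == torch.float32  # global default restored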
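
The passes.py hunks separate weight loading from weight initialization and avoid allocating a fresh buffer when the existing parameter already has the target shape and device. A small sketch of that reuse-or-allocate decision follows, under the assumption that the sharded shape is materialized as a concrete tuple so the comparison with param.size() is meaningful; materialize_param, world_size, and parallel_dim are illustrative names, not the pass's API.

import torch
import torch.nn as nn


def materialize_param(param, world_size, parallel_dim=None, device="cpu"):
    # Shard the parallel dimension by world_size; reuse the original parameter
    # when its shape and device already match the target, otherwise allocate a
    # zeroed buffer that weight loading or initialization will fill later.
    shape = tuple(
        param.size(dim) // world_size if dim == parallel_dim else param.size(dim)
        for dim in range(param.ndim)
    )
    if shape == tuple(param.size()) and param.device == torch.device(device):
        return param
    return nn.Parameter(
        torch.zeros(*shape, dtype=param.dtype, device=device),
        requires_grad=param.requires_grad,
    )


p = nn.Parameter(torch.randn(16, 4))
assert materialize_param(p, world_size=2) is p                             # unchanged: reused as-is
assert materialize_param(p, world_size=2, parallel_dim=0).shape == (8, 4)  # sharded: new zero buffer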