
Commit

fix
zhenglongjiepheonix committed Jul 22, 2024
1 parent 2c561d3 commit fc96b6f
Showing 2 changed files with 41 additions and 29 deletions.
7 changes: 7 additions & 0 deletions optimum/fx/parallelization/api.py
@@ -136,11 +136,18 @@ def parallelize_model(
weight_map[key] = weight_file
parallel_ctx.weight_map = weight_map

torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
if torch_dtype is not None:
dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

with MetaAwareMethodsPatcher():
model = model_cls(model_config, *model_args, **kwargs)
# TODO: remove this once support training-time trace
model.eval()

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

move_model_to_device(model, device=parallel_ctx.current_device)
initialize_parameter_meta(model)
backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
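The api.py hunk above temporarily switches the default torch dtype around the meta-aware model construction: model_cls._set_default_torch_dtype(torch_dtype) makes new parameters come out in torch_dtype and returns the previous default, which torch.set_default_dtype(dtype_orig) restores afterwards. As a rough illustration, here is a standalone sketch of the same save/restore pattern (not part of the commit); build_with_dtype and the toy nn.Linear model are made-up names, and only plain torch is assumed.

import torch
import torch.nn as nn

def build_with_dtype(model_cls, torch_dtype=None, **kwargs):
    # hypothetical helper mirroring the save/restore pattern in the hunk above
    dtype_orig = None
    if torch_dtype is not None:
        dtype_orig = torch.get_default_dtype()   # remember the previous default
        torch.set_default_dtype(torch_dtype)     # new floating-point params now use torch_dtype
    try:
        model = model_cls(**kwargs)
    finally:
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)  # restore so unrelated code is unaffected
    return model

model = build_with_dtype(lambda: nn.Linear(4, 4), torch_dtype=torch.float64)
print(next(model.parameters()).dtype)  # torch.float64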
63 changes: 34 additions & 29 deletions optimum/fx/parallelization/passes.py
@@ -493,25 +493,26 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
new_parameters, tied_parameters = [], {}
for name, param in sorted(graph_module.named_parameters(remove_duplicate=False)):
param_meta: ParameterMeta = getattr(param, "meta")
# skip already initialized parameters
if not param_meta.need_initialize:
continue
# skip already initialized tied parameters
# skip already initialized/loaded tied parameters
if param_meta.is_tied and id(param) in tied_parameters:
new_parameters.append((name, tied_parameters[id(param)]))
continue

shape = [
shape = (
param.size(dim) // world_size if dim == param_meta.dim and param_meta.is_parallel else param.size(dim)
for dim in range(param.ndim)
]

new_param = nn.Parameter(
torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device), requires_grad=param.requires_grad
)

if shape == tuple(param.size()) and param.device == ctx.current_device:
new_param = param
else:
new_param = nn.Parameter(
torch.zeros(*shape, dtype=param.dtype, device=ctx.current_device),
requires_grad=param.requires_grad,
)

# load weights if possible
for source, target in sorted(param_meta.mapping.items()):
# weights loading
if target.source in ctx.weight_map:
from safetensors import safe_open

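A minimal sketch of the shape logic in the hunk above (not part of the commit): the size along the parallel dimension is divided by the tensor-parallel world size, and the existing parameter is reused when its shape and device already match, so no zero-filled replacement is allocated. ParameterMeta and the execution context are reduced to plain arguments here, and shard_shape and materialize are hypothetical names.

import torch
import torch.nn as nn

def shard_shape(param, shard_dim, world_size):
    # only the sharded dimension shrinks; every other dimension keeps its size
    return tuple(
        param.size(d) // world_size if d == shard_dim else param.size(d)
        for d in range(param.ndim)
    )

def materialize(param, shard_dim, world_size, device, is_parallel=True):
    shape = shard_shape(param, shard_dim, world_size) if is_parallel else tuple(param.size())
    # reuse the parameter if it already has the target shape on the target device
    if shape == tuple(param.size()) and param.device == device:
        return param
    return nn.Parameter(
        torch.zeros(*shape, dtype=param.dtype, device=device),
        requires_grad=param.requires_grad,
    )

w = nn.Parameter(torch.randn(8, 4))
print(materialize(w, shard_dim=0, world_size=2, device=torch.device("cpu")).shape)  # torch.Size([4, 4])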
@@ -530,29 +531,33 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf
tensor = torch.empty_like(tensor).copy_(tensor)
with torch.no_grad():
new_param.data[source_index].copy_(tensor)
continue

# initialization
if not param_meta.is_parallel or tp_rank == 0:
# initialize weight on master rank
weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
init_fn(weight)
weight = weight.to(ctx.current_device)
else:
weight = None
index = [
source.to_slice() if dim == param_meta.dim else slice(None, None, None)
for dim in range(param.ndim)
]
with torch.no_grad():
if param_meta.is_parallel:
scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
# weights initialization
if param_meta.need_initialize:
for source, target in sorted(param_meta.mapping.items()):
if target.source in ctx.weight_map:
continue
if not param_meta.is_parallel or tp_rank == 0:
# initialize weight on master rank
weight = torch.empty(*target.shape, dtype=param.dtype, device="cpu")
init_fn = param_meta.init_fn if param_meta.init_fn else config.weight_init_fn
init_fn(weight)
weight = weight.to(ctx.current_device)
else:
new_param.data[index].copy_(weight)
weight = None
index = [
source.to_slice() if dim == param_meta.dim else slice(None, None, None)
for dim in range(param.ndim)
]
with torch.no_grad():
if param_meta.is_parallel:
scatter(ctx.tp_group, weight, new_param.data[index], scatter_dim=param_meta.dim)
else:
new_param.data[index].copy_(weight)
setattr(new_param, "meta", param_meta)

new_parameters.append((name, new_param))
if id(new_param) != id(param):
new_parameters.append((name, new_param))
if param_meta.is_tied:
tied_parameters[id(param)] = new_param

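In the hunk above, a weight whose source key is present in ctx.weight_map is loaded slice by slice from the safetensors checkpoint, and only parameters that still need initialization are filled by the init function on the master tensor-parallel rank and then scattered along the parallel dimension. The sketch below (not part of the commit) shows the single-process side of that split: it loads just one rank's slice from a safetensors file and falls back to initialization otherwise. The file name demo.safetensors, the key linear.weight and the shard bounds are made up for illustration; only torch and safetensors are assumed.

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# pretend this is one entry of ctx.weight_map: key -> checkpoint file
save_file({"linear.weight": torch.randn(8, 4)}, "demo.safetensors")

world_size, tp_rank = 2, 0
shard_rows = 8 // world_size            # shard along dim 0: 8 rows split across 2 ranks
new_param = torch.zeros(shard_rows, 4)  # this rank's shard

with safe_open("demo.safetensors", framework="pt", device="cpu") as f:
    if "linear.weight" in f.keys():
        # load only this rank's slice instead of the full tensor
        rows = slice(tp_rank * shard_rows, (tp_rank + 1) * shard_rows)
        shard = f.get_slice("linear.weight")[rows, :]
        with torch.no_grad():
            new_param.copy_(shard)
    else:
        # fall back to initialization; in the pass this happens on tp_rank == 0 only,
        # followed by a scatter along the parallel dim to the other ranks
        torch.nn.init.normal_(new_param)

print(new_param.shape)  # torch.Size([4, 4])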
