Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify Parallelization Strategy to Make it More General #1988

Merged
2 changes: 1 addition & 1 deletion .github/workflows/test_fx_automatic_parallel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
config:
- name: GPU-enabled Optimum Test Suite
image: nvidia/cuda:12.4.1-devel-ubuntu22.04
gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"]
gpu_target: ["nvidia-multi-gpu-a10-runners"]

name: ${{ matrix.config.name }}
runs-on:
Expand Down
87 changes: 43 additions & 44 deletions optimum/fx/parallelization/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import importlib
import os
from functools import partial
from typing import List, Union
from typing import Callable, List

import torch
from torch.fx import GraphModule
from transformers import AutoConfig

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline
Expand All @@ -43,30 +44,31 @@ def parallelize_backend(


def parallelize_model(
model: Union[torch.nn.Module, str],
model: str,
parallel_ctx: ParallelExecutionCtx,
*model_args,
**kwargs,
):
) -> Callable:
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved
"""
API for automatic model parallelism through Pytorch FX.

Args:
model (Union[torch.nn.Module, str]):
Model to parallelize, could either be a module or a model id on the Huggingface Hub.
parallel_ctx (ParallelExecutionCtx):
model (`str`):
Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
of the model.
parallel_ctx (`ParallelExecutionCtx`):
Parallel execution context containing process groups the current process belongs to.
*model_args (Any):
*model_args (`Any`):
Additional postional arguments for intializing the model if a model id is passed.
revision (str, defaults to `main`):
revision (`str`, defaults to `main`):
Model revision for weights downloading if a model id is passed.
cache_dir (Optional[str], defaults to `None`):
cache_dir (`Optional[str]`, defaults to `None`):
Cache directory to store downloaded weights. Defaults to None.
local_files_only (bool, defaults to `False`):
local_files_only (`bool`, defaults to `False`):
Whether to use local files only, will avoid downloading from remote if set to `True`.
skip_load_weights (bool, defaults to `False`):
skip_load_weights (`bool`, defaults to `False`):
Whether to skip loading weights from disk to model.
**kwargs (Dict[str, Any]):
**kwargs (`Dict[str, Any]`):
Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`.
"""
revision = kwargs.pop("revision", "main")
Expand All @@ -80,44 +82,41 @@ def parallelize_model(
setattr(parallel_config, k, v)
kwargs.pop(k)

if isinstance(model, str):
from transformers import AutoConfig

is_local = os.path.isdir(model)
if not is_local:
hf_folder = download_model_from_hf(
model_name_or_path=model,
cache_dir=cache_dir,
revision=revision,
local_files_only=local_files_only,
skip_download_weights=skip_load_weights,
)
else:
hf_folder = model

# should be able to load config using only local files
model_config, kwargs = AutoConfig.from_pretrained(
hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
is_local = os.path.isdir(model)
if not is_local:
hf_folder = download_model_from_hf(
model_name_or_path=model,
cache_dir=cache_dir,
revision=revision,
local_files_only=local_files_only,
skip_download_weights=skip_load_weights,
)
else:
hf_folder = model

# try getting model class info from config
model_arch = model_config.architectures
model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
# should be able to load config using only local files
model_config, kwargs = AutoConfig.from_pretrained(
hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
)

if not skip_load_weights:
parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
# try getting model class info from config
model_arch = model_config.architectures
model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
if torch_dtype is not None:
dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
if not skip_load_weights:
parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

with MetaAwareMethodsPatcher():
model = model_cls(model_config, *model_args, **kwargs)
# TODO: remove this once support training-time trace
model.eval()
torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
if torch_dtype is not None:
dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
with MetaAwareMethodsPatcher():
model = model_cls(model_config, *model_args, **kwargs)
# TODO: remove this once support training-time trace
model.eval()

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

move_model_to_device(model, device=parallel_ctx.current_device)
initialize_parameter_meta(model)
Expand Down
5 changes: 5 additions & 0 deletions optimum/fx/parallelization/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,13 @@ class Config:
- weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`)
Initialization function of weights in `nn.Linear` and `nn.Embedding` layers,
if not provided weights loading path.

- enable_sequence_parallel (`bool`, defaults to `False`):
Whether to enable Megatron-style sequence parallelism in searching parallelization
strategies.
"""

lint_and_recompile: bool = True
clean_markers_after_all_passes: bool = True
weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
enable_sequence_parallel: bool = False
225 changes: 225 additions & 0 deletions optimum/fx/parallelization/decomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Callable, Dict, List

import torch
import torch.nn.functional as F
import torch.utils._pytree as pytree
from torch import SymBool, SymFloat, SymInt
from torch._decomp import core_aten_decompositions
from torch._functorch._aot_autograd.functional_utils import from_fun, to_fun
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, disable_functional_mode
from torch.fx import Graph, GraphModule, Interpreter, Proxy, traceback
from torch.fx.experimental.proxy_tensor import (
ProxyTorchDispatchMode,
_ProxyTensor,
_SymNodeDict,
decompose,
disable_proxy_modes_tracing,
fetch_object_proxy,
fetch_sym_proxy,
get_proxy_slot,
track_tensor_tree,
)
from torch.fx.proxy import GraphAppendingTracer
from torch.utils.weak import WeakTensorKeyDictionary


def is_leaf_module(m):
return (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) and not isinstance(
m, torch.nn.Sequential
)


@contextlib.contextmanager
def trace_decomp_origin():
creat_node = Graph.create_node

def create_node_(*args, **kwargs):
node = creat_node(*args, **kwargs)
node.meta["traced_from"] = traceback.get_current_meta()["from_node"]
return node

try:
Graph.create_node = create_node_
yield
finally:
Graph.create_node = creat_node
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved


class DecompTracer(GraphAppendingTracer):
"""
DecompTracer is a tracer class which works together with `DecompositionInterpreter`, it keeps track of tensors and their
corresponding proxy objects during execution process. When invoked with `create_proxy`, it creates a node in the containing
graph and associate the output tensor of the node with the created proxy.

See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details.
"""

def __init__(self, graph: Graph):
super().__init__(graph)
self.tensor_tracker = WeakTensorKeyDictionary()
self.symnode_tracker = _SymNodeDict()


class DecompositionInterpreter(Interpreter):
"""
DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose
high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way.

Notes:
- Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
in the orignal graph module.

- The traced graph is a low-level equivalent representation of the original graph module, and is only used for
parallel axis propagation and analysis, the original graph module is still used for real execution.
"""

def __init__(
self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs
):
super().__init__(module, **kwargs)
self.new_graph = new_graph
self.tracer = DecompTracer(new_graph)

self.decomposition_table = decomposition_table
if self.decomposition_table is None:
self.decomposition_table = {}

self.leaf_function_targets = leaf_function_targets
if self.leaf_function_targets is None:
self.leaf_function_targets = []

self.fun_mode = FunctionalTensorMode()
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")

def placeholder(self, target, args, kwargs):
out = super().placeholder(target, args, kwargs)
out = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), out)
proxy = self.tracer.create_proxy("placeholder", target, args, kwargs)

with disable_proxy_modes_tracing():
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)

out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out)
return out

def call_function(self, target, args, kwargs):
if target in self.leaf_function_targets:
args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args)
kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs)

with disable_proxy_modes_tracing(), disable_functional_mode():
out = target(*args, **kwargs)

args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs))
proxy_args, proxy_kwargs = pytree.tree_map_only(
(SymInt, SymFloat, SymBool),
fetch_sym_proxy(self.tracer),
pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)),
)
proxy = self.tracer.create_proxy("call_function", target, proxy_args, proxy_kwargs)

with disable_proxy_modes_tracing():
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)

out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out)
return out

return super().call_function(target, args, kwargs)

def call_module(self, target, args, kwargs):
assert isinstance(target, str)
submod = self.fetch_attr(target)
if not is_leaf_module(submod):
return super().call_module(target, args, kwargs)

args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args)
kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs)

with disable_proxy_modes_tracing(), disable_functional_mode():
out = submod(*args, **kwargs)

args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs))
proxy_args, proxy_kwargs = pytree.tree_map_only(
(SymInt, SymFloat, SymBool),
fetch_sym_proxy(self.tracer),
pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)),
)
proxy = self.tracer.create_proxy("call_module", target, proxy_args, proxy_kwargs)

with disable_proxy_modes_tracing():
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)

out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out)
return out

def get_attr(self, target, args, kwargs):
out = super().get_attr(target, args, kwargs)
proxy = Proxy(self.new_graph.get_attr(target), self.tracer)
with disable_proxy_modes_tracing():
track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
return out

def output(self, target, args, kwargs):
args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args)
kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs)
out = super().output(target, args, kwargs)

def unwrap(e):
return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node)

self.new_graph.output(pytree.tree_map(unwrap, out))
return out

def run(self, *args, **kwargs):
with self.fun_mode:
args = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), args)
kwargs = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), kwargs)
with traceback.preserve_node_meta(), trace_decomp_origin(), decompose(self.decomposition_table), self.mode:
return super().run(*args, **kwargs)


def decompose_and_functionalize(
graph_module: GraphModule,
decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(),
leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
) -> Callable:
"""
API to decompose and functionalize a high-level graph module.

Args:
graph_module (`GraphModule`):
The high-level graph module to be decomposed and functionalized.
decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`):
The lookup table which maps high-level torch op to their equivalent low-level implementation.
leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`):
Functions which will not be traced through for convenience, `F.scaled_dot_product_attention` is
treated as a leaf function by default so that we don't have to deal with all detailed version of
sdpas in the traced graph.

Returns:
Callable: a wrapper which returns the traced low-level graph when called with concrete arguments.
"""
new_graph = Graph(owning_module=graph_module)
interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets)

def wrapper(*args, **kwargs):
interp.run(*args, **kwargs)
return new_graph

return wrapper
15 changes: 15 additions & 0 deletions optimum/fx/parallelization/op_registry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .op_handlers import REGISTRY, FallbackParallelAxisPropagateHandler
Loading
Loading