Modify Parallelization Strategy to Make it More General #1988
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…parallelization_strategy
optimum/fx/parallelization/api.py
Outdated
""" | ||
API for automatic model parallelism through Pytorch FX. | ||
|
||
Args: | ||
model (Union[torch.nn.Module, str]): | ||
Model to parallelize, could either be a module or a model id on the Huggingface Hub. | ||
model (str): |
Suggested change:
-        model (str):
+        model (`str`):
optimum/fx/parallelization/decomp.py
Outdated
class DecompTracer(GraphAppendingTracer):
    def __init__(self, graph: Graph):
        super().__init__(graph)
        self.tensor_tracker = WeakTensorKeyDictionary()
        self.symnode_tracker = _SymNodeDict()
Maybe add a docstring explaining what it does.
optimum/fx/parallelization/decomp.py
Outdated
that certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
in the orignal graph module.

Note that the traced graph is a low-level equivalent representation of the original graph module, and is only used for
parallel axis propagation and analysis, the original graph module is still used for real execution.
Can you group notes as follows:
Notes:
1. Certain primitive layers ....
2. The traced graph is a low-level equivalent...
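To make the decomposition step described in that docstring more concrete, here is a rough, self-contained sketch (not the PR's actual helper; the function traced below is illustrative) of how a callable can be lowered to core aten ops with a decomposition table. The parallel-axis analysis runs on this kind of low-level graph, while the original module remains the one that is actually executed.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

# Rough sketch: trace a callable down to core aten ops via a decomposition table.
# The PR's own helper additionally keeps layers such as nn.Linear / nn.Embedding
# as leaves so they can later be swapped for parallelized counterparts in the
# original graph module.
def fn(x, w):
    return torch.nn.functional.gelu(x @ w)

low_level = make_fx(fn, decomposition_table=core_aten_decompositions())(
    torch.randn(2, 4), torch.randn(4, 4)
)
print(low_level.graph)  # aten-level nodes only; used for analysis, not for real execution
```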
optimum/fx/parallelization/decomp.py
Outdated
    leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention],
) -> Callable:
    """
    API to decompose and funcitonalize a high-level graph module.
Suggested change:
-    API to decompose and funcitonalize a high-level graph module.
+    API to decompose and functionalize a high-level graph module.
optimum/fx/parallelization/decomp.py
Outdated
        graph_module (GraphModule):
            The high-level graph module to be decomposed and functionalized.
        decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`):
            The lookup table which maps high-level torch op to their equivalent low-level implementation.
        leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
Suggested change:
-        graph_module (GraphModule):
-            The high-level graph module to be decomposed and functionalized.
-        decomposition_table (Dict[torch._ops.OperatorBase, Callable], defaults to `core_aten_decompostions()`):
-            The lookup table which maps high-level torch op to their equivalent low-level implementation.
-        leaf_function_targets (List[Callable], defaults to `[F.scaled_dot_product_attention]`):
+        graph_module (`GraphModule`):
+            The high-level graph module to be decomposed and functionalized.
+        decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`):
+            The lookup table which maps high-level torch op to their equivalent low-level implementation.
+        leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`):
class Registry:
    def __init__(self) -> None:
        self.mapping = {}

    def register(self, op_types):
        def wrapper(cls):
            if isinstance(op_types, (list, tuple)):
                for op_type in op_types:
                    self.mapping[op_type] = cls
            else:
                self.mapping[op_types] = cls
            return cls

        return wrapper

    def is_supported(self, op_type) -> bool:
        return op_type in self.mapping
What is this registry used for?
This is for registering the parallel axis policies of different aten ops.
Can you add a docstring to explain that please?
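For readers following along, a minimal sketch of how such a registry decorator is typically used. The op overloads and the handler class name below are purely illustrative, and the snippet assumes the `Registry` class quoted above is in scope:

```python
import torch

registry = Registry()  # the Registry class shown in the diff above

# register one (illustrative) handler class for several aten ops at once
@registry.register([torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor])
class BinaryOpParallelAxisPolicy:
    """Illustrative handler describing how the parallel axis propagates through these ops."""

assert registry.is_supported(torch.ops.aten.add.Tensor)
assert not registry.is_supported(torch.ops.aten.cat.default)
```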
def propagate(self) -> bool:
    arg = self.node.all_input_nodes[0]
    axis = self.extract_axis(arg)
    return [axis]
nit: it's not returning a bool.
If I understand properly, it returns the axis that is supposed to be parallel?
Yes, you are right: it looks up the axes of all the inputs and tries to infer the axis of the output, returning an empty list if there is no valid axis on which the output can be parallelized.
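As a side note, here is a tiny, self-contained sketch of that contract (not the PR's actual handler; the op semantics below are illustrative): given the parallel axes of the inputs, the handler returns the candidate parallel axes of the output, and an empty list means no valid axis exists.

```python
from typing import List, Optional

# Illustrative sketch of the propagation contract described above; None means
# "not sharded / replicated", an int is the axis along which a tensor is sharded.
def propagate_binary_elementwise(lhs_axis: Optional[int], rhs_axis: Optional[int]) -> List[Optional[int]]:
    if lhs_axis == rhs_axis:
        # inputs agree (sharded on the same axis, or both replicated) -> output follows
        return [lhs_axis]
    # conflicting input layouts: the output cannot be parallelized as-is
    return []

print(propagate_binary_elementwise(1, 1))        # [1]    -> output shardable on axis 1
print(propagate_binary_elementwise(None, None))  # [None] -> everything replicated
print(propagate_binary_elementwise(0, 1))        # []     -> no valid parallel axis
```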
def search(idx: int):
    if idx == len(nodes):
        return True
Suggested change:
-    def search(idx: int):
-        if idx == len(nodes):
-            return True
+    def search(idx: int) -> bool:
+        return idx == len(nodes)
The `search` here is actually a backtracking search function with more logic following, so we can only return True once we have reached the very last op.
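For context, a generic, self-contained backtracking skeleton (unrelated to the PR's actual code; names and the constraint are illustrative) showing why the `idx == len(nodes)` base case is the only place that can return success: the logic below it tries candidate choices per node and backtracks on failure.

```python
from typing import Dict, List

def assign_colors(nodes: List[str], colors: List[str]) -> Dict[str, str]:
    assignment: Dict[str, str] = {}

    def search(idx: int) -> bool:
        if idx == len(nodes):        # every node received a valid choice -> success
            return True
        for color in colors:         # try each candidate for the current node
            # example constraint: adjacent nodes must not share a color
            if idx > 0 and assignment[nodes[idx - 1]] == color:
                continue
            assignment[nodes[idx]] = color
            if search(idx + 1):      # keep the choice if the rest can be completed
                return True
            del assignment[nodes[idx]]   # otherwise backtrack and try the next candidate
        return False

    return assignment if search(0) else {}

print(assign_colors(["a", "b", "c"], ["red", "green"]))
# {'a': 'red', 'b': 'green', 'c': 'red'}
```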
…parallelization_strategy
LGTM!
Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new
aten op, you need to register the corresponding handler class by decorating it with `register` function.
nit:
Suggested change:
-    Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new
-    aten op, you need to register the corresponding handler class by decorating it with `register` function.
+    Registry class handles registration of parallel axis propagation handlers of different aten ops.
+    To support a new aten op, you need to register the corresponding handler class by decorating
+    it with the `register` function.
…parallelization_strategy
Merging this; the failures are irrelevant.
As per the title, this PR tries a more general approach rather than relying purely on human heuristics. Basically, it uses the following steps to search for a possible parallelization strategy for a transformer model.
As for the API design, we drop support for passing custom modules and focus only on models in transformers, because supporting custom models is not a priority for now.