Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenglongjiepheonix committed Aug 20, 2024
1 parent bf99175 commit 4d9d036
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions optimum/fx/parallelization/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class DecompTracer(GraphAppendingTracer):
See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details.
"""

def __init__(self, graph: Graph):
super().__init__(graph)
self.tensor_tracker = WeakTensorKeyDictionary()
Expand All @@ -77,8 +78,8 @@ def __init__(self, graph: Graph):
class DecompositionInterpreter(Interpreter):
"""
DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose
high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way.
high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way.
Notes:
- Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific
heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
Expand Down

0 comments on commit 4d9d036

Please sign in to comment.