From 4d9d036d10f896e0a4871514f850839e7061e8c3 Mon Sep 17 00:00:00 2001
From: Longjie Zheng
Date: Tue, 20 Aug 2024 21:54:41 +0200
Subject: [PATCH] format

---
 optimum/fx/parallelization/decomp.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py
index 7ba18f43438..26258d451bf 100644
--- a/optimum/fx/parallelization/decomp.py
+++ b/optimum/fx/parallelization/decomp.py
@@ -68,6 +68,7 @@ class DecompTracer(GraphAppendingTracer):
     See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details.
     """
+
     def __init__(self, graph: Graph):
         super().__init__(graph)
         self.tensor_tracker = WeakTensorKeyDictionary()
@@ -77,8 +78,8 @@ def __init__(self, graph: Graph):
 class DecompositionInterpreter(Interpreter):
     """
     DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose
-    high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way.
-
+    high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way.
+
     Notes:
         - Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts
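
---

Note (illustration only, not part of the patch): the sketch below shows how an interpreter of this kind is typically driven, using PyTorch's upstream torch.fx.experimental.proxy_tensor.DecompositionInterpreter, on which the patched class is modeled per its docstring. The Toy module, the input shape, and the choice of decomposition table are assumptions for the example, not taken from this diff.

# Minimal sketch, assuming PyTorch's upstream DecompositionInterpreter;
# the optimum class touched by this patch mirrors it but is not used here.
import torch
import torch.fx as fx
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter


class Toy(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # F.gelu is a high-level op that decomposes into core aten primitives
        return torch.nn.functional.gelu(x) + x


traced = fx.symbolic_trace(Toy())  # high-level fx graph module
decomposed_graph = fx.Graph()      # decomposed aten ops are appended here
DecompositionInterpreter(
    traced,
    decomposed_graph,
    decomposition_table=core_aten_decompositions(),
).run(torch.randn(2, 4))  # runs nodes in topo order, decomposing via torch dispatch
decomposed = fx.GraphModule(traced, decomposed_graph)
print(decomposed.graph)  # now contains core aten ops instead of F.gelu

The optimum variant differs from the upstream one in that, as the docstring notes, it preserves primitive layers such as nn.Linear and nn.Embedding so they can later be swapped for parallelized counterparts.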