Low performance from unnecessary permutations #936
Hi @jonpryai, thanks for flagging this. It does seem like at least one of the permutes could be redundant, but without a minimal repro it's hard to determine whether they can be removed and whether we need a pass to handle this case. Do you mind sharing details on how to reproduce this? Thanks!
I use this to compile
Profiling the above is difficult because the run includes all of AIT's compilation and kernel profiling. I locate the generated test.so in /tmp and copy it to the current dir.
Example ONNX section:
For security reasons, I'm unable to download external files; I hope you understand. It'll be easier to reproduce if you can share the model graph that AIT dumps automatically via dump_graph_debug_str_to_file. You'll need to set the following environment variable: Once you do that, could you share the contents of
It does seem like either permute_2 or permute_4 can be removed here. It'll be easier to remove permute_2, imo. And sorry for the delay, but this is what I believe we need:
Here are some pointers:
Let me know if there are any questions there.
I am not very familiar with the code, so I could be wrong, but my first impression is that while the optimizer is able to consider different orderings (NHWC and NCHW) for the conv2d, for some reason it is married to NCHW for the elementwise ops, and maybe doesn't take the permutation cost into account. I think both permute_2 and permute_4 can be removed. There are also two copies of permute_4 that yield exactly the same tensor. What is happening here is:
which is the same thing as
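The equivalence claimed here is easy to check numerically. This is a minimal NumPy sketch (the shapes are made up, and NumPy stands in for the AIT graph): permuting NHWC to NCHW, doing an elementwise add, and permuting back is identical to adding in NHWC directly, because elementwise ops are layout-agnostic.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 8, 3))  # NHWC
y = rng.standard_normal((1, 8, 8, 3))  # NHWC

# The permute_2 / permute_4 pattern: nhwc2nchw -> add -> nchw2nhwc
with_permutes = np.transpose(
    np.transpose(x, (0, 3, 1, 2)) + np.transpose(y, (0, 3, 1, 2)),
    (0, 2, 3, 1),
)

# Same result with no permutes at all
without_permutes = x + y

assert np.allclose(with_permutes, without_permutes)
```

The same argument applies to any purely elementwise op, which is why both permutes around the addition are dead weight.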
Ah I see, both permutes can definitely be removed in that case. I'm not sure which pass introduces them in the first place. Do you still have the dumped graphs in your directory? We can see which pass adds the permutes by looking at the {passname}_pseudo_code.txt files.
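One way to act on this suggestion is to find the first per-pass dump that mentions a permute; since the pass dumps are written in order, that file names the pass that introduced it. A self-contained sketch (the file names and directory below are stand-ins, not real AIT output):

```shell
# Stand-in dump files: in a real run, AIT writes one {passname}_pseudo_code.txt
# per optimization pass into the debug dump directory.
mkdir -p /tmp/ait_dump_demo && cd /tmp/ait_dump_demo
printf 'conv2d_bias(x)\n' > 00_toposort_pseudo_code.txt
printf 'conv2d_bias(x)\npermute021(y)\n' > 01_some_pass_pseudo_code.txt
printf 'conv2d_bias(x)\npermute021(y)\n' > 02_later_pass_pseudo_code.txt

# The earliest file that mentions a permute is the pass that introduced it.
grep -l 'permute' *_pseudo_code.txt | sort | head -n 1
```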
They are present in everything except toposort_pseudo_code.txt. So the bind_constants pass is causing it?
Actually, that's not true. It's even in the toposort; the nodes just haven't been annotated yet.
Is it possible these nodes are being inserted by fx2ait?
It could be fx2ait, but it may also be onnx2torch. I'm curious whether replicating the model in PyTorch and then using fx2ait gives us the same graph. If not, then I assume it's onnx2torch.
You're right, the permutes are being added in fx2ait. The result from each conv2d is being permuted via ait_nhwc2nchw (here). AIT does that because PyTorch takes channels-first tensors for conv, maxpool, etc., whereas AIT takes channels-last tensors. A potential workaround is to add a permute after each Conv2D? cc: @chenyang78
Is it possible to just make all the elementwise ops also do the permutation? Then we would end up with a graph like toNCHW -> conv2d -> toNHWC -> toNCHW -> elementwise -> toNHWC, and the remove-permutations pass would find the redundant permutes.
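The removal step this proposal relies on can be sketched in a few lines. This is a toy pass over a made-up op list, not AIT's real IR: it cancels back-to-back permutes whose axis orders are inverses, e.g. nhwc2nchw (0,3,1,2) immediately followed by nchw2nhwc (0,2,3,1).

```python
def is_inverse(p, q):
    """True if applying permutation p and then q restores the original axis order."""
    return [q[i] for i in p] == list(range(len(p)))

def remove_redundant_permutes(ops):
    """ops: list of ("permute", order) or ("op", name) tuples, in graph order."""
    out = []
    for op in ops:
        if (out
                and op[0] == "permute"
                and out[-1][0] == "permute"
                and is_inverse(out[-1][1], op[1])):
            out.pop()  # the adjacent pair cancels: drop both permutes
        else:
            out.append(op)
    return out

graph = [
    ("permute", (0, 3, 1, 2)),  # toNCHW
    ("op", "conv2d"),
    ("permute", (0, 2, 3, 1)),  # toNHWC
    ("permute", (0, 3, 1, 2)),  # toNCHW  <- cancels with the permute above
    ("op", "elementwise_add"),
    ("permute", (0, 2, 3, 1)),  # toNHWC
]
print(remove_redundant_permutes(graph))
```

On the example graph the toNHWC/toNCHW pair around the elementwise op is dropped, which is exactly the shape of the graph the comment above describes.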
It sounds like that could work. But would it be possible to try this?
Hi @jonpryai, have you solved the problem?
@xmfbit No, not really. I am just trying to quickly see what the inference performance of a model would be with AITemplate. I'm wondering whether an FX graph, instead of an ONNX model, might work correctly. Otherwise it may actually be easier to write the code to create an AITemplate model directly instead of trying to fix fx2ait. Trying to import a typical dla34 model gives a good example of the issues.
I'm using fx2ait to load an ONNX graph. After optimization, the results are not good:
BS: 1, PT Eager time per iter: 0.01654841552734375ms, PT Eager QPS: 60.43, FX2AIT time per iter: 0.024108586425781252ms, FX2AIT Eager QPS: 41.48, Speedup: 0.69
Let alone compared to TensorRT. I profiled the optimized graph and found:
61.9 11,952,884,520 64,800 184,458.1 123,999.0 18,112 1,017,919 216,853.6 void ::PermuteKernel<(unsigned long)4, (unsigned long)2, int>(::PermuteKernelPara…
Analyzing this in nsys, I see that the graph is consistently doing:
permute -> elementwise addition -> permute.
These permutations don't do anything, because the elementwise operator doesn't care about the ordering.
How can this be fixed?