
Commit

Update documentation to reflect use of odml torch.
PiperOrigin-RevId: 706025583
haozha111 authored and copybara-github committed Dec 13, 2024
1 parent b19fa25 commit cd66b8f
Showing 1 changed file with 3 additions and 3 deletions.
ai_edge_torch/generative/doc/system_overview.md: 3 additions & 3 deletions
@@ -47,7 +47,7 @@ The user journey of the Edge Generative API begins from the authoring library. T
During our initial investigation, we found that although there are many open-source PyTorch implementations of common LLMs (e.g. Llama), it is often quite difficult to convert those models from the source program to the TF Lite format and achieve good performance. The difficulties include:

1) **Not able to be exported by Torch Dynamo**. The PyTorch 2.0 compiler uses Torch Dynamo under the hood to perform graph capture (tracing a PyTorch `nn.Module` and converting it to a full graph representation with Aten operations). For an arbitrary PyTorch model, Dynamo may fail during full-graph export. Rewriting the model to be Dynamo-exportable usually takes non-trivial work and may require a deep understanding of the compiler stack (see the export sketch after this list).
- 2) **Not able to be converted by AI Edge Torch**. The AI Edge Torch converter utilizes TorchXLA, StableHLO and the TF Lite converter to convert the Aten graph to the TF Lite model format. During this process, it may fail if the source FX graph has certain graph patterns / ops that are not supported by the converter. Because the conversion goes through multiple stages, it is very difficult for most users with little compiler knowledge to triage and fix a conversion issue (whether in the model itself or in the converter stack).
+ 2) **Not able to be converted by AI Edge Torch**. The AI Edge Torch converter utilizes ODML Torch, StableHLO and the TF Lite converter to convert the Aten graph to the TF Lite model format. During this process, it may fail if the source FX graph has certain graph patterns / ops that are not supported by the converter. Because the conversion goes through multiple stages, it is very difficult for most users with little compiler knowledge to triage and fix a conversion issue (whether in the model itself or in the converter stack).
3) **Difficult to achieve good OOB performance**. Even if you get lucky and make it through to the TF Lite model format, there is no guarantee that the model will run performantly on device and be able to leverage on-device accelerators such as XNNPack, GPU or NPUs. For example, performance-critical layers such as KV Cache and SDPA need to be handled specially in the converter and runtime to ensure they can be accelerated.
4) **Applying quantization consistently is difficult**. It is hard to specify quantization recipes in a consistent way for an arbitrary PyTorch generative AI model. For ODML, we need to carefully co-design our custom building blocks and the AI Edge quantizer to ensure they work smoothly together and are easy to configure with custom quantization recipes without a noticeable drop in model quality.
5) **Many OSS models are not designed / optimized for mobile execution**. For example, those implementations may contain distributed training logic, CUDA dependencies, and tensor parallel optimizations for fast training. Models usually need to be rewritten to remove those pieces before exporting to mobile.
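To make the first difficulty concrete, the following is a minimal, hypothetical sketch of Dynamo full-graph capture via `torch.export`; the `TinyBlock` module and tensor shapes are invented for illustration and are not part of the original doc.

```python
import torch


class TinyBlock(torch.nn.Module):
  """Hypothetical stand-in for a real LLM building block."""

  def __init__(self):
    super().__init__()
    self.fc1 = torch.nn.Linear(16, 32)
    self.fc2 = torch.nn.Linear(32, 16)

  def forward(self, x):
    return self.fc2(torch.nn.functional.relu(self.fc1(x)))


model = TinyBlock().eval()
sample_args = (torch.randn(1, 16),)

# Dynamo-based full-graph capture. This is the step that can fail for an
# arbitrary model (e.g. data-dependent control flow or unsupported ops).
exported_program = torch.export.export(model, sample_args)
print(exported_program.graph_module.graph)  # FX graph with Aten operations
```

When export fails at this step, the error typically points at the specific graph break, which is where the model-rewriting work described above begins.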
@@ -83,7 +83,7 @@ TODO: fill in this part.

The AI Edge Torch Generative API adopts the same PyTorch to TF Lite conversion flow as traditional models. At a high level, it involves the following steps (a code sketch follows the list):
1) PyTorch compiler (Dynamo). We leverage [Torch Dynamo](https://pytorch.org/docs/stable/torch.compiler_deepdive.html) which is the official graph compiler for PyTorch 2.0. It performs graph capturing and exports the PyTorch `nn.Module` to an FX graph with Aten operations.
- 2) TorchXLA. For the moment, we are using [Torch/XLA](https://github.com/pytorch/xla) to compile the FX graph to a StableHLO graph. The converted graph will be stored as a TensorFlow SavedModel.
+ 2) ODML Torch. For the moment, we are using [ODML Torch](https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/odml_torch) to compile the FX graph to a StableHLO graph. The converted graph will be stored as a TensorFlow SavedModel.
3) TF Lite MLIR converter. The TF Lite converter will consume the SavedModel with StableHLO ops, and further lower it down to TF Lite ops.
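As a rough sketch of this flow (the toy model, input shape, and output path below are assumptions; a real generative model would come from the re-authored examples instead), the public `ai_edge_torch.convert` entry point drives all three steps:

```python
import ai_edge_torch
import torch

# A toy, Dynamo-exportable module standing in for a real generative model.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 16),
).eval()
sample_args = (torch.randn(1, 16),)

# Under the hood: Dynamo export -> ODML Torch lowering to StableHLO ->
# TF Lite MLIR converter producing the final .tflite flatbuffer.
edge_model = ai_edge_torch.convert(model, sample_args)
edge_model.export("/tmp/toy_model.tflite")
```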

Converting a Generative AI model with several billion parameters usually takes a lot of CPU and RAM on your machine, as the converter performs all kinds of graph optimizations and transformations to ensure the converted model is highly optimized. Going forward, we will continue to improve the converter infrastructure to reduce its system overhead.
@@ -100,7 +100,7 @@ For code examples, please refer to [this](https://github.com/google-ai-edge/ai-e

### Composite op lowering via high-level function boundary

- To address the performance issues of LLMs, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc., there is a chance that certain op groups are not properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement the attention computation; however, it will be converted to a less efficient form when converted directly (which leads to unnecessary FLOPs). The core problem here is that the converter doesn't know the boundary of the `SDPA` op, so it can't apply any optimizations to it. We use TorchXLA's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to a TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` Python implementation:
+ To address the performance issues of LLMs, we identified that the model's forward pass is usually bottlenecked by a few key operations, such as KV cache or scaled dot product attention. As the converter traces the whole model and performs graph lowering through Aten/CoreAten/StableHLO etc., there is a chance that certain op groups are not properly fused together, resulting in bad runtime performance. For example, we leverage PyTorch's [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) to implement the attention computation; however, it will be converted to a less efficient form when converted directly (which leads to unnecessary FLOPs). The core problem here is that the converter doesn't know the boundary of the `SDPA` op, so it can't apply any optimizations to it. We use ODML Torch's high-level function boundary API to mark the boundary of the `SDPA` operation, and then lower it to a TF Lite custom op. For example, see how HLFB is applied inside the `SDPA` Python implementation:
https://github.com/google-ai-edge/ai-edge-torch/blob/9b06abe7c645f1d784804eb7f63f93458f358ba1/ai_edge_torch/generative/layers/scaled_dot_product_attention.py#L69-L117

As a result, everything between the `builder.mark_inputs` and `builder.mark_outputs` calls will be wrapped inside a StableHLO composite op named `odml.scaled_dot_product_attention`, and mapped to TF Lite's [SDPA custom op](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/genai/sdpa.cc) by the converter. At runtime, the delegate will match and replace `odml.scaled_dot_product_attention` with the most efficient kernels available on that delegate backend.
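For reference, below is a condensed, approximate sketch of that pattern. It follows the shape of the linked implementation, but the exact builder constructor and attributes in the repo may differ, so treat the signatures here as assumptions rather than the authoritative code:

```python
import math

import torch
import torch.nn.functional as F
from ai_edge_torch.hlfb import StableHLOCompositeBuilder


def sdpa_with_hlfb(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   mask: torch.Tensor, head_size: int) -> torch.Tensor:
  """Wraps SDPA in a StableHLO composite so the converter sees one fused op."""
  builder = StableHLOCompositeBuilder(
      name="odml.scaled_dot_product_attention",
      attr={"scale": 1.0 / math.sqrt(head_size)},  # assumed attribute layout
  )
  # Everything between mark_inputs and mark_outputs becomes the composite body
  # that the converter later maps to the TF Lite SDPA custom op.
  q, k, v, mask = builder.mark_inputs(q, k, v, mask)
  out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
  return builder.mark_outputs(out)
```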
