Add BUILD rules and update copy.bara.sky for toy models.
PiperOrigin-RevId: 678471108
haozha111 authored and copybara-github committed Sep 25, 2024
1 parent c0f0b63 commit 9ae6590
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 87 deletions.
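Note: the BUILD rules and the copy.bara.sky update named in the commit title are not part of the diff below; only the three Python files are shown. For orientation, a build target for the new converter script might look roughly like the following sketch. It is hypothetical: the target name, visibility, and dependency labels are assumptions, not contents of this commit.

# Hypothetical BUILD sketch (Bazel/Starlark, Python-like syntax).
# All labels below are illustrative guesses, not taken from this commit.
py_binary(
    name = "convert_toy_model",
    srcs = ["convert_toy_model.py"],
    deps = [
        "//ai_edge_torch",
        "//ai_edge_torch/generative/examples/test_models:toy_model",
        "//ai_edge_torch/generative/examples/test_models:toy_model_with_kv_cache",
        "//ai_edge_torch/generative/layers:kv_cache",
    ],
)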
105 changes: 105 additions & 0 deletions ai_edge_torch/generative/examples/test_models/convert_toy_model.py
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
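The script wires convert_toy_model into app.run, so running it converts only the single-layer model; convert_toy_model_with_kv_cache has to be invoked by changing the app.run target. As a quick sanity check after conversion, the returned edge model can be called in-process and compared against the PyTorch output. A minimal sketch, assuming ai_edge_torch's documented convention that a converted model is directly callable with the original sample inputs (tolerances are illustrative):

# Minimal verification sketch; KV_CACHE_MAX_LEN = 100 as in the script above.
import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model
import numpy as np
import torch

model = toy_model.ToySingleLayerModel(toy_model.get_model_config()).eval()
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)

edge_model = ai_edge_torch.convert(model, (idx, input_pos))
np.testing.assert_allclose(
    edge_model(idx, input_pos),  # executed by the TFLite interpreter
    model(idx, input_pos).detach().numpy(),
    atol=1e-5,
    rtol=1e-5,
)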
33 changes: 2 additions & 31 deletions ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import ai_edge_torch
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
58 changes: 2 additions & 56 deletions ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import ai_edge_torch
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -87,11 +86,6 @@ def forward(
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
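The block removed above is the same prefill/decode conversion flow that now lives in convert_toy_model.py. For completeness, exercising the two signatures of the exported model would look roughly like the sketch below; the signature_name keyword is an assumption based on ai_edge_torch's multi-signature examples, and the KV-cache handling is simplified for illustration.

# Hypothetical sketch: call each signature of the converted model by name.
import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.layers import kv_cache as kv_utils

config = toy_model_with_kv_cache.get_model_config()
model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
kv = kv_utils.KVCache.from_model_config(config)
tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
decode_token, decode_input_pos = (
    toy_model_with_kv_cache.get_sample_decode_inputs()
)

edge_model = (
    ai_edge_torch.signature(
        'prefill',
        model,
        sample_kwargs={'tokens': tokens, 'input_pos': input_pos, 'kv_cache': kv},
    )
    .signature(
        'decode',
        model,
        sample_kwargs={
            'tokens': decode_token,
            'input_pos': decode_input_pos,
            'kv_cache': kv,
        },
    )
    .convert()
)

# Select a signature by name at call time (assumed keyword).
prefill_out = edge_model(
    tokens=tokens, input_pos=input_pos, kv_cache=kv, signature_name='prefill'
)
decode_out = edge_model(
    tokens=decode_token,
    input_pos=decode_input_pos,
    kv_cache=kv,
    signature_name='decode',
)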
