Skip to content

Commit

Permalink
Add LoRA support to AI Edge Transformers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713026247
  • Loading branch information
hheydary authored and copybara-github committed Jan 7, 2025
1 parent b183411 commit 603e8ea
Show file tree
Hide file tree
Showing 16 changed files with 1,009 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'gemma',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -49,19 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = gemma1.build_2b_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'gemma2',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -49,19 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = gemma2.build_2b_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
22 changes: 16 additions & 6 deletions ai_edge_torch/generative/examples/llama/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'llama',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -55,6 +60,11 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)

_BUILDER = {
'1b': llama.build_1b_model,
Expand All @@ -66,13 +76,13 @@ def main(_):
pytorch_model = _BUILDER[_MODEL_SIZE.value](
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
25 changes: 16 additions & 9 deletions ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'openelm',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -49,22 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = openelm.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = (
f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
)

converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
17 changes: 11 additions & 6 deletions ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'paligemma',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
Expand Down Expand Up @@ -73,11 +78,11 @@ def main(_):
version=int(_VERSION.value),
kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'

converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
prefill_seq_len=_PREFILL_SEQ_LEN.value,
pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
quantize=_QUANTIZE.value,
Expand Down
24 changes: 17 additions & 7 deletions ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@

_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'phi3',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -49,19 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = phi3.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
22 changes: 16 additions & 6 deletions ai_edge_torch/generative/examples/phi/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'phi2',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -49,19 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = phi2.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
26 changes: 17 additions & 9 deletions ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
'The path to the model checkpoint, or directory holding the checkpoint.',
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
_OUTPUT_PATH = flags.DEFINE_string(
'output_path',
'/tmp/',
'The tflite file path to export.',
'The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
'output_name_prefix',
'qwen',
'The prefix of the output tflite model name.',
)
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
'prefill_seq_lens',
Expand All @@ -55,6 +60,12 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


_BUILDER = {
'0.5b': qwen.build_0_5b_model,
Expand All @@ -67,16 +78,13 @@ def main(_):
pytorch_model = _BUILDER[_MODEL_SIZE.value](
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
model_size = _MODEL_SIZE.value.replace('.', '_')
output_filename = (
f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
)
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

Expand Down
Loading

0 comments on commit 603e8ea

Please sign in to comment.