Add LoRA support to AI Edge Transformers.
PiperOrigin-RevId: 704425190
hheydary authored and copybara-github committed Dec 20, 2024
1 parent 9d387ec commit 1672654
Showing 16 changed files with 984 additions and 124 deletions.
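
Across the example converter scripts below, the change follows one pattern: the --tflite_path flag is replaced by --output_path plus a new --output_name_prefix flag, the hand-built output filenames (quantization suffix, model size, KV-cache length) are removed in favor of converter-side naming, and most text-decoder converters gain a --lora_ranks flag whose value is forwarded to converter.convert_to_tflite (the PaliGemma script gets only the path/prefix rename). A usage sketch follows the first file's diff.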
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
Expand All @@ -49,19 +54,24 @@
True,
'Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
'lora_ranks',
None,
'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
pytorch_model = gemma1.build_2b_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
output_path=_OUTPUT_PATH.value,
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
prefill_seq_len=_PREFILL_SEQ_LENS.value,
quantize=_QUANTIZE.value,
lora_ranks=_LORA_RANKS.value,
export_config=ExportConfig(),
)

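As promised above, a minimal usage sketch of the updated API, not part of this commit: every parameter name comes from the diff, but the import paths, the kv_cache_max_len value, and the choice of ranks are illustrative assumptions.

# Hedged sketch: converting Gemma 2B with two LoRA rank variants using the
# renamed output_path/output_name_prefix parameters and the new lora_ranks
# parameter. Import paths and numeric values are assumptions, not from the diff.
import os
import pathlib

from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
pytorch_model = gemma1.build_2b_model(checkpoint, kv_cache_max_len=1280)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='gemma',
    prefill_seq_len=[256, 512],  # the multi-integer flag becomes a list
    quantize=True,
    lora_ranks=[4, 16],  # per the flag's description: converted per rank
    export_config=ExportConfig(),
)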
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

22 changes: 16 additions & 6 deletions ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -35,10 +35,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)

 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

25 changes: 16 additions & 9 deletions ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

16 changes: 10 additions & 6 deletions ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -34,10 +34,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -65,11 +70,10 @@ def main(_):
   pytorch_model = paligemma.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
24 changes: 17 additions & 7 deletions ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -26,13 +26,18 @@

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

22 changes: 16 additions & 6 deletions ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -29,10 +29,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

26 changes: 17 additions & 9 deletions ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -35,10 +35,15 @@
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
 )
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
+)
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+

 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

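One consequence of removing the hand-built filenames: the qwen script used to embed the model size in the name (qwen_{model_size}_{quant_suffix}_ekv...), while the converter now only sees output_name_prefix. A hedged sketch of folding the size back into the prefix, assuming the converter uses the prefix verbatim when naming its outputs:

# Hedged sketch, not from this commit: reproduce the old size-tagged naming
# via the prefix. _BUILDER, converter, and ExportConfig as in the script above;
# the checkpoint path and kv_cache_max_len value are placeholders.
size = '0.5b'
pytorch_model = _BUILDER[size]('/path/to/qwen-checkpoint', kv_cache_max_len=1024)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix=f'qwen_{size.replace(".", "_")}',  # e.g. 'qwen_0_5b'
    prefill_seq_len=[256],
    quantize=True,  # any quantization suffix is now the converter's job
    lora_ranks=None,  # no LoRA variants requested
    export_config=ExportConfig(),
)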