diff --git a/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py b/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py index a4196f64..d45c5c31 100644 --- a/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +++ b/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'gemma', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,19 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = gemma1.build_2b_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py b/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py index 6cfa0ae5..c0437a62 100644 --- a/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +++ b/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'gemma2', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,19 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = gemma2.build_2b_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py index d4a7bc92..add9e4ee 100644 --- 
a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py @@ -35,10 +35,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'llama', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -55,6 +60,11 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) _BUILDER = { '1b': llama.build_1b_model, @@ -66,13 +76,13 @@ def main(_): pytorch_model = _BUILDER[_MODEL_SIZE.value]( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py b/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py index 20e7523c..c4f779bd 100644 --- a/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'openelm', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,22 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = openelm.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = ( - f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' - ) - converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py index c60ed743..4414c5b3 100644 --- a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +++ 
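The converter example scripts (the two Gemma scripts, Llama, and OpenELM above; the remaining example scripts below follow the same pattern) no longer assemble the output filename themselves. They now forward `output_path`, `output_name_prefix`, and the new `lora_ranks` flag to `converter.convert_to_tflite`, which derives the filename. A hedged sketch of the equivalent direct API call is below; the checkpoint path, sequence lengths, and rank are placeholders, and the `gemma1` import path is assumed from the example script layout.

```python
# Illustrative only: checkpoint path, sequence lengths, and rank are placeholders.
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = gemma1.build_2b_model('/path/to/gemma-2b', kv_cache_max_len=1024)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',
    output_name_prefix='gemma',
    prefill_seq_len=[256, 512],   # one prefill signature per length
    quantize=True,
    lora_ranks=[16],              # also emits *_lora_r16 prefill/decode signatures
    export_config=ExportConfig(),
)
```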
b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py @@ -34,10 +34,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'paligemma', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LEN = flags.DEFINE_integer( 'prefill_seq_len', @@ -65,11 +70,10 @@ def main(_): pytorch_model = paligemma.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LEN.value, pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value), quantize=_QUANTIZE.value, diff --git a/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py b/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py index 797f06b0..7d7f88b7 100644 --- a/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +++ b/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py @@ -26,13 +26,18 @@ _CHECKPOINT_PATH = flags.DEFINE_string( 'checkpoint_path', - os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'), + os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'phi3', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,19 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = phi3.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/phi/convert_to_tflite.py b/ai_edge_torch/generative/examples/phi/convert_to_tflite.py index 5381203f..01121f4e 100644 --- a/ai_edge_torch/generative/examples/phi/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/phi/convert_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = 
flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'phi2', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,19 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = phi2.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py index 893ab6a3..601b3f7b 100644 --- a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py @@ -35,10 +35,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'qwen', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -55,6 +60,12 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) + _BUILDER = { '0.5b': qwen.build_0_5b_model, @@ -67,16 +78,13 @@ def main(_): pytorch_model = _BUILDER[_MODEL_SIZE.value]( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - model_size = _MODEL_SIZE.value.replace('.', '_') - output_filename = ( - f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' - ) converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py index 7a1d6e2f..b664e261 100644 --- a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite 
model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'smollm', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,20 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = smollm.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py index ebc8e18a..eadfa960 100644 --- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py @@ -29,10 +29,15 @@ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'), 'The path to the model checkpoint, or directory holding the checkpoint.', ) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', +_OUTPUT_PATH = flags.DEFINE_string( + 'output_path', '/tmp/', - 'The tflite file path to export.', + 'The path to export the tflite model.', +) +_OUTPUT_NAME_PREFIX = flags.DEFINE_string( + 'output_name_prefix', + 'tinyllama', + 'The prefix of the output tflite model name.', ) _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer( 'prefill_seq_lens', @@ -49,21 +54,24 @@ True, 'Whether the model should be quantized.', ) +_LORA_RANKS = flags.DEFINE_multi_integer( + 'lora_ranks', + None, + 'If set, the model will be converted with the provided list of LoRA ranks.', +) def main(_): pytorch_model = tiny_llama.build_model( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = ( - f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' - ) converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), + output_path=_OUTPUT_PATH.value, + output_name_prefix=_OUTPUT_NAME_PREFIX.value, prefill_seq_len=_PREFILL_SEQ_LENS.value, quantize=_QUANTIZE.value, + lora_ranks=_LORA_RANKS.value, export_config=ExportConfig(), ) diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py index b52e04b4..7da8efc2 100644 --- a/ai_edge_torch/generative/layers/attention.py +++ b/ai_edge_torch/generative/layers/attention.py @@ -19,6 +19,7 @@ from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils +from ai_edge_torch.generative.layers import lora as lora_utils from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb @@ -66,6 +67,7 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, kv_cache: kv_utils.KVCacheEntry = None, + lora: Optional[lora_utils.LoRAEntry] = None, 
) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]: """Forward function of the TransformerBlock. @@ -75,6 +77,7 @@ def forward( mask (torch.Tensor): the optional mask tensor. input_pos (torch.Tensor): the optional input position tensor. kv_cache (KVCacheEntry): the optional kv cache entry. + lora (LoRAEntry): the optional lora entry. Returns: output activation from this transformer block, and updated kv cache (if @@ -83,7 +86,9 @@ def forward( kv = None if self.config.parallel_residual: x_norm = self.pre_atten_norm(x) - atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache) + atten_func_out = self.atten_func( + x_norm, rope, mask, input_pos, kv_cache, lora + ) if kv_cache is None: attn_out = atten_func_out else: @@ -92,7 +97,9 @@ def forward( output = x + attn_out + ff_out else: x_norm = self.pre_atten_norm(x) - atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache) + atten_func_out = self.atten_func( + x_norm, rope, mask, input_pos, kv_cache, lora + ) if kv_cache is None: attn_out = atten_func_out else: @@ -152,6 +159,7 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, kv_cache: Optional[kv_utils.KVCacheEntry] = None, + lora: Optional[lora_utils.LoRAEntry] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]: """Forward function of the CausalSelfAttention layer, which can support @@ -163,6 +171,7 @@ def forward( mask (torch.Tensor): the optional mask tensor. input_pos (torch.Tensor): the optional input position tensor. kv_cache (KVCacheEntry): The KV cache entry corresponding to this module. + lora (LoRAEntry): the optional lora entry. Returns: output activation from this self attention layer, and the updated @@ -201,6 +210,11 @@ def forward( dim=-1, ) + if lora is not None: + q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape) + k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape) + v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape) + q = self.query_norm(q) k = self.key_norm(k) @@ -218,7 +232,7 @@ def forward( kv_cache = kv_utils.update(kv_cache, input_pos, k, v) k, v = kv_cache.k_cache, kv_cache.v_cache - y = self.sdpa_func( + sdpa_out = self.sdpa_func( q, k, v, @@ -226,10 +240,13 @@ def forward( mask=mask, softcap=self.config.logit_softcap, ) - y = y.reshape(B, T, -1) + sdpa_out = sdpa_out.reshape(B, T, -1) # Compute the output projection. - y = self.output_projection(y) + y = self.output_projection(sdpa_out) + if lora is not None: + y += lora_utils.apply_lora(sdpa_out, lora.attention.output) + return y if kv_cache is None else (y, kv_cache) diff --git a/ai_edge_torch/generative/layers/lora.py b/ai_edge_torch/generative/layers/lora.py new file mode 100644 index 00000000..14c6e542 --- /dev/null +++ b/ai_edge_torch/generative/layers/lora.py @@ -0,0 +1,551 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
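In the `CausalSelfAttention` changes above, the LoRA contributions are residual: the query/key/value deltas are computed from the block input and added to the fused projection outputs (reshaped to the per-head layout via the `shape` argument), while the output delta is computed from the flattened SDPA result and added after the output projection. Below is a minimal standalone sketch of that pattern with toy shapes; plain tensors stand in for the library's `LoRAWeight`, and the local `apply_lora` repeats the two matmuls from the new `lora.py`.

```python
import torch

def apply_lora(x, a_prime, b_prime, shape=None):
  # Same two matmuls as lora_utils.apply_lora: (x @ A') @ B', optionally reshaped.
  out = torch.matmul(torch.matmul(x, a_prime), b_prime)
  return out.reshape(shape) if shape is not None else out

batch, seq, dim, rank = 1, 4, 8, 2
x = torch.randn(batch, seq, dim)
w_q = torch.nn.Linear(dim, dim, bias=False)    # frozen base projection
q_a_prime = torch.randn(dim, rank)             # pre-transposed LoRA pair
q_b_prime = torch.randn(rank, dim)

q = w_q(x)
q = q + apply_lora(x, q_a_prime, q_b_prime, shape=q.shape)  # LoRA delta on top
```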
+# ============================================================================== + +"""LoRA weights for generative models.""" + +import dataclasses +from typing import Any, Callable, List, Optional, Tuple + +from ai_edge_torch.generative.layers import model_config +import flatbuffers +import numpy as np +import safetensors +import torch +import torch.utils._pytree as pytree + +from tensorflow.lite.python import schema_py_generated as schema_fb # pylint: disable=g-direct-tensorflow-import + +TFLITE_SCHEMA_VERSION = 3 + + +@dataclasses.dataclass +class LoRAWeight: + """LoRA weight per projection. The weights are pre-transposed.""" + + a_prime: torch.Tensor + b_prime: torch.Tensor + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, LoRAWeight): + return False + if self.a_prime.shape != other.a_prime.shape: + return False + if self.b_prime.shape != other.b_prime.shape: + return False + return torch.allclose(self.a_prime, other.a_prime) and torch.allclose( + self.b_prime, other.b_prime + ) + + +@dataclasses.dataclass +class AttentionLoRA: + """LoRA weights for attention module.""" + + query: LoRAWeight + key: LoRAWeight + value: LoRAWeight + output: LoRAWeight + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, AttentionLoRA): + return False + return ( + self.query == other.query + and self.key == other.key + and self.value == other.value + and self.output == other.output + ) + + +@dataclasses.dataclass +class LoRAEntry: + """LoRA weights for a single layer.""" + + attention: AttentionLoRA + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, LoRAEntry): + return False + return self.attention == other.attention + + +@dataclasses.dataclass +class LoRATensorNames: + """Tensor names for LoRA weights.""" + + attn_query_w_a: str + attn_query_w_b: str + + attn_key_w_a: str + attn_key_w_b: str + + attn_value_w_a: str + attn_value_w_b: str + + attn_output_w_a: str + attn_output_w_b: str + + +@dataclasses.dataclass +class LoRA: + """LoRA weights for all modules.""" + + adapters: Tuple[LoRAEntry, ...] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, LoRA): + return False + if len(self.adapters) != len(other.adapters): + return False + return all( + adapter == other_adapter + for adapter, other_adapter in zip(self.adapters, other.adapters) + ) + + def get_rank(self) -> int: + """Returns the rank of the LoRA weights.""" + return self.adapters[0].attention.query.a_prime.shape[1] + + @classmethod + def from_safetensors( + cls, + path: str, + scale: float, + config: model_config.ModelConfig, + lora_tensor_names: LoRATensorNames, + dtype: torch.dtype = torch.float32, + ) -> "LoRA": + """Creates LoRA weights from a Hugging Face model. + + Args: + path: Path to the model. + scale: Scale factor for the LoRA weights (applied only to one of the + projections). The scaling factor depnds on the training configuration. + The common values are either `lora_alpha / rank` or `lora_alpha / + sqrt(rank)`. + config: Model configuration. + lora_tensor_names: Tensor names for the LoRA weights. + dtype: Data type of the LoRA weights. + + Returns: + LoRA weights for all modules. 
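The `a_prime`/`b_prime` naming reflects that `from_safetensors` stores the checkpoint's A and B matrices pre-transposed, with the scale folded into `a_prime`. For a PEFT-style pair `lora_A.weight` of shape (rank, in_dim) and `lora_B.weight` of shape (out_dim, rank), the applied delta `x @ a_prime @ b_prime` equals `scale * (B @ A @ x)` in the usual column convention. A quick numerical check of that equivalence, with illustrative shapes and scale:

```python
import torch

in_dim, out_dim, rank = 8, 6, 2
scale = 2.0 / rank                      # e.g. lora_alpha / rank
lora_a = torch.randn(rank, in_dim)      # PEFT-style lora_A.weight
lora_b = torch.randn(out_dim, rank)     # PEFT-style lora_B.weight
x = torch.randn(3, in_dim)

a_prime = lora_a.T * scale              # as built in LoRA.from_safetensors
b_prime = lora_b.T
delta = x @ a_prime @ b_prime           # what apply_lora computes
assert torch.allclose(delta, scale * (lora_b @ (lora_a @ x.T)).T, atol=1e-5)
```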
+ """ + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + adapters = [] + for i in range(config.num_layers): + attention_lora = AttentionLoRA( + query=LoRAWeight( + a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i)) + .to(dtype) + .T + * scale, + b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i)) + .to(dtype) + .T, + ), + key=LoRAWeight( + a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i)) + .to(dtype) + .T + * scale, + b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i)) + .to(dtype) + .T, + ), + value=LoRAWeight( + a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i)) + .to(dtype) + .T + * scale, + b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i)) + .to(dtype) + .T, + ), + output=LoRAWeight( + a_prime=f.get_tensor( + lora_tensor_names.attn_output_w_a.format(i) + ) + .to(dtype) + .T + * scale, + b_prime=f.get_tensor( + lora_tensor_names.attn_output_w_b.format(i) + ) + .to(dtype) + .T, + ), + ) + adapters.append(LoRAEntry(attention=attention_lora)) + return cls(adapters=adapters) + + @classmethod + def from_flatbuffers( + cls, + flatbuffer_model: bytearray, + dtype: torch.dtype = torch.float32, + ) -> "LoRA": + """Creates LoRA weights from FlatBuffers. + + Args: + flatbuffer_model: FlatBuffers model. + dtype: Data type of the LoRA weights. + + Returns: + LoRA weights for all modules. + """ + model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0) + model = schema_fb.ModelT.InitFromObj(model) + + flat_names = [] + tensors = [] + for tensor in model.subgraphs[0].tensors: + name = tensor.name.decode("utf-8") + assert name.startswith("lora_") + flat_names.append(name[5:]) + buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes() + arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape) + torch_tensor = torch.from_numpy(arr).to(dtype) + tensors.append(torch_tensor) + + return _unflatten_lora(tensors, (flat_names, [])) + + @classmethod + def zeros( + cls, + rank: int, + config: model_config.ModelConfig, + dtype: torch.dtype = torch.float32, + ) -> "LoRA": + """Creates LoRA weights with zeros. + + Args: + rank: Rank of the LoRA weights. + config: Model configuration. + dtype: Data type of the LoRA weights. + + Returns: + LoRA weights with zeros. + """ + return cls._from_tensor_generator( + tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype), + rank=rank, + config=config, + dtype=dtype, + ) + + @classmethod + def random( + cls, + rank: int, + config: model_config.ModelConfig, + dtype: torch.dtype = torch.float32, + ) -> "LoRA": + """Creates LoRA weights with random values. + + Args: + rank: Rank of the LoRA weights. + config: Model configuration. + dtype: Data type of the LoRA weights. + + Returns: + LoRA weights with random values. 
+ """ + return cls._from_tensor_generator( + tensor_generator=lambda shape, dtype: torch.randint( + low=0, high=128, size=shape, dtype=dtype + ), + rank=rank, + config=config, + dtype=dtype, + ) + + @classmethod + def _from_tensor_generator( + cls, + tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor], + rank: int, + config: model_config.ModelConfig, + dtype: torch.dtype = torch.float32, + ) -> "LoRA": + """Creates LoRA weights from a tensor generator.""" + adapters = [] + block_config = config.block_config(0) + q_per_kv = ( + block_config.attn_config.num_heads + // block_config.attn_config.num_query_groups + ) + q_out_dim = q_per_kv * block_config.attn_config.head_dim + k_out_dim = v_out_dim = block_config.attn_config.head_dim + + for _ in range(config.num_layers): + attention_lora = AttentionLoRA( + query=LoRAWeight( + a_prime=tensor_generator((config.embedding_dim, rank), dtype), + b_prime=tensor_generator((rank, q_out_dim), dtype), + ), + key=LoRAWeight( + a_prime=tensor_generator((config.embedding_dim, rank), dtype), + b_prime=tensor_generator((rank, k_out_dim), dtype), + ), + value=LoRAWeight( + a_prime=tensor_generator((config.embedding_dim, rank), dtype), + b_prime=tensor_generator((rank, v_out_dim), dtype), + ), + output=LoRAWeight( + a_prime=tensor_generator( + ( + block_config.attn_config.num_heads + * block_config.attn_config.head_dim, + rank, + ), + dtype, + ), + b_prime=tensor_generator((rank, config.embedding_dim), dtype), + ), + ) + adapters.append(LoRAEntry(attention=attention_lora)) + return cls(adapters=adapters) + + def to_tflite(self) -> bytearray: + """Converts LoRA to FlatBuffers.""" + return _lora_to_flatbuffers(self) + + +def apply_lora( + x: torch.Tensor, + lora_weight: LoRAWeight, + shape: Optional[Tuple[int, ...]] = None, +) -> torch.Tensor: + """Applies LoRA weights to a tensor. + + Args: + x: Input tensor. + lora_weight: LoRA weight. + shape: Output shape. If None, the output shape is the same as the input + shape. + + Returns: + Output tensor. 
+ """ + output = torch.matmul( + torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime + ) + if shape is not None: + output = output.reshape(shape) + return output + + +def _flatten_attention_lora( + lora: AttentionLoRA, block_index: int +) -> Tuple[List[torch.Tensor], List[str]]: + """Flattens LoRA weights for attention module.""" + flattened = [] + flat_names = [] + flattened.append(lora.query.a_prime) + flat_names.append(f"atten_q_a_prime_weight_{block_index}") + flattened.append(lora.query.b_prime) + flat_names.append(f"atten_q_b_prime_weight_{block_index}") + flattened.append(lora.key.a_prime) + flat_names.append(f"atten_k_a_prime_weight_{block_index}") + flattened.append(lora.key.b_prime) + flat_names.append(f"atten_k_b_prime_weight_{block_index}") + flattened.append(lora.value.a_prime) + flat_names.append(f"atten_v_a_prime_weight_{block_index}") + flattened.append(lora.value.b_prime) + flat_names.append(f"atten_v_b_prime_weight_{block_index}") + flattened.append(lora.output.a_prime) + flat_names.append(f"atten_o_a_prime_weight_{block_index}") + flattened.append(lora.output.b_prime) + flat_names.append(f"atten_o_b_prime_weight_{block_index}") + return flattened, flat_names + + +def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]: + """Flattens LoRA weights.""" + flattened = [] + flat_names = [] + none_names = [] + for i, entry in enumerate(lora.adapters): + attn_flattened, attn_flat_names = _flatten_attention_lora( + lora=entry.attention, block_index=i + ) + flattened.extend(attn_flattened) + flat_names.extend(attn_flat_names) + return flattened, [flat_names, none_names] + + +def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]: + """Flattens LoRA weights with keys.""" + flattened, (flat_names, _) = _flatten_lora(lora) + return [ + (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened) + ], flat_names + + +def _unflatten_lora( + values: List[torch.Tensor], context: Tuple[List[str], List[Any]] +) -> LoRA: + """Unflattens LoRA object.""" + flat_names, _ = context + names_weights = list(zip(flat_names, values)) + adapters = {} + while names_weights: + name, weight = names_weights.pop(0) + block_idx = int(name.split("_")[-1]) + if block_idx not in adapters: + adapters[block_idx] = LoRAEntry( + attention=AttentionLoRA( + query=LoRAWeight( + a_prime=None, + b_prime=None, + ), + key=LoRAWeight( + a_prime=None, + b_prime=None, + ), + value=LoRAWeight( + a_prime=None, + b_prime=None, + ), + output=LoRAWeight( + a_prime=None, + b_prime=None, + ), + ) + ) + + if name.startswith("atten_"): + if "q_a_prime" in name: + adapters[block_idx].attention.query.a_prime = weight + elif "q_b_prime" in name: + adapters[block_idx].attention.query.b_prime = weight + elif "k_a_prime" in name: + adapters[block_idx].attention.key.a_prime = weight + elif "k_b_prime" in name: + adapters[block_idx].attention.key.b_prime = weight + elif "v_a_prime" in name: + adapters[block_idx].attention.value.a_prime = weight + elif "v_b_prime" in name: + adapters[block_idx].attention.value.b_prime = weight + elif "o_a_prime" in name: + adapters[block_idx].attention.output.a_prime = weight + elif "o_b_prime" in name: + adapters[block_idx].attention.output.b_prime = weight + else: + raise ValueError(f"Unsupported name: {name}") + else: + raise ValueError(f"Unsupported name: {name}") + + return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters))) + + +pytree.register_pytree_node( + LoRA, + _flatten_lora, + _unflatten_lora, + flatten_with_keys_fn=_flatten_lora_with_keys, + 
serialized_type_name="", +) + + +def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int: + """Adds a buffer to the FlatBuffers.""" + data_offset = None + if data is not None: + assert data.dtype == np.float32 + schema_fb.BufferStartDataVector(builder, data.size * data.itemsize) + for value in reversed(data.flatten().tolist()): + builder.PrependFloat32(value) + data_offset = builder.EndVector() + + schema_fb.BufferStart(builder) + if data is not None: + schema_fb.BufferAddData(builder, data_offset) + buffer_offset = schema_fb.BufferEnd(builder) + return buffer_offset + + +def _add_tensor( + builder: flatbuffers.Builder, + name: str, + shape: Tuple[int, ...], + buffer_idx: int, +) -> int: + """Adds a tensor to the FlatBuffers.""" + name_offset = builder.CreateString(name) + schema_fb.TensorStartShapeVector(builder, len(shape)) + for dim in reversed(shape): + builder.PrependInt32(dim) + shape_offset = builder.EndVector() + schema_fb.TensorStart(builder) + schema_fb.TensorAddName(builder, name_offset) + schema_fb.TensorAddShape(builder, shape_offset) + schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32) + schema_fb.TensorAddBuffer(builder, buffer_idx) + tensor_offset = schema_fb.TensorEnd(builder) + return tensor_offset + + +def _lora_to_flatbuffers(lora: LoRA) -> bytearray: + """Converts LoRA to FlatBuffers.""" + tensors, (names, _) = _flatten_lora(lora) + # Need to manually add the "lora_" prefix to the names here. The export will + # add the prefix automatically. + names = [f"lora_{name}" for name in names] + builder = flatbuffers.Builder(4096) + + # Convention to add an empty buffer in the beginning. + buffer_offsets = [_add_buffer(builder, None)] + for tensor in tensors: + buffer_offsets.append( + _add_buffer(builder, tensor.detach().type(torch.float32).numpy()) + ) + + schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets)) + for buffer_offset in reversed(buffer_offsets): + builder.PrependUOffsetTRelative(buffer_offset) + buffers_offset = builder.EndVector() + + tensor_offsets = [] + for i, (name, tensor) in enumerate(zip(names, tensors)): + # Note that the zeroth buffer is empty and reserved for the convention. 
+ tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1)) + + schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets)) + for tensor_offset in reversed(tensor_offsets): + builder.PrependUOffsetTRelative(tensor_offset) + tensors_offset = builder.EndVector() + + string_offset = builder.CreateString("lora_params") + schema_fb.SubGraphStart(builder) + schema_fb.SubGraphAddName(builder, string_offset) + schema_fb.SubGraphAddTensors(builder, tensors_offset) + subgraph_offset = schema_fb.SubGraphEnd(builder) + + schema_fb.ModelStartSubgraphsVector(builder, 1) + builder.PrependUOffsetTRelative(subgraph_offset) + subgraphs_offset = builder.EndVector() + + string_offset = builder.CreateString("lora_params") + schema_fb.ModelStart(builder) + schema_fb.ModelAddVersion(builder, TFLITE_SCHEMA_VERSION) + schema_fb.ModelAddDescription(builder, string_offset) + schema_fb.ModelAddBuffers(builder, buffers_offset) + schema_fb.ModelAddSubgraphs(builder, subgraphs_offset) + model_offset = schema_fb.ModelEnd(builder) + builder.Finish(model_offset) + flatbuffer_model = builder.Output() + + return flatbuffer_model diff --git a/ai_edge_torch/generative/test/fixtures/test_lora_rank16.safetensors b/ai_edge_torch/generative/test/fixtures/test_lora_rank16.safetensors new file mode 100644 index 00000000..7711e8e5 Binary files /dev/null and b/ai_edge_torch/generative/test/fixtures/test_lora_rank16.safetensors differ diff --git a/ai_edge_torch/generative/test/test_lora.py b/ai_edge_torch/generative/test/test_lora.py new file mode 100644 index 00000000..f8c11d4d --- /dev/null +++ b/ai_edge_torch/generative/test/test_lora.py @@ -0,0 +1,147 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
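Registering `LoRA` as a pytree node lets `torch.export` and the converter accept a `LoRA` object directly as a sample input: it flattens into eight named tensors per layer (A' and B' for each of q/k/v/o), which is also the layout that `to_tflite`/`from_flatbuffers` serialize through a standalone "lora_params" flatbuffer. A sketch of both round trips follows, using a minimal config modeled on `_get_test_config` in the test below; the rank and layer count are illustrative.

```python
from ai_edge_torch.generative.layers import lora as lora_utils
import ai_edge_torch.generative.layers.model_config as cfg
import torch.utils._pytree as pytree

# Minimal config mirroring _get_test_config in test_lora.py.
attn_config = cfg.AttentionConfig(num_heads=1, head_dim=8, num_query_groups=1)
block_config = cfg.TransformerBlockConfig(attn_config=attn_config, ff_config=None)
config = cfg.ModelConfig(
    kv_cache_max_len=16, embedding_dim=8, block_configs=block_config,
    num_layers=2, max_seq_len=None, vocab_size=None,
)

lora = lora_utils.LoRA.random(rank=4, config=config)

# Pytree round trip: 8 tensors per layer become individual export inputs.
flat, spec = pytree.tree_flatten(lora)
assert len(flat) == 8 * config.num_layers
assert pytree.tree_unflatten(flat, spec) == lora

# Flatbuffer round trip: the same tensors packed into a 'lora_params' model.
blob = lora.to_tflite()
assert lora_utils.LoRA.from_flatbuffers(blob) == lora
```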
+# ============================================================================== + +"""A suite of tests to validate LoRA utilities.""" + +from ai_edge_torch.generative.layers import lora as lora_utils +import ai_edge_torch.generative.layers.model_config as cfg +import torch +from absl.testing import absltest as googletest +from tensorflow.python.platform import resource_loader # pylint: disable=g-direct-tensorflow-import + + +class TestLora(googletest.TestCase): + """Tests for LoRA utilities.""" + + def test_safetensors_builder(self): + """Converts a safetensors file to a LoRA module.""" + + tensor_names = lora_utils.LoRATensorNames( + attn_query_w_a=( + "base_model.model.model.layers.{}.self_attn.q_proj.lora_A.weight" + ), + attn_query_w_b=( + "base_model.model.model.layers.{}.self_attn.q_proj.lora_B.weight" + ), + attn_key_w_a=( + "base_model.model.model.layers.{}.self_attn.k_proj.lora_A.weight" + ), + attn_key_w_b=( + "base_model.model.model.layers.{}.self_attn.k_proj.lora_B.weight" + ), + attn_value_w_a=( + "base_model.model.model.layers.{}.self_attn.v_proj.lora_A.weight" + ), + attn_value_w_b=( + "base_model.model.model.layers.{}.self_attn.v_proj.lora_B.weight" + ), + attn_output_w_a=( + "base_model.model.model.layers.{}.self_attn.o_proj.lora_A.weight" + ), + attn_output_w_b=( + "base_model.model.model.layers.{}.self_attn.o_proj.lora_B.weight" + ), + ) + + safetensors_file = resource_loader.get_path_to_datafile( + "fixtures/test_lora_rank16.safetensors" + ) + config = self._get_test_config( + num_layers=1, + head_dim=8, + num_query_groups=1, + kv_cache_max_len=16, + ) + lora = lora_utils.LoRA.from_safetensors( + safetensors_file, + scale=1.0, + lora_tensor_names=tensor_names, + config=config, + ) + self.assertEqual(lora.get_rank(), 16) + + def test_torch_export(self): + """Tests the export of the LoRA module.""" + + class TestModel(torch.nn.Module): + + def forward(self, x: torch.Tensor, lora: lora_utils.LoRA) -> torch.Tensor: + x += lora_utils.apply_lora(x, lora.adapters[0].attention.query) + return x + + n = 1 + head_dim = 2 + num_query_groups = 1 + key_length = 4 + config = self._get_test_config( + num_layers=n, + head_dim=head_dim, + num_query_groups=num_query_groups, + kv_cache_max_len=key_length, + ) + inputs = torch.zeros((n, 1, head_dim)) + lora = lora_utils.LoRA.zeros(rank=16, config=config) + model = TestModel() + exported_program = torch.export.export(model, (inputs, lora)) + input_specs = exported_program.graph_signature.input_specs + # 9 inputs: 1 for x, 2 for query lora, 2 for key lora, 2 for value lora, + # 2 for output lora. 
+ self.assertLen(input_specs, 9) + self.assertEqual(input_specs[0].arg.name, "x") + self.assertEqual(input_specs[1].arg.name, "lora_atten_q_a_prime_weight_0") + self.assertEqual(input_specs[2].arg.name, "lora_atten_q_b_prime_weight_0") + self.assertEqual(input_specs[3].arg.name, "lora_atten_k_a_prime_weight_0") + self.assertEqual(input_specs[4].arg.name, "lora_atten_k_b_prime_weight_0") + self.assertEqual(input_specs[5].arg.name, "lora_atten_v_a_prime_weight_0") + self.assertEqual(input_specs[6].arg.name, "lora_atten_v_b_prime_weight_0") + self.assertEqual(input_specs[7].arg.name, "lora_atten_o_a_prime_weight_0") + self.assertEqual(input_specs[8].arg.name, "lora_atten_o_b_prime_weight_0") + + def test_lora_tflite_serialization(self): + """Tests the serialization of the LoRA module.""" + config = self._get_test_config( + num_layers=2, + head_dim=8, + num_query_groups=1, + kv_cache_max_len=16, + ) + lora = lora_utils.LoRA.random(rank=16, config=config) + flatbuffer_model = lora.to_tflite() + recovered_lora = lora_utils.LoRA.from_flatbuffers(flatbuffer_model) + self.assertEqual(lora, recovered_lora) + + def _get_test_config( + self, num_layers, head_dim, num_query_groups, kv_cache_max_len + ): + """Returns a test model config.""" + attn_config = cfg.AttentionConfig( + num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups + ) + block_config = cfg.TransformerBlockConfig( + attn_config=attn_config, ff_config=None + ) + config = cfg.ModelConfig( + kv_cache_max_len=kv_cache_max_len, + embedding_dim=head_dim, + block_configs=block_config, + num_layers=num_layers, + max_seq_len=None, + vocab_size=None, + ) + return config + + +if __name__ == "__main__": + googletest.main() diff --git a/ai_edge_torch/generative/utilities/converter.py b/ai_edge_torch/generative/utilities/converter.py index 15d86682..bbb5dd5b 100644 --- a/ai_edge_torch/generative/utilities/converter.py +++ b/ai_edge_torch/generative/utilities/converter.py @@ -15,16 +15,15 @@ """Common utility functions for model conversion.""" -from functools import partial -from typing import Any, Union - +import os +from typing import Optional, Union from ai_edge_torch._convert import converter as converter_utils +from ai_edge_torch.generative.layers import lora as lora_utils import ai_edge_torch.generative.layers.kv_cache as kv_utils import ai_edge_torch.generative.layers.model_config as cfg from ai_edge_torch.generative.quantize import quant_recipes from ai_edge_torch.generative.utilities.model_builder import ExportConfig import torch -import torch.nn as nn class ExportableModule(torch.nn.Module): @@ -41,11 +40,13 @@ def forward(self, *export_args, **export_kwargs): def convert_to_tflite( pytorch_model: torch.nn.Module, - tflite_path: str, + output_path: str, + output_name_prefix: str, prefill_seq_len: Union[int, list[int]], pixel_values_size: torch.Size = None, quantize: bool = True, config: cfg.ModelConfig = None, + lora_ranks: Optional[list[int]] = None, export_config: ExportConfig = None, ): """Converts a nn.Module model to multi-signature tflite model. @@ -79,21 +80,65 @@ def convert_to_tflite( Args: pytorch_model (torch.nn.Module): PyTorch model to convert to tflite. - tflite_path (str): The tflite file path to export. - prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to - export. + output_path (str): The path to export the tflite model. + output_name_prefix (str): The prefix of the tflite model name. + prefill_seq_len (Union[int, list[int]]): The prefill sequence length to + use. 
If a list, the model will have multiple prefill signatures. pixel_values_size (torch.Size, optional): The size of pixel values to pass to the model. If None, the model is not expected to take pixel values. quantize (bool, optional): Whether the model should be quanized. Defaults to True. config (cfg.ModelConfig, optional): The model config used to configure KV cache. If None, it uses the config of the pytorch_model. + lora_ranks (list[int], optional): The ranks of the LORA layers. If None, + no LoRA signatures will be added. """ + # pylint: disable=protected-access + torch._dynamo.config.cache_size_limit = 64 + + config = config if config else pytorch_model.config prefill_seq_lens = ( [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len ) + loras = [None] + if lora_ranks is not None: + for rank in lora_ranks: + lora = lora_utils.LoRA.zeros(rank, config) + loras.append(lora) + + quant_suffix = 'q8' if quantize else 'f32' + kv_size = config.kv_cache_max_len + lora_suffix = ( + '' if lora_ranks is None else f'_lora{",".join(map(str, lora_ranks))}' + ) + output_filename = ( + f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite' + ) + output_file = os.path.join(output_path, output_filename) + + _export_helper( + pytorch_model, + output_file, + prefill_seq_lens, + pixel_values_size, + quantize, + config, + loras, + export_config, + ) - # Tensors used to trace the model graph during conversion. + +def _export_helper( + pytorch_model: torch.nn.Module, + output_file: str, + prefill_seq_lens: list[int], + pixel_values_size: torch.Size, + quantize: bool, + config: cfg.ModelConfig, + loras: list[None | lora_utils.LoRA], + export_config: ExportConfig, +): + """Helper function to export a model to tflite.""" prefill_tokens_list = [] prefill_input_pos_list = [] for seq_len in prefill_seq_lens: @@ -108,9 +153,7 @@ def convert_to_tflite( decode_token = torch.tensor([[0]], dtype=torch.int) decode_input_pos = torch.tensor([0], dtype=torch.int) - kv = kv_utils.KVCache.from_model_config( - config if config else pytorch_model.config - ) + kv = kv_utils.KVCache.from_model_config(config) quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None @@ -119,44 +162,55 @@ def convert_to_tflite( mod = ExportableModule(pytorch_model, export_config=export_config) converter = converter_utils.Converter() - for i in range(len(prefill_seq_lens)): - prefill_seq_len = prefill_seq_lens[i] - prefill_tokens = prefill_tokens_list[i] - prefill_input_pos = prefill_input_pos_list[i] - if i == 0 and len(prefill_seq_lens) == 1: - prefill_signature_name = 'prefill' - else: - prefill_signature_name = f'prefill_{prefill_seq_len}' - converter.add_signature( - prefill_signature_name, - mod, - sample_kwargs={ - 'tokens': prefill_tokens, - 'input_pos': prefill_input_pos, - 'kv_cache': kv, - }, - ) - if prefill_pixel_values is not None: + for lora in loras: + for i in range(len(prefill_seq_lens)): + prefill_seq_len = prefill_seq_lens[i] + prefill_tokens = prefill_tokens_list[i] + prefill_input_pos = prefill_input_pos_list[i] + if i == 0 and len(prefill_seq_lens) == 1: + prefill_signature_name = 'prefill' + else: + prefill_signature_name = f'prefill_{prefill_seq_len}' + + sample_kwargs = { + 'tokens': prefill_tokens, + 'input_pos': prefill_input_pos, + 'kv_cache': kv, + } + if lora is not None: + prefill_signature_name += f'_lora_r{lora.get_rank()}' + sample_kwargs['lora'] = lora + converter.add_signature( - prefill_signature_name + '_pixel', + prefill_signature_name, mod, - 
sample_kwargs={ - 'tokens': prefill_tokens, - 'input_pos': prefill_input_pos, - 'kv_cache': kv, - 'pixel_values': prefill_pixel_values, - }, + sample_kwargs=sample_kwargs, ) + if prefill_pixel_values is not None: + converter.add_signature( + prefill_signature_name + '_pixel', + mod, + sample_kwargs={ + 'tokens': prefill_tokens, + 'input_pos': prefill_input_pos, + 'kv_cache': kv, + 'pixel_values': prefill_pixel_values, + }, + ) + + sample_kwargs = { + 'tokens': decode_token, + 'input_pos': decode_input_pos, + 'kv_cache': kv, + } + if lora is not None: + sample_kwargs['lora'] = lora - converter.add_signature( - 'decode', - mod, - sample_kwargs={ - 'tokens': decode_token, - 'input_pos': decode_input_pos, - 'kv_cache': kv, - }, - ) + converter.add_signature( + 'decode' if lora is None else f'decode_lora_r{lora.get_rank()}', + mod, + sample_kwargs=sample_kwargs, + ) edge_model = converter.convert(quant_config=quant_config) - edge_model.export(tflite_path) + edge_model.export(output_file) diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py index d259a5be..6dce5e61 100644 --- a/ai_edge_torch/generative/utilities/model_builder.py +++ b/ai_edge_torch/generative/utilities/model_builder.py @@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import attention from ai_edge_torch.generative.layers import builder from ai_edge_torch.generative.layers import kv_cache as kv_utils +from ai_edge_torch.generative.layers import lora as lora_utils import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb @@ -29,6 +30,7 @@ import torch from torch import nn + TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.up_proj", ff_down_proj="model.layers.{}.mlp.down_proj", @@ -97,6 +99,7 @@ def forward( tokens: torch.Tensor, input_pos: torch.Tensor, kv_cache: kv_utils.KVCache, + lora: Optional[lora_utils.LoRA] = None, export_config: Optional[ExportConfig] = None, ) -> dict[torch.Tensor, kv_utils.KVCache]: _, seq_len = tokens.size() @@ -118,7 +121,7 @@ def forward( ) return self.forward_with_embeds( - input_embeds, rope, mask, input_pos, kv_cache, export_config + input_embeds, rope, mask, input_pos, kv_cache, lora, export_config ) def forward_with_embeds( @@ -128,6 +131,7 @@ def forward_with_embeds( mask: torch.Tensor, input_pos: torch.Tensor, kv_cache: kv_utils.KVCache, + lora: Optional[lora_utils.LoRA] = None, export_config: Optional[ExportConfig] = None, ) -> dict[torch.Tensor, kv_utils.KVCache]: """Forwards the model with input embeddings.""" @@ -143,7 +147,8 @@ def forward_with_embeds( updated_kv_entries = [] for i, block in enumerate(self.transformer_blocks): kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, rope, mask, input_pos, kv_entry) + lora_adapter = lora.adapters[i] if lora else None + x, kv_entry = block(x, rope, mask, input_pos, kv_entry, lora_adapter) if kv_entry: updated_kv_entries.append(kv_entry) updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
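With these changes, `convert_to_tflite` derives the output filename from the name prefix, quantization mode, KV cache size, and any LoRA ranks, and `_export_helper` emits one prefill/decode signature set per LoRA rank in addition to the base (no-LoRA) set. The short sketch below mirrors that naming logic; the prefix, KV size, sequence lengths, and rank are illustrative values, not defaults from the code.

```python
# Mirrors the naming logic in converter.convert_to_tflite / _export_helper.
output_name_prefix, quantize, kv_size = 'gemma', True, 1024
prefill_seq_lens, lora_ranks = [256, 512], [16]

quant_suffix = 'q8' if quantize else 'f32'
lora_suffix = '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
print(f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite')
# -> gemma_q8_ekv1024_lora16.tflite

signatures = []
for rank in [None] + lora_ranks:          # None is the base (no-LoRA) signature set
  lora_tag = '' if rank is None else f'_lora_r{rank}'
  for seq_len in prefill_seq_lens:
    signatures.append(f'prefill_{seq_len}{lora_tag}')
  signatures.append(f'decode{lora_tag}')
print(signatures)
# -> ['prefill_256', 'prefill_512', 'decode',
#     'prefill_256_lora_r16', 'prefill_512_lora_r16', 'decode_lora_r16']
```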