From 5a25efa38ce719310168e4ed8eb6047f01856036 Mon Sep 17 00:00:00 2001
From: Yishuang Pang
Date: Fri, 6 Sep 2024 17:25:19 -0700
Subject: [PATCH] Add a flag in the converter config for generating fake weights. When it is set to true, all weights will be filled with zeros.

PiperOrigin-RevId: 671939022
---
 .../python/genai/converter/converter_base.py  |  6 +++-
 .../python/genai/converter/llm_converter.py   |  6 +++-
 .../genai/converter/weight_bins_writer.py     | 10 +++++-
 .../converter/weight_bins_writer_test.py      | 33 +++++++++++++++++++
 4 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/mediapipe/tasks/python/genai/converter/converter_base.py b/mediapipe/tasks/python/genai/converter/converter_base.py
index d13a923988..b759b12d56 100644
--- a/mediapipe/tasks/python/genai/converter/converter_base.py
+++ b/mediapipe/tasks/python/genai/converter/converter_base.py
@@ -171,5 +171,9 @@ def __init__(self, output_dir: str, backend: str):
       os.mkdir(self._output_dir)
     self._backend = backend
 
-  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
+  def write_variables(
+      self,
+      variables: Dict[str, Tuple[np.ndarray, bool]],
+      use_fake_values: bool = False,
+  ):
     raise NotImplementedError("The write_variables method is not implemented.")
diff --git a/mediapipe/tasks/python/genai/converter/llm_converter.py b/mediapipe/tasks/python/genai/converter/llm_converter.py
index fc177a1005..7d1c372919 100644
--- a/mediapipe/tasks/python/genai/converter/llm_converter.py
+++ b/mediapipe/tasks/python/genai/converter/llm_converter.py
@@ -48,6 +48,8 @@ class ConversionConfig(object):
     lora_output_tflite_file: A string indicating the name of the generated
       tflite file for the LoRA weight. Only applicable when the lora_rank is not
       zero.
+    use_fake_weights: Whether to use fake weights. If set to True, the weights
+      will be filled with zeros.
   """
 
   def __init__(
@@ -69,6 +71,7 @@ def __init__(
       lora_ckpt: Optional[str] = None,
       lora_rank: Optional[int] = None,
       lora_output_tflite_file: Optional[str] = None,
+      use_fake_weights: bool = False,
   ):
     self.input_ckpt = input_ckpt
     self.ckpt_format = ckpt_format
@@ -87,6 +90,7 @@ def __init__(
     self.combine_file_only = combine_file_only
     self.vocab_model_file = vocab_model_file
     self.obfuscate = obfuscate
+    self.use_fake_weights = use_fake_weights
     if output_tflite_file:
       parent_dir = os.path.dirname(output_tflite_file)
       if not os.path.isdir(parent_dir):
@@ -291,7 +295,7 @@ def maybe_quantize_and_write_tensors_to_bins(
         output_dir=config.output_dir,
         backend=config.backend,
     )
-    writer.write_variables(quantized_tensors)
+    writer.write_variables(quantized_tensors, config.use_fake_weights)
     del quantized_tensors
     del writer
 
diff --git a/mediapipe/tasks/python/genai/converter/weight_bins_writer.py b/mediapipe/tasks/python/genai/converter/weight_bins_writer.py
index c0340f77e1..40bff63460 100644
--- a/mediapipe/tasks/python/genai/converter/weight_bins_writer.py
+++ b/mediapipe/tasks/python/genai/converter/weight_bins_writer.py
@@ -50,17 +50,25 @@ def get_weight_info(self, var_name: str, weight: np.ndarray) -> str:
     shape_str = '_'.join(map(str, weight.shape))
     return f'mdl_vars.{var_name}.{dtype_str}.{shape_str}\n'
 
-  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
+  def write_variables(
+      self,
+      variables: Dict[str, Tuple[np.ndarray, bool]],
+      use_fake_values: bool = False,
+  ):
     """Writes variable to the binary files. One for each layer.
 
     Args:
       variables: A dictionary that maps from the target variable names to the
         quantized tensor values along with a boolean that indicates whether to
         pack the values (only applicable for the 4-bit quantized tensors).
+      use_fake_values: If True, zero-filled tensors are written instead of
+        the real weights; the arrays in `variables` are left unmodified.
     """
     weights_info = []
     for var_name, value in variables.items():
       output = value[0]
+      if use_fake_values:
+        output = np.zeros_like(output)
       if value[1]:
         # Squeeze the tensor to make sure it is a 1D array for packing.
         output = np.expand_dims(np.ravel(output), axis=-1)
diff --git a/mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py b/mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py
index 401a055282..81e7701eef 100644
--- a/mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py
+++ b/mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py
@@ -57,6 +57,39 @@ def test_load_to_actions(self):
     )
     self.assertEqual(file_size, 6 * 4)
 
+  @parameterized.named_parameters(
+      (
+          'real_weights',
+          np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+          False,
+      ),
+      (
+          'fake_weights',
+          np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+          True,
+      ),
+  )
+  def test_write_variables(self, var_values, use_fake_values):
+    output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
+    writer = weight_bins_writer.WeightBinsWriter(
+        output_dir=output_dir, backend='gpu'
+    )
+    variables = {
+        'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
+            var_values,
+            False,
+        ),
+    }
+    writer.write_variables(variables, use_fake_values=use_fake_values)
+    with open(
+        os.path.join(output_dir, 'params.lm.softmax.logits_ffn.linear.w'), 'rb'
+    ) as f:
+      data = np.frombuffer(f.read(), dtype=np.float32).reshape(var_values.shape)
+    expected_values = var_values if not use_fake_values else np.zeros_like(
+        var_values
+    )
+    self.assertTrue(np.all(data == expected_values))
+
 
 if __name__ == '__main__':
   absltest.main()