
Commit

Add a flag in the converter config for generating fake weights. When it is set to true, all weights will be filled with zeros.

PiperOrigin-RevId: 671939022
yishuangP authored and copybara-github committed Sep 7, 2024
1 parent 9cff42d commit 5a25efa
Showing 4 changed files with 52 additions and 3 deletions.
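For orientation, a hedged sketch of how the new flag would be driven through the public converter API. The paths, the model settings, and the convert_checkpoint entry point below are placeholders and assumptions for illustration, not part of this commit:

from mediapipe.tasks.python.genai import converter

# All paths and model settings here are placeholders (assumptions).
config = converter.ConversionConfig(
    input_ckpt='/tmp/gemma-2b-it/',              # placeholder checkpoint dir
    ckpt_format='safetensors',                   # placeholder format
    model_type='GEMMA_2B',                       # placeholder model type
    backend='gpu',
    output_dir='/tmp/intermediate/',
    vocab_model_file='/tmp/gemma-2b-it/',
    output_tflite_file='/tmp/gemma_fake.tflite',
    use_fake_weights=True,  # the new flag: every tensor is written as zeros
)
converter.convert_checkpoint(config)  # assumed entry point

A zero-filled model produced this way keeps the real tensor names, dtypes, and shapes, which is enough to exercise the conversion and loading pipeline without shipping real weights.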
6 changes: 5 additions & 1 deletion mediapipe/tasks/python/genai/converter/converter_base.py
@@ -171,5 +171,9 @@ def __init__(self, output_dir: str, backend: str):
       os.mkdir(self._output_dir)
     self._backend = backend
 
-  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
+  def write_variables(
+      self,
+      variables: Dict[str, Tuple[np.ndarray, bool]],
+      use_fake_values: bool = False,
+  ):
     raise NotImplementedError("The write_variables method is not implemented.")
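A minimal sketch of a concrete writer honoring the new argument. Only the write_variables signature comes from the hunk above; the base-class name (ModelWriterBase) and the file-per-tensor layout are assumptions:

import os
from typing import Dict, Tuple

import numpy as np

from mediapipe.tasks.python.genai.converter import converter_base


class ZeroableWriter(converter_base.ModelWriterBase):  # hypothetical subclass
  """Writes each tensor to its own file, zero-filled when requested."""

  def write_variables(
      self,
      variables: Dict[str, Tuple[np.ndarray, bool]],
      use_fake_values: bool = False,
  ):
    for var_name, (tensor, _) in variables.items():
      if use_fake_values:
        tensor = np.zeros_like(tensor)  # keep shape/dtype, drop real values
      tensor.tofile(os.path.join(self._output_dir, var_name))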
6 changes: 5 additions & 1 deletion mediapipe/tasks/python/genai/converter/llm_converter.py
@@ -48,6 +48,8 @@ class ConversionConfig(object):
     lora_output_tflite_file: A string indicating the name of the generated
       tflite file for the LoRA weight. Only applicable when the lora_rank is not
       zero.
+    use_fake_weights: Whether to use fake weights. If set to True, the weights
+      will be filled with zeros.
   """
 
   def __init__(
@@ -69,6 +71,7 @@ def __init__(
       lora_ckpt: Optional[str] = None,
       lora_rank: Optional[int] = None,
       lora_output_tflite_file: Optional[str] = None,
+      use_fake_weights: bool = False,
   ):
     self.input_ckpt = input_ckpt
     self.ckpt_format = ckpt_format
@@ -87,6 +90,7 @@ def __init__(
     self.combine_file_only = combine_file_only
     self.vocab_model_file = vocab_model_file
     self.obfuscate = obfuscate
+    self.use_fake_weights = use_fake_weights
     if output_tflite_file:
       parent_dir = os.path.dirname(output_tflite_file)
       if not os.path.isdir(parent_dir):
@@ -291,7 +295,7 @@ def maybe_quantize_and_write_tensors_to_bins(
         output_dir=config.output_dir,
         backend=config.backend,
     )
-    writer.write_variables(quantized_tensors)
+    writer.write_variables(quantized_tensors, config.use_fake_weights)
     del quantized_tensors
     del writer

10 changes: 9 additions & 1 deletion mediapipe/tasks/python/genai/converter/weight_bins_writer.py
@@ -50,17 +50,25 @@ def get_weight_info(self, var_name: str, weight: np.ndarray) -> str:
     shape_str = '_'.join(map(str, weight.shape))
     return f'mdl_vars.{var_name}.{dtype_str}.{shape_str}\n'
 
-  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
+  def write_variables(
+      self,
+      variables: Dict[str, Tuple[np.ndarray, bool]],
+      use_fake_values: bool = False,
+  ):
     """Writes variable to the binary files. One for each layer.
 
     Args:
       variables: A dictionary that maps from the target variable names to the
         quantized tensor values along with a boolean that indicates whether to
         pack the values (only applicable for the 4-bit quantized tensors).
+      use_fake_values: Whether to use fake values for the weights. If set to
+        True, the weights will be filled with zeros.
     """
     weights_info = []
     for var_name, value in variables.items():
       output = value[0]
+      if use_fake_values:
+        output.fill(0)
       if value[1]:
         # Squeeze the tensor to make sure it is a 1D array for packing.
         output = np.expand_dims(np.ravel(output), axis=-1)
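Two details of the hunk above are easy to miss: np.ndarray.fill zeroes the tensor in place (no copy is made), and the packing branch reshapes the flattened tensor into a column vector. A small self-contained numpy check, independent of the converter code:

import numpy as np

output = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
output.fill(0)  # in-place: mutates the caller's array, preserves shape/dtype
assert np.all(output == 0)

packed_input = np.expand_dims(np.ravel(output), axis=-1)
assert packed_input.shape == (6, 1)  # flattened to an (N, 1) column for packing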
33 changes: 33 additions & 0 deletions mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py
@@ -57,6 +57,39 @@ def test_load_to_actions(self):
     )
     self.assertEqual(file_size, 6 * 4)
 
+  @parameterized.named_parameters(
+      (
+          'real_weights',
+          np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+          False,
+      ),
+      (
+          'fake_weights',
+          np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+          True,
+      ),
+  )
+  def test_write_variables(self, var_values, use_fake_values):
+    output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
+    writer = weight_bins_writer.WeightBinsWriter(
+        output_dir=output_dir, backend='gpu'
+    )
+    variables = {
+        'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
+            var_values,
+            False,
+        ),
+    }
+    writer.write_variables(variables, use_fake_values=use_fake_values)
+    with open(
+        os.path.join(output_dir, 'params.lm.softmax.logits_ffn.linear.w'), 'rb'
+    ) as f:
+      data = np.frombuffer(f.read(), dtype=np.float32).reshape(var_values.shape)
+    expected_values = var_values if not use_fake_values else np.zeros_like(
+        var_values
+    )
+    self.assertTrue(np.all(data == expected_values))
+
 
 if __name__ == '__main__':
   absltest.main()
