Skip to content

Commit

Permalink
Refactor verifier so that verification can return boolean value indicating if verification succeeds.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 695933836
  • Loading branch information
ai-edge-bot authored and copybara-github committed Nov 13, 2024
1 parent 78d1c7b commit 956cfc9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
15 changes: 1 addition & 14 deletions ai_edge_torch/generative/examples/gemma/verify_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@

"""Verifies the reauthored Gemma2 model."""

import logging
from absl import app
from absl import flags
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.examples.gemma import verify_util
import kagglehub

Expand All @@ -38,18 +36,7 @@
def main(_):
checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")

logging.info("Building the reauthored model from: %s", checkpoint)
reauthored_model = gemma2.build_2b_model(checkpoint)

verify_util.verify_reauthored_gemma_model(
checkpoint=checkpoint,
variant="2b-v2",
reauthored_model=reauthored_model,
generate_prompts=_PROMPTS.value,
forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
max_new_tokens=_MAX_NEW_TOKENS.value,
atol=1e-04,
)
verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)


if __name__ == "__main__":
Expand Down
31 changes: 28 additions & 3 deletions ai_edge_torch/generative/examples/gemma/verify_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from typing import List, Tuple

from ai_edge_torch.generative.examples.gemma import gemma2
import ai_edge_torch.generative.layers.attention_utils as attn_utils
from ai_edge_torch.generative.utilities import verifier
from gemma import config as gemma_config
Expand Down Expand Up @@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
max_new_tokens: int = 20,
rtol: float = 1e-05,
atol: float = 1e-05,
):
"""Verifies the reauthored Gemma model against the original model."""
) -> bool:
"""Verifies the reauthored Gemma model against the original model.
Returns True if the verification passes, False otherwise.
"""
config = gemma_config.get_model_config(variant)
config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
# Use float32 to be compatible with the reauthored model.
Expand All @@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
original_model = gemma_model.GemmaForCausalLM(config).eval()
original_model.load_weights(os.path.join(checkpoint, weight_filename))

verifier.verify_reauthored_model(
return verifier.verify_reauthored_model(
original_model=GemmaWrapper(original_model),
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
Expand All @@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
rtol=rtol,
atol=atol,
)


def verify_gemma2(
    gemma2_model_path: str, prompts: List[str], max_new_tokens: int
) -> bool:
  """Builds the reauthored Gemma2 2B model and verifies it against the original.

  Args:
    gemma2_model_path: Path to the Gemma2 checkpoint directory.
    prompts: Prompts whose generated answers are compared between models.
    max_new_tokens: Maximum number of new tokens to generate per prompt.

  Returns:
    True if the verification passes, False otherwise.
  """
  logging.info("Building the reauthored model from: %s", gemma2_model_path)
  model = gemma2.build_2b_model(gemma2_model_path)

  # Fixed token IDs used for the forward-pass numerical comparison
  # (presumably a tokenized sample prompt — confirm against the tokenizer).
  input_ids = [[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]]

  return verify_reauthored_gemma_model(
      checkpoint=gemma2_model_path,
      variant="2b-v2",
      reauthored_model=model,
      generate_prompts=prompts,
      forward_input_ids=input_ids,
      max_new_tokens=max_new_tokens,
      atol=1e-04,
  )
20 changes: 16 additions & 4 deletions ai_edge_torch/generative/utilities/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def verify_reauthored_model(
rtol: float = 1e-05,
atol: float = 1e-05,
continue_on_failure: bool = False,
):
) -> bool:
"""Verifies the reauthored model against the original model.
It verifies the reauthored model with two methods:
Expand All @@ -237,7 +237,8 @@ def verify_reauthored_model(
2. It compares the answer generated by the original and the reauthored model
with a prompt.
It prints out "PASS" or "FAILED" to the console.
It prints out "PASS" or "FAILED" to the console. It returns True if all
verification passes, False otherwise.
Args:
original_model (ModelWrapper): The original model.
Expand All @@ -253,6 +254,8 @@ def verify_reauthored_model(
continue_on_failure (bool): If True, it continues to verify the next prompt
or input IDs even if a previous one fails.
"""
failure_count = 0

for input_ids in forward_input_ids:
logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
try:
Expand All @@ -261,8 +264,9 @@ def verify_reauthored_model(
)
except AssertionError as e:
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
failure_count += 1
if not continue_on_failure:
raise e
return False
else:
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)

Expand All @@ -274,7 +278,15 @@ def verify_reauthored_model(
)
except AssertionError as e:
logging.error("*** FAILED *** verify with prompts: %s", prompts)
failure_count += 1
if not continue_on_failure:
raise e
return False
else:
logging.info("*** PASSED *** verify with prompts: %s", prompts)

if failure_count == 0:
logging.info("*** PASSED *** verify_reauthored_model")
return True
else:
logging.error("*** FAILED *** verify_reauthored_model")
return False

0 comments on commit 956cfc9

Please sign in to comment.