SentencePiece conversion script clean-up.
PiperOrigin-RevId: 697648774
hheydary authored and copybara-github committed Nov 18, 2024
1 parent cad366d commit dc9a6eb
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions ai_edge_torch/generative/tools/tokenizer_to_sentencepiece.py
@@ -34,10 +34,11 @@
 
 from absl import app
 from absl import flags
+import sentencepiece as spm
+import sentencepiece.sentencepiece_model_pb2 as spm_model
 import transformers
 
-from sentencepiece import sentencepiece_model_pb2 as spm_model
-import sentencepiece as spm
 
 _CHECKPOINT = flags.DEFINE_string(
     "checkpoint",
     None,
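For reference, the two import spellings bind the same generated protobuf module; the clean-up sorts the imports and standardizes on the import x.y as z form. A minimal sanity check, assuming a standard sentencepiece pip install:

import sentencepiece as spm
import sentencepiece.sentencepiece_model_pb2 as spm_model
from sentencepiece import sentencepiece_model_pb2 as spm_model_old_style

# Both spellings resolve to the same module object in sys.modules.
assert spm_model is spm_model_old_style

# The module provides the ModelProto message used throughout this script.
proto = spm_model.ModelProto()
print(type(proto).__name__)  # -> ModelProto
print(spm.__name__)          # -> sentencepiece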
@@ -123,7 +124,7 @@ def _normalize_gpt2(token: str) -> str:
 
 def _add_token(
     token: str,
-    id: int,
+    id_: int,
     tokenizer: transformers.PreTrainedTokenizer,
     sp_model: spm_model.ModelProto,
     tokens_seen: set[str],
@@ -132,23 +133,26 @@ def _add_token(
   """Adds a token to the SentencePieceModel protobuf with a derived type."""
   unk_token = tokenizer.unk_token or tokenizer.pad_token or tokenizer.eos_token
   if token == unk_token:
-    type = spm_model.ModelProto.SentencePiece.UNKNOWN
+    type_ = spm_model.ModelProto.SentencePiece.UNKNOWN
   elif token in tokenizer.special_tokens_map:
-    type = spm_model.ModelProto.SentencePiece.CONTROL
+    type_ = spm_model.ModelProto.SentencePiece.CONTROL
     sp_model.trainer_spec.control_symbols.append(token)
   elif token in tokenizer.get_added_vocab():
-    type = spm_model.ModelProto.SentencePiece.USER_DEFINED
+    type_ = spm_model.ModelProto.SentencePiece.USER_DEFINED
     sp_model.trainer_spec.user_defined_symbols.append(token)
   else:
-    type = spm_model.ModelProto.SentencePiece.NORMAL
+    type_ = spm_model.ModelProto.SentencePiece.NORMAL
 
-  count_type = type
-  normalized = _NORMALIZE_FUNCS[_NORMALIZE_TOKENS.value](token, id, tokenizer)
+  count_type = type_
+  normalized = _NORMALIZE_FUNCS[_NORMALIZE_TOKENS.value](token, id_, tokenizer)
   if normalized == token:
     pass
   elif normalized in tokens_seen:
     logging.debug(
-        'DUPLICATE: token "%s"(id=%d) normalized to "%s"', token, id, normalized
+        'DUPLICATE: token "%s"(id=%d) normalized to "%s"',
+        token,
+        id_,
+        normalized,
     )
     normalized = token
     # Change only the type of counts for logging. When UNUSED is set for SPM
@@ -157,7 +161,7 @@ def _add_token(
     count_type = spm_model.ModelProto.SentencePiece.Type.UNUSED
   else:
     tokens_seen.add(normalized)
-  sp_model.pieces.add(piece=normalized, score=-id, type=type)
+  sp_model.pieces.add(piece=normalized, score=-id_, type=type_)
   counts[count_type] = counts.get(count_type, 0) + 1
 
 
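The id to id_ and type to type_ renames follow the usual Python convention of adding a trailing underscore so the builtins stay usable in the same scope. A hypothetical illustration of the hazard, not taken from this file:

# Shadowing: the parameter name hides the builtin in this scope.
def describe_shadowed(token, id):
  return id(token)  # would raise TypeError: 'int' object is not callable

# Trailing underscore: the builtin id() remains available.
def describe_clean(token, id_):
  return f"piece {id_} at address {id(token)}"

print(describe_clean("hello", 42))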
@@ -176,15 +180,15 @@ def _build_spm_model_from_tokenizer(
   id_to_token = {id: tk for tk, id in tokenizer.vocab.items()}
   tokens_seen = set(tokenizer.vocab.keys())
   counts = {}
-  for id in range(len(tokenizer.vocab)):
-    _add_token(id_to_token[id], id, tokenizer, sp_model, tokens_seen, counts)
+  for id_ in range(len(tokenizer.vocab)):
+    _add_token(id_to_token[id_], id_, tokenizer, sp_model, tokens_seen, counts)
 
   logging.info("number of tokens: %d", len(sp_model.pieces))
-  for type in counts:
+  for type_ in counts:
     logging.info(
         "number of %s: %d",
-        spm_model.ModelProto.SentencePiece.Type.Name(type),
-        counts[type],
+        spm_model.ModelProto.SentencePiece.Type.Name(type_),
+        counts[type_],
     )
 
   return sp_model
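Since _build_spm_model_from_tokenizer returns a fully populated ModelProto, the result can be serialized and loaded into a processor without touching disk. A sketch, assuming sp_model is the proto built above and that the installed sentencepiece wrapper accepts a model_proto argument (recent releases do):

import sentencepiece as spm

serialized = sp_model.SerializeToString()  # assumes sp_model from above
sp = spm.SentencePieceProcessor(model_proto=serialized)

print(sp.GetPieceSize())         # should equal len(sp_model.pieces)
print(sp.Encode("hello world"))  # token IDs under the converted model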
@@ -220,7 +224,7 @@ def _verify_spm_tokenizer(
   # as the token IDs encoded by the SentencePiece tokenizer.
   for string in _STRINGS_TO_VERIFY.value:
     ids_by_tokenizer = tokenizer.encode(string)
-    ids_by_spm = spm_tokenizer.encode(string)
+    ids_by_spm = spm_tokenizer.Encode(string)
     logging.info("String to verify: %s", string)
     logging.info("Token IDs by the original tokenizer: %s", ids_by_tokenizer)
     logging.info("Token IDs by the SentencePiece tokenizer: %s", ids_by_spm)
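Switching from spm_tokenizer.encode to spm_tokenizer.Encode standardizes on the capitalized method that mirrors the underlying C++ API; current Python wrappers also expose snake_case aliases, so both spellings should return the same IDs. A quick check, with "tokenizer.model" as a hypothetical model path:

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # hypothetical path

ids_camel = sp.Encode("String to verify")  # canonical, mirrors the C++ name
ids_snake = sp.encode("String to verify")  # snake_case alias in the wrapper
assert ids_camel == ids_snake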
@@ -239,7 +243,7 @@ def _verify_spm_tokenizer(
     id_pair = random.sample(list(range(len(tokenizer.vocab))), 2)
     string = tokenizer.decode(id_pair)
     ids_by_tokenizer = tokenizer.encode(string)
-    ids_by_spm = spm_tokenizer.encode(string)
+    ids_by_spm = spm_tokenizer.Encode(string)
     if not _is_same_ids(ids_by_tokenizer, ids_by_spm):
       num_not_matched_strict += 1
       if _is_same_ids(ids_by_tokenizer, id_pair):
