Skip to content

Commit

Permalink
Add export_model_with_tokenizer to Text Classifier API.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567744604
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 22, 2023
1 parent 9d85141 commit 573fdad
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 5 deletions.
21 changes: 20 additions & 1 deletion mediapipe/model_maker/python/text/text_classifier/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,23 @@ py_test(
deps = [":dataset"],
)

py_library(
name = "model_with_tokenizer",
srcs = ["model_with_tokenizer.py"],
)

py_test(
name = "model_with_tokenizer_test",
srcs = ["model_with_tokenizer_test.py"],
tags = ["requires-net:external"],
deps = [
":bert_tokenizer",
":model_spec",
":model_with_tokenizer",
"//mediapipe/model_maker/python/core/utils:hub_loader",
],
)

py_library(
name = "bert_tokenizer",
srcs = ["bert_tokenizer.py"],
Expand Down Expand Up @@ -145,10 +162,12 @@ py_library(
name = "text_classifier",
srcs = ["text_classifier.py"],
deps = [
":bert_tokenizer",
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":model_with_tokenizer",
":preprocessor",
":text_classifier_options",
"//mediapipe/model_maker/python/core/data:dataset",
Expand All @@ -165,7 +184,7 @@ py_library(

py_test(
name = "text_classifier_test",
size = "large",
size = "enormous",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/model_maker/python/text/text_classifier/testdata",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
self._seq_len = seq_len

def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
"""Processes one input_tensor example.
Args:
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
Returns:
A dictionary of lists all with shape (1, self._seq_len) containing the
keys "input_word_ids", "input_type_ids", and "input_mask".
"""
tokens = self._tokenizer.tokenize(input_tensor.numpy()[0].decode("utf-8"))
tokens = tokens[0 : (self._seq_len - 2)] # account for [CLS] and [SEP]
tokens.insert(0, "[CLS]")
Expand Down Expand Up @@ -96,7 +105,18 @@ def __init__(self, vocab_file: str, do_lower_case: bool, seq_len: int):
self._sep_id = vocab.index("[SEP]")
self._pad_id = vocab.index("[PAD]")

def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
def process_fn(self, input_tensor: tf.Tensor) -> Mapping[str, tf.Tensor]:
"""Tensor implementation of the process function.
This implementation can be used within a model graph directly since it
takes in tensors and outputs tensors.
Args:
input_tensor: Input string tensor
Returns:
Dictionary of tf.Tensors.
"""
input_ids = self._tokenizer.tokenize(input_tensor).flat_values
input_ids = input_ids[: (self._seq_len - 2)]
input_ids = tf.concat(
Expand All @@ -112,7 +132,20 @@ def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
input_type_ids = tf.zeros(self._seq_len, dtype=tf.int32)
input_mask = tf.cast(input_ids != self._pad_id, dtype=tf.int32)
return {
"input_word_ids": input_ids.numpy().tolist(),
"input_type_ids": input_type_ids.numpy().tolist(),
"input_mask": input_mask.numpy().tolist(),
"input_word_ids": input_ids,
"input_type_ids": input_type_ids,
"input_mask": input_mask,
}

def process(self, input_tensor: tf.Tensor) -> Mapping[str, Sequence[int]]:
"""Processes one input_tensor example.
Args:
input_tensor: A tensor with shape (1, None) of a utf-8 encoded string.
Returns:
A dictionary of lists all with shape (1, self._seq_len) containing the
keys "input_word_ids", "input_type_ids", and "input_mask".
"""
result = self.process_fn(input_tensor)
return {k: v.numpy().tolist() for k, v in result.items()}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text classifier export module library."""
import tensorflow as tf


class ModelWithTokenizer(tf.keras.Model):
"""A model with the tokenizer included in graph for exporting to TFLite."""

def __init__(self, tokenizer, model):
super().__init__()
self._tokenizer = tokenizer
self._model = model

@tf.function(
input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.string, name="input")
]
)
def call(self, input_tensor):
x = self._tokenizer.process_fn(input_tensor)
x = {k: tf.expand_dims(v, axis=0) for k, v in x.items()}
x = self._model(x)
return x
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from unittest import mock as unittest_mock

import tensorflow as tf
import tensorflow_hub

from mediapipe.model_maker.python.core.utils import hub_loader
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer


class BertTokenizerTest(tf.test.TestCase):
_SEQ_LEN = 128

def setUp(self):
super().setUp()
# Mock tempfile.gettempdir() to be unique for each test to avoid race
# condition when downloading model since these tests may run in parallel.
mock_gettempdir = unittest_mock.patch.object(
tempfile,
"gettempdir",
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
self._ms = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value()
self._tokenizer = self._create_tokenizer()
self._model = self._create_model()

def _create_tokenizer(self):
vocab_file = os.path.join(
tensorflow_hub.resolve(self._ms.get_path()), "assets", "vocab.txt"
)
return bert_tokenizer.BertFastTokenizer(vocab_file, True, self._SEQ_LEN)

def _create_model(self):
encoder_inputs = dict(
input_word_ids=tf.keras.layers.Input(
shape=(self._SEQ_LEN,),
dtype=tf.int32,
name="input_word_ids",
),
input_mask=tf.keras.layers.Input(
shape=(self._SEQ_LEN,),
dtype=tf.int32,
name="input_mask",
),
input_type_ids=tf.keras.layers.Input(
shape=(self._SEQ_LEN,),
dtype=tf.int32,
name="input_type_ids",
),
)
renamed_inputs = dict(
input_ids=encoder_inputs["input_word_ids"],
input_mask=encoder_inputs["input_mask"],
segment_ids=encoder_inputs["input_type_ids"],
)
encoder = hub_loader.HubKerasLayerV1V2(
self._ms.get_path(),
signature="tokens",
output_key="pooled_output",
trainable=True,
)
pooled_output = encoder(renamed_inputs)

output = tf.keras.layers.Dropout(rate=0.1)(pooled_output)
initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
output = tf.keras.layers.Dense(
2,
kernel_initializer=initializer,
name="output",
activation="softmax",
dtype=tf.float32,
)(output)
return tf.keras.Model(inputs=encoder_inputs, outputs=output)

def test_model_with_tokenizer(self):
model = model_with_tokenizer.ModelWithTokenizer(
self._tokenizer, self._model
)
output = model(tf.constant(["Example input".encode("utf-8")]))
self.assertAllEqual(output.shape, (1, 2))
self.assertEqual(tf.reduce_sum(output), 1)


if __name__ == "__main__":
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ def create_int_feature(values):
tfrecord_cache_files=tfrecord_cache_files,
)

@property
def tokenizer(self) -> bert_tokenizer.BertTokenizer:
return self._tokenizer


TextClassifierPreprocessor = Union[
BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import bert_tokenizer
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import model_with_tokenizer
from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
Expand Down Expand Up @@ -620,3 +622,56 @@ def _get_metadata_writer(self, tflite_model: bytearray, vocab_filepath: str):
ids_name=self._model_spec.tflite_input_name["ids"],
mask_name=self._model_spec.tflite_input_name["mask"],
segment_name=self._model_spec.tflite_input_name["segment_ids"])

def export_model_with_tokenizer(
self,
model_name: str = "model_with_tokenizer.tflite",
quantization_config: Optional[quantization.QuantizationConfig] = None,
):
"""Converts and saves the model to a TFLite file with the tokenizer.
Note that unlike the export_model method, this export method will include
a FastBertTokenizer in the TFLite graph. The resulting TFLite will not have
metadata information to use with MediaPipe Tasks, but can be run directly
using TFLite Inference: https://www.tensorflow.org/lite/guide/inference
For more information on the tokenizer, see:
https://www.tensorflow.org/text/api_docs/python/text/FastBertTokenizer
Args:
model_name: File name to save TFLite model with tokenizer. The full export
path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
if (
self._hparams.tokenizer
!= bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER
):
print(
f"WARNING: This model was trained with {self._hparams.tokenizer} "
"tokenizer, but the exported model with tokenizer will have a "
f"{bert_tokenizer.SupportedBertTokenizers.FAST_BERT_TOKENIZER} "
"tokenizer."
)
tokenizer = bert_tokenizer.BertFastTokenizer(
vocab_file=self._text_preprocessor.get_vocab_file(),
do_lower_case=self._model_spec.do_lower_case,
seq_len=self._model_options.seq_len,
)
else:
tokenizer = self._text_preprocessor.tokenizer

model = model_with_tokenizer.ModelWithTokenizer(tokenizer, self._model)
model(tf.constant(["Example input data".encode("utf-8")])) # build model
saved_model_file = os.path.join(
self._hparams.export_dir, "saved_model_with_tokenizer"
)
model.save(saved_model_file)
tflite_model = model_util.convert_to_tflite_from_file(
saved_model_file,
quantization_config=quantization_config,
allow_custom_ops=True,
)
model_util.save_tflite(tflite_model, tflite_file)
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ def test_create_and_train_bert(self, supported_model):
output_metadata_file, self._BERT_CLASSIFIER_JSON_FILE, shallow=False
)
)
bert_classifier.export_model_with_tokenizer()
output_tflite_with_tokenizer_file = os.path.join(
options.hparams.export_dir, 'model_with_tokenizer.tflite'
)
self.assertTrue(os.path.exists(output_tflite_with_tokenizer_file))
self.assertGreater(os.path.getsize(output_tflite_with_tokenizer_file), 0)

def test_label_mismatch(self):
options = text_classifier.TextClassifierOptions(
Expand Down

0 comments on commit 573fdad

Please sign in to comment.