PiperOrigin-RevId: 634921095 Change-Id: Ib82e581bd5f36c334e14af6366050eb0f50d8e9f
1 parent 638fa44, commit 0e411ed
Showing 4 changed files with 282 additions and 2 deletions.
concordia/language_model/pytorch_gemma_model.py (new file)
@@ -0,0 +1,159 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Gemma Language Model, for models running on the local machine."""

from collections.abc import Collection, Sequence
import os
import re

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
import numpy as np
import transformers

from typing_extensions import override


def _extract_choices(text):
  """Extracts a single choice letter, e.g. 'a' from '(a)' or 'a)'."""
  match = re.search(r'\(?(\w)\)', text)
  if match:
    return match.group(1)
  return None

class PyTorchGemmaLanguageModel(language_model.LanguageModel):
  """PyTorch Language Model API, for models running on the local machine."""

  def __init__(
      self,
      *,
      # The default model is the 2 billion parameter instruction-tuned Gemma.
      model_name: str = 'google/gemma-2b-it',
      measurements: measurements_lib.Measurements | None = None,
      channel: str = language_model.DEFAULT_STATS_CHANNEL,
  ) -> None:
    """Initializes the instance.

    Args:
      model_name: The local language model to use. For more details, see
        transformers.AutoModelForCausalLM at huggingface.
      measurements: The measurements object to log usage statistics to.
      channel: The channel to write the statistics to.
    """
    self._model_name = model_name
    self._tokenizer_name = model_name

    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    self._model = transformers.GemmaForCausalLM.from_pretrained(
        self._model_name)
    self._tokenizer = transformers.AutoTokenizer.from_pretrained(
        self._tokenizer_name)

    self._measurements = measurements
    self._channel = channel

    self._text_system_message = (
        'You always continue sentences provided by the user and you never '
        'repeat what the user already said.')
  @override
  def sample_text(
      self,
      prompt: str,
      *,
      max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
      max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
      terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
      temperature: float = language_model.DEFAULT_TEMPERATURE,
      timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
      seed: int | None = None,
  ) -> str:
    del temperature, timeout, seed  # Unused.

    prompt_with_system_message = f'{self._text_system_message}\n\n{prompt}'
    prompt_length = len(prompt_with_system_message)

    inputs = self._tokenizer(prompt_with_system_message, return_tensors='pt')

    generated_tokens = self._model.generate(
        inputs.input_ids,
        max_new_tokens=max_tokens,
        return_dict_in_generate=True,
        output_scores=True,
    )

    response = self._tokenizer.decode(
        np.int64(generated_tokens.sequences[0]),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    # The decoded sequence echoes the prompt, so strip it by character count.
    response = response[prompt_length:]

    response = response[:max_characters]
    # It would be better to implement terminators in the model generation, but
    # this is a quick way to implement our API for now.
    for terminator in terminators:
      terminator_index = response.find(terminator)
      if terminator_index != -1:  # Only truncate when the terminator appears.
        response = response[:terminator_index]

    if self._measurements is not None:
      self._measurements.publish_datum(
          self._channel, {'raw_text_length': len(response)}
      )
    return response
  @override
  def sample_choice(
      self,
      prompt: str,
      responses: Sequence[str],
      *,
      seed: int | None = None,
  ) -> tuple[int, str, dict[str, float]]:
    del seed  # Unused.

    inputs = self._tokenizer(prompt, return_tensors='pt')
    generated_tokens = self._model.generate(
        inputs.input_ids,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # Decode the single most likely next token from the first step's scores.
    sample = self._tokenizer.batch_decode(
        [np.argmax(generated_tokens.scores[0][0])],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False)[0]
    if len(sample) == 1:
      # i.e. this would be a sample such as "a"
      answer = sample
    elif len(sample) == 2:
      # i.e. this would be a sample such as "a)"
      answer = sample[0]
    else:
      # Extract a substring like "(a)" wherever it may be in a longer string;
      # _extract_choices returns None when no such substring exists.
      answer = _extract_choices(sample)
    try:
      idx = responses.index(answer)
      print(f'sample: {sample}, response: {idx}')
    except ValueError:
      raise language_model.InvalidResponseError(
          f'Invalid response: {answer}. '
          f'LLM Input: {prompt}\nLLM Output: {sample}'
      ) from None

    if self._measurements is not None:
      self._measurements.publish_datum(self._channel, {'choices_calls': 1})
    debug = {}
    return idx, responses[idx], debug
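
The comment in sample_text notes that terminators would be better handled inside generation itself. A minimal sketch of that approach using the transformers StoppingCriteria hook follows; the _StopOnTerminators class and its wiring are illustrative, not part of this commit:

import transformers

class _StopOnTerminators(transformers.StoppingCriteria):
  """Stops generation once any terminator string appears in the completion."""

  def __init__(self, tokenizer, terminators, prompt_length):
    self._tokenizer = tokenizer
    self._terminators = terminators
    self._prompt_length = prompt_length

  def __call__(self, input_ids, scores, **kwargs) -> bool:
    # Decode the full sequence and inspect only the text after the prompt.
    text = self._tokenizer.decode(input_ids[0], skip_special_tokens=True)
    completion = text[self._prompt_length:]
    return any(t in completion for t in self._terminators)

Inside sample_text this would be passed to self._model.generate as stopping_criteria=transformers.StoppingCriteriaList([...]), removing the need for post-hoc string truncation.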
Example notebook (new file):
@@ -0,0 +1,120 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61PmYiKzWgz4"
      },
      "source": [
        "# Illustrate how to use a PyTorch Gemma model running locally.\n",
        "\n",
        "Note: This will download a 2 billion parameter model from Hugging Face, so make sure you have enough space for that."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "evAcCqotVPWY"
      },
      "outputs": [],
      "source": [
        "!pip install --ignore-requires-python git+https://github.com/google-deepmind/concordia.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WvRB61fWVRpW"
      },
      "outputs": [],
      "source": [
        "!pip install -r https://raw.githubusercontent.com/google-deepmind/concordia/main/examples/requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AVD1XXzoU5-o"
      },
      "outputs": [],
      "source": [
        "from concordia.language_model import pytorch_gemma_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e28CPlxZViAQ"
      },
      "outputs": [],
      "source": [
        "# This will download the model from Hugging Face.\n",
        "model = pytorch_gemma_model.PyTorchGemmaLanguageModel(\n",
        "    model_name='google/gemma-2b-it',\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DT8QPSHDV36C"
      },
      "outputs": [],
      "source": [
        "choice_response = model.sample_choice('What comes next? a,b,c,d,e,f,',\n",
        "                                      responses=['d', 'g', 'z'])\n",
        "choice_response[1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qDKp1LeWWUfZ"
      },
      "outputs": [],
      "source": [
        "text_response = model.sample_text('What is the meaning of life?',\n",
        "                                  max_tokens=40)\n",
        "text_response"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RMj2IZffYYh7"
      },
      "source": [
        "```\n",
        "Copyright 2023 DeepMind Technologies Limited.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "```"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
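
For reference, sample_choice returns an (index, response, debug) tuple, so choice_response[1] in the cell above is the matched response string. Unpacked explicitly, following the notebook's own example:

idx, response, debug = model.sample_choice(
    'What comes next? a,b,c,d,e,f,', responses=['d', 'g', 'z'])
print(idx, response)  # Expect 1 and 'g' if the model continues the alphabet.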
setup.py
@@ -65,6 +65,7 @@
         'retry',
         'scipy',
         'termcolor',
+        'transformers',
         'typing-extensions',
     ),
     extras_require={
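
The new 'transformers' dependency is unpinned, yet the model code relies on GemmaForCausalLM, which only exists in relatively recent transformers releases (Gemma support arrived around version 4.38; the exact floor is an assumption, as this diff pins no version). A quick post-install sanity check:

import transformers

# Assert the Gemma class exists rather than assuming a particular version.
assert hasattr(transformers, 'GemmaForCausalLM'), 'transformers too old for Gemma'
print(transformers.__version__)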