Add PyTorch Gemma Language Model
PiperOrigin-RevId: 634921095
Change-Id: Ib82e581bd5f36c334e14af6366050eb0f50d8e9f
jzleibo authored and copybara-github committed May 17, 2024
1 parent 638fa44 commit 0e411ed
Showing 4 changed files with 282 additions and 2 deletions.
4 changes: 2 additions & 2 deletions concordia/language_model/ollama_model.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Ollama Language Model."""
+"""Ollama Language Model, for models running on the local machine."""

from collections.abc import Collection, Sequence
import re
@@ -31,7 +31,7 @@ def _extract_choices(text):


class OllamaLanguageModel(language_model.LanguageModel):
-  """Language Model that uses Ollama LLM models."""
+  """Language Model that uses Ollama LLM models running on the local machine."""

def __init__(
self,
159 changes: 159 additions & 0 deletions concordia/language_model/pytorch_gemma_model.py
@@ -0,0 +1,159 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Gemma Language Model, for models running on the local machine."""

from collections.abc import Collection, Sequence
import os
import re

from concordia.language_model import language_model
from concordia.utils import measurements as measurements_lib
import numpy as np
import transformers

from typing_extensions import override


def _extract_choices(text):
match = re.search(r'\(?(\w)\)', text)
if match:
return match.group(1)
return None


class PyTorchGemmaLanguageModel(language_model.LanguageModel):
  """PyTorch Language Model API, for models running on the local machine."""

def __init__(
self,
*,
# The default model is the 2 billion parameter instruction-tuned Gemma.
model_name: str = 'google/gemma-2b-it',
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
  ) -> None:
    """Initializes the instance.

    Args:
model_name: The local language model to use. For more details,
see transformers.AutoModelForCausalLM at huggingface.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
self._model_name = model_name
self._tokenizer_name = model_name

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

self._model = transformers.GemmaForCausalLM.from_pretrained(
self._model_name)
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
self._tokenizer_name)

self._measurements = measurements
self._channel = channel

self._text_system_message = (
'You always continue sentences provided by the user and you never ' +
'repeat what the user already said.')

@override
def sample_text(
self,
prompt: str,
*,
max_tokens: int = language_model.DEFAULT_MAX_TOKENS,
max_characters: int = language_model.DEFAULT_MAX_CHARACTERS,
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None,
) -> str:
del temperature, timeout, seed # Unused.

prompt_with_system_message = f'{self._text_system_message}\n\n{prompt}'
prompt_length = len(prompt_with_system_message)

inputs = self._tokenizer(prompt_with_system_message, return_tensors='pt')

generated_tokens = self._model.generate(
inputs.input_ids,
max_new_tokens=max_tokens,
return_dict_in_generate=True,
output_scores=True,
)

response = self._tokenizer.decode(
np.int64(generated_tokens.sequences[0]),
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
response = response[prompt_length:]

response = response[:max_characters]
    # It would be better to implement terminators in the model generation (see
    # the sketch after this file's diff), but this is a quick way to implement
    # our API for now.
    for terminator in terminators:
      if terminator in response:
        response = response[:response.find(terminator)]

if self._measurements is not None:
self._measurements.publish_datum(
self._channel, {'raw_text_length': len(response)}
)
return response

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
del seed # Unused.

inputs = self._tokenizer(prompt, return_tensors='pt')
generated_tokens = self._model.generate(
inputs.input_ids,
max_new_tokens=1,
return_dict_in_generate=True,
output_scores=True,
)
sample = self._tokenizer.batch_decode(
[np.argmax(generated_tokens.scores[0][0])],
skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
if len(sample) == 1:
# i.e. this would be a sample such as "a"
answer = sample
elif len(sample) == 2:
# i.e. this would be a sample such as "a)"
answer = sample[0]
else:
# extract a substring like "(a)" wherever it may be in a longer string
answer = _extract_choices(sample)
try:
idx = responses.index(answer)
print(f'sample: {sample}, response: {idx}')
except ValueError:
raise language_model.InvalidResponseError(
f'Invalid response: {answer}. '
f'LLM Input: {prompt}\nLLM Output: {sample}'
) from None

if self._measurements is not None:
self._measurements.publish_datum(self._channel, {'choices_calls': 1})
debug = {}
return idx, responses[idx], debug
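
The comment in sample_text above notes that terminators would be better handled during generation. A minimal sketch of that approach, not part of this commit, is below; it uses the transformers StoppingCriteria API, and the class name _StopOnTerminators, its constructor arguments, and the wiring into generate() are illustrative assumptions rather than code from this change.

import transformers


class _StopOnTerminators(transformers.StoppingCriteria):
  """Stops generation once any terminator appears in the generated text."""

  def __init__(self, tokenizer, prompt_length: int, terminators):
    self._tokenizer = tokenizer
    self._prompt_length = prompt_length  # Character length of the prompt.
    self._terminators = tuple(terminators)

  def __call__(self, input_ids, scores, **kwargs) -> bool:
    # Decode the whole sequence so far and inspect only the continuation.
    text = self._tokenizer.decode(input_ids[0], skip_special_tokens=True)
    continuation = text[self._prompt_length:]
    return any(t in continuation for t in self._terminators)


# Inside sample_text, generation could then stop at a terminator rather than
# truncating the decoded string afterwards, roughly:
#   criteria = transformers.StoppingCriteriaList(
#       [_StopOnTerminators(self._tokenizer, prompt_length, terminators)])
#   generated_tokens = self._model.generate(
#       inputs.input_ids,
#       max_new_tokens=max_tokens,
#       stopping_criteria=criteria,
#       return_dict_in_generate=True,
#       output_scores=True,
#   )
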
120 changes: 120 additions & 0 deletions examples/pytorch_gemma_local.ipynb
@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "61PmYiKzWgz4"
},
"source": [
"# Illustrate how to use a PyTorch Gemma model running locally.\n",
"\n",
"Note: This will download a 2 billion parameter model from Hugging Face, so make sure you have enough space for that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "evAcCqotVPWY"
},
"outputs": [],
"source": [
"!pip install --ignore-requires-python git+https://github.com/google-deepmind/concordia.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WvRB61fWVRpW"
},
"outputs": [],
"source": [
"!pip install -r https://raw.githubusercontent.com/google-deepmind/concordia/main/examples/requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AVD1XXzoU5-o"
},
"outputs": [],
"source": [
"from concordia.language_model import pytorch_gemma_model"
]
},
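    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional sketch, not part of the original notebook: the Gemma checkpoints\n",
        "# on Hugging Face are gated, so the download in the next cell may require\n",
        "# accepting the model license and authenticating with a Hugging Face token.\n",
        "# from huggingface_hub import login\n",
        "# login()"
      ]
    },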
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e28CPlxZViAQ"
},
"outputs": [],
"source": [
"# This will download the model from Hugging Face.\n",
"model = pytorch_gemma_model.PyTorchGemmaLanguageModel(\n",
" model_name='google/gemma-2b-it',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DT8QPSHDV36C"
},
"outputs": [],
"source": [
"choice_response = model.sample_choice('What comes next? a,b,c,d,e,f,',\n",
" responses=['d', 'g', 'z'])\n",
"choice_response[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qDKp1LeWWUfZ"
},
"outputs": [],
"source": [
"text_response = model.sample_text('What is the meaning of life?',\n",
" max_tokens=40)\n",
"text_response"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RMj2IZffYYh7"
},
"source": [
"```\n",
"Copyright 2023 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"```"
]
}
],
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"toc_visible": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
1 change: 1 addition & 0 deletions setup.py
@@ -65,6 +65,7 @@
'retry',
'scipy',
'termcolor',
'transformers',
'typing-extensions',
),
extras_require={
