From 303be7d20d03f7877d81a23126773faf8d62722f Mon Sep 17 00:00:00 2001
From: erikkaum
Date: Wed, 31 Jul 2024 17:42:17 +0200
Subject: [PATCH] Incorporate changes from main

---
 backends/client/src/v2/client.rs               |  1 +
 backends/client/src/v2/sharded_client.rs       |  1 +
 backends/client/src/v3/client.rs               |  1 +
 backends/client/src/v3/sharded_client.rs       |  1 +
 backends/v3/src/queue.rs                       |  2 +
 benchmark/src/lib.rs                           |  3 +
 benchmark/src/main.rs                          |  7 +++
 benchmark/src/table.rs                         |  2 +
 clients/python/README.md                       |  4 ++
 docs/openapi.json                              | 24 ++++++++
 .../models/test_no_repeat_ngram.py             | 57 +++++++++++++++++++
 proto/generate.proto                           |  2 +
 proto/v3/generate.proto                        |  2 +
 router/src/infer/v2/queue.rs                   |  2 +
 router/src/lib.rs                              | 22 +++++++
 router/src/server.rs                           |  2 +
 router/src/validation.rs                       |  6 ++
 .../utils/logits_process.py                    | 15 +++--
 server/text_generation_server/utils/tokens.py  | 32 ++++++++++-
 19 files changed, 177 insertions(+), 9 deletions(-)
 create mode 100644 integration-tests/models/test_no_repeat_ngram.py

diff --git a/backends/client/src/v2/client.rs b/backends/client/src/v2/client.rs
index 9a2e6ac79f9..2fd5df1231d 100644
--- a/backends/client/src/v2/client.rs
+++ b/backends/client/src/v2/client.rs
@@ -143,6 +143,7 @@ impl Client {
             seed: 0,
             repetition_penalty: 1.2,
             frequency_penalty: 0.1,
+            no_repeat_ngram_size: 0,
             watermark: true,
             grammar: String::new(),
             grammar_type: GrammarType::None as i32,
diff --git a/backends/client/src/v2/sharded_client.rs b/backends/client/src/v2/sharded_client.rs
index 7b24aec3634..b53deefe25d 100644
--- a/backends/client/src/v2/sharded_client.rs
+++ b/backends/client/src/v2/sharded_client.rs
@@ -228,6 +228,7 @@ impl Health for ShardedClient {
             seed: 0,
             repetition_penalty: 1.0,
             frequency_penalty: 0.0,
+            no_repeat_ngram_size: 0,
             watermark: false,
             grammar: String::new(),
             grammar_type: GrammarType::None as i32,
diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index a996b14fae8..1738cd917e6 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -166,6 +166,7 @@ impl Client {
             seed: 0,
             repetition_penalty: 1.2,
             frequency_penalty: 0.1,
+            no_repeat_ngram_size: 0,
             watermark: true,
             grammar: String::new(),
             grammar_type: GrammarType::None as i32,
diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs
index ae8a899b38a..2a94da8de9c 100644
--- a/backends/client/src/v3/sharded_client.rs
+++ b/backends/client/src/v3/sharded_client.rs
@@ -231,6 +231,7 @@ impl Health for ShardedClient {
             seed: 0,
             repetition_penalty: 1.0,
             frequency_penalty: 0.0,
+            no_repeat_ngram_size: 0,
             watermark: false,
             grammar: String::new(),
             grammar_type: GrammarType::None as i32,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 9427bd60c57..f14714a59f0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -454,6 +454,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
             seed: value.seed,
             repetition_penalty: value.repetition_penalty,
             frequency_penalty: value.frequency_penalty,
+            no_repeat_ngram_size: value.no_repeat_ngram_size,
             watermark: value.watermark,
             grammar,
             grammar_type: grammar_type.into(),
@@ -497,6 +498,7 @@ mod tests {
                 seed: 0,
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
+                no_repeat_ngram_size: 0,
                 watermark: false,
                 grammar: None,
             },
diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs
index c33d64e673e..4afce9514a8 100644
--- a/benchmark/src/lib.rs
+++ b/benchmark/src/lib.rs
@@ -31,6 +31,7 @@ pub async fn run(
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
     frequency_penalty: Option<f32>,
+    no_repeat_ngram_size: Option<u32>,
     watermark: bool,
     do_sample: bool,
     client: ShardedClient,
@@ -44,6 +45,7 @@ pub async fn run(
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
         frequency_penalty: frequency_penalty.unwrap_or(0.0),
+        no_repeat_ngram_size: no_repeat_ngram_size.unwrap_or(0),
         watermark,
         grammar: String::new(),
         grammar_type: GrammarType::None as i32,
@@ -145,6 +147,7 @@ pub async fn run(
         typical_p,
         repetition_penalty,
         frequency_penalty,
+        no_repeat_ngram_size,
         watermark,
         do_sample,
     );
diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 2ee3d7c551a..3eece4eb810 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -89,6 +89,11 @@ struct Args {
     #[clap(long, env)]
     frequency_penalty: Option<f32>,
 
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    no_repeat_ngram_size: Option<u32>,
+
     /// Generation parameter in case you want to specifically test/debug particular
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
@@ -125,6 +130,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         typical_p,
         repetition_penalty,
         frequency_penalty,
+        no_repeat_ngram_size,
         watermark,
         do_sample,
         master_shard_uds_path,
@@ -196,6 +202,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         typical_p,
         repetition_penalty,
         frequency_penalty,
+        no_repeat_ngram_size,
         watermark,
         do_sample,
         sharded_client,
diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs
index 1585a25f4fc..3a70f68d24e 100644
--- a/benchmark/src/table.rs
+++ b/benchmark/src/table.rs
@@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
     frequency_penalty: Option<f32>,
+    no_repeat_ngram_size: Option<u32>,
     watermark: bool,
     do_sample: bool,
 ) -> Table {
@@ -35,6 +36,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Typical P", &format!("{typical_p:?}")]);
     builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
     builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
+    builder.push_record(["No Repeat Ngram Size", &format!("{no_repeat_ngram_size:?}")]);
     builder.push_record(["Watermark", &watermark.to_string()]);
     builder.push_record(["Do Sample", &do_sample.to_string()]);
diff --git a/clients/python/README.md b/clients/python/README.md
index 88239aa16cc..f611a25c0b6 100644
--- a/clients/python/README.md
+++ b/clients/python/README.md
@@ -135,6 +135,10 @@ class Parameters:
     # Penalize new tokens based on their existing frequency in the text so far,
     # decreasing the model's likelihood to repeat the same line verbatim.
     frequency_penalty: Optional[float]
+    # n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    # sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    # generating the same n-grams in the completion.
+    no_repeat_ngram_size: Optional[int]
     # Whether to prepend the prompt to the generated text
     return_full_text: bool
    # Stop generating tokens if a member of `stop_sequences` is generated
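
For reference, a minimal usage sketch for the documented parameter from the Python client, assuming `Client.generate` also forwards the new `no_repeat_ngram_size` field (the endpoint URL is a placeholder):

```python
from text_generation import Client

# Placeholder endpoint; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# With no_repeat_ngram_size=2, no bi-gram (pair of consecutive tokens) that has
# already appeared may be generated again, so the prompt's repeated "The cat"
# cannot be continued the same way twice.
response = client.generate(
    "The cat sat on the mat. The cat",
    max_new_tokens=20,
    no_repeat_ngram_size=2,  # assumes the client exposes the new field
)
print(response.generated_text)
```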
diff --git a/docs/openapi.json b/docs/openapi.json
index ed9b0b961c1..f367d55a7df 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -864,6 +864,14 @@
           "nullable": true,
           "minimum": 0
         },
+        "no_repeat_ngram_size": {
+          "type": "integer",
+          "format": "int32",
+          "description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
+          "example": "12",
+          "nullable": true,
+          "minimum": 0
+        },
         "presence_penalty": {
           "type": "number",
           "format": "float",
@@ -1140,6 +1148,14 @@
           "example": "mistralai/Mistral-7B-Instruct-v0.2",
           "nullable": true
         },
+        "no_repeat_ngram_size": {
+          "type": "integer",
+          "format": "int32",
+          "description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
+          "example": "12",
+          "nullable": true,
+          "minimum": 0
+        },
         "prompt": {
           "$ref": "#/components/schemas/Prompt"
         },
@@ -1397,6 +1413,14 @@
           "nullable": true,
           "minimum": 0
         },
+        "no_repeat_ngram_size": {
+          "type": "integer",
+          "format": "int32",
+          "description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
+          "example": "12",
+          "nullable": true,
+          "minimum": 0
+        },
         "repetition_penalty": {
           "type": "number",
           "format": "float",
diff --git a/integration-tests/models/test_no_repeat_ngram.py b/integration-tests/models/test_no_repeat_ngram.py
new file mode 100644
index 00000000000..281812f99d2
--- /dev/null
+++ b/integration-tests/models/test_no_repeat_ngram.py
@@ -0,0 +1,57 @@
+import pytest
+import requests
+
+
+@pytest.fixture(scope="module")
+def bloom_560_handle(launcher):
+    with launcher("bigscience/bloom-560m") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def bloom_560(bloom_560_handle):
+    await bloom_560_handle.health(240)
+    return bloom_560_handle.client
+
+
+@pytest.mark.release
+@pytest.mark.asyncio
+async def test_bloom_560m(bloom_560):
+    base_url = bloom_560.base_url
+    prompt = "The cat sat on the mat. The cat"
+
+    repeated_2grams_control = await call_model(base_url, prompt, 0)
+    assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case"
+
+    repeated_2grams_test = await call_model(base_url, prompt, 2)
+    assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
+
+
+async def call_model(base_url, prompt, n_grams):
+    data = {
+        "inputs": prompt,
+        "parameters": {
+            "max_new_tokens": 20,
+            "seed": 42,
+            "no_repeat_ngram_size": n_grams,
+            "details": True,
+        },
+    }
+    res = requests.post(f"{base_url}/generate", json=data)
+    res = res.json()
+
+    tokens = res["details"]["tokens"]
+    token_texts = [token["text"] for token in tokens]
+
+    # find repeated 2-grams among the generated tokens
+    ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 1)]
+    ngram_counts = {}
+    for ngram in ngrams:
+        if ngram in ngram_counts:
+            ngram_counts[ngram] += 1
+        else:
+            ngram_counts[ngram] = 1
+
+    repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]
+
+    return repeated
diff --git a/proto/generate.proto b/proto/generate.proto
index 6351e37f2c9..68cd5fc9b0e 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -74,6 +74,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// frequency penalty
     float frequency_penalty = 9;
+    /// no_repeat_ngram_size
+    uint32 no_repeat_ngram_size = 12;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 926c878ea44..9979ea332ac 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -95,6 +95,8 @@ message NextTokenChooserParameters {
     float repetition_penalty = 7;
     /// frequency penalty
     float frequency_penalty = 9;
+    /// no_repeat_ngram_size
+    uint32 no_repeat_ngram_size = 12;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
     /// grammar (applied if not empty)
diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs
index 0b51645a84b..a4e7e0f1d5b 100644
--- a/router/src/infer/v2/queue.rs
+++ b/router/src/infer/v2/queue.rs
@@ -377,6 +377,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
             seed: value.seed,
             repetition_penalty: value.repetition_penalty,
             frequency_penalty: value.frequency_penalty,
+            no_repeat_ngram_size: value.no_repeat_ngram_size,
             watermark: value.watermark,
             grammar,
             grammar_type: grammar_type.into(),
@@ -420,6 +421,7 @@ mod tests {
                 seed: 0,
                 repetition_penalty: 0.0,
                 frequency_penalty: 0.0,
+                no_repeat_ngram_size: 0,
                 watermark: false,
                 grammar: None,
             },
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 14bb8270d18..c12d421ba6c 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -223,6 +223,13 @@ pub(crate) struct GenerateParameters {
     )]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// The number of highest probability vocabulary tokens to keep for top-k-filtering.
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
@@ -327,6 +334,7 @@ fn default_parameters() -> GenerateParameters {
         temperature: None,
         repetition_penalty: None,
         frequency_penalty: None,
+        no_repeat_ngram_size: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -424,6 +432,13 @@ pub struct CompletionRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// Up to 4 sequences where the API will stop generating further tokens.
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
@@ -740,6 +755,13 @@ pub(crate) struct ChatRequest {
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
 
+    /// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    /// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
+    /// generating the same n-grams in the completion.
+    #[serde(default)]
+    #[schema(nullable = true, example = "12")]
+    pub no_repeat_ngram_size: Option<u32>,
+
     /// UNUSED
     /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
     /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
diff --git a/router/src/server.rs b/router/src/server.rs
index dcbaa2ada7a..2972f820f9b 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -654,6 +654,7 @@ async fn completions(
             temperature,
             repetition_penalty: req.repetition_penalty,
             frequency_penalty: req.frequency_penalty,
+            no_repeat_ngram_size: req.no_repeat_ngram_size,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
@@ -1100,6 +1101,7 @@ async fn chat_completions(
             temperature,
             repetition_penalty,
             frequency_penalty: req.frequency_penalty,
+            no_repeat_ngram_size: req.no_repeat_ngram_size,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 3d1a4103fd7..fcb3640115c 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -195,6 +195,7 @@ impl Validation {
             temperature,
             repetition_penalty,
             frequency_penalty,
+            no_repeat_ngram_size,
             top_k,
             top_p,
             typical_p,
@@ -238,6 +239,8 @@ impl Validation {
             return Err(ValidationError::FrequencyPenalty);
         }
 
+        let no_repeat_ngram_size = no_repeat_ngram_size.unwrap_or(0);
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
@@ -373,6 +376,7 @@ impl Validation {
             temperature,
             repetition_penalty,
             frequency_penalty,
+            no_repeat_ngram_size,
             top_k,
             top_p,
             typical_p,
@@ -687,6 +691,8 @@ pub struct ValidParameters {
     pub repetition_penalty: f32,
     /// / frequency penalty
     pub frequency_penalty: f32,
+    /// / no_repeat_ngram_size
+    pub no_repeat_ngram_size: u32,
     /// / token watermarking using "A Watermark for Large Language Models"
     pub watermark: bool,
     /// / grammar (applied if not empty)
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 9abd886f250..d8c7c2c6ae9 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -20,17 +20,13 @@
     TypicalLogitsWarper,
 )
 
+from transformers.generation.logits_process import _calc_banned_ngram_tokens
+
 mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 
 
 class StaticWarper:
-    def __init__(
-        self,
-        temperature=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-    ):
+    def __init__(self, temperature=1.0, top_k=None, top_p=None, typical_p=None):
         self.warpers = []
 
         if temperature is not None and temperature != 1.0:
@@ -84,7 +80,10 @@ def static_warper(
     typical_p: Optional[float],
 ) -> StaticWarper:
     return StaticWarper(
-        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        typical_p=typical_p,
     )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 9ab49665a75..3766041e9c4 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -18,7 +18,11 @@
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
-from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
+from transformers import (
+    PreTrainedTokenizerBase,
+    RepetitionPenaltyLogitsProcessor,
+    NoRepeatNGramLogitsProcessor,
+)
 
 
 class NextTokenChooser:
@@ -28,6 +32,7 @@ def __init__(
         temperature: float = 1.0,
         repetition_penalty: float = 1.0,
         frequency_penalty: float = 0.0,
+        no_repeat_ngram_size: Optional[int] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         typical_p: Optional[float] = None,
@@ -57,6 +62,12 @@ def __init__(
             if grammar != ""
             else None
         )
+
+        self.no_repeat_ngram_processor = (
+            NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            if no_repeat_ngram_size and no_repeat_ngram_size > 0
+            else None
+        )
         self.tokenizer = tokenizer
 
         has_warpers = (
@@ -87,6 +98,8 @@ def __call__(self, input_ids, scores):
             scores = self.frequency_processor(input_ids, scores)
         if self.grammar_processor is not None:
             scores = self.grammar_processor(scores, self.fsm_grammar_state)
+        if self.no_repeat_ngram_processor is not None:
+            scores = self.no_repeat_ngram_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -116,6 +129,7 @@ def from_pb(
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             frequency_penalty=pb.frequency_penalty,
+            no_repeat_ngram_size=pb.no_repeat_ngram_size,
             top_k=pb.top_k,
             top_p=pb.top_p,
             typical_p=pb.typical_p,
@@ -239,6 +253,7 @@ def __init__(
         temperature: List[float],
         repetition_penalty: List[float],
         frequency_penalty: List[float],
+        no_repeat_ngram_size: List[int],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
@@ -287,6 +302,18 @@ def __init__(
             else None
         )
 
+        self.no_repeat_ngram_processor = (
+            HeterogeneousProcessorWrapper(
+                {
+                    i: NoRepeatNGramLogitsProcessor(n)
+                    for i, n in enumerate(no_repeat_ngram_size)
+                    if n > 0
+                }
+            )
+            if any([n > 0 for n in no_repeat_ngram_size])
+            else None
+        )
+
         if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -353,6 +380,8 @@ def __call__(
                 _scores = self.frequency_processor(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            if self.no_repeat_ngram_processor is not None:
+                _scores = self.no_repeat_ngram_processor(input_ids, _scores)
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
@@ -487,6 +516,7 @@ def from_pb(
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
             frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
+            no_repeat_ngram_size=[pb_.no_repeat_ngram_size for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
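
For context on the server-side mechanism: the patch reuses `NoRepeatNGramLogitsProcessor` from `transformers`, which at each decoding step sets the score of every token that would complete an already-seen n-gram to `-inf`. A minimal standalone sketch of that behavior (token IDs and vocabulary size are toy values, made up for illustration):

```python
import torch
from transformers import NoRepeatNGramLogitsProcessor

# Toy history: the bi-gram (5, 7) already occurred, and the sequence ends in 5,
# so generating token 7 next would repeat that bi-gram.
input_ids = torch.tensor([[5, 7, 9, 5]])
scores = torch.zeros(1, 10)  # uniform scores over a 10-token toy vocabulary

processor = NoRepeatNGramLogitsProcessor(2)  # no_repeat_ngram_size=2
filtered = processor(input_ids, scores)

print(filtered[0, 7])  # tensor(-inf): token 7 is banned
print(filtered[0, 3])  # tensor(0.): all other tokens are untouched
```

This is the same filtering `NextTokenChooser.__call__` applies above whenever `no_repeat_ngram_size > 0`; with the default of 0 the processor is never constructed and scores pass through unchanged.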