Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

no repeat ngram size ci #2308

Closed
wants to merge 16 commits into from
1 change: 1 addition & 0 deletions backends/client/src/v2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
no_repeat_ngram_size: 0,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v2/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ impl Health for ShardedClient {
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v3/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
no_repeat_ngram_size: 0,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v3/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ impl Health for ShardedClient {
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/v3/src/client/grpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
no_repeat_ngram_size: 0,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/v3/src/client/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ impl Health for ShardedClient {
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
2 changes: 2 additions & 0 deletions backends/v3/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
no_repeat_ngram_size: value.no_repeat_ngram_size,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
Expand Down Expand Up @@ -497,6 +498,7 @@ mod tests {
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: None,
},
Expand Down
3 changes: 3 additions & 0 deletions benchmark/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub async fn run(
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
no_repeat_ngram_size: Option<u32>,
watermark: bool,
do_sample: bool,
client: ShardedClient,
Expand All @@ -44,6 +45,7 @@ pub async fn run(
seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
no_repeat_ngram_size: no_repeat_ngram_size.unwrap_or(0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down Expand Up @@ -145,6 +147,7 @@ pub async fn run(
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
);
Expand Down
7 changes: 7 additions & 0 deletions benchmark/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ struct Args {
#[clap(long, env)]
frequency_penalty: Option<f32>,

/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
no_repeat_ngram_size: Option<u32>,
ErikKaum marked this conversation as resolved.
Show resolved Hide resolved

/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
Expand Down Expand Up @@ -125,6 +130,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
master_shard_uds_path,
Expand Down Expand Up @@ -196,6 +202,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
sharded_client,
Expand Down
2 changes: 2 additions & 0 deletions benchmark/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
no_repeat_ngram_size: Option<u32>,
watermark: bool,
do_sample: bool,
) -> Table {
Expand All @@ -35,6 +36,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Typical P", &format!("{typical_p:?}")]);
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
builder.push_record(["No Repeat Ngram Size", &format!("{no_repeat_ngram_size:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]);

Expand Down
4 changes: 4 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class Parameters:
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float]
# n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
# sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
# generating the same n-grams in the completion.
no_repeat_ngram_size: Optional[int]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
Expand Down
24 changes: 24 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,14 @@
"nullable": true,
"minimum": 0
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": "12",
"nullable": true,
"minimum": 0
},
"presence_penalty": {
"type": "number",
"format": "float",
Expand Down Expand Up @@ -1140,6 +1148,14 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2",
"nullable": true
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": "12",
"nullable": true,
"minimum": 0
},
"prompt": {
"$ref": "#/components/schemas/Prompt"
},
Expand Down Expand Up @@ -1397,6 +1413,14 @@
"nullable": true,
"minimum": 0
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": "12",
"nullable": true,
"minimum": 0
},
"repetition_penalty": {
"type": "number",
"format": "float",
Expand Down
61 changes: 61 additions & 0 deletions integration-tests/models/test_no_repeat_ngram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from collections import Counter

import pytest
import requests


@pytest.fixture(scope="module")
def bloom_560_handle(launcher):
    # Module-scoped so the server is launched once and shared by every test
    # in this file. `launcher` is a project-provided fixture (conftest) that
    # starts a text-generation-inference server for the given model id and
    # tears it down when the context manager exits.
    with launcher("bigscience/bloom-560m") as handle:
        yield handle


@pytest.fixture(scope="module")
async def bloom_560(bloom_560_handle):
    # Block until the launched server reports healthy (240 s timeout),
    # then hand tests the ready-to-use client.
    await bloom_560_handle.health(240)
    return bloom_560_handle.client


@pytest.mark.asyncio
async def test_bloom_560m(bloom_560):
    """`no_repeat_ngram_size=2` must suppress the repeated bi-grams that an
    unconstrained completion of this repetition-baiting prompt produces."""
    base_url = bloom_560.base_url
    prompt = "The cat sat on the mat. The cat"

    # Control run: no constraint, so the prompt should bait the model into
    # repeating at least one bi-gram.
    control_repeats = await call_model(base_url, prompt, 0)
    assert (
        len(control_repeats) > 0
    ), "Expected to find repeated bi-grams in control case"

    # Constrained run: with no_repeat_ngram_size=2, no bi-gram may recur.
    constrained_repeats = await call_model(base_url, prompt, 2)
    assert (
        len(constrained_repeats) == 0
    ), f"Expected no repeated bi-grams, but found: {constrained_repeats}"


async def call_model(base_url, prompt, n_grams):
data = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 20,
"seed": 42,
"no_repeat_ngram_size": n_grams,
"details": True,
},
}
res = requests.post(f"{base_url}/generate", json=data)
res = res.json()

tokens = res["details"]["tokens"]
token_texts = [token["text"] for token in tokens]

# find repeated 2grams
ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 2 + 1)]
ngram_counts = {}
for ngram in ngrams:
if ngram in ngram_counts:
ngram_counts[ngram] += 1
else:
ngram_counts[ngram] = 1

repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]

return repeated
2 changes: 2 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// no_repeat_ngram_size
uint32 no_repeat_ngram_size = 12;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
Expand Down
2 changes: 2 additions & 0 deletions proto/v3/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// no_repeat_ngram_size
uint32 no_repeat_ngram_size = 12;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
Expand Down
2 changes: 2 additions & 0 deletions router/src/infer/v2/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
no_repeat_ngram_size: value.no_repeat_ngram_size,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
Expand Down Expand Up @@ -420,6 +421,7 @@ mod tests {
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: None,
},
Expand Down
22 changes: 22 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ pub(crate) struct GenerateParameters {
)]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// The number of highest probability vocabulary tokens to keep for top-k-filtering.
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
Expand Down Expand Up @@ -327,6 +334,7 @@ fn default_parameters() -> GenerateParameters {
temperature: None,
repetition_penalty: None,
frequency_penalty: None,
no_repeat_ngram_size: None,
top_k: None,
top_p: None,
typical_p: None,
Expand Down Expand Up @@ -424,6 +432,13 @@ pub struct CompletionRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
Expand Down Expand Up @@ -740,6 +755,13 @@ pub(crate) struct ChatRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
Expand Down
2 changes: 2 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ async fn completions(
temperature,
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
no_repeat_ngram_size: req.no_repeat_ngram_size,
top_k: None,
top_p: req.top_p,
typical_p: None,
Expand Down Expand Up @@ -1100,6 +1101,7 @@ async fn chat_completions(
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
no_repeat_ngram_size: req.no_repeat_ngram_size,
top_k: None,
top_p: req.top_p,
typical_p: None,
Expand Down
6 changes: 6 additions & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ impl Validation {
temperature,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
top_k,
top_p,
typical_p,
Expand Down Expand Up @@ -238,6 +239,8 @@ impl Validation {
return Err(ValidationError::FrequencyPenalty);
}

let no_repeat_ngram_size = no_repeat_ngram_size.unwrap_or(0);

// Different because the proto default value is not a valid value
// for the user
let top_p = top_p
Expand Down Expand Up @@ -373,6 +376,7 @@ impl Validation {
temperature,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
top_k,
top_p,
typical_p,
Expand Down Expand Up @@ -687,6 +691,8 @@ pub struct ValidParameters {
pub repetition_penalty: f32,
/// / frequency penalty
pub frequency_penalty: f32,
/// / no_repeat_ngram_size
pub no_repeat_ngram_size: u32,
/// / token watermarking using "A Watermark for Large Language Models"
pub watermark: bool,
/// / grammar (applied if not empty)
Expand Down
Loading
Loading