Skip to content

Commit

Permalink
Incorporate changes from main
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikKaum committed Jul 31, 2024
1 parent 7451041 commit 303be7d
Show file tree
Hide file tree
Showing 19 changed files with 177 additions and 9 deletions.
1 change: 1 addition & 0 deletions backends/client/src/v2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
no_repeat_ngram_size: 0,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v2/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ impl Health for ShardedClient {
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v3/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ impl Client {
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
no_repeat_ngram_size: 0,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
1 change: 1 addition & 0 deletions backends/client/src/v3/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ impl Health for ShardedClient {
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down
2 changes: 2 additions & 0 deletions backends/v3/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
no_repeat_ngram_size: value.no_repeat_ngram_size,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
Expand Down Expand Up @@ -497,6 +498,7 @@ mod tests {
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: None,
},
Expand Down
3 changes: 3 additions & 0 deletions benchmark/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub async fn run(
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
no_repeat_ngram_size: Option<u32>,
watermark: bool,
do_sample: bool,
client: ShardedClient,
Expand All @@ -44,6 +45,7 @@ pub async fn run(
seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
no_repeat_ngram_size: no_repeat_ngram_size.unwrap_or(0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
Expand Down Expand Up @@ -145,6 +147,7 @@ pub async fn run(
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
);
Expand Down
7 changes: 7 additions & 0 deletions benchmark/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ struct Args {
#[clap(long, env)]
frequency_penalty: Option<f32>,

/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
no_repeat_ngram_size: Option<u32>,

/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
Expand Down Expand Up @@ -125,6 +130,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
master_shard_uds_path,
Expand Down Expand Up @@ -196,6 +202,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
typical_p,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
watermark,
do_sample,
sharded_client,
Expand Down
2 changes: 2 additions & 0 deletions benchmark/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub(crate) fn parameters_table(
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
no_repeat_ngram_size: Option<u32>,
watermark: bool,
do_sample: bool,
) -> Table {
Expand All @@ -35,6 +36,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Typical P", &format!("{typical_p:?}")]);
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
builder.push_record(["No Repeat Ngram Size", &format!("{no_repeat_ngram_size:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]);

Expand Down
4 changes: 4 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class Parameters:
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float]
# n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
# sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
# generating the same n-grams in the completion.
no_repeat_ngram_size: Optional[int]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
Expand Down
24 changes: 24 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,14 @@
"nullable": true,
"minimum": 0
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": 12,
"nullable": true,
"minimum": 0
},
"presence_penalty": {
"type": "number",
"format": "float",
Expand Down Expand Up @@ -1140,6 +1148,14 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2",
"nullable": true
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": 12,
"nullable": true,
"minimum": 0
},
"prompt": {
"$ref": "#/components/schemas/Prompt"
},
Expand Down Expand Up @@ -1397,6 +1413,14 @@
"nullable": true,
"minimum": 0
},
"no_repeat_ngram_size": {
"type": "integer",
"format": "int32",
"description": "n-grams are groups of \"n\" consecutive words, characters, or tokens taken from a sequence of text. Given the\nsentence: \"She runs fast\", the bi-grams (n=2) would be (\"she\", \"runs\") and (\"runs\", \"fast\"). Set this to avoid\ngenerating the same n-grams in the completion.",
"example": 12,
"nullable": true,
"minimum": 0
},
"repetition_penalty": {
"type": "number",
"format": "float",
Expand Down
57 changes: 57 additions & 0 deletions integration-tests/models/test_no_repeat_ngram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List
import pytest
import requests

@pytest.fixture(scope="module")
def bloom_560_handle(launcher):
    # Launch a text-generation-inference server for bloom-560m; the context
    # manager tears the server down once the module's tests complete.
    # The decorator was missing: `bloom_560` (below) requests this by name as
    # a fixture, so without `@pytest.fixture` collection fails.
    with launcher("bigscience/bloom-560m") as handle:
        yield handle


@pytest.fixture(scope="module")
async def bloom_560(bloom_560_handle):
    # Block (up to 240s) until the launched server reports healthy, then
    # expose its client to the tests in this module.
    await bloom_560_handle.health(240)
    return bloom_560_handle.client

@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560):
    """Check that no_repeat_ngram_size=2 suppresses repeated bi-grams.

    The same repetitive prompt is generated twice: a control run without
    the constraint (expected to repeat bi-grams) and a constrained run
    with it (expected to produce none).
    """
    server_url = bloom_560.base_url
    repetitive_prompt = "The cat sat on the mat. The cat"

    # Control run: constraint disabled (0), the model should repeat itself.
    control_repeats = await call_model(server_url, repetitive_prompt, 0)
    assert len(control_repeats) > 0, "Expected to find repeated bi-grams in control case"

    # Constrained run: any repeated bi-gram is a failure.
    repeated_2grams_test = await call_model(server_url, repetitive_prompt, 2)
    assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"


async def call_model(base_url, prompt, n_grams):
data = {
"inputs": prompt,
"parameters" : {
"max_new_tokens": 20,
"seed": 42,
"no_repeat_ngram_size": n_grams,
"details": True
}
}
res = requests.post(f"{base_url}/generate", json=data)
res = res.json()

tokens = res['details']['tokens']
token_texts = [token['text'] for token in tokens]

# find repeated 2grams
ngrams = [tuple(token_texts[i:i+2]) for i in range(len(token_texts)-2+1)]
ngram_counts = {}
for ngram in ngrams:
if ngram in ngram_counts:
ngram_counts[ngram] += 1
else:
ngram_counts[ngram] = 1

repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]

return repeated

2 changes: 2 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// no_repeat_ngram_size
uint32 no_repeat_ngram_size = 12;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
Expand Down
2 changes: 2 additions & 0 deletions proto/v3/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ message NextTokenChooserParameters {
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 9;
/// no_repeat_ngram_size
uint32 no_repeat_ngram_size = 12;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
Expand Down
2 changes: 2 additions & 0 deletions router/src/infer/v2/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ impl From<ValidParameters> for NextTokenChooserParameters {
seed: value.seed,
repetition_penalty: value.repetition_penalty,
frequency_penalty: value.frequency_penalty,
no_repeat_ngram_size: value.no_repeat_ngram_size,
watermark: value.watermark,
grammar,
grammar_type: grammar_type.into(),
Expand Down Expand Up @@ -420,6 +421,7 @@ mod tests {
seed: 0,
repetition_penalty: 0.0,
frequency_penalty: 0.0,
no_repeat_ngram_size: 0,
watermark: false,
grammar: None,
},
Expand Down
22 changes: 22 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ pub(crate) struct GenerateParameters {
)]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// The number of highest probability vocabulary tokens to keep for top-k-filtering.
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
Expand Down Expand Up @@ -327,6 +334,7 @@ fn default_parameters() -> GenerateParameters {
temperature: None,
repetition_penalty: None,
frequency_penalty: None,
no_repeat_ngram_size: None,
top_k: None,
top_p: None,
typical_p: None,
Expand Down Expand Up @@ -424,6 +432,13 @@ pub struct CompletionRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
Expand Down Expand Up @@ -740,6 +755,13 @@ pub(crate) struct ChatRequest {
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,

/// n-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
/// sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). Set this to avoid
/// generating the same n-grams in the completion.
#[serde(default)]
#[schema(nullable = true, example = "12")]
pub no_repeat_ngram_size: Option<u32>,

/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
Expand Down
2 changes: 2 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ async fn completions(
temperature,
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
no_repeat_ngram_size: req.no_repeat_ngram_size,
top_k: None,
top_p: req.top_p,
typical_p: None,
Expand Down Expand Up @@ -1100,6 +1101,7 @@ async fn chat_completions(
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
no_repeat_ngram_size: req.no_repeat_ngram_size,
top_k: None,
top_p: req.top_p,
typical_p: None,
Expand Down
6 changes: 6 additions & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ impl Validation {
temperature,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
top_k,
top_p,
typical_p,
Expand Down Expand Up @@ -238,6 +239,8 @@ impl Validation {
return Err(ValidationError::FrequencyPenalty);
}

let no_repeat_ngram_size = no_repeat_ngram_size.unwrap_or(0);

// Different because the proto default value is not a valid value
// for the user
let top_p = top_p
Expand Down Expand Up @@ -373,6 +376,7 @@ impl Validation {
temperature,
repetition_penalty,
frequency_penalty,
no_repeat_ngram_size,
top_k,
top_p,
typical_p,
Expand Down Expand Up @@ -687,6 +691,8 @@ pub struct ValidParameters {
pub repetition_penalty: f32,
/// / frequency penalty
pub frequency_penalty: f32,
/// / no_repeat_ngram_size
pub no_repeat_ngram_size: u32,
/// / token watermarking using "A Watermark for Large Language Models"
pub watermark: bool,
/// / grammar (applied if not empty)
Expand Down
15 changes: 7 additions & 8 deletions server/text_generation_server/utils/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,13 @@
TypicalLogitsWarper,
)

from transformers.generation.logits_process import _calc_banned_ngram_tokens

mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None


class StaticWarper:
def __init__(
self,
temperature=1.0,
top_k=None,
top_p=None,
typical_p=None,
):
def __init__(self, temperature=1.0, top_k=None, top_p=None, typical_p=None):
self.warpers = []

if temperature is not None and temperature != 1.0:
Expand Down Expand Up @@ -84,7 +80,10 @@ def static_warper(
typical_p: Optional[float],
) -> StaticWarper:
return StaticWarper(
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
temperature=temperature,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
)


Expand Down
Loading

0 comments on commit 303be7d

Please sign in to comment.