Skip to content

Commit

Permalink
feat: add --grammar-support cli flag and validation error
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Feb 13, 2024
1 parent bdb9705 commit a4ff4f3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 0 deletions.
10 changes: 10 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ struct Args {
#[clap(long, env)]
tokenizer_config_path: Option<String>,

/// Enable Outlines grammar-constrained generation.
/// This feature allows you to generate text that follows a specific grammar.
#[clap(long, env)]
grammar_support: bool,

/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
Expand Down Expand Up @@ -1051,6 +1056,11 @@ fn spawn_webserver(
args.model_id,
];

// Grammar support
if args.grammar_support {
router_args.push("--grammar-support".to_string());
}

// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct Args {
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
grammar_support: bool,
}

#[tokio::main]
Expand Down Expand Up @@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
grammar_support,
} = args;

// Launch Tokio runtime
Expand Down Expand Up @@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
tokenizer_config,
messages_api_enabled,
grammar_support,
)
.await?;
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ pub async fn run(
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool,
grammar_support: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
Expand Down Expand Up @@ -841,6 +842,7 @@ pub async fn run(
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
);
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
Expand Down
10 changes: 10 additions & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Validation {
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
grammar_support: bool,
/// Channel to communicate with the background tokenization task
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
}
Expand All @@ -32,6 +33,7 @@ impl Validation {
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
grammar_support: bool,
) -> Self {
// If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer {
Expand Down Expand Up @@ -66,6 +68,7 @@ impl Validation {
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
}
}

Expand Down Expand Up @@ -293,6 +296,11 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;

// Reject the request if it sets a grammar while grammar support is disabled
if !grammar.is_empty() && !self.grammar_support {
return Err(ValidationError::Grammar);
}

let parameters = NextTokenChooserParameters {
temperature,
repetition_penalty,
Expand Down Expand Up @@ -455,6 +463,8 @@ pub enum ValidationError {
StopSequence(usize, usize),
#[error("tokenizer error {0}")]
Tokenizer(String),
#[error("grammar is not supported")]
Grammar,
}

#[cfg(test)]
Expand Down

0 comments on commit a4ff4f3

Please sign in to comment.