diff --git a/router/src/validation.rs b/router/src/validation.rs index f350d15e820..204dbf92a6a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,7 +1,9 @@ /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; +use serde_json::Value; use text_generation_client::{ GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, }; @@ -313,14 +315,24 @@ impl Validation { return Err(ValidationError::Grammar); } match grammar { - // currently both are handled the same way since compilation is done in Python GrammarType::Json(json) => { - // JSONSchema::options() - // .with_draft(Draft::Draft202012) - // .compile(&json) - // .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + let json = match json { + // if value is a string, we need to parse it again to make sure its + // a valid json + Value::String(s) => serde_json::from_str(&s) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string())), + Value::Object(_) => Ok(json), + _ => Err(ValidationError::Grammar), + }?; + + // Check if the json is a valid JSONSchema + JSONSchema::options() + .with_draft(Draft::Draft202012) + .compile(&json) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; ( + // Serialize json to string serde_json::to_string(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, ProtoGrammarType::Json.into(), @@ -497,7 +509,7 @@ pub enum ValidationError { Tokenizer(String), #[error("grammar is not supported")] Grammar, - #[error("grammar is not a valid JSONSchema: {0}")] + #[error("grammar is not valid: {0}")] InvalidGrammar(String), }