From 2251f5b16f808f2d4c0cebc3486b7a42439140c2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 19 Feb 2024 16:56:09 +0100 Subject: [PATCH 1/6] fix(router): fix openapi and add jsonschema validation --- Cargo.lock | 157 +++++++++++++++++++++++++++++++++++++++++++ docs/openapi.json | 51 ++++++++++++++ router/Cargo.toml | 1 + router/src/lib.rs | 12 +++- router/src/server.rs | 3 +- 5 files changed, 222 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f81517310f3..ccccdf3c9d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" dependencies = [ "cfg-if", + "getrandom", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -265,6 +267,21 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -716,6 +733,16 @@ dependencies = [ "cc", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "2.0.1" @@ -780,6 +807,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "futures" version = "0.3.30" @@ -895,8 +932,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1181,6 +1220,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "iso8601" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153" +dependencies = [ + "nom", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1223,6 +1271,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonschema" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" +dependencies = [ + "ahash", + "anyhow", + "base64 0.21.7", + "bytecount", + "clap", + "fancy-regex", + "fraction", + "getrandom", + "iso8601", + "itoa", + "memchr", + "num-cmp", + "once_cell", + "parking_lot", + "percent-encoding", + "regex", + "reqwest", + "serde", + "serde_json", + "time", + "url", + "uuid", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1574,12 +1652,84 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -2874,6 +3024,7 @@ dependencies = [ "futures-util", "hf-hub", "init-tracing-opentelemetry", + "jsonschema", "metrics", "metrics-exporter-prometheus", "minijinja", @@ -3530,6 +3681,12 @@ dependencies = [ "zip", ] +[[package]] +name = "uuid" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" + [[package]] name = "valuable" version = "0.1.0" diff --git a/docs/openapi.json b/docs/openapi.json index d72d32e9064..fad01aec0c1 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1022,6 +1022,57 @@ } } }, + "GrammarType": { + "oneOf": [ + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "json" + ] + }, + "value": { + "type": "string", + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", + "example": { + "properties": { + "location": { + "type": "string" + } + } + } + } + } + }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "regex" + ] + }, + "value": { + "type": "string" + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, "Info": { "type": "object", "required": [ diff --git a/router/Cargo.toml b/router/Cargo.toml index 1a7ceb7095b..97a46ad1c6d 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,6 +22,7 @@ text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { version = "0.3.0", features = ["tokio"] } +jsonschema = "0.17.1" metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.12.1", features = [] } nohash-hasher = "0.2.0" diff --git a/router/src/lib.rs b/router/src/lib.rs index 8c7ca74b7cd..e1217fa722e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -46,6 +46,7 @@ impl HubTokenizerConfig { } mod json_object_or_string_to_string { + use jsonschema::JSONSchema; use serde::{Deserialize, Deserializer}; use serde_json::Value; @@ -57,6 +58,10 @@ mod json_object_or_string_to_string { { let value = Value::deserialize(deserializer)?; + JSONSchema::options() + .compile(&value) + .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?; + match value { Value::String(s) => Ok(s), // Safely handle serialization and return an error if it fails @@ -70,13 +75,18 @@ mod json_object_or_string_to_string { } } -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, ToSchema)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { + /// A string that represents a [JSON Schema](https://json-schema.org/). + /// + /// JSON Schema is a declarative language that allows to annotate JSON documents + /// with types and descriptions. #[serde( rename = "json", deserialize_with = "json_object_or_string_to_string::deserialize" )] + #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] Json(String), #[serde(rename = "regex")] Regex(String), diff --git a/router/src/server.rs b/router/src/server.rs index 0fc76916328..1849298fbaf 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,7 +5,7 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta, ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse, - FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, + FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation, }; @@ -800,6 +800,7 @@ pub async fn run( Info, CompatGenerateRequest, GenerateRequest, + GrammarType, ChatRequest, Message, ChatCompletionChoice, From 0533e67ea62aa63ff9563f66b19b3e1bcdc5000e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 19 Feb 2024 17:33:23 +0100 Subject: [PATCH 2/6] use Draft202012 --- router/Cargo.toml | 2 +- router/src/lib.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/router/Cargo.toml b/router/Cargo.toml index 97a46ad1c6d..8f13dc33578 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -22,7 +22,7 @@ text-generation-client = { path = "client" } clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" hf-hub = { version = "0.3.0", features = ["tokio"] } -jsonschema = "0.17.1" +jsonschema = { version = "0.17.1", features = ["draft202012"] } metrics = "0.21.1" metrics-exporter-prometheus = { version = "0.12.1", features = [] } nohash-hasher = "0.2.0" diff --git a/router/src/lib.rs b/router/src/lib.rs index e1217fa722e..c1da657272a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -46,7 +46,7 @@ impl HubTokenizerConfig { } mod json_object_or_string_to_string { - use jsonschema::JSONSchema; + use jsonschema::{Draft, JSONSchema}; use serde::{Deserialize, Deserializer}; use serde_json::Value; @@ -59,6 +59,7 @@ mod json_object_or_string_to_string { let value = Value::deserialize(deserializer)?; JSONSchema::options() + .with_draft(Draft::Draft202012) .compile(&value) .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?; From bada345055b66f11afb7631df5db6e8b7a8c6cd5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 20 Feb 2024 14:55:30 +0100 Subject: [PATCH 3/6] ??? --- integration-tests/models/test_flash_awq.py | 3 -- .../models/test_flash_awq_sharded.py | 2 - integration-tests/models/test_flash_medusa.py | 3 -- .../models/test_flash_mistral.py | 3 -- integration-tests/models/test_flash_phi.py | 3 -- .../models/test_flash_starcoder_gptq.py | 3 -- .../models/test_grammar_llama.py | 5 --- integration-tests/models/test_mamba.py | 3 -- router/src/lib.rs | 38 +------------------ router/src/validation.rs | 15 +++++++- server/text_generation_server/utils/tokens.py | 5 +-- 11 files changed, 18 insertions(+), 65 deletions(-) diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index 62a95f48791..ead918c32bc 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", @@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index 1c687fc9915..a83614acdfb 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot ): diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index a0ce057092f..e0cc1039451 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_simple(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_all_params(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep Learning?", @@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): responses = await generate_load( flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py index ace3328b1c7..52b51928dfb 100644 --- a/integration-tests/models/test_flash_mistral.py +++ b/integration-tests/models/test_flash_mistral.py @@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral(flash_mistral, response_snapshot): response = await flash_mistral.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral_all_params(flash_mistral, response_snapshot): response = await flash_mistral.generate( "Test request", @@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot): responses = await generate_load( flash_mistral, "Test request", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py index 0987b3a1b95..9d6ca56693d 100644 --- a/integration-tests/models/test_flash_phi.py +++ b/integration-tests/models/test_flash_phi.py @@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi(flash_phi, response_snapshot): response = await flash_phi.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi_all_params(flash_phi, response_snapshot): response = await flash_phi.generate( "Test request", @@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index 5e448d5532a..329158b7813 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot): response = await flash_starcoder_gptq.generate( "def geometric_mean(L: List[float]):", @@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq_default_params( flash_starcoder_gptq, generous_response_snapshot ): @@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params( @pytest.mark.asyncio -@pytest.mark.private async def test_flash_starcoder_gptq_load( flash_starcoder_gptq, generate_load, generous_response_snapshot ): diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index f068496c180..ba123999a09 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "Test request", max_new_tokens=10, decoder_input_details=True @@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "Whats Googles DNS", @@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot) @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): response = await flash_llama_grammar.generate( "info: david holtz like trees and has two cats. ", @@ -98,7 +95,6 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_load( flash_llama_grammar, generate_load, response_snapshot ): @@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load( # this is the same as the above test, but only fires off a single request # this is only to ensure that the parallel and single inference produce the same result @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_grammar_single_load_instance( flash_llama_grammar, generate_load, response_snapshot ): diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index 5ec2ec31113..bf3701b4db4 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( "What is Deep Learning?", max_new_tokens=10 @@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): response = await fused_kernel_mamba.generate( "blue, red, yellow, ", @@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_mamba_load( fused_kernel_mamba, generate_load, generous_response_snapshot ): diff --git a/router/src/lib.rs b/router/src/lib.rs index c1da657272a..7bf51f5db71 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -45,37 +45,6 @@ impl HubTokenizerConfig { } } -mod json_object_or_string_to_string { - use jsonschema::{Draft, JSONSchema}; - use serde::{Deserialize, Deserializer}; - use serde_json::Value; - - // A custom deserializer that treats both strings and objects as strings. - // This provides flexibility with input formats for the 'grammar' field. - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; - - JSONSchema::options() - .with_draft(Draft::Draft202012) - .compile(&value) - .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?; - - match value { - Value::String(s) => Ok(s), - // Safely handle serialization and return an error if it fails - Value::Object(o) => { - serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string())) - } - _ => Err(serde::de::Error::custom( - "expected string or object for grammar", - )), - } - } -} - #[derive(Clone, Debug, Deserialize, ToSchema)] #[serde(tag = "type", content = "value")] pub(crate) enum GrammarType { @@ -83,12 +52,9 @@ pub(crate) enum GrammarType { /// /// JSON Schema is a declarative language that allows to annotate JSON documents /// with types and descriptions. - #[serde( - rename = "json", - deserialize_with = "json_object_or_string_to_string::deserialize" - )] + #[serde(rename = "json")] #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))] - Json(String), + Json(serde_json::Value), #[serde(rename = "regex")] Regex(String), } diff --git a/router/src/validation.rs b/router/src/validation.rs index bf85b12f137..f350d15e820 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -314,7 +314,18 @@ impl Validation { } match grammar { // currently both are handled the same way since compilation is done in Python - GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()), + GrammarType::Json(json) => { + // JSONSchema::options() + // .with_draft(Draft::Draft202012) + // .compile(&json) + // .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + + ( + serde_json::to_string(&json) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, + ProtoGrammarType::Json.into(), + ) + } GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()), } } @@ -486,6 +497,8 @@ pub enum ValidationError { Tokenizer(String), #[error("grammar is not supported")] Grammar, + #[error("grammar is not a valid JSONSchema: {0}")] + InvalidGrammar(String), } #[cfg(test)] diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 72c6c21c7b5..32789850077 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -328,7 +328,6 @@ def __call__( scores = scores.view(B, S, -1) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) - mask = torch.full((scores.shape[-1],), -math.inf, device=self.device) for j in range(S): _scores = scores[:, j] @@ -338,10 +337,10 @@ def __call__( _scores = self.repetition_processor(input_ids, _scores) if self.frequency_processor is not None: _scores = self.frequency_processor(input_ids, _scores) - for warper in self.warpers: - _scores = warper(input_ids, _scores) if self.grammar_processor is not None: _scores = self.grammar_processor(_scores, self.fsm_grammar_states) + for warper in self.warpers: + _scores = warper(input_ids, _scores) _next_ids = self.choice(_scores) scores[:, j] = _scores next_ids[:, j] = _next_ids From 19dd0a8fde5596a008370ca0d66ba9449d1dba0a Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 20 Feb 2024 15:16:02 +0100 Subject: [PATCH 4/6] parse string to json --- router/src/validation.rs | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/router/src/validation.rs b/router/src/validation.rs index f350d15e820..204dbf92a6a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,7 +1,9 @@ /// Payload validation logic use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; +use serde_json::Value; use text_generation_client::{ GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters, }; @@ -313,14 +315,24 @@ impl Validation { return Err(ValidationError::Grammar); } match grammar { - // currently both are handled the same way since compilation is done in Python GrammarType::Json(json) => { - // JSONSchema::options() - // .with_draft(Draft::Draft202012) - // .compile(&json) - // .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; + let json = match json { + // if value is a string, we need to parse it again to make sure its + // a valid json + Value::String(s) => serde_json::from_str(&s) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string())), + Value::Object(_) => Ok(json), + _ => Err(ValidationError::Grammar), + }?; + + // Check if the json is a valid JSONSchema + JSONSchema::options() + .with_draft(Draft::Draft202012) + .compile(&json) + .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?; ( + // Serialize json to string serde_json::to_string(&json) .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?, ProtoGrammarType::Json.into(), @@ -497,7 +509,7 @@ pub enum ValidationError { Tokenizer(String), #[error("grammar is not supported")] Grammar, - #[error("grammar is not a valid JSONSchema: {0}")] + #[error("grammar is not valid: {0}")] InvalidGrammar(String), } From befca7a9a36e7e6c77e939f23ed84d10bf6ebe69 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:19:12 +0100 Subject: [PATCH 5/6] fix snapshot --- .github/workflows/tests.yaml | 2 +- .../test_flash_llama_grammar_json.json | 92 +++++++++---------- .../models/test_grammar_llama.py | 2 +- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5b19eb8c02b..96010e8a97d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,7 +17,7 @@ concurrency: jobs: run_tests: - runs-on: ubuntu-latest + runs-on: large-runner env: SCCACHE_GHA_ENABLED: "on" diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json index 7b12b1587a2..d7fb620d4c6 100644 --- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -136,128 +136,128 @@ "text": "\",\"" }, { - "id": 4230, - "logprob": -0.020492554, + "id": 29882, + "logprob": -0.08862305, "special": false, - "text": "last" + "text": "h" }, { - "id": 1170, - "logprob": -0.0013818741, + "id": 711, + "logprob": -0.66259766, "special": false, - "text": "Name" + "text": "ob" }, { - "id": 4710, - "logprob": -0.0067749023, + "id": 1609, + "logprob": -5.51939e-05, "special": false, - "text": "\":\"" + "text": "by" }, { - "id": 29950, - "logprob": -0.11578369, + "id": 4710, + "logprob": -0.23120117, "special": false, - "text": "H" + "text": "\":\"" }, { - "id": 14339, - "logprob": -0.004131317, + "id": 29911, + "logprob": -2.3730469, "special": false, - "text": "olt" + "text": "T" }, { - "id": 29920, - "logprob": -0.0033359528, + "id": 11003, + "logprob": -0.032104492, "special": false, - "text": "z" + "text": "rees" }, { "id": 3284, - "logprob": -0.20471191, + "logprob": -0.22021484, "special": false, "text": "\",\"" }, { - "id": 29882, - "logprob": -0.0069274902, + "id": 4230, + "logprob": -0.06726074, "special": false, - "text": "h" + "text": "last" }, { - "id": 20838, - "logprob": -0.19580078, + "id": 1170, + "logprob": -0.003501892, "special": false, - "text": "obb" + "text": "Name" }, { - "id": 29891, - "logprob": -2.2649765e-06, + "id": 4710, + "logprob": -0.0045661926, "special": false, - "text": "y" + "text": "\":\"" }, { - "id": 4710, - "logprob": -0.32080078, + "id": 29950, + "logprob": -0.12512207, "special": false, - "text": "\":\"" + "text": "H" }, { - "id": 29911, - "logprob": -2.1035156, + "id": 14339, + "logprob": -0.009552002, "special": false, - "text": "T" + "text": "olt" }, { - "id": 11003, - "logprob": -0.020767212, + "id": 29920, + "logprob": -0.00042438507, "special": false, - "text": "rees" + "text": "z" }, { "id": 3284, - "logprob": -0.6010742, + "logprob": -0.11651611, "special": false, "text": "\",\"" }, { "id": 29876, - "logprob": -0.57666016, + "logprob": -0.29736328, "special": false, "text": "n" }, { "id": 398, - "logprob": -0.0061073303, + "logprob": -0.003030777, "special": false, "text": "um" }, { "id": 29907, - "logprob": -0.45703125, + "logprob": -0.3774414, "special": false, "text": "C" }, { "id": 1446, - "logprob": -0.0002872944, + "logprob": -0.0003130436, "special": false, "text": "ats" }, { "id": 1115, - "logprob": -0.0021018982, + "logprob": -0.0021514893, "special": false, "text": "\":" }, { "id": 29906, - "logprob": -0.08996582, + "logprob": -0.071899414, "special": false, "text": "2" }, { "id": 29913, - "logprob": -0.021697998, + "logprob": -0.018997192, "special": false, "text": "}" }, @@ -270,5 +270,5 @@ ], "top_tokens": null }, - "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" } diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py index ba123999a09..585d0656c3c 100644 --- a/integration-tests/models/test_grammar_llama.py +++ b/integration-tests/models/test_grammar_llama.py @@ -89,7 +89,7 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot): assert response.details.generated_tokens == 30 assert ( response.generated_text - == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}' + == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' ) assert response == response_snapshot From f8b54b78d77eb4051ecf58f00e892d89174ff6d0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:27:35 +0100 Subject: [PATCH 6/6] remove unused files --- .github/workflows/tests.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 96010e8a97d..29ff6d4545e 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,7 +17,7 @@ concurrency: jobs: run_tests: - runs-on: large-runner + runs-on: ubuntu-latest env: SCCACHE_GHA_ENABLED: "on" @@ -41,6 +41,10 @@ jobs: components: rustfmt, clippy - name: Install Protoc uses: arduino/setup-protoc@v1 + - name: Clean unused files + run: | + sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android + sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET - name: Install sccache run: | curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache