feat: add mocked http request tests #1395

Closed · wants to merge 1 commit
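Summary of the diff below: an optional `skip_generation` field on `GenerateParameters`, an early return in the `/generate` handler that echoes the inputs back when that flag is set, a `ShardedClient::empty()` constructor for running the router without model shards, and two mocked HTTP request tests that drive the axum `Router` directly through `tower::ServiceExt::oneshot`.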
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions router/Cargo.toml
@@ -43,6 +43,7 @@ utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
hf-hub = "0.3.1"
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
tower = "0.4.13"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
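(The new `tower` dependency appears to be needed for `tower::util::ServiceExt::oneshot`, which the tests added in `router/src/server.rs` below use to call the axum `Router` as a plain `tower::Service` instead of standing up an HTTP server.)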
5 changes: 5 additions & 0 deletions router/client/src/sharded_client.rs
@@ -17,6 +17,11 @@ impl ShardedClient {
        Self { clients }
    }

    /// Create a new ShardedClient with no shards. Only used when testing the
    /// router in isolation, without a running model server.
    pub fn empty() -> Self {
        Self { clients: vec![] }
    }

    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
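(`ShardedClient::empty()` lets the tests construct `Health` and `Infer` without any gRPC connection; the `skip_generation` flag added below then keeps `/generate` from ever reaching a shard.)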
4 changes: 4 additions & 0 deletions router/src/lib.rs
@@ -138,6 +138,9 @@ pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
    pub top_n_tokens: Option<u32>,

    // Useful when testing the router in isolation: skips inference and
    // makes `/generate` echo the inputs back
    skip_generation: Option<bool>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -162,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
        decoder_input_details: false,
        seed: None,
        top_n_tokens: None,
        skip_generation: None,
    }
}

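Because `skip_generation` is an `Option`, request bodies that omit it still deserialize: serde fills a missing `Option` field with `None`, so existing clients are unaffected (the second test below relies on this). A minimal standalone sketch of that behaviour, not part of the diff, assuming `serde` (derive feature) and `serde_json` as the router already uses; `Params` is a hypothetical stand-in, not the real struct:

use serde::Deserialize;

// Hypothetical stand-in for the relevant part of `GenerateParameters`;
// the real struct in router/src/lib.rs has many more fields.
#[derive(Deserialize)]
struct Params {
    skip_generation: Option<bool>,
}

fn main() {
    // A body that never mentions the new field still parses, with `None`:
    let old: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(old.skip_generation, None);

    // The router tests added below opt in explicitly:
    let new: Params = serde_json::from_str(r#"{"skip_generation": true}"#).unwrap();
    assert_eq!(new.skip_generation, Some(true));
}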
221 changes: 221 additions & 0 deletions router/src/server.rs
@@ -160,6 +160,15 @@ async fn generate(

    let details: bool = req.parameters.details || req.parameters.decoder_input_details;

    // Early return if skip_generation is set: echo the inputs back without
    // touching the model shards, so the HTTP layer can be tested in isolation
    if req.parameters.skip_generation.unwrap_or(false) {
        let response = GenerateResponse {
            generated_text: req.inputs.clone(),
            details: None,
        };
        return Ok((HeaderMap::new(), Json(response)));
    }

    // Inference
    let (response, best_of_responses) = match req.parameters.best_of {
        Some(best_of) if best_of > 1 => {
@@ -838,3 +847,215 @@ impl From<InferError> for Event {
            .unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::HttpBody;
    use axum::{
        body::Body,
        http::{self, Request, StatusCode},
    };
    use serde_json::json;
    use tower::util::ServiceExt;

    /// Build the router for testing purposes
    async fn build_router() -> Router<(), axum::body::Body> {
        // Set dummy values for testing
        let validation_workers = 1;
        let tokenizer = None;
        let waiting_served_ratio = 1.0;
        let max_batch_prefill_tokens = 1;
        let max_batch_total_tokens = 1;
        let max_concurrent_requests = 1;
        let max_waiting_tokens = 1;
        let requires_padding = false;
        let allow_origin = None;
        let max_best_of = 1;
        let max_stop_sequences = 1;
        let max_input_length = 1024;
        let max_total_tokens = 2048;
        let max_top_n_tokens = 5;

        // Create an empty client
        let shardless_client = ShardedClient::empty();

        // Create validation and inference
        let validation = Validation::new(
            validation_workers,
            tokenizer,
            max_best_of,
            max_stop_sequences,
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
        );

        // Create shard info
        let shard_info = ShardInfo {
            dtype: "demo".to_string(),
            device_type: "none".to_string(),
            window_size: Some(1),
            speculate: 0,
            requires_padding,
        };

        // Create model info
        let model_info = HubModelInfo {
            model_id: "test".to_string(),
            sha: None,
            pipeline_tag: None,
        };

        // Setup extension
        let generation_health = Arc::new(AtomicBool::new(false));
        let health_ext = Health::new(shardless_client.clone(), generation_health.clone());

        // Build the Infer struct with the dummy values
        let infer = Infer::new(
            shardless_client,
            validation,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_concurrent_requests,
            shard_info.requires_padding,
            shard_info.window_size,
            shard_info.speculate,
            generation_health,
        );

        // CORS layer
        let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
        let cors_layer = CorsLayer::new()
            .allow_methods([Method::GET, Method::POST])
            .allow_headers([http::header::CONTENT_TYPE])
            .allow_origin(allow_origin);

        // Endpoint info
        let info = Info {
            model_id: model_info.model_id,
            model_sha: model_info.sha,
            model_dtype: shard_info.dtype,
            model_device_type: shard_info.device_type,
            model_pipeline_tag: model_info.pipeline_tag,
            max_concurrent_requests,
            max_best_of,
            max_stop_sequences,
            max_input_length,
            max_total_tokens,
            waiting_served_ratio,
            max_batch_total_tokens,
            max_waiting_tokens,
            validation_workers,
            version: env!("CARGO_PKG_VERSION"),
            sha: option_env!("VERGEN_GIT_SHA"),
            docker_label: option_env!("DOCKER_LABEL"),
        };

        let compat_return_full_text = true;

        // Create router
        let app: Router<(), Body> = Router::new()
            // removed the swagger ui for testing
            // Base routes
            .route("/", post(compat_generate))
            .route("/info", get(get_model_info))
            .route("/generate", post(generate))
            .route("/generate_stream", post(generate_stream))
            // AWS Sagemaker route
            .route("/invocations", post(compat_generate))
            // Base Health route
            .route("/health", get(health))
            // Inference API health route
            .route("/", get(health))
            // AWS Sagemaker health route
            .route("/ping", get(health))
            // Prometheus metrics route
            .route("/metrics", get(metrics))
            .layer(Extension(info))
            .layer(Extension(health_ext.clone()))
            .layer(Extension(compat_return_full_text))
            .layer(Extension(infer))
            // removed the prometheus layer for testing
            .layer(OtelAxumLayer::default())
            .layer(cors_layer);

        app
    }

    #[tokio::test]
    async fn test_echo_inputs_when_skip_generation() {
        let app = build_router().await;

        let request_body = json!({
            "inputs": "Hello world!",
            "parameters": {
                "stream": false,
                // skip_generation is needed here to avoid requests
                // to the non-existing client shards
                "skip_generation": true
            }
        });
        // `Router` implements `tower::Service<Request<Body>>`, so we can
        // call it like any tower service; there is no need to run an HTTP server.
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/generate")
                    .method(Method::POST)
                    .header(http::header::CONTENT_TYPE, "application/json")
                    .body(axum::body::Body::from(request_body.to_string()))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Read the (single-frame) JSON body via `HttpBody::data`
        let body = response.into_body().data().await.unwrap().unwrap();
        let utf8_body = std::str::from_utf8(&body[..]).unwrap();

        let expected_response_body = json!({
            "generated_text": "Hello world!"
        });
        assert_eq!(utf8_body, expected_response_body.to_string());
    }

    #[tokio::test]
    async fn test_return_json_error_on_empty_inputs() {
        let app = build_router().await;

        let request_body = json!({
            "inputs": "",
            "parameters": {
                // no need for skip_generation here: validation rejects the
                // empty input before any shard would be contacted
                "stream": false
            }
        });

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/generate")
                    .method(Method::POST)
                    .header(http::header::CONTENT_TYPE, "application/json")
                    .body(axum::body::Body::from(request_body.to_string()))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);

        // Read the (single-frame) JSON error body via `HttpBody::data`
        let body = response.into_body().data().await.unwrap().unwrap();
        let utf8_body = std::str::from_utf8(&body[..]).unwrap();

        let expected_response_body = json!({
            "error": "Input validation error: `inputs` cannot be empty",
            "error_type": "validation"
        });
        assert_eq!(utf8_body, expected_response_body.to_string());
    }
}