diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index 502feb8b30a..e04ed21175e 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -20,16 +20,9 @@ jobs: echo text-generation-launcher --help python update_doc.py md --check - - name: Install Protoc - uses: arduino/setup-protoc@v1 - - name: Clean unused files - run: | - sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android - sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET - - - name: Install - run: | - make install-cpu + - name: Install router + id: install-router + run: cargo install --path router/ - name: Check that openapi schema is up-to-date run: | diff --git a/router/src/main.rs b/router/src/main.rs index 8618f57eb38..21cd66496da 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,6 @@ use axum::http::HeaderValue; use clap::Parser; +use clap::Subcommand; use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use opentelemetry::sdk::propagation::TraceContextPropagator; @@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] struct Args { + #[command(subcommand)] + command: Option<Commands>, + #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, #[clap(default_value = "2", long, env)] @@ -85,10 +89,15 @@ struct Args { max_client_batch_size: usize, } +#[derive(Debug, Subcommand)] +enum Commands { + PrintSchema, +} + #[tokio::main] async fn main() -> Result<(), RouterError> { - // Get args let args = Args::parse(); + // Pattern match configuration let Args { max_concurrent_requests, @@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + command, } = args; - // Launch Tokio runtime - init_logging(otlp_endpoint, otlp_service_name, json_output); + 
let print_schema_command = match command { + Some(Commands::PrintSchema) => true, + None => { + // only init logging if we are not running the print schema command + init_logging(otlp_endpoint, otlp_service_name, json_output); + false + } + }; // Validate args if max_input_tokens >= max_total_tokens { @@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + print_schema_command, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index d24774f96c3..9be6a35cd8b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1387,10 +1387,10 @@ async fn tokenize( /// Prometheus metrics scrape endpoint #[utoipa::path( -get, -tag = "Text Generation Inference", -path = "/metrics", -responses((status = 200, description = "Prometheus Metrics", body = String)) + get, + tag = "Text Generation Inference", + path = "/metrics", + responses((status = 200, description = "Prometheus Metrics", body = String)) )] async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String { prom_handle.render() @@ -1430,6 +1430,7 @@ pub async fn run( messages_api_enabled: bool, grammar_support: bool, max_client_batch_size: usize, + print_schema_command: bool, ) -> Result<(), WebServerError> { // OpenAPI documentation #[derive(OpenApi)] @@ -1500,6 +1501,12 @@ pub async fn run( struct ApiDoc; // Create state + if print_schema_command { + let api_doc = ApiDoc::openapi(); + let api_doc = serde_json::to_string_pretty(&api_doc).unwrap(); + println!("{}", api_doc); + std::process::exit(0); + } // Open connection, get model info and warmup let (scheduler, health_ext, shard_info, max_batch_total_tokens): ( diff --git a/update_doc.py b/update_doc.py index f4ac12908d2..df8f0471466 100644 --- a/update_doc.py +++ b/update_doc.py @@ -1,9 +1,7 @@ import subprocess import argparse import ast -import requests import json -import time import os TEMPLATE = """ @@ -126,155 +124,55 @@ def check_supported_models(check: bool): 
f.write(final_doc) -def start_server_and_wait(): - log_file = open("/tmp/server_log.txt", "w") - - process = subprocess.Popen( - ["text-generation-launcher"], - stdout=log_file, - stderr=subprocess.STDOUT, - universal_newlines=True, - ) - print("Server is starting...") - - start_time = time.time() - while True: - try: - response = requests.get("http://127.0.0.1:3000/health") - if response.status_code == 200: - print("Server is up and running!") - return process, log_file - except requests.RequestException: - # timeout after 3 minutes (CI can be slow sometimes) - if time.time() - start_time > 180: - log_file.close() - with open("/tmp/server_log.txt", "r") as f: - print("Server log:") - print(f.read()) - os.remove("/tmp/server_log.txt") - raise TimeoutError("Server didn't start within 60 seconds") - time.sleep(1) - - -def stop_server(process, log_file, show=False): - process.terminate() - process.wait() - log_file.close() - - if show: - with open("/tmp/server_log.txt", "r") as f: - print("Server log:") - print(f.read()) - os.remove("/tmp/server_log.txt") - - -def get_openapi_json(): - response = requests.get("http://127.0.0.1:3000/api-doc/openapi.json") - # error if not 200 - response.raise_for_status() - return response.json() - - -def update_openapi_json(new_data, filename="docs/openapi.json"): - with open(filename, "w") as f: - json.dump(new_data, f, indent=2) - - -def compare_openapi(old_data, new_data): - differences = [] - - def compare_recursive(old, new, path=""): - if isinstance(old, dict) and isinstance(new, dict): - for key in set(old.keys()) | set(new.keys()): - new_path = f"{path}.{key}" if path else key - if key not in old: - differences.append(f"Added: {new_path}") - elif key not in new: - differences.append(f"Removed: {new_path}") - else: - compare_recursive(old[key], new[key], new_path) - elif old != new: - differences.append(f"Changed: {path}") - - compare_recursive(old_data, new_data) - return differences - - -def openapi(check: bool): +def 
get_openapi_schema(): try: - server_process, log_file = start_server_and_wait() - - try: - new_openapi_data = get_openapi_json() - - if check: - try: - with open("docs/openapi.json", "r") as f: - old_openapi_data = json.load(f) - except FileNotFoundError: - print( - "docs/openapi.json not found. Run without --check to create it." - ) - return - - differences = compare_openapi(old_openapi_data, new_openapi_data) - - if differences: - print("The following differences were found:") - for diff in differences: - print(diff) - print( - "Please run the script without --check to update the documentation." - ) - else: - print("Documentation is up to date.") - else: - update_openapi_json(new_openapi_data) - print("Documentation updated successfully.") - - finally: - stop_server(server_process, log_file) - - except TimeoutError as e: - print(f"Error: {e}") - raise SystemExit(1) - except requests.RequestException as e: - print(f"Error communicating with the server: {e}") + output = subprocess.check_output(["text-generation-router", "print-schema"]) + return json.loads(output) + except subprocess.CalledProcessError as e: + print(f"Error running text-generation-router print-schema: {e}") raise SystemExit(1) except json.JSONDecodeError: - print("Error: Invalid JSON received from the server") - raise SystemExit(1) - except Exception as e: - print(f"An unexpected error occurred: {e}") + print("Error: Invalid JSON received from text-generation-router print-schema") raise SystemExit(1) -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Update documentation for text-generation-launcher" - ) - subparsers = parser.add_subparsers(dest="command", required=True) +def check_openapi(check: bool): + new_openapi_data = get_openapi_schema() + filename = "docs/openapi.json" + tmp_filename = "openapi_tmp.json" - openapi_parser = subparsers.add_parser( - "openapi", help="Update OpenAPI documentation" - ) - openapi_parser.add_argument( - "--check", - action="store_true", 
- help="Check if the OpenAPI documentation needs updating", - ) + with open(tmp_filename, "w") as f: + json.dump(new_openapi_data, f, indent=2) + + if check: + diff = subprocess.run( + ["diff", tmp_filename, filename], capture_output=True + ).stdout.decode() + os.remove(tmp_filename) + + if diff: + print(diff) + raise Exception( + "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" + ) + + return True + else: + os.rename(tmp_filename, filename) + print("OpenAPI documentation updated.") + return True - md_parser = subparsers.add_parser("md", help="Update launcher and supported models") - md_parser.add_argument( - "--check", - action="store_true", - help="Check if the launcher documentation needs updating", - ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--check", action="store_true") args = parser.parse_args() - if args.command == "openapi": - openapi(args.check) - elif args.command == "md": - check_cli(args.check) - check_supported_models(args.check) + check_cli(args.check) + check_supported_models(args.check) + check_openapi(args.check) + + +if __name__ == "__main__": + main()