Skip to content

Commit

Permalink
feat: improve update doc and add command to print router schema
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jul 2, 2024
1 parent caa4401 commit e4161a1
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 160 deletions.
13 changes: 3 additions & 10 deletions .github/workflows/autodocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,9 @@ jobs:
echo text-generation-launcher --help
python update_doc.py md --check
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Clean unused files
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install
run: |
make install-cpu
- name: Install router
id: install-router
run: cargo install --path router/

- name: Check that openapi schema is up-to-date
run: |
Expand Down
23 changes: 20 additions & 3 deletions router/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
Expand Down Expand Up @@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,

#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
Expand Down Expand Up @@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize,
}

/// Optional subcommands accepted by the router binary in addition to the
/// normal server flags.
#[derive(Debug, Subcommand)]
enum Commands {
    /// Print the router's OpenAPI schema as JSON to stdout and exit
    /// without starting the server (used by `update_doc.py`).
    PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();

// Pattern match configuration
let Args {
max_concurrent_requests,
Expand Down Expand Up @@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
command,
} = args;

// Launch Tokio runtime
init_logging(otlp_endpoint, otlp_service_name, json_output);
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};

// Validate args
if max_input_tokens >= max_total_tokens {
Expand Down Expand Up @@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await?;
Ok(())
Expand Down
15 changes: 11 additions & 4 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1387,10 +1387,10 @@ async fn tokenize(

/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
Expand Down Expand Up @@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
Expand Down Expand Up @@ -1500,6 +1501,12 @@ pub async fn run(
struct ApiDoc;

// Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}

// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
Expand Down
184 changes: 41 additions & 143 deletions update_doc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import subprocess
import argparse
import ast
import requests
import json
import time
import os

TEMPLATE = """
Expand Down Expand Up @@ -126,155 +124,55 @@ def check_supported_models(check: bool):
f.write(final_doc)


def start_server_and_wait():
    """Launch ``text-generation-launcher`` and block until it is healthy.

    Polls ``http://127.0.0.1:3000/health`` once per second. On success returns
    ``(process, log_file)``; the caller is responsible for stopping the server
    (see ``stop_server``). On timeout the spawned process is terminated, its
    log is printed for debugging, and ``TimeoutError`` is raised.
    """
    startup_timeout = 180  # seconds; CI can be slow sometimes
    log_file = open("/tmp/server_log.txt", "w")

    process = subprocess.Popen(
        ["text-generation-launcher"],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    print("Server is starting...")

    start_time = time.time()
    while True:
        try:
            response = requests.get("http://127.0.0.1:3000/health")
            if response.status_code == 200:
                print("Server is up and running!")
                return process, log_file
        except requests.RequestException:
            pass

        # Bug fix: the timeout was previously only checked inside the
        # RequestException handler, so a server that kept answering with a
        # non-200 status would make this loop spin forever.
        if time.time() - start_time > startup_timeout:
            # Bug fix: terminate the child so it is not leaked on timeout.
            process.terminate()
            process.wait()
            log_file.close()
            with open("/tmp/server_log.txt", "r") as f:
                print("Server log:")
                print(f.read())
            os.remove("/tmp/server_log.txt")
            # Bug fix: the message claimed "60 seconds" while the actual
            # timeout was 180; report the real value.
            raise TimeoutError(
                f"Server didn't start within {startup_timeout} seconds"
            )
        time.sleep(1)


def stop_server(process, log_file, show=False):
    """Terminate the launcher process and close its log file.

    When ``show`` is true, the captured server log is printed and the log
    file is deleted afterwards.
    """
    process.terminate()
    process.wait()
    log_file.close()

    if not show:
        return

    with open("/tmp/server_log.txt", "r") as f:
        print("Server log:")
        print(f.read())
    os.remove("/tmp/server_log.txt")


def get_openapi_json():
    """Fetch the OpenAPI schema from the locally running router.

    Raises ``requests.HTTPError`` for any non-2xx response.
    """
    url = "http://127.0.0.1:3000/api-doc/openapi.json"
    response = requests.get(url)
    # Surface non-success status codes as exceptions instead of bad data.
    response.raise_for_status()
    return response.json()


def update_openapi_json(new_data, filename="docs/openapi.json"):
    """Serialize ``new_data`` as pretty-printed JSON into ``filename``."""
    serialized = json.dumps(new_data, indent=2)
    with open(filename, "w") as f:
        f.write(serialized)


def compare_openapi(old_data, new_data):
    """Recursively diff two OpenAPI documents.

    Returns a list of human-readable difference strings of the form
    ``Added: path``, ``Removed: path`` or ``Changed: path``, where ``path``
    is a dot-joined key path. Non-dict values are compared by equality.
    """
    differences = []

    def walk(old, new, path=""):
        # Only dict/dict pairs are descended into; everything else is a
        # leaf-level equality check.
        if isinstance(old, dict) and isinstance(new, dict):
            for key in set(old) | set(new):
                child = f"{path}.{key}" if path else key
                if key not in old:
                    differences.append(f"Added: {child}")
                elif key not in new:
                    differences.append(f"Removed: {child}")
                else:
                    walk(old[key], new[key], child)
        elif old != new:
            differences.append(f"Changed: {path}")

    walk(old_data, new_data)
    return differences


# NOTE(review): this span of the commit render interleaved the deleted
# server-based `openapi()` helper with the new function without diff
# markers; reconstructed below is the new `get_openapi_schema()` only.
def get_openapi_schema():
    """Obtain the router's OpenAPI schema.

    Runs ``text-generation-router print-schema`` and parses its stdout as
    JSON. Exits the interpreter with status 1 if the command fails or emits
    invalid JSON.
    """
    try:
        output = subprocess.check_output(["text-generation-router", "print-schema"])
        return json.loads(output)
    except subprocess.CalledProcessError as e:
        print(f"Error running text-generation-router print-schema: {e}")
        raise SystemExit(1)
    except json.JSONDecodeError:
        print("Error: Invalid JSON received from text-generation-router print-schema")
        raise SystemExit(1)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Update documentation for text-generation-launcher"
)
subparsers = parser.add_subparsers(dest="command", required=True)
# NOTE(review): this span of the commit render interleaved deleted argparse
# setup lines with the new function; reconstructed below is `check_openapi`.
def check_openapi(check: bool):
    """Regenerate ``docs/openapi.json`` or verify it is up to date.

    The fresh schema is written to a temporary file first. In check mode the
    temp file is diffed against the committed file and an exception is raised
    on any difference; otherwise the temp file is moved into place.
    """
    new_openapi_data = get_openapi_schema()
    filename = "docs/openapi.json"
    tmp_filename = "openapi_tmp.json"

    with open(tmp_filename, "w") as f:
        json.dump(new_openapi_data, f, indent=2)

    if check:
        diff = subprocess.run(
            ["diff", tmp_filename, filename], capture_output=True
        ).stdout.decode()
        os.remove(tmp_filename)

        if diff:
            print(diff)
            raise Exception(
                "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
            )

        return True
    else:
        os.rename(tmp_filename, filename)
        print("OpenAPI documentation updated.")
        return True

md_parser = subparsers.add_parser("md", help="Update launcher and supported models")
md_parser.add_argument(
"--check",
action="store_true",
help="Check if the launcher documentation needs updating",
)

def main():
    """CLI entry point: update (or, with ``--check``, verify) all generated docs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--check", action="store_true")

    args = parser.parse_args()

    # Run every doc generator with the same check/update flag.
    check_cli(args.check)
    check_supported_models(args.check)
    check_openapi(args.check)


if __name__ == "__main__":
    main()

0 comments on commit e4161a1

Please sign in to comment.