From 3508938af7b364c760cc1e7286f97971ad70d1b0 Mon Sep 17 00:00:00 2001
From: Jon Durbin
Date: Wed, 23 Aug 2023 13:56:52 +0000
Subject: [PATCH] More lmoe updates.

---
 README.md                | 13 +++++-
 airoboros/lmoe/api.py    | 93 ++++++++++++++++++++++++----------------
 airoboros/lmoe/router.py |  6 +--
 airoboros/lmoe/vllm.py   | 92 ++++++++++++++++++++++++++++-----------
 4 files changed, 136 insertions(+), 68 deletions(-)

diff --git a/README.md b/README.md
index 7edf528..9684957 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,8 @@ First, download the base llama-2 model for whichever model size you want, e.g.:
 
 Next, download the LMoE package that corresponds to that base model, e.g.: [airoboros-lmoe-7b-2.1](https://huggingface.co/jondurbin/airoboros-lmoe-7b-2.1)
 
+*NOTE: 13b also available, 70b in progress*
+
 Here's an example command to start the server:
 
 ```
@@ -102,7 +104,16 @@ curl -H 'content-type: application/json' http://127.0.0.1:8000/v1/chat/completio
 }'
 ```
 
-*I also tried adding vllm support, but it's not working quite right (yet) - see airoboros/lmoe/vllm.py*
+I've also added a vllm-based server, but the results aren't quite as good (not sure why yet). To use it, make sure you install `vllm` and `fschat`, or run `pip install airoboros[vllm]`
+
+```
+python -m airoboros.lmoe.vllm \
+  --model ./llama-2-7b-hf \
+  --lmoe ../airoboros-lmoe-7b-2.1 \
+  --router-max-samples 100 \
+  --port 8000 \
+  --host 127.0.0.1
+```
 
 ## Generating instructions
 
diff --git a/airoboros/lmoe/api.py b/airoboros/lmoe/api.py
index 9adfe80..1fa6642 100644
--- a/airoboros/lmoe/api.py
+++ b/airoboros/lmoe/api.py
@@ -1,4 +1,6 @@
 import argparse
+import asyncio
+import datetime
 import fastapi
 import glob
 import os
@@ -23,6 +25,7 @@ from typing import List, Dict
 
 warnings.filterwarnings("ignore")
 
+MODEL_LOCK = asyncio.Lock()
 MODELS = {}
 ROLE_MAP = {
     "user": "USER",
@@ -80,36 +83,8 @@ async def list_models():
     }
 
 
-@app.post("/v1/chat/completions")
-async def chat_completions(raw_request: Request):
-    """Simulate the OpenAI /v1/chat/completions endpoint.
-
-    NOTE: Parameters supported in request include:
-      - model: str, must be loaded from CLI args.
-      - messages: list[dict[str, str]]
-      - temperature: float
-      - repetition_penalty: float
-      - top_p: float
-      - top_k: int
-      - stop: list[str]
-      - max_tokens: int
-
-    Example request:
-      curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{
-        "model": "airoboros-lmoe-7b-2.1",
-        "messages": [
-          {
-            "role": "system",
-            "content": "A chat.",
-          },
-          {
-            "role": "user",
-            "content": "Write a poem about Ouroboros."
-          }
-        ]
-      }'
-    """
-    request = ChatRequest(**await raw_request.json())
+def complete_request(request):
+    """Sync method to complete a request, to make sure we aren't messing with the model/LoRAs concurrently."""
     if any(
         [
             getattr(request, key, 0) < 0
@@ -158,7 +133,7 @@ async def chat_completions(raw_request: Request):
             expected == "assistant"
         else:
             expected == "user"
-    prompt = "\n".join(prompt_parts + ["ASSISTANT: "])
+    prompt = " ".join(prompt_parts + ["ASSISTANT: "])
     logger.debug(f"Prompt:\n{prompt}")
 
     # Validate the length of the input.
@@ -166,20 +141,22 @@ async def chat_completions(raw_request: Request):
         "cuda"
     )
     max_len = MODELS[request.model]["config"].max_position_embeddings
-    max_tokens = request.max_tokens or max_len - len(input_ids) - 1
-    if len(input_ids) + max_tokens > max_len:
+    max_tokens = request.max_tokens or max_len - len(input_ids[0]) - 1
+    if len(input_ids[0]) + max_tokens > max_len:
         raise HTTPException(
             status_code=422,
             detail="Prompt length + max_tokens exceeds max model length.",
         )
 
     # Route the request to the appropriate expert (LoRA).
+    started_at = datetime.datetime.utcnow()
     expert = MODELS[request.model]["router"].route(prompt)
     model = MODELS[request.model]["model"]
     loaded_expert = getattr(model, "__expert__", None)
     if loaded_expert != expert:
         model.set_adapter(expert)
     setattr(model, "__expert__", expert)
+    routing_duration = (datetime.datetime.utcnow() - started_at).total_seconds()
 
     # Update our stopping criteria.
     stop_words = request.stop
@@ -196,6 +173,7 @@ async def chat_completions(raw_request: Request):
     )
 
     # Generate the response.
+    started_at = datetime.datetime.utcnow()
     with torch.no_grad():
         outputs = model.generate(
             input_ids=input_ids,
@@ -212,11 +190,14 @@ async def chat_completions(raw_request: Request):
         .split("ASSISTANT:")[1]
         .strip()
     )
+    duration = (datetime.datetime.utcnow() - started_at).total_seconds()
     request_id = f"cmpl-{uuid.uuid4()}"
     return {
         "id": request_id,
         "object": "chat.completion",
         "created": int(time.time()),
+        "duration": duration,
+        "routing_duration": routing_duration,
         "model": request.model,
         "expert": expert,
         "choices": [
@@ -230,13 +211,47 @@ async def chat_completions(raw_request: Request):
             }
         ],
         "usage": {
-            "prompt_tokens": len(input_ids),
-            "completion_tokens": len(outputs),
-            "total_tokens": len(input_ids) + len(outputs),
+            "prompt_tokens": len(input_ids[0]),
+            "completion_tokens": len(outputs[0]),
+            "total_tokens": len(input_ids[0]) + len(outputs[0]),
         },
     }
 
 
+@app.post("/v1/chat/completions")
+async def chat_completions(raw_request: Request):
+    """Simulate the OpenAI /v1/chat/completions endpoint.
+
+    NOTE: Parameters supported in request include:
+      - model: str, must be loaded from CLI args.
+      - messages: list[dict[str, str]]
+      - temperature: float
+      - repetition_penalty: float
+      - top_p: float
+      - top_k: int
+      - stop: list[str]
+      - max_tokens: int
+
+    Example request:
+      curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{
+        "model": "airoboros-lmoe-7b-2.1",
+        "messages": [
+          {
+            "role": "system",
+            "content": "A chat.",
+          },
+          {
+            "role": "user",
+            "content": "Write a poem about Ouroboros."
+          }
+        ]
+      }'
+    """
+    request = ChatRequest(**await raw_request.json())
+    async with MODEL_LOCK:
+        return complete_request(request)
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="airoboros LMoE API server, somewhat similar to OpenAI API.",
@@ -245,7 +260,7 @@ def main():
     parser.add_argument("-p", "--port", type=int, default=8000, help="port number")
     parser.add_argument(
         "-k",
-        "--router-max-k",
+        "--router-k",
        type=int,
        default=20,
        help="k, when doing faiss approximate knn search to select expert",
@@ -296,7 +311,9 @@ def main():
             adapter_name="general",
         ),
         "router": Router(
-            input_paths=routing_paths, max_samples=args.router_max_samples
+            input_paths=routing_paths,
+            max_samples=args.router_max_samples,
+            k=args.router_k,
         ),
     }
     logger.info(
diff --git a/airoboros/lmoe/router.py b/airoboros/lmoe/router.py
index 57686e8..93094e4 100644
--- a/airoboros/lmoe/router.py
+++ b/airoboros/lmoe/router.py
@@ -25,13 +25,13 @@ def __init__(
         self,
         model_name_or_path: str = "thenlper/gte-small",
         input_paths: List[str] = [],
-        max_k: int = 50,
+        k: int = 50,
         max_samples: int = 500,
     ):
         """Constructor."""
         self.model = SentenceTransformer(model_name_or_path, device="cuda")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.max_k = max_k
+        self.k = k
         self.max_samples = max_samples
         if not input_paths:
             input_paths = [
@@ -71,7 +71,7 @@ def route(self, prompt: str) -> str:
         best_expert = None
         best_distance = math.inf
         for expert, index in self.indices.items():
-            distances, _ = index.search(query_emb, k=min(index.ntotal, self.max_k))
+            distances, _ = index.search(query_emb, k=min(index.ntotal, self.k))
             distances = distances[0].tolist()
             average_distance = sum(distances) / len(distances)
             logger.debug(f"Average distance [{expert}]: {average_distance}")
diff --git a/airoboros/lmoe/vllm.py b/airoboros/lmoe/vllm.py
index fe944cb..957f09d 100644
--- a/airoboros/lmoe/vllm.py
+++ b/airoboros/lmoe/vllm.py
@@ -22,7 +22,6 @@
 from vllm.entrypoints.openai.api_server import (
     create_error_response,
     check_model,
-    get_gen_prompt,
     check_length,
 )
 from vllm.entrypoints.openai.protocol import (
@@ -46,6 +45,11 @@
 from airoboros.lmoe.lora import lora_merge_unmerge_state_dict
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
+MODEL_LOCK = asyncio.Lock()
+ROLE_MAP = {
+    "user": "USER",
+    "assistant": "ASSISTANT",
+}
 served_model = None
 tokenizer = None
 
@@ -61,29 +65,43 @@ async def show_available_models():
     return ModelList(data=model_cards)
 
 
-@app.post("/v1/chat/completions")
-async def create_chat_completion(raw_request: Request):
-    """Completion API similar to OpenAI's API.
+async def complete_request(raw_request, request):
+    """Complete a chat request, which is wrapped by an asyncio lock."""
 
-    See https://platform.openai.com/docs/api-reference/chat/create
-    for the API specification. This API mimics the OpenAI ChatCompletion API.
-
-    NOTE: Currently we do not support the following features:
-        - function_call (Users should implement this by themselves)
-        - logit_bias (to be supported by vLLM engine)
-    """
-    request = ChatCompletionRequest(**await raw_request.json())
-
-    # Hacky, but we'll inject a default system message, since it's
-    # slightly different for airoboros 2.1.
+    # Make sure we have a system prompt.
     if request.messages[0]["role"] != "system":
         request.messages = [{"role": "system", "content": "A chat."}] + request.messages
-
-    logger.info(f"Received chat completion request: {request}")
+    logger.debug(f"Received chat completion request: {request}")
+
+    # Build the prompt, with a bit more (very basic) validation.
+    prompt_parts = []
+    expected = "system"
+    for message in request.messages:
+        if message["role"] == "system":
+            prompt_parts.append(message["content"])
+            expected = "user"
+        elif message["role"] not in ROLE_MAP:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST, f"Invalid role found: {message['role']}"
+            )
+        elif message["role"] != expected:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "Invalid messages structure, expected system -> [user assistant]* user",
+            )
+        else:
+            prompt_parts.append(
+                f"{ROLE_MAP[message['role']]}: {message['content'].strip()}"
+            )
+        if message["role"] == "user":
+            expected = "assistant"
+        else:
+            expected = "user"
+    prompt = " ".join(prompt_parts + ["ASSISTANT: "])
+    logger.debug(f"Prompt:\n{prompt}")
 
     # Route the request to the appropriate expert (LoRA).
-    instruction = request.messages[0]["content"] + request.messages[-1]["content"]
-    expert = router.route(instruction)
+    expert = router.route(prompt)
     loaded_expert = getattr(engine, "__expert__", None)
     if loaded_expert != expert:
         if loaded_expert is not None:
@@ -108,7 +126,6 @@ async def create_chat_completion(raw_request: Request):
             HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported"
         )
 
-    prompt = await get_gen_prompt(request).strip() + " "
     error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -219,7 +236,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     for output in final_res.outputs:
         choice_data = ChatCompletionResponseChoice(
             index=output.index,
-            message=ChatMessage(role="assistant", content=output.text),
+            message=ChatMessage(role="assistant", content=output.text.strip()),
             finish_reason=output.finish_reason,
         )
         choices.append(choice_data)
@@ -255,6 +272,22 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     return response
 
+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    async with MODEL_LOCK:
+        return await complete_request(raw_request, request)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server."
     )
@@ -289,7 +322,13 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
         "to use when building router index.",
     )
     parser.add_argument(
-        "--lmoe-path",
+        "--router-k",
+        type=int,
+        default=25,
+        help="k, when doing faiss approximate knn search to select expert",
+    )
+    parser.add_argument(
+        "--lmoe",
         type=str,
         required=True,
         help="Path to LMoE directory with adapters and data",
@@ -318,13 +357,14 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     # Setup our router, and load all of the adapters so they
     # are ready to swap in/out.
     routing_paths = [
-        str(p)
-        for p in glob.glob(os.path.join(args.lmoe_path, "routing_data", "*.jsonl"))
+        str(p) for p in glob.glob(os.path.join(args.lmoe, "routing_data", "*.jsonl"))
     ]
-    router = Router(input_paths=routing_paths, max_samples=args.router_max_samples)
+    router = Router(
+        input_paths=routing_paths, max_samples=args.router_max_samples, k=args.router_k
+    )
     adapters = {}
     adapter_configs = {}
-    for directory in glob.glob(os.path.join(args.lmoe_path, "adapters", "*")):
+    for directory in glob.glob(os.path.join(args.lmoe, "adapters", "*")):
         adapters[str(directory).split("/")[-1]] = torch.load(
             os.path.join(str(directory), "adapter_model.bin"), map_location="cuda:0"
         )