More lmoe updates.
jondurbin committed Aug 23, 2023
1 parent 04b6c42 commit 3508938
Showing 4 changed files with 136 additions and 68 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -69,6 +69,8 @@ First, download the base llama-2 model for whichever model size you want, e.g.:

Next, download the LMoE package that corresponds to that base model, e.g.: [airoboros-lmoe-7b-2.1](https://huggingface.co/jondurbin/airoboros-lmoe-7b-2.1)

*NOTE: 13b also available, 70b in progress*

Here's an example command to start the server:

```
@@ -102,7 +104,16 @@ curl -H 'content-type: application/json' http://127.0.0.1:8000/v1/chat/completions
}'
```

*I also tried adding vllm support, but it's not working quite right (yet) - see airoboros/lmoe/vllm.py*
I've also added a vllm-based server, but the results aren't quite as good (not sure why yet). To use it, make sure you install `vllm` and `fschat`, or `pip install airoboros[vllm]`.

```
python -m airoboros.lmoe.vllm \
--model ./llama-2-7b-hf \
  --lmoe ../airoboros-lmoe-7b-2.1 \
--router-max-samples 100 \
--port 8000 \
--host 127.0.0.1
```
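
Either server exposes the same `/v1/chat/completions` route, so any OpenAI-style client can talk to it. Here's a minimal Python client sketch (the model name, prompt, and sampling settings are just examples); the extra `expert`, `duration`, and `routing_duration` fields are only reported by the non-vllm server (airoboros/lmoe/api.py):

```
import json
import urllib.request

payload = {
    "model": "airoboros-lmoe-7b-2.1",
    "messages": [
        {"role": "system", "content": "A chat."},
        {"role": "user", "content": "Write a poem about Ouroboros."},
    ],
    "temperature": 0.7,
    "max_tokens": 512,
}
request = urllib.request.Request(
    "http://127.0.0.1:8000/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"content-type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.load(response)

# Extra fields (expert, duration, routing_duration) come from api.py only,
# so fall back gracefully when they're absent.
print(result.get("expert"), result.get("routing_duration"), result.get("duration"))
print(result["choices"][0]["message"]["content"])
```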

## Generating instructions

93 changes: 55 additions & 38 deletions airoboros/lmoe/api.py
@@ -1,4 +1,6 @@
import argparse
import asyncio
import datetime
import fastapi
import glob
import os
@@ -23,6 +25,7 @@
from typing import List, Dict

warnings.filterwarnings("ignore")
MODEL_LOCK = asyncio.Lock()
MODELS = {}
ROLE_MAP = {
"user": "USER",
@@ -80,36 +83,8 @@ async def list_models():
}


@app.post("/v1/chat/completions")
async def chat_completions(raw_request: Request):
"""Simulate the OpenAI /v1/chat/completions endpoint.
NOTE: Parameters supported in request include:
- model: str, must be loaded from CLI args.
- messages: list[dict[str, str]]
- temperature: float
- repetition_penalty: float
- top_p: float
- top_k: int
- stop: list[str]
- max_tokens: int
Example request:
curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{
"model": "airoboros-lmoe-7b-2.1",
"messages": [
{
"role": "system",
"content": "A chat.",
},
{
"role": "user",
"content": "Write a poem about Ouroboros."
}
]
}'
"""
request = ChatRequest(**await raw_request.json())
def complete_request(request):
"""Sync method to complete a request, to make sure we aren't message with model/LoRAs concurrently."""
if any(
[
getattr(request, key, 0) < 0
@@ -158,28 +133,30 @@ async def chat_completions(raw_request: Request):
expected == "assistant"
else:
expected == "user"
prompt = "\n".join(prompt_parts + ["ASSISTANT: "])
prompt = " ".join(prompt_parts + ["ASSISTANT: "])
logger.debug(f"Prompt:\n{prompt}")

# Validate the length of the input.
input_ids = MODELS["__tokenizer__"](prompt, return_tensors="pt")["input_ids"].to(
"cuda"
)
max_len = MODELS[request.model]["config"].max_position_embeddings
max_tokens = request.max_tokens or max_len - len(input_ids) - 1
if len(input_ids) + max_tokens > max_len:
max_tokens = request.max_tokens or max_len - len(input_ids[0]) - 1
if len(input_ids[0]) + max_tokens > max_len:
raise HTTPException(
status_code=422,
detail="Prompt length + max_tokens exceeds max model length.",
)

# Route the request to the appropriate expert (LoRA).
started_at = datetime.datetime.utcnow()
expert = MODELS[request.model]["router"].route(prompt)
model = MODELS[request.model]["model"]
loaded_expert = getattr(model, "__expert__", None)
if loaded_expert != expert:
model.set_adapter(expert)
setattr(model, "__expert__", expert)
routing_duration = (datetime.datetime.utcnow() - started_at).total_seconds()

# Update our stopping criteria.
stop_words = request.stop
@@ -196,6 +173,7 @@ async def chat_completions(raw_request: Request):
)

# Generate the response.
started_at = datetime.datetime.utcnow()
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
@@ -212,11 +190,14 @@ async def chat_completions(raw_request: Request):
.split("ASSISTANT:")[1]
.strip()
)
duration = (datetime.datetime.utcnow() - started_at).total_seconds()
request_id = f"cmpl-{uuid.uuid4()}"
return {
"id": request_id,
"object": "chat.completion",
"created": int(time.time()),
"duration": duration,
"routing_duration": routing_duration,
"model": request.model,
"expert": expert,
"choices": [
@@ -230,13 +211,47 @@ async def chat_completions(raw_request: Request):
}
],
"usage": {
"prompt_tokens": len(input_ids),
"completion_tokens": len(outputs),
"total_tokens": len(input_ids) + len(outputs),
"prompt_tokens": len(input_ids[0]),
"completion_tokens": len(outputs[0]),
"total_tokens": len(input_ids[0]) + len(outputs[0]),
},
}


@app.post("/v1/chat/completions")
async def chat_completions(raw_request: Request):
"""Simulate the OpenAI /v1/chat/completions endpoint.
NOTE: Parameters supported in request include:
- model: str, must be loaded from CLI args.
- messages: list[dict[str, str]]
- temperature: float
- repetition_penalty: float
- top_p: float
- top_k: int
- stop: list[str]
- max_tokens: int
Example request:
curl -s -XPOST http://127.0.0.1:8000/v1/chat/completions -H 'content-type: application/json' -d '{
"model": "airoboros-lmoe-7b-2.1",
"messages": [
{
"role": "system",
"content": "A chat.",
},
{
"role": "user",
"content": "Write a poem about Ouroboros."
}
]
}'
"""
request = ChatRequest(**await raw_request.json())
async with MODEL_LOCK:
return complete_request(request)


def main():
parser = argparse.ArgumentParser(
description="airoboros LMoE API server, somewhat similar to OpenAI API.",
@@ -245,7 +260,7 @@ def main():
parser.add_argument("-p", "--port", type=int, default=8000, help="port number")
parser.add_argument(
"-k",
"--router-max-k",
"--router-k",
type=int,
default=20,
help="k, when doing faiss approximate knn search to select expert",
@@ -296,7 +311,9 @@ def main():
adapter_name="general",
),
"router": Router(
input_paths=routing_paths, max_samples=args.router_max_samples
input_paths=routing_paths,
max_samples=args.router_max_samples,
k=args.router_k,
),
}
logger.info(
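
For reference, here's a tiny standalone sketch (illustrative names only, not part of the repo) of the serialization pattern the api.py changes above introduce: generation is synchronous and swaps LoRA adapters on a shared model, so requests queue behind a single `asyncio.Lock` instead of running concurrently.

```
import asyncio

MODEL_LOCK = asyncio.Lock()       # one lock guards the shared model
CURRENT_ADAPTER = {"name": None}  # stand-in for the model's active LoRA adapter

def complete_request(expert: str) -> str:
    # Stand-in for the sync, GPU-bound path: swap the adapter, then generate.
    # Interleaving two of these could generate with the wrong expert loaded.
    if CURRENT_ADAPTER["name"] != expert:
        CURRENT_ADAPTER["name"] = expert
    return f"completed with expert={CURRENT_ADAPTER['name']}"

async def handle(expert: str) -> str:
    # Concurrent requests queue here, so adapter swaps never interleave.
    async with MODEL_LOCK:
        return complete_request(expert)

async def main():
    print(await asyncio.gather(handle("creative"), handle("coding")))

asyncio.run(main())
```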
6 changes: 3 additions & 3 deletions airoboros/lmoe/router.py
@@ -25,13 +25,13 @@ def __init__(
self,
model_name_or_path: str = "thenlper/gte-small",
input_paths: List[str] = [],
max_k: int = 50,
k: int = 50,
max_samples: int = 500,
):
"""Constructor."""
self.model = SentenceTransformer(model_name_or_path, device="cuda")
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.max_k = max_k
self.k = k
self.max_samples = max_samples
if not input_paths:
input_paths = [
@@ -71,7 +71,7 @@ def route(self, prompt: str) -> str:
best_expert = None
best_distance = math.inf
for expert, index in self.indices.items():
distances, _ = index.search(query_emb, k=min(index.ntotal, self.max_k))
distances, _ = index.search(query_emb, k=min(index.ntotal, self.k))
distances = distances[0].tolist()
average_distance = sum(distances) / len(distances)
logger.debug(f"Average distance [{expert}]: {average_distance}")
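
The router change above just renames `max_k` to `k` and threads it through, but the routing approach itself is worth a sketch: each expert gets a faiss index of embedded sample instructions, and a prompt is routed to the expert whose k nearest samples have the lowest average distance. A rough standalone approximation (toy routing data and expert names, `faiss-cpu` and `sentence-transformers` installed; the real router loads `routing_data/*.jsonl` from the LMoE package):

```
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("thenlper/gte-small")  # same default as router.py

# Toy routing data; expert names and samples here are purely illustrative.
samples = {
    "creative": ["Write a poem about the sea.", "Tell a short story about a dragon."],
    "reasoning": ["If x + 2 = 5, what is x?", "Work through this logic puzzle step by step."],
}

indices = {}
for expert, texts in samples.items():
    embeddings = encoder.encode(texts).astype(np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    indices[expert] = index

def route(prompt: str, k: int = 25) -> str:
    query = encoder.encode([prompt]).astype(np.float32)
    best_expert, best_distance = None, float("inf")
    for expert, index in indices.items():
        # Average distance to the k nearest samples for this expert.
        distances, _ = index.search(query, min(index.ntotal, k))
        average = float(np.mean(distances[0]))
        if average < best_distance:
            best_expert, best_distance = expert, average
    return best_expert

print(route("Compose a haiku about autumn leaves."))
```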
92 changes: 66 additions & 26 deletions airoboros/lmoe/vllm.py
@@ -22,7 +22,6 @@
from vllm.entrypoints.openai.api_server import (
create_error_response,
check_model,
get_gen_prompt,
check_length,
)
from vllm.entrypoints.openai.protocol import (
@@ -46,6 +45,11 @@
from airoboros.lmoe.lora import lora_merge_unmerge_state_dict

TIMEOUT_KEEP_ALIVE = 5 # seconds
MODEL_LOCK = asyncio.Lock()
ROLE_MAP = {
"user": "USER",
"assistant": "ASSISTANT",
}

served_model = None
tokenizer = None
@@ -61,29 +65,43 @@ async def show_available_models():
return ModelList(data=model_cards)


@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
async def complete_request(raw_request, request):
"""Complete a chat request, which is wrapped by asyncio lock."""

See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())

# Hacky, but we'll inject a default system message, since it's
# slightly different for airoboros 2.1.
# Make sure we have a system prompt.
if request.messages[0]["role"] != "system":
request.messages = [{"role": "system", "content": "A chat."}] + request.messages

logger.info(f"Received chat completion request: {request}")
logger.debug(f"Received chat completion request: {request}")

# Build the prompt, with a bit more (very basic) validation.
prompt_parts = []
expected = "system"
for message in request.messages:
if message["role"] == "system":
prompt_parts.append(message["content"])
expected = "user"
elif message["role"] not in ROLE_MAP:
return create_error_response(
HTTPStatus.BAD_REQUEST, f"Invalid role found: {message['role']}"
)
elif message["role"] != expected:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Invalid messages structure, expected system -> [user assistant]* user",
)
else:
prompt_parts.append(
f"{ROLE_MAP[message['role']]}: {message['content'].strip()}"
)
if message["role"] == "user":
expected == "assistant"
else:
expected == "user"
prompt = " ".join(prompt_parts + ["ASSISTANT: "])
logger.debug(f"Prompt:\n{prompt}")

# Route the request to the appropriate expert (LoRA).
instruction = request.messages[0]["content"] + request.messages[-1]["content"]
expert = router.route(instruction)
expert = router.route(prompt)
loaded_expert = getattr(engine, "__expert__", None)
if loaded_expert != expert:
if loaded_expert is not None:
@@ -108,7 +126,6 @@ async def create_chat_completion(raw_request: Request):
HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported"
)

prompt = await get_gen_prompt(request).strip() + " "
error_check_ret = await check_length(request, prompt)
if error_check_ret is not None:
return error_check_ret
@@ -219,7 +236,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
message=ChatMessage(role="assistant", content=output.text.strip()),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
@@ -255,6 +272,22 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
return response


@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
async with MODEL_LOCK:
return await complete_request(raw_request, request)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
@@ -289,7 +322,13 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
"to use when building router index.",
)
parser.add_argument(
"--lmoe-path",
"--router-k",
type=int,
default=25,
help="k, when doing faiss approximate knn search to select expert",
)
parser.add_argument(
"--lmoe",
type=str,
required=True,
help="Path to LMoE directory with adapters and data",
@@ -318,13 +357,14 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
# Setup our router, and load all of the adapters so they
# are ready to swap in/out.
routing_paths = [
str(p)
for p in glob.glob(os.path.join(args.lmoe_path, "routing_data", "*.jsonl"))
str(p) for p in glob.glob(os.path.join(args.lmoe, "routing_data", "*.jsonl"))
]
router = Router(input_paths=routing_paths, max_samples=args.router_max_samples)
router = Router(
input_paths=routing_paths, max_samples=args.router_max_samples, k=args.router_k
)
adapters = {}
adapter_configs = {}
for directory in glob.glob(os.path.join(args.lmoe_path, "adapters", "*")):
for directory in glob.glob(os.path.join(args.lmoe, "adapters", "*")):
adapters[str(directory).split("/")[-1]] = torch.load(
os.path.join(str(directory), "adapter_model.bin"), map_location="cuda:0"
)
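
Both servers now build the same airoboros 2.1-style prompt from the chat messages: the system prompt (injected as "A chat." if missing), then space-joined `USER:`/`ASSISTANT:` turns, ending with a trailing `ASSISTANT: `. A small standalone sketch of that formatting (example messages only):

```
ROLE_MAP = {"user": "USER", "assistant": "ASSISTANT"}

def build_prompt(messages):
    # Inject the default system prompt if the caller didn't provide one.
    if messages[0]["role"] != "system":
        messages = [{"role": "system", "content": "A chat."}] + messages
    parts = [messages[0]["content"]]
    for message in messages[1:]:
        parts.append(f"{ROLE_MAP[message['role']]}: {message['content'].strip()}")
    return " ".join(parts + ["ASSISTANT: "])

print(build_prompt([{"role": "user", "content": "Write a poem about Ouroboros."}]))
# -> "A chat. USER: Write a poem about Ouroboros. ASSISTANT: "
```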
