
Commit

optimise networking, remove bloat
Gary authored and Gary committed Dec 16, 2024
1 parent 9ec26a0 commit f6765d7
Showing 11 changed files with 143 additions and 165 deletions.
113 changes: 63 additions & 50 deletions exo/api/chatgpt_api.py
@@ -21,6 +21,7 @@
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
from collections import defaultdict

class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -160,6 +161,11 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.token_queues = defaultdict(asyncio.Queue)

# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))

cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -346,9 +352,6 @@ async def handle_post_chat_completions(self, request):
# request_id = str(uuid.uuid4())
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)

if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")

try:
@@ -367,53 +370,63 @@
)
await response.prepare(request)

async def stream_result(_request_id: str, token: int, is_finished: bool):
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if token == eos_token_id:
try:
# Stream tokens while waiting for inference to complete
while True:
token, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)

finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)

if token == eos_token_id:
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
finish_reason = "length"

completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
[token],
stream,
finish_reason,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())

if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
finish_reason = "length"

completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
[token],
stream,
finish_reason,
"chat.completion",
break

await response.write_eof()
return response

except asyncio.TimeoutError:
return web.json_response({"detail": "Response generation timed out"}, status=408)

except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt: {str(e)}"},
status=500
)
if DEBUG >= 2: print(f"Streaming completion: {completion}")
try:
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
except Exception as e:
if DEBUG >= 2: print(f"Error streaming completion: {e}")
if DEBUG >= 2: traceback.print_exc()

def on_result(_request_id: str, token: int, is_finished: bool):
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, token, is_finished))

return _request_id == request_id and is_finished

_, token, _ = await callback.wait(on_result, timeout=self.response_timeout)
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
except asyncio.TimeoutError:
print("WARNING: Stream task timed out. This should not happen.")
await response.write_eof()
return response
else:
_, token, _ = await callback.wait(
lambda _request_id, token, is_finished: _request_id == request_id and is_finished,
timeout=self.response_timeout,
)

finally:
# Clean up the queue for this request
if request_id in self.token_queues:
del self.token_queues[request_id]
else:
tokens = []
while True:
token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.append(token)
if is_finished:
break
finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
@@ -426,9 +439,6 @@ def on_result(_request_id: str, token: int, is_finished: bool):
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
finally:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

async def handle_delete_model(self, request):
try:
@@ -566,6 +576,9 @@ async def handle_get_topology(self, request):
status=500
)

async def handle_token(self, request_id: str, token: int, is_finished: bool):
await self.token_queues[request_id].put((token, is_finished))

async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
await runner.setup()
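The change above drops the per-request on_token callback (registered and deregistered around every chat completion) in favour of a single handler that pushes tokens into per-request asyncio queues, which the HTTP handler drains with a timeout. A minimal sketch of that producer/consumer pattern, assuming the node invokes handle_token for every generated token; the TokenStreamer class and the 90-second default are illustrative, not exo's actual API:

```python
import asyncio
from collections import defaultdict

class TokenStreamer:
  def __init__(self, response_timeout: float = 90.0):
    self.response_timeout = response_timeout
    # One queue per request_id, created lazily on first access.
    self.token_queues: dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

  async def handle_token(self, request_id: str, token: int, is_finished: bool):
    # Producer side: registered once with the node's token callback.
    await self.token_queues[request_id].put((token, is_finished))

  async def stream(self, request_id: str):
    # Consumer side: drain tokens until the model reports completion,
    # failing the request if nothing arrives within response_timeout.
    try:
      while True:
        token, is_finished = await asyncio.wait_for(
          self.token_queues[request_id].get(), timeout=self.response_timeout
        )
        yield token
        if is_finished:
          break
    finally:
      # Drop the queue so finished or abandoned requests do not leak.
      self.token_queues.pop(request_id, None)
```

One long-lived handler plus per-request queues keeps timeout handling and cleanup local to the request, instead of juggling callback registration, stream tasks, and deregistration in a finally block as the removed code did.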
7 changes: 7 additions & 0 deletions exo/inference/mlx/perf_improvements.md
@@ -0,0 +1,7 @@
# Perf improvements

Target: 460 tok/sec
- removing sample goes from 369 -> 402
- performance degrades as we generate more tokens
- make mlx inference engine synchronous, removing thread pool executor: 402 -> 413
- remove self.on_opaque_status.trigger_all: 413 -> 418
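A before/after sketch of the "removing thread pool executor" item, under the commit's assumption that the MLX-side calls return quickly enough to run directly on the event loop; forward() below is a stand-in for the real model call, not exo code:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def forward(x: int) -> int:
  # Stand-in for the real MLX forward pass (self.model(x, **state) in the engine).
  return x * 2

class ExecutorEngine:
  """Old pattern: every inference call hops through a single-worker thread pool."""
  def __init__(self):
    self.executor = ThreadPoolExecutor(max_workers=1)

  async def infer(self, x: int) -> int:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self.executor, forward, x)

class SynchronousEngine:
  """New pattern: call the model directly; no per-call thread handoff."""
  async def infer(self, x: int) -> int:
    return forward(x)

async def main():
  print(await ExecutorEngine().infer(21), await SynchronousEngine().infer(21))

asyncio.run(main())
```

The trade-off is that any call which does block for a while now stalls the event loop, so this only pays off when the per-call work is short relative to the thread-handoff overhead it removes.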
49 changes: 21 additions & 28 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -1,7 +1,7 @@
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.sample_utils import top_p_sampling
from mlx_lm.sample_utils import top_p_sampling, make_sampler
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
@@ -10,8 +10,6 @@
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache

@@ -40,61 +38,60 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)
self.caches = OrderedDict()
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)

async def poll_state(self, request_id: str, max_caches=2):
if request_id in self.caches:
self.caches.move_to_end(request_id)
else:
newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
newcache = make_prompt_cache(self.model)
if len(self.caches) > max_caches:
self.caches.popitem(last=False)
self.caches[request_id] = newcache
return {"cache": self.caches[request_id]}

async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
y = mx.array(x)
logits = y[:, -1, :]
out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
return out
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
if (temp, top_p, 0.0, 1) != self.sampler_params:
self.sampler_params = (temp, top_p, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
logits = mx.array(x)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return np.asarray(self.sampler(logprobs), dtype=int)

async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
return np.array(tokens)
tokens = self.tokenizer.encode(prompt)
return np.asarray(tokens)

async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
return tokens
return self.tokenizer.decode(tokens)

async def save_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
self.model.save_weights(path)

async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
self.model.load_weights(path)

async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id)
x = mx.array(input_data)
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
output_data = np.array(self.model(x, **state), copy=False)
return output_data

async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)
await self.save_session('loss', loss_fns[loss])
loop = asyncio.get_running_loop()
#print(f"evaluate in <- {inputs}")
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
#print(f"evaluate out -> {score}")
score = self.session['loss'](self.model, x, y, l)
return score

async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
@@ -130,7 +127,7 @@ def train_step(inp, tar, lng):
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])

return score, np.array(layers[0]['input_layernorm'])
return score, np.array(layers[0]['input_layernorm'], copy=False)

async def ensure_shard(self, shard: Shard):
if self.shard == shard:
@@ -139,11 +136,7 @@ async def ensure_shard(self, shard: Shard):
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)

if self.shard != shard:

def load_shard_wrapper():
return asyncio.run(load_shard(model_path, shard))

model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
model_shard, self.tokenizer = await load_shard(model_path, shard)
self.shard = shard
self.model = model_shard
self.caches = OrderedDict()
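For readability, the new sampling path from the engine diff above, gathered into one self-contained sketch. The SamplerCache wrapper is mine; the parameter tuple, make_sampler reuse, and logsumexp shift mirror the diff:

```python
import numpy as np
import mlx.core as mx
from mlx_lm.sample_utils import make_sampler

class SamplerCache:
  """Rebuild the sampler only when (temp, top_p, min_p, min_tokens_to_keep) change."""

  def __init__(self):
    self.params = (0.0, 0.0, 0.0, 1)
    self.sampler = make_sampler(*self.params)

  def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
    params = (temp, top_p, 0.0, 1)
    if params != self.params:
      self.params = params
      self.sampler = make_sampler(*params)
    logits = mx.array(x)[:, -1, :]  # logits for the last position only
    # Shift by logsumexp so the sampler receives log-probabilities
    # (effectively a log-softmax for a single-row batch, as in the diff).
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    return np.asarray(self.sampler(logprobs), dtype=int)
```

Rebuilding the sampler only when the parameter tuple changes avoids constructing it on every token.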
12 changes: 8 additions & 4 deletions exo/main.py
@@ -150,9 +150,9 @@
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
# node.on_token.register("update_topology_viz").on_next(
# lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
# )

def preemptively_start_download(request_id: str, opaque_status: str):
try:
@@ -200,7 +200,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, request_id=request_id)

_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
tokens = []
def on_token(_request_id, _token, _is_finished):
tokens.append(_token)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)

print("\nGenerated response:")
print(tokenizer.decode(tokens))
18 changes: 4 additions & 14 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -71,7 +71,7 @@ async def health_check(self) -> bool:
traceback.print_exc()
return False

async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -82,14 +82,9 @@ async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str]
),
request_id=request_id,
)
response = await self.stub.SendPrompt(request)
await self.stub.SendPrompt(request)

if not response.tensor_data or not response.shape or not response.dtype:
return None

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -100,12 +95,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Option
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
)
response = await self.stub.SendTensor(request)

if not response.tensor_data or not response.shape or not response.dtype:
return None

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
await self.stub.SendTensor(request)

async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
15 changes: 6 additions & 9 deletions exo/networking/grpc/grpc_server.py
@@ -50,10 +50,9 @@ async def SendPrompt(self, request, context):
)
prompt = request.prompt
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, request_id)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
await self.node.process_prompt(shard, prompt, request_id)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
return node_service_pb2.Empty()

async def SendTensor(self, request, context):
shard = Shard(
@@ -64,11 +63,9 @@ async def SendTensor(self, request, context):
)
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id

result = await self.node.process_tensor(shard, tensor, request_id)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
await self.node.process_tensor(shard, tensor, request_id)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
return node_service_pb2.Empty()

async def SendExample(self, request, context):
shard = Shard(
4 changes: 2 additions & 2 deletions exo/networking/grpc/node_service.proto
@@ -3,8 +3,8 @@ syntax = "proto3";
package node_service;

service NodeService {
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendPrompt (PromptRequest) returns (Empty) {}
rpc SendTensor (TensorRequest) returns (Empty) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}