Add sglang example #92
base: main
Changes from 3 commits
model.py
@@ -0,0 +1,169 @@
import os
from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format

import sglang as sgl
from transformers import AutoTokenizer

class MyRunner(ModelRunner):
  """A custom runner that loads the model and generates text using sglang inference."""

  def load_model(self):
    """Load the model here."""
    # If a checkpoints section is present in config.yaml, the checkpoints are
    # downloaded to this path at model upload time.
    checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
    self.pipe = sgl.Engine(model_path=checkpoints)
    self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)

  def predict(self, request: service_pb2.PostModelOutputsRequest
             ) -> service_pb2.MultiOutputResponse:
    """This is the method that will be called when the runner is run. It takes in an input and
    returns an output.
    """

    # TODO: Could cache the model and this conversion if the hash is the same.
    model = request.model
    output_info = None
    if request.model.model_version.id != "":
      output_info = json_format.MessageToDict(
          model.model_version.output_info, preserving_proto_field_name=True)

    outputs = []
    # TODO: parallelize this over inputs in a single request.
    for inp in request.inputs:
      output = resources_pb2.Output()

      data = inp.data

      # Optional use of output_info
      inference_params = {}
      if output_info and "params" in output_info:
        inference_params = output_info["params"]

      temperature = inference_params.get("temperature", 0.7)
      max_tokens = inference_params.get("max_tokens", 256)
      top_p = inference_params.get("top_p", 0.9)

      if data.text.raw != "":
        prompt = data.text.raw
        messages = [{"role": "user", "content": prompt}]
        gen_config = dict(temperature=temperature, max_new_tokens=max_tokens, top_p=top_p)
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        res = self.pipe.generate(prompt, gen_config)
        text = res["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
        text = text.replace("\n\n", "")
        output.data.text.raw = text

      output.status.code = status_code_pb2.SUCCESS
      outputs.append(output)
    return service_pb2.MultiOutputResponse(outputs=outputs,)

  def generate(self, request: service_pb2.PostModelOutputsRequest
              ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Example yielding generated text back to the caller as a stream of responses."""

    # TODO: Could cache the model and this conversion if the hash is the same.
    model = request.model
    output_info = None
    if request.model.model_version.id != "":
      output_info = json_format.MessageToDict(
          model.model_version.output_info, preserving_proto_field_name=True)

    # TODO: parallelize this over inputs in a single request.
    for inp in request.inputs:
      output = resources_pb2.Output()

      data = inp.data

      # Optional use of output_info
      inference_params = {}
      if output_info and "params" in output_info:
        inference_params = output_info["params"]

      messages = []
      temperature = inference_params.get("temperature", 0.7)
      max_tokens = inference_params.get("max_tokens", 256)
      top_p = inference_params.get("top_p", 0.9)

      if data.text.raw != "":
        prompt = data.text.raw
        messages.append({"role": "user", "content": prompt})
        kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_p=top_p,
        )
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        for item in self.pipe.generate(prompt, kwargs, stream=True):
          text = item["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
          text = text.replace("\n\n", "")
          output.data.text.raw = text
          output.status.code = status_code_pb2.SUCCESS
          yield service_pb2.MultiOutputResponse(outputs=[output],)

  def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
            ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Example reading a stream of requests and yielding a stream of responses."""

    output_info = None
    inference_params = {}
    for ri, request in enumerate(request_iterator):
      if ri == 0:  # only the first request has model information.
        model = request.model
        if request.model.model_version.id != "":
          output_info = json_format.MessageToDict(
              model.model_version.output_info, preserving_proto_field_name=True)
        # Optional use of output_info
        if output_info and "params" in output_info:
          inference_params = output_info["params"]
      # TODO: parallelize this over inputs in a single request.
      for inp in request.inputs:
        output = resources_pb2.Output()

        data = inp.data

        system_prompt = "You are a helpful assistant"

        messages = [{"role": "system", "content": system_prompt}]
        temperature = inference_params.get("temperature", 0.7)
        max_tokens = inference_params.get("max_tokens", 100)
        top_p = inference_params.get("top_p", 1.0)

        if data.text.raw != "":
          prompt = data.text.raw
          messages.append({"role": "user", "content": prompt})
          sampling_params = dict(
              temperature=temperature,
              max_new_tokens=max_tokens,
              top_p=top_p,
          )
          # Stream tokens from the same sglang engine used by predict() and generate().
          prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
          for item in self.pipe.generate(prompt, sampling_params, stream=True):
            text = item["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
            text = text.replace("\n\n", "")
            output.data.text.raw = text
            output.status.code = status_code_pb2.SUCCESS
            yield service_pb2.MultiOutputResponse(outputs=[output],)
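
For orientation, here is a minimal standalone sketch of the inference path the runner above wraps, assuming the same sglang offline Engine API used in this PR and a local checkpoints directory holding the downloaded Llama-3.2-1B-Instruct weights; the path and prompt are placeholders:

```python
# Standalone sketch of the inference path wrapped by MyRunner above.
# Assumptions: sglang's offline Engine API as used in this PR, and a local
# "checkpoints" directory with the downloaded weights (placeholder path).
import sglang as sgl
from transformers import AutoTokenizer

checkpoints = "./checkpoints"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoints)
engine = sgl.Engine(model_path=checkpoints)

messages = [{"role": "user", "content": "Write a one-line summary of sglang."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
sampling_params = {"temperature": 0.7, "top_p": 0.9, "max_new_tokens": 256}

# Non-streaming call, mirroring predict(): a single dict with a "text" field.
result = engine.generate(prompt, sampling_params)
print(result["text"])

# Streaming call, mirroring generate()/stream(): each yielded dict carries a "text" field.
for chunk in engine.generate(prompt, sampling_params, stream=True):
    print(chunk["text"], end="", flush=True)
```

The sampling parameter names (temperature, top_p, max_new_tokens) mirror those passed to self.pipe.generate in predict() and generate().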
config.yaml
@@ -0,0 +1,22 @@
# Config file for the sglang runner

model:
  id: "sglang-llama3_2-1b-instruct"
  user_id: ""
  app_id: ""
  model_type_id: "text-to-text"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "4"
  cpu_memory: "24Gi"

Review comment: Need to reduce.

  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "24Gi"

checkpoints:
  type: "huggingface"
  repo_id: "meta-llama/Llama-3.2-1B-Instruct"
  hf_token: "your token"
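
One possible follow-up to the "Need to reduce" comment above, sketched with an assumed value rather than a measured one; the right cpu_memory request for a 1B-parameter model would need to be confirmed:

```yaml
# Hypothetical adjustment for the review comment; "8Gi" is an assumed value,
# not a benchmarked requirement for Llama-3.2-1B-Instruct.
inference_compute_info:
  cpu_limit: "4"
  cpu_memory: "8Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "24Gi"
```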
Review comment: I used the requirements below, with pinned dependency versions, to test locally and it worked. I think it's better to include the requirements with their versions here, because previously (I'm not sure why) I was getting errors when I didn't pin the dependency versions.
requirements.txt
@@ -0,0 +1,19 @@
torch==2.4.0
tokenizers==0.20.2
transformers==4.46.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.23.3
xformers==0.0.27.post2
einops==0.8.0
requests==2.32.2
packaging
ninja
protobuf==3.20.0

sglang[all]==0.3.5.post2
orjson==3.10.11
python-multipart==0.0.17

--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4/
Review comment: In prod and dev we have CUDA 12.4; I'm not sure if it works with this.
Reply: But I tested on q22, which also has CUDA 12.4, where prediction is successful, so I don't think this will be an issue.

flashinfer
Review comment: These dependencies are used in model.py and can be removed.