Add sglang example #92

Open
wants to merge 6 commits into base: main
Changes from 3 commits
169 changes: 169 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/1/model.py
@@ -0,0 +1,169 @@
import os
import subprocess
import sys
import threading
Contributor
These dependencies (subprocess, sys, threading) are not used in model.py and can be removed.
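
If those are dropped as suggested, the import block would presumably reduce to the following sketch, based only on what the rest of model.py in this diff references:

```python
import os  # used to locate the downloaded checkpoints directory
from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format

import sglang as sgl
from transformers import AutoTokenizer
```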

from typing import Iterator

from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format

import sglang as sgl
from transformers import AutoTokenizer

class MyRunner(ModelRunner):
  """A custom runner that loads the model and generates text using sglang inference."""

  def load_model(self):
    """Load the model."""
    # If a checkpoints section is present in config.yaml, the checkpoints are downloaded
    # to this path at model upload time.
    checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
    self.pipe = sgl.Engine(model_path=checkpoints)
    self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)

  def predict(self, request: service_pb2.PostModelOutputsRequest
             ) -> service_pb2.MultiOutputResponse:
    """This is the method that will be called when the runner is run. It takes in one or more
    inputs and returns one output per input.
    """

    # TODO: Could cache the model and this conversion if the hash is the same.
    model = request.model
    output_info = None
    if request.model.model_version.id != "":
      output_info = json_format.MessageToDict(
          model.model_version.output_info, preserving_proto_field_name=True)

    outputs = []
    # TODO: parallelize this over inputs in a single request.
    for inp in request.inputs:
      output = resources_pb2.Output()

      data = inp.data

      # Optional use of output_info to override the default sampling parameters.
      inference_params = {}
      if output_info and "params" in output_info:
        inference_params = output_info["params"]

      temperature = inference_params.get("temperature", 0.7)
      max_tokens = inference_params.get("max_tokens", 256)
      top_p = inference_params.get("top_p", 0.9)

      if data.text.raw != "":
        prompt = data.text.raw
        messages = [{"role": "user", "content": prompt}]
        gen_config = dict(temperature=temperature, max_new_tokens=max_tokens, top_p=top_p)
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        res = self.pipe.generate(prompt, gen_config)
        # Strip the assistant header and blank lines left around the generated text.
        text = res["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
        text = text.replace("\n\n", "")
        output.data.text.raw = text

      output.status.code = status_code_pb2.SUCCESS
      outputs.append(output)
    return service_pb2.MultiOutputResponse(outputs=outputs)

  def generate(self, request: service_pb2.PostModelOutputsRequest
              ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Example yielding a streamed response for each input as the text is generated."""

    # TODO: Could cache the model and this conversion if the hash is the same.
    model = request.model
    output_info = None
    if request.model.model_version.id != "":
      output_info = json_format.MessageToDict(
          model.model_version.output_info, preserving_proto_field_name=True)

    # TODO: parallelize this over inputs in a single request.
    for inp in request.inputs:
      output = resources_pb2.Output()

      data = inp.data

      # Optional use of output_info to override the default sampling parameters.
      inference_params = {}
      if output_info and "params" in output_info:
        inference_params = output_info["params"]

      messages = []
      temperature = inference_params.get("temperature", 0.7)
      max_tokens = inference_params.get("max_tokens", 256)
      top_p = inference_params.get("top_p", 0.9)

      if data.text.raw != "":
        prompt = data.text.raw
        messages.append({"role": "user", "content": prompt})
        kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_p=top_p,
        )
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        for item in self.pipe.generate(prompt, kwargs, stream=True):
          text = item["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
          text = text.replace("\n\n", "")
          output.data.text.raw = text
          output.status.code = status_code_pb2.SUCCESS
          yield service_pb2.MultiOutputResponse(outputs=[output])


  def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
            ) -> Iterator[service_pb2.MultiOutputResponse]:
    """Example yielding streamed responses for a stream of requests."""

    for ri, request in enumerate(request_iterator):
      output_info = None
      if ri == 0:  # only the first request carries the model information.
        model = request.model
        if request.model.model_version.id != "":
          output_info = json_format.MessageToDict(
              model.model_version.output_info, preserving_proto_field_name=True)
      # Optional use of output_info to override the default sampling parameters.
      inference_params = {}
      if output_info and "params" in output_info:
        inference_params = output_info["params"]
      # TODO: parallelize this over inputs in a single request.
      for inp in request.inputs:
        output = resources_pb2.Output()

        data = inp.data

        system_prompt = "You are a helpful assistant"

        messages = [{"role": "system", "content": system_prompt}]
        temperature = inference_params.get("temperature", 0.7)
        max_tokens = inference_params.get("max_tokens", 100)
        top_p = inference_params.get("top_p", 1.0)

        if data.text.raw != "":
          prompt = data.text.raw
          messages.append({"role": "user", "content": prompt})
          kwargs = dict(
              temperature=temperature,
              max_new_tokens=max_tokens,
              top_p=top_p,
          )
          prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
          for item in self.pipe.generate(prompt, kwargs, stream=True):
            text = item["text"].replace("<|start_header_id|>assistant<|end_header_id|>", "")
            text = text.replace("\n\n", "")
            output.data.text.raw = text
            output.status.code = status_code_pb2.SUCCESS
            yield service_pb2.MultiOutputResponse(outputs=[output])
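
For reference, here is a minimal sketch (not part of this PR) of a request whose fields line up with what predict() reads above. The version id, prompt, and parameter values are made up, and it assumes ModelVersion.output_info is the resources_pb2.OutputInfo message with a protobuf Struct params field:

```python
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from google.protobuf import struct_pb2

# Hypothetical inference parameters; predict() looks these up under output_info["params"].
params = struct_pb2.Struct()
params.update({"temperature": 0.7, "max_tokens": 256, "top_p": 0.9})

request = service_pb2.PostModelOutputsRequest(
    model=resources_pb2.Model(
        model_version=resources_pb2.ModelVersion(
            id="test-version",  # non-empty so predict() parses output_info
            output_info=resources_pb2.OutputInfo(params=params),
        )),
    inputs=[
        resources_pb2.Input(
            data=resources_pb2.Data(text=resources_pb2.Text(raw="Write a haiku about GPUs.")))
    ],
)
# The runner's predict(request) then returns one output per input.
```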
22 changes: 22 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/config.yaml
@@ -0,0 +1,22 @@
# Config file for the sglang runner

model:
  id: "sglang-llama3_2-1b-instruct"
  user_id: ""
  app_id: ""
  model_type_id: "text-to-text"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "4"
  cpu_memory: "24Gi"
Contributor
Need to reduce cpu_memory because at most 16Gi is available.
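
As a sketch of what this comment asks for, the compute block would change roughly as follows; 16Gi is an assumed value sitting at the stated maximum, not a figure from the PR:

```yaml
inference_compute_info:
  cpu_limit: "4"
  cpu_memory: "16Gi"  # assumed value; reduced from 24Gi to fit the 16Gi ceiling
```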

  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "24Gi"

checkpoints:
  type: "huggingface"
  repo_id: "meta-llama/Llama-3.2-1B-Instruct"
  hf_token: "your token"
Contributor
@luv-bansal Nov 18, 2024
I used the requirements below, with pinned dependency versions, to test locally and it worked. I think it's better to include the requirements with their versions here, because earlier (I'm not sure why) I was getting errors when I didn't specify the dependency versions:

torch==2.4.0
tokenizers==0.20.2
transformers==4.46.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.23.3
xformers==0.0.27.post2
einops==0.8.0
requests==2.32.2
packaging
ninja
protobuf==3.20.0

sglang[all]==0.3.5.post2
orjson==3.10.11
python-multipart==0.0.17

--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4/
flashinfer

19 changes: 19 additions & 0 deletions models/model_upload/llms/sglang-llama-3_2-1b-instruct/requirements.txt
@@ -0,0 +1,19 @@
torch==2.4.0
tokenizers==0.20.2
transformers==4.46.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.23.3
xformers==0.0.27.post2
einops==0.8.0
requests==2.32.2
packaging
ninja
protobuf==3.20.0

sglang[all]==0.3.5.post2
orjson==3.10.11
python-multipart==0.0.17

--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4/
Contributor
In prod and dev we have CUDA 12.4; I'm not sure whether this cu121 wheel index works there, this needs to be verified.
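
If flashinfer publishes a CUDA 12.4 wheel index under the same path scheme, which is an assumption on my part and exactly what needs verifying, the corresponding lines would presumably look like:

```
--extra-index-url https://flashinfer.ai/whl/cu124/torch2.4/
flashinfer
```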

Contributor
@luv-bansal Nov 18, 2024
But I tested on q22, which also has CUDA 12.4, and prediction was successful, so I don't think this will be an issue.

flashinfer