Skip to content

Commit

Permalink
Serving LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
truskovskiyk committed Sep 18, 2024
1 parent 89daf7b commit 27498e3
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 12 deletions.
5 changes: 5 additions & 0 deletions module-3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,8 @@ python generative-api/pipeline_phi3.py ./data/test.json

- https://github.com/microsoft/nni
- https://github.com/autogluon/autogluon


## Updated design doc

[Google doc](https://docs.google.com/document/d/1vkjE5QohSkxkcWCWahciqR43K4RjCjXMpixx3hoYjXo/edit?usp=sharing)
2 changes: 1 addition & 1 deletion module-3/generative-example/run_training_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"WANDB_PROJECT": os.getenv("WANDB_PROJECT"),
"WANDB_API_KEY": os.getenv("WANDB_API_KEY"),
}
custom_image = Image.from_registry("ghcr.io/kyryl-opens-ml/generative-example:pr-11").env(env)
custom_image = Image.from_registry("ghcr.io/kyryl-opens-ml/generative-example:main").env(env)


@app.function(image=custom_image, gpu="A100", timeout=10 * 60 * 60)
Expand Down
7 changes: 6 additions & 1 deletion module-4/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,9 @@ dagster dev -f dagster_pipelines/text2sql_pipeline.py -p 3000 -h 0.0.0.0
### References:

- [Introducing Asset Checks](https://dagster.io/blog/dagster-asset-checks)
- [Anomaly Detection](https://dagster.io/glossary/anomaly-detection)
- [Anomaly Detection](https://dagster.io/glossary/anomaly-detection)


## Updated design doc

[Google doc](https://docs.google.com/document/d/1j9-RFCrLRQy54TsywHxvje56EuntAbUbSlw_POsWl5Q/edit?usp=sharing)
31 changes: 21 additions & 10 deletions module-5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,6 @@ make run_pytriton
```


# LLMs


- https://github.com/vllm-project/vllm
- https://github.com/huggingface/text-generation-inference
- https://github.com/predibase/lorax
- https://github.com/triton-inference-server/vllm_backend
- https://github.com/ray-project/ray-llm


# KServe

Install KServe
Expand All @@ -115,3 +105,24 @@ Call API
```
curl -v -H "Host: custom-model.default.example.com" -H "Content-Type: application/json" "http://localhost:8080/v1/models/custom-model:predict" -d @data-samples/kserve-input.json
```


# Serving LLMs via vLLM


```
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
vllm serve microsoft/Phi-3-mini-4k-instruct --dtype auto --max-model-len 512 --enable-lora --gpu-memory-utilization 0.8
vllm serve microsoft/Phi-3-mini-4k-instruct --enable-lora \
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
```

## Updated design doc

[Google doc](https://docs.google.com/document/d/1ZCnnsnHHiDkc3FgK2XBVur9W7nkDA7SKoPd1pGa-irQ/edit?usp=sharing)
Empty file.
88 changes: 88 additions & 0 deletions module-5/serving-llm/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pathlib import Path
import wandb
import requests
import json


# Root of the vLLM OpenAI-compatible HTTP API; adjust if the server is remote.
BASE_URL = "http://localhost:8000/v1"

def load_from_registry(model_name: str, model_path: Path):
    """Download a model artifact from the W&B registry into *model_path*.

    Args:
        model_name: W&B artifact reference (name[:alias]) of the model.
        model_path: Local directory the artifact is downloaded into.
    """
    run = wandb.init()
    with run:
        model_artifact = run.use_artifact(model_name, type="model")
        download_dir = model_artifact.download(root=model_path)
        print(f"{download_dir}")


def list_of_models():
    """Fetch and pretty-print the models currently served by the vLLM server."""
    response = requests.get(f"{BASE_URL}/models")
    print(json.dumps(response.json(), indent=4))

def load_adapter(lora_name: str, lora_path: str):
    """Register a LoRA adapter with the running vLLM server.

    Requires the server to run with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True.

    Args:
        lora_name: Name under which the adapter will be served (used as the
            "model" field in completion requests).
        lora_path: Path to the adapter weights on the server's filesystem.
    """
    # Previously both arguments were overwritten by hardcoded values
    # ("sql-test", "data/sql-adapter/"), so CLI input was silently ignored.
    url = f"{BASE_URL}/load_lora_adapter"
    payload = {
        "lora_name": lora_name,
        "lora_path": lora_path,
    }
    response = requests.post(url, json=payload)
    # This endpoint returns plain text, not JSON, so print the response object.
    print(response)

def unload_adapter(lora_name: str):
    """Remove a previously loaded LoRA adapter from the vLLM server.

    Args:
        lora_name: Name the adapter was registered under.
    """
    response = requests.post(
        f"{BASE_URL}/unload_lora_adapter",
        headers={"Content-Type": "application/json"},
        json={"lora_name": lora_name},
    )
    print(json.dumps(response.json(), indent=4))

def test_client(model: str = "sql-test", prompt: str = "test", max_tokens: int = 7, temperature: float = 0.0):
    """Send one completion request to the vLLM server and print the reply.

    Args:
        model: Served model or adapter name (e.g. "sql-test", or the base
            model id such as "microsoft/Phi-3-mini-4k-instruct").
        prompt: Text to complete.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; 0.0 means greedy decoding.
    """
    # Previously every argument was overwritten by a hardcoded value inside
    # the body; the old values are kept as defaults so behavior without
    # arguments is unchanged, but CLI arguments are now honored.
    url = f"{BASE_URL}/completions"
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    response = requests.post(url, json=payload)
    print(json.dumps(response.json(), indent=4))

def run_inference_on_json(json_file: Path):
    """POST the request payload stored in *json_file* to the completions API.

    Args:
        json_file: Path to a JSON file containing a complete request body.
    """
    with open(json_file, 'r') as fh:
        request_body = json.load(fh)
    response = requests.post(
        f"{BASE_URL}/completions",
        headers={"Content-Type": "application/json"},
        json=request_body,
    )
    print(json.dumps(response.json(), indent=4))




def cli():
    """Expose the client helpers as a Typer command-line application."""
    app = typer.Typer()
    app.command()(load_from_registry)
    app.command()(list_of_models)
    app.command()(load_adapter)
    app.command()(unload_adapter)
    app.command()(test_client)
    app.command()(run_inference_on_json)
    # NOTE(review): upload_to_registry and run_evaluate_on_json were
    # registered here but are not defined anywhere in this file, which made
    # cli() raise NameError on startup. Re-add the registrations once those
    # commands are implemented.
    app()

if __name__ == "__main__":
cli()

0 comments on commit 27498e3

Please sign in to comment.