vllm_intercept.py
# vllm_intercept.py
import openai
import os

# Set the API base to your local server
openai.api_base = os.environ.get("OPENAI_API_BASE", "http://localhost:8000/v1")

# Set the API key to a dummy value (since you're using a local server)
openai.api_key = os.environ.get("OPENAI_API_KEY", "sk-dummy")

# Save the original ChatCompletion.create method
original_chat_completion_create = openai.ChatCompletion.create


def chat_completion_create_intercept(*args, **kwargs):
    """
    Intercepts calls to openai.ChatCompletion.create and redirects them to the local vLLM server.
    """
    # Map the model name to the local model
    model = kwargs.get("model", "")
    if model == "gpt-3.5-turbo":
        kwargs["model"] = "facebook/opt-125m"
    else:
        kwargs["model"] = model  # Use the specified model

    # Convert 'messages' to a single 'prompt'
    messages = kwargs.get("messages", [])
    prompt = ""
    for message in messages:
        role = message.get('role', '')
        content = message.get('content', '')
        if role == 'system':
            prompt += f"{content}\n"
        elif role == 'user':
            prompt += f"User: {content}\n"
        elif role == 'assistant':
            prompt += f"Assistant: {content}\n"
    prompt += "Assistant:"

    # Remove 'messages' and add 'prompt'
    kwargs.pop('messages', None)
    kwargs['prompt'] = prompt

    # Call the Completion endpoint instead
    response = openai.Completion.create(*args, **kwargs)

    # Modify the response to match the ChatCompletion format
    for choice in response.get('choices', []):
        text = choice.pop('text', '')
        choice['message'] = {
            'role': 'assistant',
            'content': text.strip()
        }
        # Remove keys that are not part of the ChatCompletion response
        choice.pop('logprobs', None)
        choice.pop('finish_reason', None)

    return response


# Override the ChatCompletion.create method
openai.ChatCompletion.create = chat_completion_create_intercept
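

# A minimal usage sketch (not part of the original interception logic): running this
# module directly assumes a vLLM OpenAI-compatible server is listening at
# OPENAI_API_BASE (default http://localhost:8000/v1) and serving facebook/opt-125m.
# The prompt below is illustrative only.
if __name__ == "__main__":
    # The patched ChatCompletion.create is called like the normal OpenAI chat API;
    # "gpt-3.5-turbo" is remapped to the local model by the intercept above.
    demo_response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one short sentence."},
        ],
        max_tokens=32,
    )
    print(demo_response["choices"][0]["message"]["content"])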