-
Notifications
You must be signed in to change notification settings - Fork 0
/
openai_utils.py
218 lines (203 loc) · 7.97 KB
/
openai_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Tools to generate from OpenAI prompts."""
import asyncio
import logging
import os
from typing import Any
import aiolimiter
import openai
import openai.error
from aiohttp import ClientSession
from tqdm.asyncio import tqdm_asyncio
from zeno_build.models import lm_config
from zeno_build.prompts import chat_prompt
async def _throttled_openai_completion_acreate(
    engine: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the OpenAI Completion API with rate limiting and retries.

    Args:
        engine: Engine (model) name passed to the Completion API.
        prompt: Text prompt to complete.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling parameter.
        limiter: Rate limiter shared by all concurrent requests.

    Returns:
        The raw API response dict, or — when the prompt is filtered or all
        retries are exhausted — a stub dict with the same
        ``choices[0]["text"]`` shape the caller indexes.
    """
    async with limiter:
        # Up to 3 attempts; transient failures sleep 10s and retry.
        for _ in range(3):
            try:
                return await openai.Completion.acreate(
                    engine=engine,
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except asyncio.exceptions.TimeoutError:
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.error.InvalidRequestError:
                logging.warning("OpenAI API Invalid Request: Prompt was filtered")
                # Bug fix: the Completion API response (and the caller,
                # which reads choices[0]["text"]) uses the "text" key —
                # not the ChatCompletion "message" shape.
                return {
                    "choices": [{"text": "Invalid Request: Prompt was filtered"}]
                }
            except openai.error.APIConnectionError:
                logging.warning(
                    "OpenAI API Connection Error: Error Communicating with OpenAI"
                )
                await asyncio.sleep(10)
            except openai.error.Timeout:
                logging.warning("OpenAI APITimeout Error: OpenAI Timeout")
                await asyncio.sleep(10)
            except openai.error.ServiceUnavailableError as e:
                logging.warning(f"OpenAI service unavailable error: {e}")
                await asyncio.sleep(10)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                await asyncio.sleep(10)
        # Retries exhausted: return an empty completion in the Completion
        # API shape (bug fix — was the ChatCompletion "message" shape).
        return {"choices": [{"text": ""}]}
async def generate_from_openai_completion(
    full_contexts: list[chat_prompt.ChatMessages],
    prompt_template: chat_prompt.ChatMessages,
    model_config: lm_config.LMConfig,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 150,
) -> list[str]:
    """Generate from OpenAI Completion API.

    Args:
        full_contexts: List of full contexts to generate from.
        prompt_template: Prompt template to use.
        model_config: Model configuration.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    openai.aiosession.set(ClientSession())
    rate_limit = aiolimiter.AsyncLimiter(requests_per_minute)
    # Render one text prompt per context and fan out throttled requests.
    pending = []
    for context in full_contexts:
        rendered = prompt_template.to_text_prompt(
            full_context=context.limit_length(context_length),
            name_replacements=model_config.name_replacements,
        )
        pending.append(
            _throttled_openai_completion_acreate(
                engine=model_config.model,
                prompt=rendered,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                limiter=rate_limit,
            )
        )
    completions = await tqdm_asyncio.gather(*pending)
    # The session was just set above, so .get() is never None; mypy can't see that.
    await openai.aiosession.get().close()  # type: ignore
    return [completion["choices"][0]["text"] for completion in completions]
async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the OpenAI ChatCompletion API with rate limiting and retries.

    Args:
        model: Model name passed to the ChatCompletion API.
        messages: Chat messages in the OpenAI role/content dict format.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling parameter.
        limiter: Rate limiter shared by all concurrent requests.

    Returns:
        The raw API response dict, or a stub dict in the same
        ``choices[0]["message"]["content"]`` shape on filtered prompts or
        after all retries fail.
    """
    retry_delay = 10  # seconds to back off on transient failures
    async with limiter:
        for _attempt in range(3):
            try:
                return await openai.ChatCompletion.acreate(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(retry_delay)
            except asyncio.exceptions.TimeoutError:
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(retry_delay)
            except openai.error.InvalidRequestError:
                # Unrecoverable: the prompt was rejected, so do not retry.
                logging.warning("OpenAI API Invalid Request: Prompt was filtered")
                return {
                    "choices": [
                        {"message": {"content": "Invalid Request: Prompt was filtered"}}
                    ]
                }
            except openai.error.APIConnectionError:
                logging.warning(
                    "OpenAI API Connection Error: Error Communicating with OpenAI"
                )
                await asyncio.sleep(retry_delay)
            except openai.error.Timeout:
                logging.warning("OpenAI APITimeout Error: OpenAI Timeout")
                await asyncio.sleep(retry_delay)
            except openai.error.ServiceUnavailableError as e:
                logging.warning(f"OpenAI service unavailable error: {e}")
                await asyncio.sleep(retry_delay)
            except openai.error.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                await asyncio.sleep(retry_delay)
        # Retries exhausted: yield an empty message in the expected shape.
        return {"choices": [{"message": {"content": ""}}]}
async def generate_from_openai_chat_completion(
    full_contexts: list[chat_prompt.ChatMessages],
    prompt_template: chat_prompt.ChatMessages,
    model_config: lm_config.LMConfig,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 150,
) -> list[str]:
    """Generate from OpenAI Chat Completion API.

    Args:
        full_contexts: List of full contexts to generate from.
        prompt_template: Prompt template to use.
        model_config: Model configuration.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    openai.api_key = os.environ["OPENAI_API_KEY"]
    openai.aiosession.set(ClientSession())
    rate_limit = aiolimiter.AsyncLimiter(requests_per_minute)
    # Build one chat-message payload per context and fan out throttled requests.
    pending = []
    for context in full_contexts:
        payload = prompt_template.to_openai_chat_completion_messages(
            full_context=context.limit_length(context_length),
        )
        pending.append(
            _throttled_openai_chat_completion_acreate(
                model=model_config.model,
                messages=payload,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                limiter=rate_limit,
            )
        )
    completions = await tqdm_asyncio.gather(*pending)
    # The session was just set above, so .get() is never None; mypy can't see that.
    await openai.aiosession.get().close()  # type: ignore
    return [completion["choices"][0]["message"]["content"] for completion in completions]