Skip to content

Commit

Permalink
Control the concurrent number of requests in codegen acc test. (#169)
Browse files Browse the repository at this point in the history
* Control the concurrent number of requests
 in codegen acc test.

Signed-off-by: Yao, Qing <[email protected]>
(cherry picked from commit f0ee48e)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
yao531441 and pre-commit-ci[bot] authored Oct 21, 2024
1 parent ffa65dc commit 84e077e
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions evals/evaluation/bigcode_evaluation_harness/api_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,29 @@ def parallel_generations_by_api(
if codegen_url := args.codegen_url:
assert "/codegen" in codegen_url, "Only OPEA codegen compatible APIs are supported"
import asyncio
import os

import requests
from tqdm.asyncio import tqdm

async def get_res(prompt):
headers = {"Content-Type": "application/json"}
data = {
"messages": prompt,
"max_tokens": 2048,
"stream": False,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
}
async with aiohttp.ClientSession() as session:
async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response:
text = await response.text()
return text
async def get_res(prompt, semaphore):
async with semaphore:
headers = {"Content-Type": "application/json"}
data = {
"messages": prompt,
"max_tokens": 2048,
"stream": False,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
}
async with aiohttp.ClientSession() as session:
async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response:
text = await response.text()
return text

prompts = [task.get_prompt(doc) for doc in dataset]
awaitables = [get_res(prompt=prompt) for prompt in prompts]
semaphore = asyncio.Semaphore(20)
awaitables = [get_res(prompt=prompt, semaphore=semaphore) for prompt in prompts]

responses = asyncio.run(tqdm.gather(*awaitables))
generations = []
for i, (prompt, response) in enumerate(zip(prompts, responses)):
Expand Down

0 comments on commit 84e077e

Please sign in to comment.