From 32cf973823ef2ce39bd1164fabab8969662d56d0 Mon Sep 17 00:00:00 2001 From: "Yao, Qing" Date: Mon, 21 Oct 2024 11:10:07 +0800 Subject: [PATCH] Control the concurrent number of requests in codegen acc test. Signed-off-by: Yao, Qing (cherry picked from commit f0ee48ec57783f0e7156f4d43294d1486ee1e5b3) --- .../api_evaluator.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/evals/evaluation/bigcode_evaluation_harness/api_evaluator.py b/evals/evaluation/bigcode_evaluation_harness/api_evaluator.py index b6faa5b1..5cf2820e 100644 --- a/evals/evaluation/bigcode_evaluation_harness/api_evaluator.py +++ b/evals/evaluation/bigcode_evaluation_harness/api_evaluator.py @@ -79,28 +79,28 @@ def parallel_generations_by_api( if codegen_url := args.codegen_url: assert "/codegen" in codegen_url, "Only OPEA codegen compatible APIs are supported" import asyncio - import os - - import requests from tqdm.asyncio import tqdm - async def get_res(prompt): - headers = {"Content-Type": "application/json"} - data = { - "messages": prompt, - "max_tokens": 2048, - "stream": False, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - } - async with aiohttp.ClientSession() as session: - async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response: - text = await response.text() - return text + async def get_res(prompt, semaphore): + async with semaphore: + headers = {"Content-Type": "application/json"} + data = { + "messages": prompt, + "max_tokens": 2048, + "stream": False, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + } + async with aiohttp.ClientSession() as session: + async with session.post(codegen_url, json=data, headers=headers, timeout=600) as response: + text = await response.text() + return text prompts = [task.get_prompt(doc) for doc in dataset] - awaitables = [get_res(prompt=prompt) for prompt in prompts] + semaphore = asyncio.Semaphore(20) + awaitables = [get_res(prompt=prompt, semaphore=semaphore) for prompt in prompts] + responses = asyncio.run(tqdm.gather(*awaitables)) generations = [] for i, (prompt, response) in enumerate(zip(prompts, responses)):