Skip to content

Commit

Permalink
修复对线程数为6的限制 (Fix the hard-coded limit of 6 on worker thread concurrency)
Browse files Browse the repository at this point in the history
  • Loading branch information
shell-nlp committed May 19, 2024
1 parent 0d04621 commit f67149b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 7 deletions.
2 changes: 1 addition & 1 deletion gpt_server/model_worker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_worker(
controller_addr: str = "http://localhost:21001",
worker_id: str = str(uuid.uuid4())[:8],
model_names: List[str] = [""],
limit_worker_concurrency: int = 6,
limit_worker_concurrency: int = 100,
conv_template: str = None, # type: ignore
):
worker = cls(
Expand Down
6 changes: 3 additions & 3 deletions gpt_server/model_worker/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,18 @@ async def generate_stream_gate(self, params):
add_generation_prompt=True,
chat_template=self.other_config["chat_template"],
)
logger.info(text)
input_ids = self.tokenizer([text], return_tensors="pt").input_ids
elif model_type == "qwen2":
logger.info("正在使用qwen-2.0 !")
text = self.tokenizer.apply_chat_template(
conversation=messages, tokenize=False, add_generation_prompt=True
)
logger.info(text)
input_ids = self.tokenizer([text], return_tensors="pt").input_ids
prompt = self.tokenizer.decode(input_ids.tolist()[0])
logger.info(prompt)
# ---------------添加额外的参数------------------------
params["messages"] = messages
params["prompt"] = prompt
params["prompt"] = text
params["stop"].extend(self.stop)
params["stop_words_ids"] = self.stop_words_ids
params["input_ids"] = input_ids
Expand Down
12 changes: 12 additions & 0 deletions gpt_server/script/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ models:
workers:
- gpus:
- 1
# - gpus:
# - 3

- llama3: #自定义的模型名称
alias: null # 别名 例如 gpt4,gpt3
Expand Down Expand Up @@ -106,6 +108,16 @@ models:
workers:
- gpus:
- 2
- bge-base-zh:
alias: null # 别名
enable: true # false true
model_name_or_path: /home/dev/model/Xorbits/bge-base-zh-v1___5/
model_type: embedding
work_mode: hf
device: gpu # gpu / cpu
workers:
- gpus:
- 2

- acge_text_embedding:
alias: null # 别名
Expand Down
14 changes: 11 additions & 3 deletions tests/test_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,18 @@ def send_request(results, i, prefill_times):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--server-address", type=str, default="http://localhost:8082")

parser.add_argument("--model-name", type=str, default="qwen")
parser.add_argument("--max-new-tokens", type=int, default=2048)
parser.add_argument("--n-thread", type=int, default=8)
parser.add_argument("--n-thread", type=int, default=20)
parser.add_argument("--test-dispatch", action="store_true")
args = parser.parse_args()

main(args)
threads = []
for i in range(1):
t = threading.Thread(target=main, args=(args,))
t.start()
threads.append(t)
time.sleep(1)
for t in threads:
t.join()
# main(args)

0 comments on commit f67149b

Please sign in to comment.