Skip to content

Commit

Permalink
keep parameters for sagemaker
Browse files Browse the repository at this point in the history
  • Loading branch information
HuXiangkun committed Sep 18, 2024
1 parent c9d6f46 commit e358997
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "refchecker"
version = "0.2.11"
version = "0.2.12"
description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models."
authors = [
"Xiangkun Hu <[email protected]>",
Expand Down
9 changes: 9 additions & 0 deletions refchecker/checker/checker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def check(
merge_psg: bool = True,
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down Expand Up @@ -95,6 +98,9 @@ def check(
questions=batch_questions,
is_joint=True,
joint_check_num=joint_check_num,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand Down Expand Up @@ -135,6 +141,9 @@ def check(
responses=[inp[2] for inp in input_flattened],
questions=[inp[3] for inp in input_flattened],
is_joint=False,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
)

Expand Down
9 changes: 9 additions & 0 deletions refchecker/checker/llm_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def _check(
questions: List[str] = None,
is_joint: bool = False,
joint_check_num: int = 5,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down Expand Up @@ -125,6 +128,9 @@ def _check(
model=self.model,
max_new_tokens=joint_check_num * 10 + 100,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand Down Expand Up @@ -204,6 +210,9 @@ def _check(
model=self.model,
max_new_tokens=10,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand Down
15 changes: 15 additions & 0 deletions refchecker/extractor/extractor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def extract(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand All @@ -24,6 +27,9 @@ def extract(
batch_responses=batch_responses,
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand All @@ -32,6 +38,9 @@ def extract(
batch_responses=batch_responses,
batch_questions=batch_questions,
max_new_tokens=max_new_tokens,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand All @@ -42,6 +51,9 @@ def extract_claim_triplets(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand All @@ -52,6 +64,9 @@ def extract_subsentence_claims(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down
12 changes: 12 additions & 0 deletions refchecker/extractor/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def extract_subsentence_claims(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down Expand Up @@ -73,6 +76,9 @@ def extract_subsentence_claims(
n_choices=1,
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand All @@ -99,6 +105,9 @@ def extract_claim_triplets(
batch_responses,
batch_questions=None,
max_new_tokens=500,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down Expand Up @@ -145,6 +154,9 @@ def extract_claim_triplets(
n_choices=1,
max_new_tokens=max_new_tokens,
api_base=self.api_base,
sagemaker_client=sagemaker_client,
sagemaker_params=sagemaker_params,
sagemaker_get_response_func=sagemaker_get_response_func,
custom_llm_api_func=custom_llm_api_func,
**kwargs
)
Expand Down
34 changes: 33 additions & 1 deletion refchecker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def get_model_batch_response(
n_choices=1,
max_new_tokens=500,
api_base=None,
sagemaker_client=None,
sagemaker_params=None,
sagemaker_get_response_func=None,
custom_llm_api_func=None,
**kwargs
):
Expand Down Expand Up @@ -98,7 +101,36 @@ def get_model_batch_response(
if not prompts or len(prompts) == 0:
raise ValueError("Invalid input.")

if custom_llm_api_func is not None:
if sagemaker_client is not None:
parameters = {
"max_new_tokens": max_new_tokens,
"temperature": temperature
}
if sagemaker_params is not None:
for k, v in sagemaker_params.items():
if k in parameters:
parameters[k] = v
response_list = []
for prompt in prompts:
r = sagemaker_client.invoke_endpoint(
EndpointName=model,
Body=json.dumps(
{
"inputs": prompt,
"parameters": parameters,
}
),
ContentType="application/json",
)
if sagemaker_get_response_func is not None:
response = sagemaker_get_response_func(r)
else:
r = json.loads(r['Body'].read().decode('utf8'))
response = r['outputs'][0]
response_list.append(response)
return response_list

elif custom_llm_api_func is not None:
return custom_llm_api_func(prompts)
else:
message_list = []
Expand Down

0 comments on commit e358997

Please sign in to comment.