diff --git a/pyproject.toml b/pyproject.toml index 343c7d1..73d6952 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "refchecker" -version = "0.2.11" +version = "0.2.12" description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models." authors = [ "Xiangkun Hu ", diff --git a/refchecker/checker/checker_base.py b/refchecker/checker/checker_base.py index 6d7f360..fd8170e 100644 --- a/refchecker/checker/checker_base.py +++ b/refchecker/checker/checker_base.py @@ -49,6 +49,9 @@ def check( merge_psg: bool = True, is_joint: bool = False, joint_check_num: int = 5, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -95,6 +98,9 @@ def check( questions=batch_questions, is_joint=True, joint_check_num=joint_check_num, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) @@ -135,6 +141,9 @@ def check( responses=[inp[2] for inp in input_flattened], questions=[inp[3] for inp in input_flattened], is_joint=False, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, ) diff --git a/refchecker/checker/llm_checker.py b/refchecker/checker/llm_checker.py index 3dd43e2..7fda80b 100644 --- a/refchecker/checker/llm_checker.py +++ b/refchecker/checker/llm_checker.py @@ -46,6 +46,9 @@ def _check( questions: List[str] = None, is_joint: bool = False, joint_check_num: int = 5, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -125,6 +128,9 @@ def _check( model=self.model, max_new_tokens=joint_check_num * 10 + 100, api_base=self.api_base, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) @@ -204,6 +210,9 @@ def _check( model=self.model, max_new_tokens=10, api_base=self.api_base, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) diff --git a/refchecker/extractor/extractor_base.py b/refchecker/extractor/extractor_base.py index 5c62a65..372811d 100644 --- a/refchecker/extractor/extractor_base.py +++ b/refchecker/extractor/extractor_base.py @@ -16,6 +16,9 @@ def extract( batch_responses, batch_questions=None, max_new_tokens=500, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -24,6 +27,9 @@ def extract( batch_responses=batch_responses, batch_questions=batch_questions, max_new_tokens=max_new_tokens, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) @@ -32,6 +38,9 @@ def extract( batch_responses=batch_responses, batch_questions=batch_questions, max_new_tokens=max_new_tokens, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) @@ -42,6 +51,9 @@ def extract_claim_triplets( batch_responses, batch_questions=None, max_new_tokens=500, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -52,6 +64,9 @@ def extract_subsentence_claims( batch_responses, batch_questions=None, max_new_tokens=500, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): diff --git a/refchecker/extractor/llm_extractor.py b/refchecker/extractor/llm_extractor.py index b1f770b..f7b0433 100644 --- a/refchecker/extractor/llm_extractor.py +++ b/refchecker/extractor/llm_extractor.py @@ -30,6 +30,9 @@ def extract_subsentence_claims( batch_responses, batch_questions=None, max_new_tokens=500, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -73,6 +76,9 @@ def extract_subsentence_claims( n_choices=1, max_new_tokens=max_new_tokens, api_base=self.api_base, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) @@ -99,6 +105,9 @@ def extract_claim_triplets( batch_responses, batch_questions=None, max_new_tokens=500, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -145,6 +154,9 @@ def extract_claim_triplets( n_choices=1, max_new_tokens=max_new_tokens, api_base=self.api_base, + sagemaker_client=sagemaker_client, + sagemaker_params=sagemaker_params, + sagemaker_get_response_func=sagemaker_get_response_func, custom_llm_api_func=custom_llm_api_func, **kwargs ) diff --git a/refchecker/utils.py b/refchecker/utils.py index 36283c7..3418d83 100644 --- a/refchecker/utils.py +++ b/refchecker/utils.py @@ -70,6 +70,9 @@ def get_model_batch_response( n_choices=1, max_new_tokens=500, api_base=None, + sagemaker_client=None, + sagemaker_params=None, + sagemaker_get_response_func=None, custom_llm_api_func=None, **kwargs ): @@ -98,7 +101,36 @@ def get_model_batch_response( if not prompts or len(prompts) == 0: raise ValueError("Invalid input.") - if custom_llm_api_func is not None: + if sagemaker_client is not None: + parameters = { + "max_new_tokens": max_new_tokens, + "temperature": temperature + } + if sagemaker_params is not None: + for k, v in sagemaker_params.items(): + if k in parameters: + parameters[k] = v + response_list = [] + for prompt in prompts: + r = sagemaker_client.invoke_endpoint( + EndpointName=model, + Body=json.dumps( + { + "inputs": prompt, + "parameters": parameters, + } + ), + ContentType="application/json", + ) + if sagemaker_get_response_func is not None: + response = sagemaker_get_response_func(r) + else: + r = json.loads(r['Body'].read().decode('utf8')) + response = r['outputs'][0] + response_list.append(response) + return response_list + + elif custom_llm_api_func is not None: return custom_llm_api_func(prompts) else: message_list = []