-
Notifications
You must be signed in to change notification settings - Fork 1
/
1_gen_origin_response.py
38 lines (30 loc) · 1.46 KB
/
1_gen_origin_response.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from easyjailbreak.selector.RandomSelector import RandomSelectPolicy
from easyjailbreak.datasets import JailbreakDataset, Instance
from easyjailbreak.seed import SeedTemplate
from easyjailbreak.mutation.rule import Translate
from easyjailbreak.models import from_pretrained
from sentence_transformers import SentenceTransformer, util
from easyjailbreak.metrics.Evaluator.Evaluator_ClassificationGetScore import EvaluatorClassificationGetScore
import torch
import json
from transformers import HfArgumentParser, AutoTokenizer
from tools.inference_models import GenerationArguments, get_inference_model
from tools.generate_response import MyGenerationArguments, generate_response, generate_response_light
# Pipeline step 1: attach a model response to every original prompt.
#
# Reads JSON-Lines records from INPUT_PATH (one JSON object per non-blank
# line, each carrying a "prompt_in_en" field), runs the configured inference
# model over those prompts, and appends each record -- extended with a
# "response" field -- to OUTPUT_PATH as JSON Lines.

INPUT_PATH = './data/0_original_prompt.json'
OUTPUT_PATH = './data/1_original_prompt_with_response.json'


def load_records(path):
    """Read the JSON-Lines file at *path*, skipping blank lines.

    Args:
        path: Path of a UTF-8 file holding one JSON object per line.

    Returns:
        A ``(records, prompts)`` tuple: the parsed objects in file order and,
        index-aligned with them, each object's ``"prompt_in_en"`` value.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"prompt_in_en"`` field.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    prompts = [record["prompt_in_en"] for record in records]
    return records, prompts


def append_responses(path, records, responses):
    """Attach responses to their records and append the records to *path*.

    Each ``records[i]`` is updated in place with ``"response": responses[i]``
    and written as one JSON line. The file is opened in append mode, so
    re-running the script adds another copy of every record.

    Args:
        path: Output file; created if missing, written as UTF-8.
        records: Dicts to extend, index-aligned with *responses*.
        responses: One JSON-serialisable response per record.

    Raises:
        ValueError: If the two sequences differ in length (nothing is written).
    """
    # Check up front: pairing by index would otherwise silently drop records
    # (too few responses) or fail halfway through the write (too many).
    if len(responses) != len(records):
        raise ValueError(
            f"got {len(responses)} responses for {len(records)} records"
        )
    # ensure_ascii=False emits raw non-ASCII text, so pin the encoding rather
    # than relying on the platform default, which may not be able to encode it.
    with open(path, 'a', encoding='utf-8') as json_file:
        for record, response in zip(records, responses):
            record["response"] = response
            json.dump(record, json_file, ensure_ascii=False)
            json_file.write('\n')


def main(input_path=INPUT_PATH, output_path=OUTPUT_PATH):
    """Generate a response for every prompt in *input_path*; append to *output_path*.

    Generation settings are parsed from the command line into
    ``MyGenerationArguments`` via ``HfArgumentParser``.
    """
    records, prompts = load_records(input_path)
    parser = HfArgumentParser(MyGenerationArguments)
    generation_args = parser.parse_args_into_dataclasses()[0]
    model = get_inference_model(generation_args)
    # NOTE(review): assumes generate_response_light returns one response per
    # prompt, in prompt order -- append_responses only enforces the count.
    responses = generate_response_light(model, prompts)
    append_responses(output_path, records, responses)


if __name__ == "__main__":
    main()