-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_sentiment.py
57 lines (42 loc) · 1.61 KB
/
test_sentiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import json
import os
import sklearn.metrics as metrics
import torch
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizerFast
# Device used for the model and its inputs: the first CUDA device when torch
# reports CUDA as available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def create_arg_parser():
    """Build the command-line parser for this evaluation script.

    The parser takes two positional string arguments:
        json_path: path of the JSON file containing the Singapore Hansard data.
        model_name_or_dir: model name or directory to load the model from.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet invoked).
    """
    cli = argparse.ArgumentParser(
        description='Test sentiment classification on Singapore Hansard using XLM-RoBERTa model')
    positional_specs = (
        ('json_path', 'Path of JSON file containing Singapore Hansard data.'),
        ('model_name_or_dir', 'Name or directory of model.'),
    )
    for dest, help_text in positional_specs:
        cli.add_argument(dest, type=str, help=help_text)
    return cli
def main(
        json_path,
        model_name_or_dir):
    """Evaluate a sentiment classifier on a JSON dataset and print a report.

    Loads an XLM-RoBERTa fast tokenizer and a sequence-classification model
    from ``model_name_or_dir``. It then predicts a class index for every
    example in the JSON file at ``json_path`` and prints scikit-learn's
    classification report (6 decimal digits) comparing gold labels with
    predictions.

    Args:
        json_path: Path to a JSON file holding a list of objects, each with a
            ``'text'`` key (the sentence to classify) and a ``'sentiment'``
            key (the gold label).
        model_name_or_dir: Model name or local directory passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        The classification report string that was printed, or ``None`` when
        the JSON file contains no examples.
    """
    tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name_or_dir)
    model = XLMRobertaForSequenceClassification.from_pretrained(model_name_or_dir).to(DEVICE)
    model.eval()  # inference mode for layers such as dropout

    # JSON text is UTF-8; don't depend on the platform's locale default encoding.
    with open(json_path, encoding='utf-8') as json_file:
        data = json.load(json_file)

    # Don't hand an empty label/prediction set to classification_report.
    if not data:
        print(f'No examples found in {json_path}; nothing to evaluate.')
        return None

    predictions = []
    labels = []
    # A single no_grad context for the whole loop: pure inference, no autograd graph.
    with torch.no_grad():
        for pair in data:
            # NOTE(review): predictions below are integer class indices, so this
            # assumes 'sentiment' holds the matching integer label id — confirm.
            label = pair['sentiment']
            # Each example is encoded on its own, so no padding is needed;
            # truncation=True caps the input at the tokenizer's maximum length.
            input_ids = tokenizer.encode(
                pair['text'], padding=False, truncation=True,
                return_tensors='pt').to(DEVICE)
            logits = model(input_ids).logits
            predictions.append(torch.argmax(logits, dim=-1)[0].item())
            labels.append(label)

    report = metrics.classification_report(labels, predictions, digits=6)
    print(report)
    return report
if __name__ == '__main__':
    # Script entry point: parse the two positional CLI arguments and run the evaluation.
    cli_args = create_arg_parser().parse_args()
    main(json_path=cli_args.json_path,
         model_name_or_dir=cli_args.model_name_or_dir)