generate_passage_embeddings.py (forked from facebookresearch/FiD)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import pickle
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

import src.data
import src.model
import src.slurm
import src.util

logger = logging.getLogger(__name__)


def embed_passages(opt, passages, model, tokenizer):
    # Effective batch size across all workers in the (distributed) job.
    batch_size = opt.per_gpu_batch_size * opt.world_size
    collator = src.data.TextCollator(tokenizer, opt.passage_maxlength)
    dataset = src.data.TextDataset(passages, title_prefix='title:', passage_prefix='context:')
    dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False, num_workers=10, collate_fn=collator)

    total = 0
    allids, allembeddings = [], []
    # Encode passages batch by batch; no gradients are needed at inference.
    with torch.no_grad():
        for k, (ids, text_ids, text_mask) in enumerate(tqdm(dataloader)):
            embeddings = model.embed_text(
                text_ids=text_ids.cuda(),
                text_mask=text_mask.cuda(),
                apply_mask=model.config.apply_passage_mask,
                extract_cls=model.config.extract_cls,
            )
            embeddings = embeddings.cpu()
            total += len(ids)
            allids.append(ids)
            allembeddings.append(embeddings)
            if k % 10000 == 0:
                logger.info('Encoded passages %d', total)

    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    # Flatten the list of per-batch id lists into one flat list of ids.
    allids = [x for idlist in allids for x in idlist]
    return allids, allembeddings
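

# Illustration only (not part of the original script): embed_text is defined
# in src.model, but with apply_mask=True a retriever of this kind
# conventionally reduces the encoder output by masked mean pooling, which is
# the pooling Contriever uses. A minimal sketch of that reduction, assuming
# `last_hidden` comes from the underlying encoder:
def _masked_mean_pool_sketch(last_hidden, text_mask):
    # last_hidden: (batch, seq_len, dim) token representations.
    # text_mask:   (batch, seq_len) with 1 for real tokens, 0 for padding.
    mask = text_mask.unsqueeze(-1).float()
    # Zero out padding positions, then average over the real tokens only.
    return (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)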


def main(opt):
    logger = src.util.init_logger(is_main=True)
    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
    model_class = src.model.Retriever
    config = src.model.RetrieverConfig.from_pretrained(pretrained_model_name_or_path='facebook/contriever')
    if opt.model_path is None:
        # No fine-tuned checkpoint given: initialize the retriever directly
        # from the pretrained Contriever weights.
        logger.info('Loading pretrained Contriever')
        model = model_class(config, initialize_wContriever=True)
    else:
        model = model_class.from_pretrained(opt.model_path)
        logger.info(f'Model loaded from {opt.model_path}')
    model.eval()
    model = model.to(opt.device)
    if not opt.no_fp16:
        model = model.half()

    # Select this worker's shard of the passage collection; the last shard
    # absorbs the remainder so every passage is covered exactly once.
    passages = src.util.load_passages(opt.passages)
    shard_size = len(passages) // opt.num_shards
    start_idx = opt.shard_id * shard_size
    end_idx = start_idx + shard_size
    if opt.shard_id == opt.num_shards - 1:
        end_idx = len(passages)
    passages = passages[start_idx:end_idx]
    logger.info(f'Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}')

    allids, allembeddings = embed_passages(opt, passages, model, tokenizer)

    # Write this shard's (ids, embeddings) pair as a pickle, suffixed with
    # the zero-padded shard id.
    output_path = Path(opt.output_path)
    save_file = output_path.parent / (output_path.name + f'_{opt.shard_id:02d}')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f'Saving {len(allids)} passage embeddings to {save_file}')
    with open(save_file, mode='wb') as f:
        pickle.dump((allids, allembeddings), f, protocol=4)
    logger.info(f'Total passages processed {len(allids)}. Written to {save_file}.')
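

# Illustration only: the shards written by main() can be read back and merged
# offline. `load_embedding_shards` is a hypothetical helper, not part of this
# repo; it assumes the same output-path prefix and shard naming used above.
def load_embedding_shards(prefix, num_shards):
    ids, embeddings = [], []
    for shard_id in range(num_shards):
        shard_file = Path(prefix).parent / (Path(prefix).name + f'_{shard_id:02d}')
        with open(shard_file, 'rb') as f:
            shard_ids, shard_embeddings = pickle.load(f)
        ids.extend(shard_ids)
        embeddings.append(shard_embeddings)
    # Stack the per-shard matrices into one (num_passages, dim) array.
    return ids, np.concatenate(embeddings, axis=0)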


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--passages', type=str, default=None, help='Path to passages (.tsv file)')
    parser.add_argument('--output_path', type=str, default='wikipedia_embeddings/passages', help='Prefix path under which to save embeddings')
    parser.add_argument('--shard_id', type=int, default=0, help='Id of the current shard')
    parser.add_argument('--num_shards', type=int, default=1, help='Total number of shards')
    parser.add_argument('--per_gpu_batch_size', type=int, default=32, help='Batch size for the passage encoder forward pass')
    parser.add_argument('--passage_maxlength', type=int, default=200, help='Maximum number of tokens in a passage')
    parser.add_argument('--model_path', type=str, help='Path to directory containing model weights and config file')
    parser.add_argument('--no_fp16', action='store_true', help='Run inference in fp32')

    args = parser.parse_args()
    src.slurm.init_distributed_mode(args)
    main(args)
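
# Example invocation (illustrative; file names and shard counts are
# assumptions, not part of the original script). Each shard can run as a
# separate job, and the resulting files can be merged afterwards:
#
#   python generate_passage_embeddings.py \
#       --passages psgs_w100.tsv \
#       --output_path wikipedia_embeddings/passages \
#       --shard_id 0 \
#       --num_shards 4 \
#       --per_gpu_batch_size 32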