import argparse
import os
import pickle

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from tqdm import tqdm, trange
from transformers import XLMRobertaModel, XLMRobertaTokenizer


def node_id_mapping(path, df_frac=1):
    """
    Map keywords and titles to node ids and save the mapping
    dictionary to node_id_map.pkl.
    Args:
        path:str - Path to the dataset tsv
        df_frac:float - Fraction of the data to use (for large files)
    Returns:
        Dictionary of {keyword/title: id}
    """
    df = pd.read_csv(path, sep="\t")
    # Use the same seed as construct_graph so both stages sample the same rows;
    # otherwise sampled texts could be missing from the id map.
    df = df.sample(frac=df_frac, random_state=42)
    all_nodes = set(df["keyword"].values)
    all_nodes.update(set(df["title"].values))
    del df
    node_id_map = dict(zip(all_nodes, range(len(all_nodes))))
    with open("node_id_map.pkl", "wb") as f:
        pickle.dump(node_id_map, f)
    return node_id_map
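
# Usage sketch (the tsv file name is hypothetical): build the text -> id map
# once; later stages reload the persisted node_id_map.pkl instead of recomputing.
#   node_id_map = node_id_mapping("query_asin.tsv", df_frac=1)
#   node_id_map["some keyword"]  # -> an integer id in [0, number_of_nodes)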


def make_semantic_node_embeddings(path, df_frac=1):
    """
    Save semantic embeddings of keywords and titles in batches
    of 1000 to the directory node_embeds/.
    Args:
        path:str - Path to the dataset tsv
        df_frac:float - Fraction of the data to use (for large files)
    """
    if os.path.isfile("./node_id_map.pkl"):
        with open("node_id_map.pkl", "rb") as f:
            node_id_map = pickle.load(f)
    else:
        node_id_map = node_id_mapping(path, df_frac)
    print("Loading Tokenizer")
    model_name = "xlm-roberta-base"
    tokenizer = XLMRobertaTokenizer.from_pretrained(model_name, do_lower_case=True)
    model = XLMRobertaModel.from_pretrained(model_name)
    model.eval()
    print("Initialize node_embeds")
    os.makedirs("node_embeds", exist_ok=True)  # np.save below assumes the directory exists
    node_keys = list(node_id_map.keys())
    print("Number of nodes", len(node_keys))
    batch_size = 1000
    print("Making semantic node embeddings")
    for ind in trange(1 + len(node_keys) // batch_size):
        raw_inputs = node_keys[ind * batch_size:(ind + 1) * batch_size]
        if len(raw_inputs) == 0:
            break
        inputs = tokenizer(raw_inputs, add_special_tokens=True, padding="max_length",
                           max_length=24, truncation=True, return_tensors="pt")
        with torch.no_grad():  # inference only; avoids tracking gradients
            # Mean-pool the token embeddings into one vector per text.
            outputs = torch.mean(model(**inputs).last_hidden_state, dim=1)
        # Save through a file object so np.save does not append ".npy";
        # get_embedding loads the files by the same extension-less name.
        with open(f"node_embeds/node_embed_{ind}", "wb") as f:
            np.save(f, outputs.numpy())
    return
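
# Resulting on-disk layout, given batch_size = 1000 and xlm-roberta-base's
# hidden size of 768: node id k lives at row k % 1000 of the float array in
# node_embeds/node_embed_{k // 1000}, e.g.
#   batch0 = np.load("node_embeds/node_embed_0")  # shape (<=1000, 768)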


def get_embedding(text, node_id_map, root_node_path="./node_embeds/"):
    """
    Retrieve the semantic embedding of a text keyword/title using
    the node_id_map and the saved semantic embeddings.
    Args:
        text:str - Text of the title/keyword
        node_id_map:dict - Map of the text to the node id
        root_node_path:str - Directory where the semantic embeddings are saved
    Returns:
        feature_embed:np.array - semantic embedding for text
    """
    # Texts missing from the map silently fall back to node id 0.
    node_id = node_id_map.get(text, 0)
    # Embeddings were saved in batches of 1000, so node k is row k % 1000
    # of the file node_embed_{k // 1000}.
    feature_embed = np.load(root_node_path + f"node_embed_{node_id // 1000}")[node_id % 1000]
    return feature_embed
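
# Usage sketch ("some keyword" is a hypothetical entry of node_id_map):
#   node_id_map = pickle.load(open("node_id_map.pkl", "rb"))
#   emb = get_embedding("some keyword", node_id_map)  # np.array of shape (768,)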


def construct_graph(path, df_frac=1, threshold=100):
    """
    Collect the 2-hop neighborhood of all keywords and titles and save
    it to a pickle file: two_hop_ngbrs.pkl.
    The 1-hop neighborhood is given priority when truncating to `threshold`.
    Args:
        path:str - Path to the dataset tsv
        df_frac:float - Fraction of the data to use (for large files)
        threshold:int - Maximum size of the neighborhood to be considered
    """
    df = pd.read_csv(path, sep="\t")
    df = df.sample(frac=df_frac, random_state=42)
    qa_df = df.groupby("keyword")["title"].apply(list).reset_index(name="neighbours")
    query_asin_map = dict(zip(qa_df["keyword"], qa_df["neighbours"]))
    aq_df = df.groupby("title")["keyword"].apply(list).reset_index(name="neighbours")
    asin_query_map = dict(zip(aq_df["title"], aq_df["neighbours"]))
    del df
    two_hop_ngbrs = {}
    with open("node_id_map.pkl", "rb") as f:
        node_id_map = pickle.load(f)
    root_node_path = "./node_embeds/"
    print("Constructing query neighborhood graph")
    for _, row in tqdm(qa_df.iterrows(), total=len(qa_df)):
        # 1-hop: the titles of this keyword; keep only strings (drops NaNs)
        # and dedupe while preserving order.
        one_hop = list(dict.fromkeys(n for n in row["neighbours"] if type(n) is str))
        seen = set(one_hop)
        two_hop = []
        # 2-hop: keywords reached through the neighbouring titles.
        for ngbr in one_hop[:threshold]:
            for nn in asin_query_map.get(ngbr, []):
                if type(nn) is str and nn not in seen:
                    seen.add(nn)
                    two_hop.append(nn)
        # 1-hop neighbours take priority when truncating to `threshold`.
        neighborhood = (one_hop + two_hop)[:threshold]
        feature = [get_embedding(row["keyword"], node_id_map, root_node_path)]
        r, c = [], []
        for idx, neighbor in enumerate(neighborhood):
            feature.append(get_embedding(neighbor, node_id_map, root_node_path))
            # Star-shaped graph: every neighbour is connected to the centre node 0.
            r.append(0)
            c.append(idx + 1)
        two_hop_ngbrs[row["keyword"]] = Data(x=torch.tensor(np.stack(feature), dtype=torch.float),
                                             edge_index=torch.tensor([r, c], dtype=torch.long))
    print("Constructing asin neighborhood graph")
    for _, row in tqdm(aq_df.iterrows(), total=len(aq_df)):
        one_hop = list(dict.fromkeys(n for n in row["neighbours"] if type(n) is str))
        seen = set(one_hop)
        two_hop = []
        # 2-hop: titles reached through the neighbouring keywords.
        for ngbr in one_hop[:threshold]:
            for nn in query_asin_map.get(ngbr, []):
                if type(nn) is str and nn not in seen:
                    seen.add(nn)
                    two_hop.append(nn)
        neighborhood = (one_hop + two_hop)[:threshold]
        feature = [get_embedding(row["title"], node_id_map, root_node_path)]
        r, c = [], []
        for idx, neighbor in enumerate(neighborhood):
            feature.append(get_embedding(neighbor, node_id_map, root_node_path))
            r.append(0)
            c.append(idx + 1)
        two_hop_ngbrs[row["title"]] = Data(x=torch.tensor(np.stack(feature), dtype=torch.float),
                                           edge_index=torch.tensor([r, c], dtype=torch.long))
    with open("two_hop_ngbrs.pkl", "wb") as f:
        pickle.dump(two_hop_ngbrs, f)
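
# Reading the output back ("some keyword" is a hypothetical dataset entry):
# each value is a star graph whose node 0 is the centre text and whose other
# nodes are its truncated 2-hop neighbourhood.
#   graphs = pickle.load(open("two_hop_ngbrs.pkl", "rb"))
#   g = graphs["some keyword"]  # torch_geometric.data.Data
#   g.x.shape                   # (1 + num_neighbours, 768) node features
#   g.edge_index                # edges (0, i) for i = 1..num_neighbours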


if __name__ == "__main__":
    """
    Main script to be called for the graph neighborhood construction for the SMLM model.
    make_semantic_node_embeddings - Creates the semantic embeddings for the nodes.
    construct_graph - Constructs the graph neighborhood using the 'Data' class of the torch-geometric library.
    Files created:
        node_id_map.pkl: map of the titles/keywords to unique ids.
        node_embeds/: directory with semantic embeddings for the titles/keywords.
        two_hop_ngbrs.pkl: final graph with the two-hop neighborhood for all titles and keywords.
    """
    parser = argparse.ArgumentParser(description="Construct graph from query-asin dataset")
    parser.add_argument("--path", metavar="P", type=str, required=True,
                        help="path to the query-asin dataset")
    parser.add_argument("--frac", metavar="F", type=float, default=1.0,
                        help="in case we need to process only a fraction of the dataset")
    parser.add_argument("--threshold", metavar="T", type=int, default=100,
                        help="threshold on the size of the neighborhood")
    args = parser.parse_args()
    # Pass the same fraction to both stages so the saved embeddings
    # cover exactly the rows that construct_graph samples.
    make_semantic_node_embeddings(args.path, df_frac=args.frac)
    construct_graph(args.path, df_frac=args.frac, threshold=args.threshold)
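
# Example invocation (the tsv file name is hypothetical):
#   python construct_graph.py --path query_asin.tsv --frac 0.1 --threshold 100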