-
Notifications
You must be signed in to change notification settings - Fork 123
/
main.py
262 lines (229 loc) · 8.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import json
import configparser
import os
import time
from openai import OpenAI
from requests import Session
from typing import TypeVar, Generator
import io
from retry import retry
from tqdm import tqdm
from arxiv_scraper import get_papers_from_arxiv_rss_api
from filter_papers import filter_by_author, filter_by_gpt
from parse_json_to_md import render_md_string
from push_to_slack import push_to_slack
from arxiv_scraper import EnhancedJSONEncoder
T = TypeVar("T")
def batched(items: list[T], batch_size: int) -> list[T]:
# takes a list and returns a list of list with batch_size
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
def argsort(seq):
# native python version of an 'argsort'
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def get_paper_batch(
session: Session,
ids: list[str],
S2_API_KEY: str,
fields: str = "paperId,title",
**kwargs,
) -> list[dict]:
# gets a batch of papers. taken from the sem scholar example.
params = {
"fields": fields,
**kwargs,
}
if S2_API_KEY is None:
headers = {}
else:
headers = {
"X-API-KEY": S2_API_KEY,
}
body = {
"ids": ids,
}
# https://api.semanticscholar.org/api-docs/graph#tag/Paper-Data/operation/post_graph_get_papers
with session.post(
"https://api.semanticscholar.org/graph/v1/paper/batch",
params=params,
headers=headers,
json=body,
) as response:
response.raise_for_status()
return response.json()
def get_author_batch(
session: Session,
ids: list[str],
S2_API_KEY: str,
fields: str = "name,hIndex,citationCount",
**kwargs,
) -> list[dict]:
# gets a batch of authors. analogous to author batch
params = {
"fields": fields,
**kwargs,
}
if S2_API_KEY is None:
headers = {}
else:
headers = {
"X-API-KEY": S2_API_KEY,
}
body = {
"ids": ids,
}
with session.post(
"https://api.semanticscholar.org/graph/v1/author/batch",
params=params,
headers=headers,
json=body,
) as response:
response.raise_for_status()
return response.json()
@retry(tries=3, delay=2.0)
def get_one_author(session, author: str, S2_API_KEY: str) -> str:
# query the right endpoint https://api.semanticscholar.org/graph/v1/author/search?query=adam+smith
params = {"query": author, "fields": "authorId,name,hIndex", "limit": "10"}
if S2_API_KEY is None:
headers = {}
else:
headers = {
"X-API-KEY": S2_API_KEY,
}
with session.get(
"https://api.semanticscholar.org/graph/v1/author/search",
params=params,
headers=headers,
) as response:
# try catch for errors
try:
response.raise_for_status()
response_json = response.json()
if len(response_json["data"]) >= 1:
return response_json["data"]
else:
return None
except Exception as ex:
print("exception happened" + str(ex))
return None
def get_papers(
ids: list[str], S2_API_KEY: str, batch_size: int = 100, **kwargs
) -> Generator[dict, None, None]:
# gets all papers, doing batching to avoid hitting the max paper limit.
# use a session to reuse the same TCP connection
with Session() as session:
# take advantage of S2 batch paper endpoint
for ids_batch in batched(ids, batch_size=batch_size):
yield from get_paper_batch(session, ids_batch, S2_API_KEY, **kwargs)
def get_authors(
all_authors: list[str], S2_API_KEY: str, batch_size: int = 100, **kwargs
):
# first get the list of all author ids by querying by author names
author_metadata_dict = {}
with Session() as session:
for author in tqdm(all_authors):
auth_map = get_one_author(session, author, S2_API_KEY)
if auth_map is not None:
author_metadata_dict[author] = auth_map
# add a 20ms wait time to avoid rate limiting
# otherwise, semantic scholar aggressively rate limits, so do 1s
if S2_API_KEY is not None:
time.sleep(0.02)
else:
time.sleep(1.0)
return author_metadata_dict
def get_papers_from_arxiv(config):
area_list = config["FILTERING"]["arxiv_category"].split(",")
paper_set = set()
for area in area_list:
papers = get_papers_from_arxiv_rss_api(area.strip(), config)
paper_set.update(set(papers))
if config["OUTPUT"].getboolean("debug_messages"):
print("Number of papers:" + str(len(paper_set)))
return paper_set
def parse_authors(lines):
# parse the comma-separated author list, ignoring lines that are empty and starting with #
author_ids = []
authors = []
for line in lines:
if line.startswith("#"):
continue
if not line.strip():
continue
author_split = line.split(",")
author_ids.append(author_split[1].strip())
authors.append(author_split[0].strip())
return authors, author_ids
if __name__ == "__main__":
# now load config.ini
config = configparser.ConfigParser()
config.read("configs/config.ini")
S2_API_KEY = os.environ.get("S2_KEY")
OAI_KEY = os.environ.get("OAI_KEY")
if OAI_KEY is None:
raise ValueError(
"OpenAI key is not set - please set OAI_KEY to your OpenAI key"
)
openai_client = OpenAI(api_key=OAI_KEY)
# load the author list
with io.open("configs/authors.txt", "r") as fopen:
author_names, author_ids = parse_authors(fopen.readlines())
author_id_set = set(author_ids)
papers = list(get_papers_from_arxiv(config))
# dump all papers for debugging
all_authors = set()
for paper in papers:
all_authors.update(set(paper.authors))
if config["OUTPUT"].getboolean("debug_messages"):
print("Getting author info for " + str(len(all_authors)) + " authors")
all_authors = get_authors(list(all_authors), S2_API_KEY)
if config["OUTPUT"].getboolean("dump_debug_file"):
with open(
config["OUTPUT"]["output_path"] + "papers.debug.json", "w"
) as outfile:
json.dump(papers, outfile, cls=EnhancedJSONEncoder, indent=4)
with open(
config["OUTPUT"]["output_path"] + "all_authors.debug.json", "w"
) as outfile:
json.dump(all_authors, outfile, cls=EnhancedJSONEncoder, indent=4)
with open(
config["OUTPUT"]["output_path"] + "author_id_set.debug.json", "w"
) as outfile:
json.dump(list(author_id_set), outfile, cls=EnhancedJSONEncoder, indent=4)
selected_papers, all_papers, sort_dict = filter_by_author(
all_authors, papers, author_id_set, config
)
filter_by_gpt(
all_authors,
papers,
config,
openai_client,
all_papers,
selected_papers,
sort_dict,
)
# sort the papers by relevance and novelty
keys = list(sort_dict.keys())
values = list(sort_dict.values())
sorted_keys = [keys[idx] for idx in argsort(values)[::-1]]
selected_papers = {key: selected_papers[key] for key in sorted_keys}
if config["OUTPUT"].getboolean("debug_messages"):
print(sort_dict)
print(selected_papers)
# pick endpoints and push the summaries
if len(papers) > 0:
if config["OUTPUT"].getboolean("dump_json"):
with open(config["OUTPUT"]["output_path"] + "output.json", "w") as outfile:
json.dump(selected_papers, outfile, indent=4)
if config["OUTPUT"].getboolean("dump_md"):
with open(config["OUTPUT"]["output_path"] + "output.md", "w") as f:
f.write(render_md_string(selected_papers))
# only push to slack for non-empty dicts
if config["OUTPUT"].getboolean("push_to_slack"):
SLACK_KEY = os.environ.get("SLACK_KEY")
if SLACK_KEY is None:
print(
"Warning: push_to_slack is true, but SLACK_KEY is not set - not pushing to slack"
)
else:
push_to_slack(selected_papers)