add iterate_expand and search module

Nathaniel Imel authored and Nathaniel Imel committed Nov 10, 2023
1 parent 362b43c commit 86f2d31
Showing 8 changed files with 178 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/examples/scratch/main.py
@@ -66,6 +66,7 @@ def main(args):
            center=center,
            n_pubs_max=n_pubs_max,
            call_size=call_size,
            record_pubs_per_update=True,
        )
        print_progress(atl)
        atl.save(atlas_dir)
@@ -75,6 +76,8 @@ def main(args):
        print_progress(atl)
        atl.save(atlas_dir)

        atl = crt.track(atl)

        if len_prev == len(atl):
            failures += 1
        else:
2 changes: 2 additions & 0 deletions src/sciterra/librarians/s2librarian.py
@@ -135,6 +135,8 @@ def get_publications(
            return []

        total = len(paper_ids)
        if call_size is None:
            call_size = CALL_SIZE
        chunked_ids = chunk_ids(paper_ids, call_size=call_size)

        if None in paper_ids:
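Since the companion change in utils.py (below) removes the default from chunk_ids, each librarian now resolves its own default call size. A minimal sketch of that resolution, with resolve_call_size a hypothetical helper written only for illustration; CALL_SIZE's real value is not shown in this diff, so 2000 below is an assumption:

from typing import Optional

CALL_SIZE = 2000  # assumption: illustrative value for the librarian's default

def resolve_call_size(call_size: Optional[int]) -> int:
    # None now means "use the librarian's default" all the way down the
    # call stack, since chunk_ids no longer supplies one of its own.
    return CALL_SIZE if call_size is None else call_size

assert resolve_call_size(None) == 2000
assert resolve_call_size(500) == 500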
29 changes: 26 additions & 3 deletions src/sciterra/mapping/atlas.py
@@ -16,11 +16,22 @@


class Atlas:
    """Data structure for storing publications.

    `self.projection`: the Projection object containing the embeddings of all publications and their mapping to str identifiers.

    `self.bad_ids`: a set of identifiers that have failed for some reason during an expansion, and will be excluded from subsequent expansions.

    `self.history`: a dict of the form {'pubs_per_update': list[list[str]], 'kernel_size': np.ndarray of ints of shape (num_pubs, last_update)}, where last_update <= the total number of expansions performed.
    """

    def __init__(
        self,
        publications: list[Publication],
        projection: Projection = None,
        bad_ids: set[str] = set(),
        history: dict[str, Any] = dict(),
    ) -> None:
        if not isinstance(publications, list):
            raise ValueError
@@ -31,6 +42,8 @@ def __init__(

        self.bad_ids = bad_ids

        self.history = history

    ######################################################################
    # Lookup ######################################################################

@@ -64,7 +77,12 @@ def save(
            return

        attributes = {
-            k: getattr(self, k) for k in ["publications", "projection", "bad_ids"]
+            k: getattr(self, k) for k in [
+                "publications",
+                "projection",
+                "bad_ids",
+                "history",
+            ]
        }

        for attribute in attributes:
@@ -91,10 +109,15 @@ def load(
        Warnings cannot be silenced.
        Args:
-            atlas_dirpath: file with vocab, assumed output from `save_to_file`
+            atlas_dirpath: directory where .pkl binaries will be read from
        """
-        attributes = {k: None for k in ["publications", "projection", "bad_ids"]}
+        attributes = {k: None for k in [
+            "publications",
+            "projection",
+            "bad_ids",
+            "history",
+        ]}
        for attribute in attributes:
            fn = f"{attribute}.pkl"
            fp = os.path.join(atlas_dirpath, fn)
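For reference, a minimal sketch of what the new history attribute holds, following the shape documented in the Atlas docstring above; all identifiers and values here are made up for illustration:

import numpy as np

# Toy history for an atlas that grew from two to three publications
# over two recorded updates (identifiers are hypothetical).
history = {
    "pubs_per_update": [
        ["id_a", "id_b"],          # publications present after update 1
        ["id_a", "id_b", "id_c"],  # publications present after update 2
    ],
    # One row per publication, one column per recorded update,
    # i.e. shape (num_pubs, last_update).
    "kernel_size": np.zeros((3, 2), dtype=int),
}

# Assuming the constructor shown above accepts an empty publication list:
# atl = Atlas(publications=[], history=history)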
23 changes: 23 additions & 0 deletions src/sciterra/mapping/cartography.py
@@ -366,6 +366,29 @@ def filter(
    # Record Atlas history
    ########################################################################

    def track(
        self,
        atl: Atlas,
        pubs: list[str] = None,
        pubs_per_update: list[list[str]] = None,
    ) -> Atlas:
        """Overwrite the data associated with tracking the degree of convergence of publications in an atlas over multiple expansions. N.B.: the atlas must be fully projected, or else `converged_kernel_size` will raise a KeyError.
        Args:
            atl: the Atlas that will be updated by overwriting `Atlas.history`
            pubs: (optional) publication identifiers, forwarded to `record_update_history`
            pubs_per_update: (optional) publications per update, forwarded to `record_update_history` and stored in `Atlas.history` in place of `self.pubs_per_update`
        Returns:
            the updated Atlas
        """
        self.record_update_history(pubs, pubs_per_update)
        kernel_size = self.converged_kernel_size(atl)
        atl.history = {
            'pubs_per_update': self.pubs_per_update if pubs_per_update is None else pubs_per_update,
            'kernel_size': kernel_size,
        }
        return atl

    ########################################################################
    # Record Atlas history
    ########################################################################

    def record_update_history(
        self,
        pubs: list[str] = None,
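A short usage sketch of track, under the constraint the docstring states (the atlas must be fully projected first, or converged_kernel_size raises a KeyError); assumes crt is a Cartographer and atl an expanded Atlas, as in the loop in search.py below:

atl = crt.project(atl, verbose=True)  # project first: track() needs all embeddings
atl = crt.track(atl)                  # overwrites atl.history

print(atl.history["pubs_per_update"])    # list[list[str]], one entry per update
print(atl.history["kernel_size"].shape)  # (num_pubs, last_update)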
80 changes: 80 additions & 0 deletions src/sciterra/mapping/search.py
@@ -0,0 +1,80 @@
"""Convenience function for building out an Atlas of publications via iterative expansion, i.e. search for similar publications."""

from .atlas import Atlas
from .cartography import Cartographer

def iterate_expand(
atl: Atlas,
crt: Cartographer,
atlas_dir: str,
target_size: int,
max_failed_expansions: int,
center: str = None,
n_pubs_max: int = None,
call_size: int = None,
n_sources_max: int = None,
record_pubs_per_update: bool = False,
) -> None:
"""Build out an Atlas of publications by iterating a sequence of [expand, save, project, save, track, save].
Args:
atl: the Atlas to expand
crt: the Cartographer to use
atlas_dir: the directory where Atlas binaries will be saved/loaded from
target_size: stop iterating when we reach this number of publications in the Atlas
max_failed_expansions: stop iterating when we fail to add new publications after this many successive iterations.
center: (if given) center the search on this publication, preferentially searching related publications.
n_pubs_max: maximum number of publications allowed in the expansion.
call_size: maximum number of papers to call API for in one query; if less than `len(paper_ids)`, chunking will be performed.
n_sources_max: maximum number of publications (already in the atlas) to draw references and citations from.
record_pubs_per_update: whether to track all the publications that exist in the resulting atlas to `self.pubs_per_update`. This should only be set to `True` when you need to later filter by degree of convergence of the atlas.
"""
converged = False
print_progress = lambda atl: print( # view incremental progress
f"Atlas has {len(atl)} publications and {len(atl.projection) if atl.projection is not None else 'None'} embeddings."
)

# Expansion loop
failures = 0
while not converged:
len_prev = len(atl)

# Retrieve up to n_pubs_max citations and references.
atl = crt.expand(
atl,
center=center,
n_pubs_max=n_pubs_max,
call_size=call_size,
n_sources_max=n_sources_max,
record_pubs_per_update=record_pubs_per_update,
)
print_progress(atl)
atl.save(atlas_dir)

# Obtain document embeddings for all new abstracts.
atl = crt.project(atl, verbose=True)
print_progress(atl)
atl.save(atlas_dir)

atl = crt.track(atl)
atl.save(atlas_dir)

if len_prev == len(atl):
failures += 0
else:
failures = 0

converged = len(atl) >= target_size or failures >= max_failed_expansions
print()

print(f"Expansion loop exited with atlas size {len(atl)}.")
4 changes: 2 additions & 2 deletions src/sciterra/misc/utils.py
@@ -107,10 +107,10 @@ def wrapped_fn(*args, **kwargs):
    return _keep_trying


-def chunk_ids(ids: list[str], call_size=2000):
+def chunk_ids(ids: list[str], call_size):
    """Helper function to chunk bibcodes or paperIds into smaller sublists if appropriate."""
    # Break into chunks
    assert (  # TODO: this seems like an irrelevant copypasta since we use SearchQuery
        call_size <= 2000
    ), "Max number of calls ExportQuery can handle at a time is 2000."
    if len(ids) > call_size:
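Callers must now pass call_size explicitly. Roughly, the chunking behaves as sketched here; the exact chunk boundaries are an assumption, since only the "sublists small enough for one call" contract is documented:

from sciterra.misc.utils import chunk_ids

ids = [f"paper_{i}" for i in range(4500)]
chunks = chunk_ids(ids, call_size=2000)
# Expected: three sublists of sizes 2000, 2000, and 500, each small
# enough for a single API call; call_size may not exceed 2000 (the
# assert above enforces ExportQuery's documented limit).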
2 changes: 1 addition & 1 deletion src/tests/test_cartography.py
@@ -475,7 +475,7 @@ def test_pubs_per_update_expand_consistent(self, tmp_path):
        )

        assert len(TestConvergence.crt.pubs_per_update) == num_expansions
-        breakpoint()

        TestConvergence.crt.record_update_history()

        # need to project all pubs before kernel calculations!
41 changes: 41 additions & 0 deletions src/tests/test_search.py
@@ -0,0 +1,41 @@
from sciterra.mapping.search import iterate_expand
from sciterra.mapping.cartography import Cartographer
from sciterra.librarians.s2librarian import SemanticScholarLibrarian
from sciterra.vectorization.scibert import SciBERTVectorizer

single_pub_bibtex_fp = "src/tests/data/single_publication.bib"

atlas_dir = "atlas_tmpdir"


class TestSearch:
    def test_search(self, tmp_path):
        librarian = SemanticScholarLibrarian()
        vectorizer = SciBERTVectorizer()
        crt = Cartographer(librarian, vectorizer)

        # Load single file from bibtex
        bibtex_fp = single_pub_bibtex_fp

        path = tmp_path / atlas_dir
        path.mkdir()

        # Construct Atlas
        atl = crt.bibtex_to_atlas(bibtex_fp)

        pub = list(atl.publications.values())[0]
        center = pub.identifier

        iterate_expand(
            atl=atl,
            crt=crt,
            atlas_dir=path,
            target_size=100,
            max_failed_expansions=2,
            center=center,
            n_pubs_max=10,
            call_size=None,
            n_sources_max=None,
            record_pubs_per_update=True,
        )
