-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add iterate_expand and search module
- Loading branch information
Nathaniel Imel
authored and
Nathaniel Imel
committed
Nov 10, 2023
1 parent
362b43c
commit 86f2d31
Showing
8 changed files
with
178 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""Convenience function for building out an Atlas of publications via iterative expansion, i.e. search for similar publications.""" | ||
|
||
from .atlas import Atlas | ||
from .cartography import Cartographer | ||
|
||
def iterate_expand( | ||
atl: Atlas, | ||
crt: Cartographer, | ||
atlas_dir: str, | ||
target_size: int, | ||
max_failed_expansions: int, | ||
center: str = None, | ||
n_pubs_max: int = None, | ||
call_size: int = None, | ||
n_sources_max: int = None, | ||
record_pubs_per_update: bool = False, | ||
) -> None: | ||
"""Build out an Atlas of publications by iterating a sequence of [expand, save, project, save, track, save]. | ||
Args: | ||
atl: the Atlas to expand | ||
crt: the Cartographer to use | ||
atlas_dir: the directory where Atlas binaries will be saved/loaded from | ||
target_size: stop iterating when we reach this number of publications in the Atlas | ||
max_failed_expansions: stop iterating when we fail to add new publications after this many successive iterations. | ||
center: (if given) center the search on this publication, preferentially searching related publications. | ||
n_pubs_max: maximum number of publications allowed in the expansion. | ||
call_size: maximum number of papers to call API for in one query; if less than `len(paper_ids)`, chunking will be performed. | ||
n_sources_max: maximum number of publications (already in the atlas) to draw references and citations from. | ||
record_pubs_per_update: whether to track all the publications that exist in the resulting atlas to `self.pubs_per_update`. This should only be set to `True` when you need to later filter by degree of convergence of the atlas. | ||
""" | ||
converged = False | ||
print_progress = lambda atl: print( # view incremental progress | ||
f"Atlas has {len(atl)} publications and {len(atl.projection) if atl.projection is not None else 'None'} embeddings." | ||
) | ||
|
||
# Expansion loop | ||
failures = 0 | ||
while not converged: | ||
len_prev = len(atl) | ||
|
||
# Retrieve up to n_pubs_max citations and references. | ||
atl = crt.expand( | ||
atl, | ||
center=center, | ||
n_pubs_max=n_pubs_max, | ||
call_size=call_size, | ||
n_sources_max=n_sources_max, | ||
record_pubs_per_update=record_pubs_per_update, | ||
) | ||
print_progress(atl) | ||
atl.save(atlas_dir) | ||
|
||
# Obtain document embeddings for all new abstracts. | ||
atl = crt.project(atl, verbose=True) | ||
print_progress(atl) | ||
atl.save(atlas_dir) | ||
|
||
atl = crt.track(atl) | ||
atl.save(atlas_dir) | ||
|
||
if len_prev == len(atl): | ||
failures += 0 | ||
else: | ||
failures = 0 | ||
|
||
converged = len(atl) >= target_size or failures >= max_failed_expansions | ||
print() | ||
|
||
print(f"Expansion loop exited with atlas size {len(atl)}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from sciterra.mapping.search import iterate_expand | ||
from sciterra.mapping.cartography import Cartographer | ||
from sciterra.librarians.s2librarian import SemanticScholarLibrarian | ||
from sciterra.vectorization.scibert import SciBERTVectorizer | ||
|
||
single_pub_bibtex_fp = "src/tests/data/single_publication.bib" | ||
|
||
atlas_dir = "atlas_tmpdir" | ||
|
||
class TestSearch: | ||
|
||
def test_search(self, tmp_path): | ||
|
||
librarian = SemanticScholarLibrarian() | ||
vectorizer = SciBERTVectorizer() | ||
crt = Cartographer(librarian, vectorizer) | ||
|
||
# Load single file from bibtex | ||
bibtex_fp = single_pub_bibtex_fp | ||
|
||
path = tmp_path / atlas_dir | ||
path.mkdir() | ||
|
||
# Construct Atlas | ||
atl = crt.bibtex_to_atlas(bibtex_fp) | ||
|
||
pub = list(atl.publications.values())[0] | ||
center = pub.identifier | ||
|
||
iterate_expand( | ||
atl=atl, | ||
crt=crt, | ||
atlas_dir=path, | ||
target_size=100, | ||
max_failed_expansions=2, | ||
center=center, | ||
n_pubs_max=10, | ||
call_size=None, | ||
n_sources_max=None, | ||
record_pubs_per_update=True, | ||
) |