diff --git a/src/examples/scratch/main.py b/src/examples/scratch/main.py
index e341875..24014ec 100644
--- a/src/examples/scratch/main.py
+++ b/src/examples/scratch/main.py
@@ -66,6 +66,7 @@ def main(args):
             center=center,
             n_pubs_max=n_pubs_max,
             call_size=call_size,
+            record_pubs_per_update=True,
         )
         print_progress(atl)
         atl.save(atlas_dir)
@@ -75,6 +76,8 @@ def main(args):
         print_progress(atl)
         atl.save(atlas_dir)
 
+        atl = crt.track(atl)
+
         if len_prev == len(atl):
             failures += 0
         else:
diff --git a/src/sciterra/librarians/s2librarian.py b/src/sciterra/librarians/s2librarian.py
index 7cd5870..fb3bbdc 100644
--- a/src/sciterra/librarians/s2librarian.py
+++ b/src/sciterra/librarians/s2librarian.py
@@ -135,6 +135,8 @@ def get_publications(
             return []
 
         total = len(paper_ids)
+        if call_size is None:
+            call_size = CALL_SIZE
         chunked_ids = chunk_ids(paper_ids, call_size=call_size)
 
         if None in paper_ids:
diff --git a/src/sciterra/mapping/atlas.py b/src/sciterra/mapping/atlas.py
index abb505b..4808e81 100644
--- a/src/sciterra/mapping/atlas.py
+++ b/src/sciterra/mapping/atlas.py
@@ -16,11 +16,22 @@
 
 
 class Atlas:
+
+    """Data structure for storing publications.
+
+    `self.projection`: the Projection object containing the embeddings of all publications and their mapping to str identifiers.
+
+    `self.bad_ids`: a set of identifiers that failed for some reason during an expansion, and that will be excluded from subsequent expansions.
+
+    `self.history`: dict of the form `{'pubs_per_update': list[list[str]], 'kernel_size': np.ndarray}`, where `kernel_size` is an array of ints of shape `(num_pubs, last_update)`, and `last_update` <= the total number of expansions performed.
+    """
+
     def __init__(
         self,
         publications: list[Publication],
         projection: Projection = None,
         bad_ids: set[str] = set(),
+        history: dict[str, Any] = dict(),
     ) -> None:
         if not isinstance(publications, list):
             raise ValueError
@@ -31,6 +42,8 @@ def __init__(
 
         self.bad_ids = bad_ids
 
+        self.history = history
+
     ######################################################################
     # Lookup
     ######################################################################
@@ -64,7 +77,12 @@ def save(
             return
 
         attributes = {
-            k: getattr(self, k) for k in ["publications", "projection", "bad_ids"]
+            k: getattr(self, k) for k in [
+                "publications",
+                "projection",
+                "bad_ids",
+                "history",
+            ]
         }
 
         for attribute in attributes:
@@ -91,10 +109,15 @@ def load(
         Warnings cannot be silenced.
 
         Args:
-            atlas_dirpath: file with vocab, assumed output from `save_to_file`
+            atlas_dirpath: directory where .pkl binaries will be read from
         """
 
-        attributes = {k: None for k in ["publications", "projection", "bad_ids"]}
+        attributes = {k: None for k in [
+            "publications",
+            "projection",
+            "bad_ids",
+            "history",
+        ]}
         for attribute in attributes:
             fn = f"{attribute}.pkl"
             fp = os.path.join(atlas_dirpath, fn)
diff --git a/src/sciterra/mapping/cartography.py b/src/sciterra/mapping/cartography.py
index 244b108..c211d1c 100644
--- a/src/sciterra/mapping/cartography.py
+++ b/src/sciterra/mapping/cartography.py
@@ -366,6 +366,38 @@ def filter(
     # Record Atlas history
     ########################################################################
 
+    def track(
+        self,
+        atl: Atlas,
+        pubs: list[str] = None,
+        pubs_per_update: list[list[str]] = None,
+    ) -> Atlas:
+        """Overwrite the data associated with tracking the degree of convergence of each publication in an atlas over multiple expansions. N.B.: the atlas must be fully projected, or else `converged_kernel_size` will raise a KeyError.
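+
+        A sketch of the structure written to `atl.history` (identifiers and values are illustrative):
+
+            atl.history['pubs_per_update']  # e.g. [['id1'], ['id1', 'id2', 'id3']]
+            atl.history['kernel_size']  # np.ndarray of ints, shape (num_pubs, last_update)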
+
+        Args:
+            atl: the Atlas that will be updated by overwriting `Atlas.history`
+
+            pubs: (optional) identifiers of publications to record for the current update; forwarded to `record_update_history`
+
+            pubs_per_update: (optional) overrides `self.pubs_per_update` as the recorded update history; also forwarded to `record_update_history`
+        """
+        self.record_update_history(pubs, pubs_per_update)
+        kernel_size = self.converged_kernel_size(atl)
+        atl.history = {
+            "pubs_per_update": self.pubs_per_update if pubs_per_update is None else pubs_per_update,
+            "kernel_size": kernel_size,
+        }
+        return atl
+
+    ########################################################################
+    # Record Atlas history
+    ########################################################################
+
     def record_update_history(
         self,
         pubs: list[str] = None,
diff --git a/src/sciterra/mapping/search.py b/src/sciterra/mapping/search.py
new file mode 100644
index 0000000..3665dd4
--- /dev/null
+++ b/src/sciterra/mapping/search.py
@@ -0,0 +1,86 @@
+"""Convenience function for building out an Atlas of publications via iterative expansion, i.e., searching for similar publications."""
+
+from .atlas import Atlas
+from .cartography import Cartographer
+
+def iterate_expand(
+    atl: Atlas,
+    crt: Cartographer,
+    atlas_dir: str,
+    target_size: int,
+    max_failed_expansions: int,
+    center: str = None,
+    n_pubs_max: int = None,
+    call_size: int = None,
+    n_sources_max: int = None,
+    record_pubs_per_update: bool = False,
+) -> Atlas:
+    """Build out an Atlas of publications by iterating a sequence of [expand, save, project, save, track, save].
+
+    Args:
+        atl: the Atlas to expand
+
+        crt: the Cartographer to use
+
+        atlas_dir: the directory where Atlas binaries will be saved to and loaded from
+
+        target_size: stop iterating when the Atlas contains at least this many publications
+
+        max_failed_expansions: stop iterating after this many successive expansions fail to add new publications
+
+        center: (if given) center the search on this publication, preferentially searching related publications
+
+        n_pubs_max: maximum number of publications allowed in one expansion
+
+        call_size: maximum number of papers to request from the API in one query; requests for more identifiers than this are chunked. If `None`, the librarian's default `CALL_SIZE` is used.
+
+        n_sources_max: maximum number of publications (already in the atlas) to draw references and citations from
+
+        record_pubs_per_update: whether to record, in `crt.pubs_per_update`, the publications present in the atlas after each update. Set this to `True` only if you later need to filter publications by the atlas's degree of convergence.
+
+    Returns:
+        the expanded Atlas, which has also been saved to `atlas_dir`
+    """
+    converged = False
+    print_progress = lambda atl: print(  # view incremental progress
+        f"Atlas has {len(atl)} publications and {len(atl.projection) if atl.projection is not None else 'None'} embeddings."
+    )
+
+    # Expansion loop
+    failures = 0
+    while not converged:
+        len_prev = len(atl)
+
+        # Retrieve up to n_pubs_max citations and references.
+        atl = crt.expand(
+            atl,
+            center=center,
+            n_pubs_max=n_pubs_max,
+            call_size=call_size,
+            n_sources_max=n_sources_max,
+            record_pubs_per_update=record_pubs_per_update,
+        )
+        print_progress(atl)
+        atl.save(atlas_dir)
+
+        # Obtain document embeddings for all new abstracts.
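+        # (Assumption: `project` embeds only publications not already in
+        # `atl.projection`, so repeated calls are incremental.)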
+        atl = crt.project(atl, verbose=True)
+        print_progress(atl)
+        atl.save(atlas_dir)
+
+        atl = crt.track(atl)
+        atl.save(atlas_dir)
+
+        if len_prev == len(atl):
+            failures += 1
+        else:
+            failures = 0
+
+        converged = len(atl) >= target_size or failures >= max_failed_expansions
+        print()
+
+    print(f"Expansion loop exited with atlas size {len(atl)}.")
+
+    return atl
diff --git a/src/sciterra/misc/utils.py b/src/sciterra/misc/utils.py
index efaad4f..be02960 100644
--- a/src/sciterra/misc/utils.py
+++ b/src/sciterra/misc/utils.py
@@ -107,10 +107,10 @@ def wrapped_fn(*args, **kwargs):
     return _keep_trying
 
 
-def chunk_ids(ids: list[str], call_size=2000):
+def chunk_ids(ids: list[str], call_size: int):
     """Helper function to chunk bibcodes or paperIds into smaller sublists if appropriate."""
     # Break into chunks
-    assert ( # TODO: this seems like an irrelevant copypasta since we use SearchQuery
+    assert (  # TODO: this seems like an irrelevant copypasta since we use SearchQuery
         call_size <= 2000
     ), "Max number of calls ExportQuery can handle at a time is 2000."
     if len(ids) > call_size:
diff --git a/src/tests/test_cartography.py b/src/tests/test_cartography.py
index da0c7c7..ef602a2 100644
--- a/src/tests/test_cartography.py
+++ b/src/tests/test_cartography.py
@@ -475,7 +475,7 @@ def test_pubs_per_update_expand_consistent(self, tmp_path):
         )
 
         assert len(TestConvergence.crt.pubs_per_update) == num_expansions
-        breakpoint()
+        TestConvergence.crt.record_update_history()
 
         # need to project all pubs before kernel calculations!
 
diff --git a/src/tests/test_search.py b/src/tests/test_search.py
new file mode 100644
index 0000000..5fcef93
--- /dev/null
+++ b/src/tests/test_search.py
@@ -0,0 +1,46 @@
+from sciterra.mapping.search import iterate_expand
+from sciterra.mapping.cartography import Cartographer
+from sciterra.librarians.s2librarian import SemanticScholarLibrarian
+from sciterra.vectorization.scibert import SciBERTVectorizer
+
+single_pub_bibtex_fp = "src/tests/data/single_publication.bib"
+
+atlas_dir = "atlas_tmpdir"
+
+class TestSearch:
+
+    def test_search(self, tmp_path):
+
+        librarian = SemanticScholarLibrarian()
+        vectorizer = SciBERTVectorizer()
+        crt = Cartographer(librarian, vectorizer)
+
+        # Load single file from bibtex
+        bibtex_fp = single_pub_bibtex_fp
+
+        path = tmp_path / atlas_dir
+        path.mkdir()
+
+        # Construct Atlas
+        atl = crt.bibtex_to_atlas(bibtex_fp)
+
+        pub = list(atl.publications.values())[0]
+        center = pub.identifier
+
+        atl = iterate_expand(
+            atl=atl,
+            crt=crt,
+            atlas_dir=path,
+            target_size=100,
+            max_failed_expansions=2,
+            center=center,
+            n_pubs_max=10,
+            call_size=None,
+            n_sources_max=None,
+            record_pubs_per_update=True,
+        )
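+
+        # Hedged sanity checks: exact counts depend on live Semantic Scholar results.
+        assert len(atl) > 1  # the expansion loop should have added publications
+        assert "pubs_per_update" in atl.history
+        assert "kernel_size" in atl.history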