Skip to content

Commit

Permalink
ran basic experiment with convergence filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel Imel authored and Nathaniel Imel committed Nov 11, 2023
1 parent f68981b commit 04b2f41
Show file tree
Hide file tree
Showing 4 changed files with 6,662 additions and 37 deletions.
1,493 changes: 1,493 additions & 0 deletions src/examples/scratch/convergence_scratch.ipynb

Large diffs are not rendered by default.

48 changes: 11 additions & 37 deletions src/examples/scratch/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from examples.scratch import util

from sciterra.mapping.atlas import Atlas
from sciterra.mapping.cartography import Cartographer
from sciterra.mapping.cartography import Cartographer, iterate_expand
from sciterra.librarians import ADSLibrarian, SemanticScholarLibrarian
from sciterra.librarians import ADSLibrarian
from sciterra.vectorization.scibert import SciBERTVectorizer
Expand Down Expand Up @@ -50,44 +50,18 @@ def main(args):
print(f"Initializing atlas.")
atl = atl_center

converged = False
print_progress = lambda atl: print( # view incremental progress
f"Atlas has {len(atl)} publications and {len(atl.projection) if atl.projection is not None else 'None'} embeddings."
iterate_expand(
atl=atl,
crt=crt,
atlas_dir=atlas_dir,
target_size=target,
max_failed_expansions=max_failures,
center=center,
n_pubs_max=n_pubs_max,
call_size=call_size,
record_pubs_per_update=True,
)

# Expansion loop
failures = 0
while not converged:
len_prev = len(atl)

# Retrieve up to n_pubs_max citations and references.
atl = crt.expand(
atl,
center=center,
n_pubs_max=n_pubs_max,
call_size=call_size,
record_pubs_per_update=True,
)
print_progress(atl)
atl.save(atlas_dir)

# Obtain document embeddings for all new abstracts.
atl = crt.project(atl, verbose=True)
print_progress(atl)
atl.save(atlas_dir)

atl = crt.track(atl)

if len_prev == len(atl):
failures += 0
else:
failures = 0

converged = len(atl) >= target or failures >= max_failures
print()

print(f"Expansion loop exited with atlas size {len(atl)}.")


if __name__ == "__main__":
args = util.get_args()
Expand Down
35 changes: 35 additions & 0 deletions src/examples/scratch/run_convergence_scratch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pandas as pd
import plotnine as pn
import matplotlib.pyplot as plt

from sciterra.mapping.atlas import Atlas
from sciterra.mapping.cartography import Cartographer
from sciterra.vectorization.scibert import SciBERTVectorizer

# Directory of the saved atlas that main() loads via Atlas.load.
# NOTE(review): absolute path on one developer's machine -- point this at a
# local atlas directory before running elsewhere.
atlas_dir = "/Users/nathanielimel/uci/projects/sciterra/src/examples/scratch/outputs/atlas_s2-11-10-23_centered_hafenetal"

def main():
    """Measure topography metrics for the converged publications of a saved atlas.

    Loads the atlas stored at the module-level ``atlas_dir``, keeps the
    publications whose recorded kernel size, taken ``con_d`` entries from the
    end of the history, is at least ``kernel_size``, and asks a
    ``Cartographer`` for the "density" and "edginess" metrics of those
    publications.

    Returns:
        The value returned by ``Cartographer.measure_topography`` for the
        selected publications, so callers can inspect or plot it instead of
        the result being silently dropped.
    """
    atl = Atlas.load(atlas_dir)

    # NOTE(review): indexed as a 2-D array below; presumably one row per
    # publication and one column per expansion update -- confirm against
    # the Atlas.history documentation.
    kernels = atl.history["kernel_size"]

    con_d = 3  # offset from the end of the history used for the check
    kernel_size = 10  # minimum kernel size for a publication to be kept

    # Boolean mask over publications, then map the mask onto identifiers.
    converged_filter = kernels[:, -con_d] >= kernel_size
    ids = np.array(atl.projection.index_to_identifier)
    converged_pub_ids = ids[converged_filter]

    crt = Cartographer(vectorizer=SciBERTVectorizer())

    # Reuse the filter's threshold rather than repeating the literal 10, so
    # the filter and the measurement cannot drift apart.
    measurements = crt.measure_topography(
        atl,
        ids=converged_pub_ids,
        metrics=["density", "edginess"],
        kernel_size=kernel_size,
    )
    return measurements


# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 04b2f41

Please sign in to comment.