Skip to content

Commit

Permalink
run black
Browse files: browse the repository at this point in the history
  • Loading branch information
Nathaniel Imel authored and committed on Nov 11, 2023
1 parent b835a22 commit 4beca1f
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 33 deletions.
14 changes: 7 additions & 7 deletions src/examples/scratch/run_convergence_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@

atlas_dir = "/Users/nathanielimel/uci/projects/sciterra/src/examples/scratch/outputs/atlas_s2-11-10-23_centered_hafenetal"

def main():

atl = Atlas.load(atlas_dir)
def main():
atl = Atlas.load(atlas_dir)

kernels = atl.history['kernel_size']
kernels = atl.history["kernel_size"]

con_d = 3
kernel_size = 10
converged_filter = kernels[:, -con_d] >= kernel_size
ids = np.array(atl.projection.index_to_identifier)
converged_pub_ids = ids[converged_filter]

crt = Cartographer(vectorizer=SciBERTVectorizer())
crt = Cartographer(vectorizer=SciBERTVectorizer())

measurements = crt.measure_topography(
atl,
atl,
ids=converged_pub_ids,
metrics=["density", "edginess"],
metrics=["density", "edginess"],
kernel_size=10,
)


if __name__ == "__main__":
main()
main()
2 changes: 1 addition & 1 deletion src/sciterra/mapping/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __getitem__(self, identifier: str) -> Publication:
if identifier in self.publications:
return self.publications[identifier]
raise ValueError(f"Identifier {identifier} not in Atlas.")

def ids(self) -> list[str]:
"""Get a list of all the publication identifiers in the Atlas."""
return list(self.publications.keys())
Expand Down
31 changes: 15 additions & 16 deletions src/sciterra/mapping/cartography.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def project(self, atl: Atlas, **kwargs) -> Atlas:
for id, pub in atl_filtered.publications.items()
if id in merged_projection.identifier_to_index
}
assert not set(atl_filtered.ids()) - set(
embedded_publications.keys()
)
assert not set(atl_filtered.ids()) - set(embedded_publications.keys())

# Overwrite atlas data
atl_filtered.publications = embedded_publications
Expand Down Expand Up @@ -309,15 +307,15 @@ def filter_by_attributes(
"abstract",
"publication_date",
],
record_pubs_per_update = False,
record_pubs_per_update=False,
**kwargs,
) -> Atlas:
"""Update an atlas by dropping publications (and corresponding data in projection) when certain fields are empty.
Args:
atl: the Atlas containing publications to filter
attributes: the list of attributes to filter publications from the atlas IF any of items are None for a publication. For example, if attributes = ["abstract"], then all publications `pub` such that `pub.abstract is None` is True will be removed from the atlas, along with the corresponding data in the projection.
attributes: the list of attributes to filter publications from the atlas IF any of items are None for a publication. For example, if attributes = ["abstract"], then all publications `pub` such that `pub.abstract is None` is True will be removed from the atlas, along with the corresponding data in the projection.
record_pubs_per_update: whether to track all the publications that exist in the resulting atlas to `self.pubs_per_update`. This should only be set to `True` when you need to later filter by degree of convergence of the atlas. This is an important parameter because `self.filter` is called in `self.project`, which typically is called after `self.expand`, where we pass in the same parameter.
Expand All @@ -341,7 +339,7 @@ def filter_by_attributes(
self.pubs_per_update[-1] = atl_filtered.ids()

return atl_filtered

def filter_by_ids(
self,
atl: Atlas,
Expand All @@ -357,15 +355,19 @@ def filter_by_ids(
drop_ids: the list of publications to filter; all publications in `atl` matching one of these ids will be removed.
"""

if all(x is not None for x in [keep_ids, drop_ids]):
raise ValueError("You must pass exactly one of `keep_ids` or `drop_ids`, but both had a value that was not `None`.")
raise ValueError(
"You must pass exactly one of `keep_ids` or `drop_ids`, but both had a value that was not `None`."
)
if keep_ids is not None:
filter_ids = set([id for id in atl.ids() if id not in keep_ids])
elif drop_ids is not None:
filter_ids = set(drop_ids)
else:
raise ValueError("You must pass exactly one of `keep_ids` or `drop_ids`, but both had value `None`.")
raise ValueError(
"You must pass exactly one of `keep_ids` or `drop_ids`, but both had value `None`."
)

# Keep track of the bad identifiers to skip them in future expansions
new_bad_ids = atl.bad_ids.union(filter_ids)
Expand Down Expand Up @@ -410,7 +412,6 @@ def filter_by_ids(

return atl_filtered


########################################################################
# Record Atlas history
########################################################################
Expand Down Expand Up @@ -581,14 +582,12 @@ def measure_topography(
ids = atl.ids()
else:
ids = list(ids)

if not ids:
raise Exception("No publications to measure topography of.")

# Get publication dates, for filtering
dates = np.array([
atl[identifier].publication_date for identifier in ids
])
dates = np.array([atl[identifier].publication_date for identifier in ids])

# Get pairwise cosine similarities for ids
embeddings = atl.projection.identifiers_to_embeddings(ids)
Expand Down Expand Up @@ -720,8 +719,8 @@ def iterate_expand(

# Obtain document embeddings for all new abstracts.
atl = crt.project(
atl,
verbose=True,
atl,
verbose=True,
record_pubs_per_update=record_pubs_per_update,
)
print_progress(atl)
Expand Down
10 changes: 6 additions & 4 deletions src/sciterra/vectorization/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def identifiers_to_embeddings(self, identifiers: list[str]) -> np.ndarray:
return self.embeddings[self.identifiers_to_indices(identifiers)]

# def identifier_to_embedding(self, identifier: str) -> np.ndarray:
# """Retrieve the document embedding of a Publication."""
# return self.embeddings[self.identifier_to_index[identifier]]
# """Retrieve the document embedding of a Publication."""
# return self.embeddings[self.identifier_to_index[identifier]]

def identifiers_to_indices(self, identifiers: list[str]) -> np.ndarray:
"""Retrieve the embedding indices for a list of identifiers."""
return np.array([self.identifier_to_index[identifier] for identifier in identifiers])
return np.array(
[self.identifier_to_index[identifier] for identifier in identifiers]
)

def __len__(self) -> int:
return len(self.identifier_to_index)
Expand Down
9 changes: 4 additions & 5 deletions src/tests/test_cartography.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def test_project_correct_number(self, tmp_path):
# 1. Simulate first part of project
# 'only project publications that have abstracts'
atl_filtered = TestS2SBProjection.crt.filter_by_attributes(
atl_exp_double, attributes=["abstract"],
atl_exp_double,
attributes=["abstract"],
)

# 'get only embeddings for publications not already projected in atlas'
Expand Down Expand Up @@ -337,7 +338,7 @@ class TestTopography:

def test_measure_topography_full(self):
bibtex_fp = ten_pub_bibtex_fp
atl = TestTopography.crt.bibtex_to_atlas(bibtex_fp)
atl = TestTopography.crt.bibtex_to_atlas(bibtex_fp)
atl = TestTopography.crt.project(atl)
metrics = [
"density",
Expand Down Expand Up @@ -395,7 +396,7 @@ def test_measure_topography_realistic(self):
ids=ids,
metrics=metrics,
)
assert measurements.shape == tuple((len(atl), len(metrics)))
assert measurements.shape == tuple((len(atl), len(metrics)))


class TestConvergence:
Expand Down Expand Up @@ -585,5 +586,3 @@ def test_iterate_expand(self, tmp_path):
n_sources_max=None,
record_pubs_per_update=True,
)


0 comments on commit 4beca1f

Please sign in to comment.