Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel Imel authored and Nathaniel Imel committed Dec 30, 2023
1 parent e154ff5 commit a1cd0bd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
33 changes: 19 additions & 14 deletions src/sciterra/mapping/cartography.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,27 @@ def batch_cospsi_matrix(embeddings: np.ndarray) -> np.ndarray:

# Helper function for filtering
def pub_has_attributes(
pub: Publication,
pub: Publication,
attributes: list[str],
) -> bool:
) -> bool:
"""Return True if a publication has all `attributes`.
Args:
attributes: the list of attributes to check are not `None` for each publication from the atlas.
Args:
attributes: the list of attributes to check are not `None` for each publication from the atlas.
"""
return pub is not None and all([getattr(pub, attr) is not None for attr in attributes])
return pub is not None and all(
[getattr(pub, attr) is not None for attr in attributes]
)


def pub_has_fields_of_study(
pub: Publication,
fields_of_study: list[str],
) -> bool:
"""Return true if any of `pub.fields_of_study` are in passed `fields_of_study`."""
return pub is not None and any([field in fields_of_study for field in pub.fields_of_study])
return pub is not None and any(
[field in fields_of_study for field in pub.fields_of_study]
)


##############################################################################
Expand Down Expand Up @@ -348,10 +353,12 @@ def expand(
def filter_by_func(
self,
atl: Atlas,
require_func: Callable[[Publication], bool] = lambda pub: pub_has_attributes(pub, attributes=[
"abstract",
"publication_date",
"fields_of_study",
require_func: Callable[[Publication], bool] = lambda pub: pub_has_attributes(
pub,
attributes=[
"abstract",
"publication_date",
"fields_of_study",
],
),
record_pubs_per_update=False,
Expand All @@ -371,9 +378,7 @@ def filter_by_func(
"""
# Filter publications
invalid_pubs = {
id: pub
for id, pub in atl.publications.items()
if not require_func(pub)
id: pub for id, pub in atl.publications.items() if not require_func(pub)
}

# Do not update if unnecessary
Expand Down
2 changes: 1 addition & 1 deletion src/sciterra/mapping/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
vectorizer_name: a str name of a vectorizer, one of `vectorization.vectorizers.keys()`, e.g. 'BOW' or 'SciBERT'.
vectorizer_kwargs: keyword args propogated to a Vectorizer initialization; if values are `None` they will be omitted
vectorizer_kwargs: keyword args propogated to a Vectorizer initialization; if values are `None` they will be omitted
"""
######################################################################
# Initialize cartography tools
Expand Down
6 changes: 5 additions & 1 deletion src/tests/test_cartography.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from datetime import datetime

from sciterra.mapping.atlas import Atlas
from sciterra.mapping.cartography import Cartographer, pub_has_attributes, pub_has_fields_of_study
from sciterra.mapping.cartography import (
Cartographer,
pub_has_attributes,
pub_has_fields_of_study,
)
from sciterra.librarians.s2librarian import SemanticScholarLibrarian
from sciterra.mapping.publication import Publication
from sciterra.vectorization import SciBERTVectorizer, Word2VecVectorizer
Expand Down

0 comments on commit a1cd0bd

Please sign in to comment.