From a1cd0bdc03ef50a7120d63318c0d8f5372e8cc86 Mon Sep 17 00:00:00 2001 From: Nathaniel Imel Date: Sat, 30 Dec 2023 15:26:27 -0800 Subject: [PATCH] run black --- src/sciterra/mapping/cartography.py | 33 +++++++++++++++++------------ src/sciterra/mapping/tracing.py | 2 +- src/tests/test_cartography.py | 6 +++++- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/sciterra/mapping/cartography.py b/src/sciterra/mapping/cartography.py index a1dab43..f272a6c 100644 --- a/src/sciterra/mapping/cartography.py +++ b/src/sciterra/mapping/cartography.py @@ -57,22 +57,27 @@ def batch_cospsi_matrix(embeddings: np.ndarray) -> np.ndarray: # Helper function for filtering def pub_has_attributes( - pub: Publication, + pub: Publication, attributes: list[str], - ) -> bool: +) -> bool: """Return True if a publication has all `attributes`. - - Args: - attributes: the list of attributes to check are not `None` for each publication from the atlas. + + Args: + attributes: the list of attributes to check are not `None` for each publication from the atlas. """ - return pub is not None and all([getattr(pub, attr) is not None for attr in attributes]) + return pub is not None and all( + [getattr(pub, attr) is not None for attr in attributes] + ) + def pub_has_fields_of_study( pub: Publication, fields_of_study: list[str], ) -> bool: """Return true if any of `pub.fields_of_study` are in passed `fields_of_study`.""" - return pub is not None and any([field in fields_of_study for field in pub.fields_of_study]) + return pub is not None and any( + [field in fields_of_study for field in pub.fields_of_study] + ) ############################################################################## @@ -348,10 +353,12 @@ def expand( def filter_by_func( self, atl: Atlas, - require_func: Callable[[Publication], bool] = lambda pub: pub_has_attributes(pub, attributes=[ - "abstract", - "publication_date", - "fields_of_study", + require_func: Callable[[Publication], bool] = lambda pub: pub_has_attributes( + pub, + attributes=[ + "abstract", + "publication_date", + "fields_of_study", ], ), record_pubs_per_update=False, @@ -371,9 +378,7 @@ def filter_by_func( """ # Filter publications invalid_pubs = { - id: pub - for id, pub in atl.publications.items() - if not require_func(pub) + id: pub for id, pub in atl.publications.items() if not require_func(pub) } # Do not update if unnecessary diff --git a/src/sciterra/mapping/tracing.py b/src/sciterra/mapping/tracing.py index 3f7b4b1..e2bce87 100644 --- a/src/sciterra/mapping/tracing.py +++ b/src/sciterra/mapping/tracing.py @@ -128,7 +128,7 @@ def __init__( vectorizer_name: a str name of a vectorizer, one of `vectorization.vectorizers.keys()`, e.g. 'BOW' or 'SciBERT'. - vectorizer_kwargs: keyword args propogated to a Vectorizer initialization; if values are `None` they will be omitted + vectorizer_kwargs: keyword args propogated to a Vectorizer initialization; if values are `None` they will be omitted """ ###################################################################### # Initialize cartography tools diff --git a/src/tests/test_cartography.py b/src/tests/test_cartography.py index 824d904..e5ff61b 100644 --- a/src/tests/test_cartography.py +++ b/src/tests/test_cartography.py @@ -7,7 +7,11 @@ from datetime import datetime from sciterra.mapping.atlas import Atlas -from sciterra.mapping.cartography import Cartographer, pub_has_attributes, pub_has_fields_of_study +from sciterra.mapping.cartography import ( + Cartographer, + pub_has_attributes, + pub_has_fields_of_study, +) from sciterra.librarians.s2librarian import SemanticScholarLibrarian from sciterra.mapping.publication import Publication from sciterra.vectorization import SciBERTVectorizer, Word2VecVectorizer