address previous test failures involving batch_size = None and missing librarian_kwargs
Nathaniel Imel authored and committed Feb 6, 2024
1 parent c082257 commit 5327254
Showing 6 changed files with 24 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/sciterra/mapping/tracing.py
@@ -114,8 +114,8 @@ def __init__(
atlas_dir: str,
atlas_center_bibtex: str,
librarian_name: str,
librarian_kwargs: str,
vectorizer_name: str,
librarian_kwargs: dict = None,
vectorizer_kwargs: dict = None,
) -> None:
"""Convenience wrapper data structure for tracked expansions, by aligning the history of a Cartographer with an Atlas.
@@ -127,10 +127,10 @@ def __init__(
librarian_name: a str name of a librarian, one of `librarians.librarians.keys()`, e.g. 'S2' or 'ADS'.
librarian_kwargs: keyword args propogated to a Librarian initialization; if values are `None` they will be omitted
vectorizer_name: a str name of a vectorizer, one of `vectorization.vectorizers.keys()`, e.g. 'BOW' or 'SciBERT'.
librarian_kwargs: keyword args propogated to a Librarian initialization; if values are `None` they will be omitted
vectorizer_kwargs: keyword args propogated to a Vectorizer initialization; if values are `None` they will be omitted
"""
######################################################################
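The docstring above notes that None-valued entries in librarian_kwargs and vectorizer_kwargs are omitted before initialization. A minimal sketch of that pattern, using the librarians and vectorizers registries the docstring mentions; the helper name is hypothetical and the project's actual __init__ may differ:

def _init_from_registry(registry: dict, name: str, kwargs: dict = None):
    # Treat a missing kwargs dict as empty, then drop None-valued entries
    # before forwarding, as the docstring describes.
    kwargs = {k: v for k, v in (kwargs or {}).items() if v is not None}
    return registry[name](**kwargs)

# e.g. librarian = _init_from_registry(librarians, librarian_name, librarian_kwargs)
# e.g. vectorizer = _init_from_registry(vectorizers, vectorizer_name, vectorizer_kwargs)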
2 changes: 1 addition & 1 deletion src/sciterra/vectorization/bow.py
@@ -38,7 +38,7 @@ def __init__(
self.embedding_dim = len(self.vocabulary)
self.count_vectorizer = CountVectorizer(vocabulary=self.vocabulary)

def embed_documents(self, docs: list[str]) -> dict[str, ndarray]:
def embed_documents(self, docs: list[str], **kwargs) -> dict[str, ndarray]:
"""Embed a list of documents (raw text) into bow document vectors using scikit-learn's CountVectorizer.
Args:
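Accepting **kwargs lets callers use one call signature across vectorizers: arguments such as batch_size, which the neural vectorizers need but bag-of-words does not, are simply accepted and ignored. A toy illustration of the signature change (not the repo's code):

def embed_documents_old(docs: list[str]):
    ...

def embed_documents_new(docs: list[str], **kwargs):
    ...  # extra keyword arguments like batch_size are accepted and ignored

embed_documents_new(["a doc"], batch_size=None)    # fine
# embed_documents_old(["a doc"], batch_size=None)  # would raise TypeError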
9 changes: 8 additions & 1 deletion src/sciterra/vectorization/gpt2.py
@@ -21,6 +21,9 @@
# This is the hidden dimension size
EMBEDDING_DIM = 768

# Default is small, otherwise memory limits become a problem
BATCH_SIZE = 8


class GPT2Vectorizer(Vectorizer):
def __init__(self, device="cuda", **kwargs) -> None:
@@ -48,7 +51,9 @@ def __init__(self, device="cuda", **kwargs) -> None:
super().__init__()

def embed_documents(
self, docs: list[str], batch_size: int = 64
self,
docs: list[str],
batch_size: int = BATCH_SIZE,
) -> dict[str, np.ndarray]:
"""Embed a list of documents (raw text) into GPT-2 vectors, by batching.
@@ -59,6 +64,8 @@ def embed_documents(
a numpy array of shape `(num_documents, embedding_dim)`
"""
if batch_size is None:
batch_size = BATCH_SIZE

embeddings = []

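With the guard in place, a caller (or a test harness) that forwards batch_size=None gets the module default instead of an error when the batching loop tries to slice with None. A sketch of the fallback-plus-batching pattern under the constants above; the real method also tokenizes each batch and runs the GPT-2 model:

import numpy as np

BATCH_SIZE = 8      # module-level default, per the diff above
EMBEDDING_DIM = 768

def embed_in_batches(docs: list[str], batch_size: int = BATCH_SIZE) -> np.ndarray:
    if batch_size is None:
        batch_size = BATCH_SIZE
    embeddings = []
    for start in range(0, len(docs), batch_size):
        batch = docs[start : start + batch_size]
        # Placeholder for tokenize + forward pass; one dummy vector per document.
        embeddings.append(np.zeros((len(batch), EMBEDDING_DIM)))
    return np.concatenate(embeddings, axis=0)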
8 changes: 7 additions & 1 deletion src/sciterra/vectorization/sbert.py
@@ -19,6 +19,8 @@
EMBEDDING_DIM = 384
MAX_SEQ_LENGTH = 256

BATCH_SIZE = 64


class SBERTVectorizer(Vectorizer):
def __init__(self, device="cuda", **kwargs) -> None:
@@ -38,7 +40,9 @@ def __init__(self, device="cuda", **kwargs) -> None:
self.model.eval()
super().__init__()

def embed_documents(self, docs: list[str], batch_size: int = 64) -> np.ndarray:
def embed_documents(
self, docs: list[str], batch_size: int = BATCH_SIZE
) -> np.ndarray:
"""Embed a list of documents (raw text) into SBERT vectors, by batching.
Args:
@@ -47,6 +51,8 @@ def embed_documents(self, docs: list[str], batch_size: int = 64) -> np.ndarray:
Returns:
a numpy array of shape `(num_documents, 384)`
"""
if batch_size is None:
batch_size = BATCH_SIZE

embeddings = []

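On the caller side, the guard makes batch_size=None behave the same as omitting the argument. A hedged usage sketch; it assumes SBERTVectorizer accepts device="cpu" and that embed_documents returns an array shaped as the docstring states:

from sciterra.vectorization.sbert import SBERTVectorizer

vectorizer = SBERTVectorizer(device="cpu")
docs = ["A paper about galaxy formation.", "A paper about word embeddings."]

# Passing batch_size=None now falls back to BATCH_SIZE instead of failing in the loop.
embeddings = vectorizer.embed_documents(docs, batch_size=None)
assert embeddings.shape == (len(docs), 384)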
6 changes: 5 additions & 1 deletion src/sciterra/vectorization/scibert.py
@@ -22,6 +22,8 @@
MODEL_PATH = "allenai/scibert_scivocab_uncased"
EMBEDDING_DIM = 768

BATCH_SIZE = 64


class SciBERTVectorizer(Vectorizer):
def __init__(self, device="cuda", **kwargs) -> None:
@@ -53,7 +55,7 @@ def __init__(self, device="cuda", **kwargs) -> None:
super().__init__()

def embed_documents(
self, docs: list[str], batch_size: int = 64
self, docs: list[str], batch_size: int = BATCH_SIZE
) -> dict[str, np.ndarray]:
"""Embed a list of documents (raw text) into SciBERT vectors, by batching.
@@ -64,6 +66,8 @@ def embed_documents(
a numpy array of shape `(num_documents, 768)`
"""
if batch_size is None:
batch_size = BATCH_SIZE

embeddings = []

Binary file modified: src/tests/data/models/word2vec_model_example.model (binary diff not shown).
