Skip to content

Commit

Permalink
Collator (#3)
Browse files Browse the repository at this point in the history
* keys_to_ignore_at_inference

* collator class

* newline at end

* apply black formatting
  • Loading branch information
nbroad1881 authored Oct 21, 2022
1 parent ee9588a commit 332483e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/strideformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .model import Strideformer
from .config import StrideformerConfig
from .pipeline import Pipeline
from .collator import StrideformerCollator
from .pipeline import Pipeline

60 changes: 60 additions & 0 deletions src/strideformer/collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from itertools import chain
from typing import List, Dict

import torch
from transformers import PreTrainedTokenizerFast


class StrideformerCollator:
    """
    Collates per-document chunk features into a single padded batch for Strideformer.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerFast", max_chunks: int = 128):
        """
        Loads a collator designed for Strideformer.
        Args:
            tokenizer (`PreTrainedTokenizerFast`):
                The tokenizer that corresponds with the first model in Strideformer.
                Only its `pad_token_id` is used here, to pad `input_ids`.
            max_chunks (`int`, *optional*, defaults to 128):
                The maximum number of chunks that can be passed to the first model.
                This is to limit OOM errors.
        """
        self.tokenizer = tokenizer
        self.max_chunks = max_chunks

    @staticmethod
    def _contains_float(value) -> bool:
        """Return True if `value` is a float or a (nested) list/tuple holding one."""
        if isinstance(value, float):
            return True
        if isinstance(value, (list, tuple)):
            return any(StrideformerCollator._contains_float(v) for v in value)
        return False

    def __call__(self, features: List[Dict]):
        """
        Put features in a format that the model can use.
        Args:
            features (`List[Dict]`):
                The list will be as long as the batch size passed to the DataLoader.
                Each element of features will have keys: input_ids, attention_mask,
                and either label or labels.
                input_ids will be of shape [num_chunks, sequence_length]
                attention_mask will be of shape [num_chunks, sequence_length]
                label will be a single value if this is single_label_classification or regression
                It will be a list if multi_label_classification
        Returns:
            (dict): input_ids, attention_mask, labels to be passed to the model.
                The chunks of all documents are flattened in batch order and only
                the first `max_chunks` of them are kept; labels are never truncated.
                labels are float when any label value is a float (e.g. regression
                targets), otherwise long.
        """

        label_key = "label" if "label" in features[0] else "labels"

        # Flatten [document][chunk] -> [chunk] and drop surplus chunks *before*
        # padding, so that discarded chunks cannot inflate the padded width.
        ids = list(chain.from_iterable(x["input_ids"] for x in features))
        mask = list(chain.from_iterable(x["attention_mask"] for x in features))
        ids = ids[: self.max_chunks]
        mask = mask[: self.max_chunks]
        labels = [x[label_key] for x in features]

        longest_seq = max((len(x) for x in ids), default=0)

        pad_id = self.tokenizer.pad_token_id
        ids = [x + [pad_id] * (longest_seq - len(x)) for x in ids]
        mask = [x + [0] * (longest_seq - len(x)) for x in mask]

        # Casting float targets to long would silently truncate them (0.5 -> 0),
        # so only integer-valued labels are stored as long.
        label_dtype = torch.float if self._contains_float(labels) else torch.long

        return {
            # reshape is a no-op normally; it keeps the result 2-D when no chunks remain.
            "input_ids": torch.tensor(ids, dtype=torch.long).reshape(
                len(ids), longest_seq
            ),
            "attention_mask": torch.tensor(mask, dtype=torch.long).reshape(
                len(mask), longest_seq
            ),
            "labels": torch.tensor(labels, dtype=label_dtype),
        }
13 changes: 9 additions & 4 deletions src/strideformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class StrideformerConfig(PretrainedConfig):
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the second model.
layer_norm_eps (`float`, *optional*, defaults to 1e-7):
layer_norm_eps (`float`, *optional*, defaults to 1e-7):
The epsilon value in LayerNorm
num_labels (`int`, *optional*, defaults to 2):
The number of labels for the classifier.
Expand All @@ -48,11 +48,16 @@ class StrideformerConfig(PretrainedConfig):
>>> model = Strideformer(config)
```"""
model_type: str = "strideformer"
keys_to_ignore_at_inference: List = []
keys_to_ignore_at_inference: List = [
"first_model_hidden_states",
"second_model_hidden_states",
]

def __init__(
self,
first_model_name_or_path: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2",
first_model_name_or_path: Optional[
str
] = "sentence-transformers/all-MiniLM-L6-v2",
freeze_first_model: Optional[bool] = True,
max_chunks: Optional[int] = 64,
hidden_size: Optional[int] = 384,
Expand Down Expand Up @@ -82,4 +87,4 @@ def __init__(
super().__init__(
num_labels=num_labels,
**kwargs,
)
)

0 comments on commit 332483e

Please sign in to comment.