Skip to content

Commit

Permalink
adjust generic document types (#359)
Browse files Browse the repository at this point in the history
* adjust generic document types: rename the annotation layers to reflect their annotation type, and reflect that change in the document type names as well

* fix-up for the generic document type adjustments

* adjust the default values of the annotation-layer-related parameters in the taskmodules
  • Loading branch information
ArneBinder authored Oct 19, 2023
1 parent a052082 commit 421f618
Show file tree
Hide file tree
Showing 18 changed files with 69 additions and 52 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ print("val docs: ", len(val_docs))
# Create a PIE taskmodule.
task_module = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=model_name,
entity_annotation="entities",
max_length=128,
)

Expand Down
2 changes: 2 additions & 0 deletions examples/predict/re_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def main():

taskmodule = TransformerSeq2SeqTaskModule(
tokenizer_name_or_path=model_name_or_path,
entity_annotation="entities",
relation_annotation="relations",
max_input_length=128,
max_target_length=128,
)
Expand Down
1 change: 1 addition & 0 deletions examples/train/span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def main():

task_module = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=model_name,
entity_annotation="entities",
max_length=128,
)

Expand Down
42 changes: 21 additions & 21 deletions src/pytorch_ie/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TextDocumentWithMultiLabel(DocumentWithMultiLabel, TextBasedDocument):

@dataclasses.dataclass
class TextDocumentWithLabeledPartitions(TextBasedDocument):
partitions: AnnotationList[LabeledSpan] = annotation_field(target="text")
labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="text")


@dataclasses.dataclass
Expand All @@ -68,59 +68,59 @@ class TextDocumentWithSentences(TextBasedDocument):


@dataclasses.dataclass
class TextDocumentWithEntities(TextBasedDocument):
entities: AnnotationList[Span] = annotation_field(target="text")
class TextDocumentWithSpans(TextBasedDocument):
spans: AnnotationList[Span] = annotation_field(target="text")


@dataclasses.dataclass
class TextDocumentWithLabeledEntities(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
class TextDocumentWithLabeledSpans(TextBasedDocument):
labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="text")


@dataclasses.dataclass
class TextDocumentWithLabeledEntitiesAndLabeledPartitions(
TextDocumentWithLabeledEntities, TextDocumentWithLabeledPartitions
class TextDocumentWithLabeledSpansAndLabeledPartitions(
TextDocumentWithLabeledSpans, TextDocumentWithLabeledPartitions
):
pass


@dataclasses.dataclass
class TextDocumentWithLabeledEntitiesAndSentences(
TextDocumentWithLabeledEntities, TextDocumentWithSentences
class TextDocumentWithLabeledSpansAndSentences(
TextDocumentWithLabeledSpans, TextDocumentWithSentences
):
pass


@dataclasses.dataclass
class TextDocumentWithLabeledEntitiesAndRelations(TextDocumentWithLabeledEntities):
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
class TextDocumentWithLabeledSpansAndBinaryRelations(TextDocumentWithLabeledSpans):
binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans")


@dataclasses.dataclass
class TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions(
TextDocumentWithLabeledEntitiesAndLabeledPartitions,
TextDocumentWithLabeledEntitiesAndRelations,
class TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
TextDocumentWithLabeledSpansAndLabeledPartitions,
TextDocumentWithLabeledSpansAndBinaryRelations,
TextDocumentWithLabeledPartitions,
):
pass


@dataclasses.dataclass
class TextDocumentWithEntitiesAndRelations(TextDocumentWithEntities):
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
class TextDocumentWithSpansAndBinaryRelations(TextDocumentWithSpans):
binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")


@dataclasses.dataclass
class TextDocumentWithEntitiesAndLabeledPartitions(
TextDocumentWithEntities, TextDocumentWithLabeledPartitions
class TextDocumentWithSpansAndLabeledPartitions(
TextDocumentWithSpans, TextDocumentWithLabeledPartitions
):
pass


@dataclasses.dataclass
class TextDocumentWithEntitiesRelationsAndLabeledPartitions(
TextDocumentWithEntitiesAndLabeledPartitions,
TextDocumentWithEntitiesAndRelations,
class TextDocumentWithSpansBinaryRelationsAndLabeledPartitions(
TextDocumentWithSpansAndLabeledPartitions,
TextDocumentWithSpansAndBinaryRelations,
TextDocumentWithLabeledPartitions,
):
pass
10 changes: 5 additions & 5 deletions src/pytorch_ie/taskmodules/transformer_re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from pytorch_ie.core import AnnotationList, Document, TaskEncoding, TaskModule
from pytorch_ie.documents import (
TextDocument,
TextDocumentWithLabeledEntitiesAndRelations,
TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions,
TextDocumentWithLabeledSpansAndBinaryRelations,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType
from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
tokenizer_name_or_path: str,
# this is deprecated, the target of the relation layer already specifies the entity layer
entity_annotation: Optional[str] = None,
relation_annotation: str = "relations",
relation_annotation: str = "binary_relations",
create_relation_candidates: bool = False,
partition_annotation: Optional[str] = None,
none_label: str = "no_relation",
Expand Down Expand Up @@ -216,9 +216,9 @@ def __init__(
@property
def document_type(self) -> Optional[Type[TextDocument]]:
if self.partition_annotation is not None:
return TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions
return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
else:
return TextDocumentWithLabeledEntitiesAndRelations
return TextDocumentWithLabeledSpansAndBinaryRelations

def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]:
return document[self.relation_annotation]
Expand Down
8 changes: 4 additions & 4 deletions src/pytorch_ie/taskmodules/transformer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.core import Annotation, TaskEncoding, TaskModule
from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledEntitiesAndRelations
from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpansAndBinaryRelations
from pytorch_ie.models.transformer_seq2seq import ModelOutputType, ModelStepInputType

InputEncodingType: TypeAlias = Dict[str, Sequence[int]]
Expand All @@ -42,13 +42,13 @@
@TaskModule.register()
class TransformerSeq2SeqTaskModule(TaskModuleType):

DOCUMENT_TYPE = TextDocumentWithLabeledEntitiesAndRelations
DOCUMENT_TYPE = TextDocumentWithLabeledSpansAndBinaryRelations

def __init__(
self,
tokenizer_name_or_path: str,
entity_annotation: str = "entities",
relation_annotation: str = "relations",
entity_annotation: str = "labeled_spans",
relation_annotation: str = "binary_relations",
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = True,
max_input_length: Optional[int] = None,
Expand Down
12 changes: 6 additions & 6 deletions src/pytorch_ie/taskmodules/transformer_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import (
TextDocument,
TextDocumentWithLabeledEntities,
TextDocumentWithLabeledEntitiesAndLabeledPartitions,
TextDocumentWithLabeledEntitiesAndSentences,
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
TextDocumentWithLabeledSpansAndSentences,
)
from pytorch_ie.models.transformer_span_classification import ModelOutputType, ModelStepInputType

Expand Down Expand Up @@ -56,7 +56,7 @@ class TransformerSpanClassificationTaskModule(TaskModuleType):
def __init__(
self,
tokenizer_name_or_path: str,
entity_annotation: str = "entities",
entity_annotation: str = "labeled_spans",
single_sentence: bool = False,
sentence_annotation: str = "sentences",
padding: Union[bool, str, PaddingStrategy] = True,
Expand Down Expand Up @@ -92,9 +92,9 @@ def __init__(
@property
def document_type(self) -> TypeAlias:
if self.single_sentence:
return TextDocumentWithLabeledEntitiesAndSentences
return TextDocumentWithLabeledSpansAndSentences
else:
return TextDocumentWithLabeledEntities
return TextDocumentWithLabeledSpans

def _config(self) -> Dict[str, Any]:
config = super()._config()
Expand Down
10 changes: 5 additions & 5 deletions src/pytorch_ie/taskmodules/transformer_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import (
TextDocument,
TextDocumentWithLabeledEntities,
TextDocumentWithLabeledEntitiesAndLabeledPartitions,
TextDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
)
from pytorch_ie.models.transformer_token_classification import ModelOutputType, ModelStepInputType
from pytorch_ie.utils.span import (
Expand Down Expand Up @@ -62,7 +62,7 @@ class TransformerTokenClassificationTaskModule(TaskModuleType):
def __init__(
self,
tokenizer_name_or_path: str,
entity_annotation: str = "entities",
entity_annotation: str = "labeled_spans",
partition_annotation: Optional[str] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
Expand Down Expand Up @@ -97,9 +97,9 @@ def __init__(
@property
def document_type(self) -> Type[TextDocument]:
if self.partition_annotation is not None:
return TextDocumentWithLabeledEntitiesAndLabeledPartitions
return TextDocumentWithLabeledSpansAndLabeledPartitions
else:
return TextDocumentWithLabeledEntities
return TextDocumentWithLabeledSpans

def _config(self) -> Dict[str, Any]:
config = super()._config()
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.data.builder import PieDatasetBuilder
from pytorch_ie.documents import TextBasedDocument, TextDocumentWithEntities
from pytorch_ie.documents import TextBasedDocument, TextDocumentWithSpans
from tests import FIXTURES_ROOT

DATASETS_ROOT = FIXTURES_ROOT / "builder" / "datasets"
Expand Down Expand Up @@ -174,10 +174,10 @@ class ExampleDocumentWithSimpleSpans(TextBasedDocument):


def convert_example_document_to_example_document_with_simple_spans(
document: TextDocumentWithEntities,
document: TextDocumentWithSpans,
) -> ExampleDocumentWithSimpleSpans:
result = ExampleDocumentWithSimpleSpans(text=document.text, spans=document.entities)
for entity in document.entities:
result = ExampleDocumentWithSimpleSpans(text=document.text, spans=document.spans)
for entity in document.spans:
result.spans.append(Span(start=entity.start, end=entity.end))
return result

Expand Down
3 changes: 2 additions & 1 deletion tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
def taskmodule():
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path
tokenizer_name_or_path=tokenizer_name_or_path,
entity_annotation="entities",
)
return taskmodule

Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_transformer_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
def taskmodule():
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path
tokenizer_name_or_path=tokenizer_name_or_path,
entity_annotation="entities",
)
return taskmodule

Expand Down
2 changes: 2 additions & 0 deletions tests/pipeline/test_re_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def test_re_generative():

taskmodule = TransformerSeq2SeqTaskModule(
tokenizer_name_or_path=model_name_or_path,
entity_annotation="entities",
relation_annotation="relations",
max_input_length=128,
max_target_length=128,
)
Expand Down
4 changes: 3 additions & 1 deletion tests/taskmodules/test_transformer_re_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def cfg(request):
def taskmodule(cfg):
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerRETextClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path, **cfg
tokenizer_name_or_path=tokenizer_name_or_path, relation_annotation="relations", **cfg
)
assert not taskmodule.is_from_pretrained

Expand Down Expand Up @@ -605,6 +605,7 @@ def test_encode_with_partition(documents):
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerRETextClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path,
relation_annotation="relations",
partition_annotation="sentences",
)
assert not taskmodule.is_from_pretrained
Expand Down Expand Up @@ -700,6 +701,7 @@ def test_encode_with_windowing(documents):
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerRETextClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path,
relation_annotation="relations",
max_window=12,
)
assert not taskmodule.is_from_pretrained
Expand Down
6 changes: 5 additions & 1 deletion tests/taskmodules/test_transformer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
@pytest.fixture(scope="module")
def taskmodule():
transformer_model = "Babelscape/rebel-large"
taskmodule = TransformerSeq2SeqTaskModule(tokenizer_name_or_path=transformer_model)
taskmodule = TransformerSeq2SeqTaskModule(
tokenizer_name_or_path=transformer_model,
entity_annotation="entities",
relation_annotation="relations",
)
assert not taskmodule.is_from_pretrained

return taskmodule
Expand Down
3 changes: 2 additions & 1 deletion tests/taskmodules/test_transformer_span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
def taskmodule():
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path
tokenizer_name_or_path=tokenizer_name_or_path,
entity_annotation="entities",
)
assert not taskmodule.is_from_pretrained

Expand Down
2 changes: 1 addition & 1 deletion tests/taskmodules/test_transformer_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def unprepared_taskmodule(config):
"""
return TransformerTokenClassificationTaskModule(
tokenizer_name_or_path="bert-base-uncased", **config
tokenizer_name_or_path="bert-base-uncased", entity_annotation="entities", **config
)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
def taskmodule():
tokenizer_name_or_path = "bert-base-cased"
taskmodule = TransformerSpanClassificationTaskModule(
tokenizer_name_or_path=tokenizer_name_or_path
tokenizer_name_or_path=tokenizer_name_or_path,
entity_annotation="entities",
)
return taskmodule

Expand Down
1 change: 1 addition & 0 deletions tests/train/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def prepared_taskmodule(documents):
taskmodule = TransformerTokenClassificationTaskModule(
tokenizer_name_or_path=MODEL_NAME,
entity_annotation="entities",
max_length=128,
)
taskmodule.prepare(documents)
Expand Down

0 comments on commit 421f618

Please sign in to comment.