diff --git a/README.md b/README.md index 713be224..c5f4b489 100644 --- a/README.md +++ b/README.md @@ -462,6 +462,7 @@ print("val docs: ", len(val_docs)) # Create a PIE taskmodule. task_module = TransformerSpanClassificationTaskModule( tokenizer_name_or_path=model_name, + entity_annotation="entities", max_length=128, ) diff --git a/examples/predict/re_generative.py b/examples/predict/re_generative.py index b42133d4..de0e656b 100644 --- a/examples/predict/re_generative.py +++ b/examples/predict/re_generative.py @@ -19,6 +19,8 @@ def main(): taskmodule = TransformerSeq2SeqTaskModule( tokenizer_name_or_path=model_name_or_path, + entity_annotation="entities", + relation_annotation="relations", max_input_length=128, max_target_length=128, ) diff --git a/examples/train/span_classification.py b/examples/train/span_classification.py index d687c9fc..7fe805ff 100644 --- a/examples/train/span_classification.py +++ b/examples/train/span_classification.py @@ -27,6 +27,7 @@ def main(): task_module = TransformerSpanClassificationTaskModule( tokenizer_name_or_path=model_name, + entity_annotation="entities", max_length=128, ) diff --git a/src/pytorch_ie/documents.py b/src/pytorch_ie/documents.py index 4d5e0d47..160e5b9b 100644 --- a/src/pytorch_ie/documents.py +++ b/src/pytorch_ie/documents.py @@ -59,7 +59,7 @@ class TextDocumentWithMultiLabel(DocumentWithMultiLabel, TextBasedDocument): @dataclasses.dataclass class TextDocumentWithLabeledPartitions(TextBasedDocument): - partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") + labeled_partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass @@ -68,59 +68,59 @@ class TextDocumentWithSentences(TextBasedDocument): @dataclasses.dataclass -class TextDocumentWithEntities(TextBasedDocument): - entities: AnnotationList[Span] = annotation_field(target="text") +class TextDocumentWithSpans(TextBasedDocument): + spans: AnnotationList[Span] = annotation_field(target="text") @dataclasses.dataclass -class TextDocumentWithLabeledEntities(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") +class TextDocumentWithLabeledSpans(TextBasedDocument): + labeled_spans: AnnotationList[LabeledSpan] = annotation_field(target="text") @dataclasses.dataclass -class TextDocumentWithLabeledEntitiesAndLabeledPartitions( - TextDocumentWithLabeledEntities, TextDocumentWithLabeledPartitions +class TextDocumentWithLabeledSpansAndLabeledPartitions( + TextDocumentWithLabeledSpans, TextDocumentWithLabeledPartitions ): pass @dataclasses.dataclass -class TextDocumentWithLabeledEntitiesAndSentences( - TextDocumentWithLabeledEntities, TextDocumentWithSentences +class TextDocumentWithLabeledSpansAndSentences( + TextDocumentWithLabeledSpans, TextDocumentWithSentences ): pass @dataclasses.dataclass -class TextDocumentWithLabeledEntitiesAndRelations(TextDocumentWithLabeledEntities): - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") +class TextDocumentWithLabeledSpansAndBinaryRelations(TextDocumentWithLabeledSpans): + binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="labeled_spans") @dataclasses.dataclass -class TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions( - TextDocumentWithLabeledEntitiesAndLabeledPartitions, - TextDocumentWithLabeledEntitiesAndRelations, +class TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + TextDocumentWithLabeledSpansAndLabeledPartitions, + TextDocumentWithLabeledSpansAndBinaryRelations, TextDocumentWithLabeledPartitions, ): pass @dataclasses.dataclass -class TextDocumentWithEntitiesAndRelations(TextDocumentWithEntities): - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") +class TextDocumentWithSpansAndBinaryRelations(TextDocumentWithSpans): + binary_relations: AnnotationList[BinaryRelation] = annotation_field(target="spans") @dataclasses.dataclass -class TextDocumentWithEntitiesAndLabeledPartitions( - TextDocumentWithEntities, TextDocumentWithLabeledPartitions +class TextDocumentWithSpansAndLabeledPartitions( + TextDocumentWithSpans, TextDocumentWithLabeledPartitions ): pass @dataclasses.dataclass -class TextDocumentWithEntitiesRelationsAndLabeledPartitions( - TextDocumentWithEntitiesAndLabeledPartitions, - TextDocumentWithEntitiesAndRelations, +class TextDocumentWithSpansBinaryRelationsAndLabeledPartitions( + TextDocumentWithSpansAndLabeledPartitions, + TextDocumentWithSpansAndBinaryRelations, TextDocumentWithLabeledPartitions, ): pass diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py index 4c184132..f2ea5edd 100644 --- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py @@ -39,8 +39,8 @@ from pytorch_ie.core import AnnotationList, Document, TaskEncoding, TaskModule from pytorch_ie.documents import ( TextDocument, - TextDocumentWithLabeledEntitiesAndRelations, - TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions, + TextDocumentWithLabeledSpansAndBinaryRelations, + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) from pytorch_ie.models.transformer_text_classification import ModelOutputType, ModelStepInputType from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize @@ -154,7 +154,7 @@ def __init__( tokenizer_name_or_path: str, # this is deprecated, the target of the relation layer already specifies the entity layer entity_annotation: Optional[str] = None, - relation_annotation: str = "relations", + relation_annotation: str = "binary_relations", create_relation_candidates: bool = False, partition_annotation: Optional[str] = None, none_label: str = "no_relation", @@ -216,9 +216,9 @@ def __init__( @property def document_type(self) -> Optional[Type[TextDocument]]: if self.partition_annotation is not None: - return TextDocumentWithLabeledEntitiesRelationsAndLabeledPartitions + return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions else: - return TextDocumentWithLabeledEntitiesAndRelations + return TextDocumentWithLabeledSpansAndBinaryRelations def get_relation_layer(self, document: Document) -> AnnotationList[BinaryRelation]: return document[self.relation_annotation] diff --git a/src/pytorch_ie/taskmodules/transformer_seq2seq.py b/src/pytorch_ie/taskmodules/transformer_seq2seq.py index 9aaa3976..1bffe6c0 100644 --- a/src/pytorch_ie/taskmodules/transformer_seq2seq.py +++ b/src/pytorch_ie/taskmodules/transformer_seq2seq.py @@ -18,7 +18,7 @@ from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation, TaskEncoding, TaskModule -from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledEntitiesAndRelations +from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpansAndBinaryRelations from pytorch_ie.models.transformer_seq2seq import ModelOutputType, ModelStepInputType InputEncodingType: TypeAlias = Dict[str, Sequence[int]] @@ -42,13 +42,13 @@ @TaskModule.register() class TransformerSeq2SeqTaskModule(TaskModuleType): - DOCUMENT_TYPE = TextDocumentWithLabeledEntitiesAndRelations + DOCUMENT_TYPE = TextDocumentWithLabeledSpansAndBinaryRelations def __init__( self, tokenizer_name_or_path: str, - entity_annotation: str = "entities", - relation_annotation: str = "relations", + entity_annotation: str = "labeled_spans", + relation_annotation: str = "binary_relations", padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = True, max_input_length: Optional[int] = None, diff --git a/src/pytorch_ie/taskmodules/transformer_span_classification.py b/src/pytorch_ie/taskmodules/transformer_span_classification.py index 1f585019..23f341c6 100644 --- a/src/pytorch_ie/taskmodules/transformer_span_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_span_classification.py @@ -22,9 +22,9 @@ from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import ( TextDocument, - TextDocumentWithLabeledEntities, - TextDocumentWithLabeledEntitiesAndLabeledPartitions, - TextDocumentWithLabeledEntitiesAndSentences, + TextDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, + TextDocumentWithLabeledSpansAndSentences, ) from pytorch_ie.models.transformer_span_classification import ModelOutputType, ModelStepInputType @@ -56,7 +56,7 @@ class TransformerSpanClassificationTaskModule(TaskModuleType): def __init__( self, tokenizer_name_or_path: str, - entity_annotation: str = "entities", + entity_annotation: str = "labeled_spans", single_sentence: bool = False, sentence_annotation: str = "sentences", padding: Union[bool, str, PaddingStrategy] = True, @@ -92,9 +92,9 @@ def __init__( @property def document_type(self) -> TypeAlias: if self.single_sentence: - return TextDocumentWithLabeledEntitiesAndSentences + return TextDocumentWithLabeledSpansAndSentences else: - return TextDocumentWithLabeledEntities + return TextDocumentWithLabeledSpans def _config(self) -> Dict[str, Any]: config = super()._config() diff --git a/src/pytorch_ie/taskmodules/transformer_token_classification.py b/src/pytorch_ie/taskmodules/transformer_token_classification.py index ef9b4f09..9763e55a 100644 --- a/src/pytorch_ie/taskmodules/transformer_token_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_token_classification.py @@ -22,8 +22,8 @@ from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import ( TextDocument, - TextDocumentWithLabeledEntities, - TextDocumentWithLabeledEntitiesAndLabeledPartitions, + TextDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, ) from pytorch_ie.models.transformer_token_classification import ModelOutputType, ModelStepInputType from pytorch_ie.utils.span import ( @@ -62,7 +62,7 @@ class TransformerTokenClassificationTaskModule(TaskModuleType): def __init__( self, tokenizer_name_or_path: str, - entity_annotation: str = "entities", + entity_annotation: str = "labeled_spans", partition_annotation: Optional[str] = None, padding: Union[bool, str, PaddingStrategy] = True, truncation: Union[bool, str, TruncationStrategy] = False, @@ -97,9 +97,9 @@ def __init__( @property def document_type(self) -> Type[TextDocument]: if self.partition_annotation is not None: - return TextDocumentWithLabeledEntitiesAndLabeledPartitions + return TextDocumentWithLabeledSpansAndLabeledPartitions else: - return TextDocumentWithLabeledEntities + return TextDocumentWithLabeledSpans def _config(self) -> Dict[str, Any]: config = super()._config() diff --git a/tests/data/test_builder.py b/tests/data/test_builder.py index 405c9548..5112a8e1 100644 --- a/tests/data/test_builder.py +++ b/tests/data/test_builder.py @@ -10,7 +10,7 @@ from pytorch_ie.annotations import LabeledSpan, Span from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.data.builder import PieDatasetBuilder -from pytorch_ie.documents import TextBasedDocument, TextDocumentWithEntities +from pytorch_ie.documents import TextBasedDocument, TextDocumentWithSpans from tests import FIXTURES_ROOT DATASETS_ROOT = FIXTURES_ROOT / "builder" / "datasets" @@ -174,10 +174,10 @@ class ExampleDocumentWithSimpleSpans(TextBasedDocument): def convert_example_document_to_example_document_with_simple_spans( - document: TextDocumentWithEntities, + document: TextDocumentWithSpans, ) -> ExampleDocumentWithSimpleSpans: - result = ExampleDocumentWithSimpleSpans(text=document.text, spans=document.entities) - for entity in document.entities: + result = ExampleDocumentWithSimpleSpans(text=document.text, spans=document.spans) + for entity in document.spans: result.spans.append(Span(start=entity.start, end=entity.end)) return result diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 6e0cc40d..4f0797a2 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -26,7 +26,8 @@ def taskmodule(): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerSpanClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", ) return taskmodule diff --git a/tests/models/test_transformer_span_classification.py b/tests/models/test_transformer_span_classification.py index 2eae1856..63b275ba 100644 --- a/tests/models/test_transformer_span_classification.py +++ b/tests/models/test_transformer_span_classification.py @@ -13,7 +13,8 @@ def taskmodule(): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerSpanClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", ) return taskmodule diff --git a/tests/pipeline/test_re_generative.py b/tests/pipeline/test_re_generative.py index cc84df4c..87899b59 100644 --- a/tests/pipeline/test_re_generative.py +++ b/tests/pipeline/test_re_generative.py @@ -22,6 +22,8 @@ def test_re_generative(): taskmodule = TransformerSeq2SeqTaskModule( tokenizer_name_or_path=model_name_or_path, + entity_annotation="entities", + relation_annotation="relations", max_input_length=128, max_target_length=128, ) diff --git a/tests/taskmodules/test_transformer_re_text_classification.py b/tests/taskmodules/test_transformer_re_text_classification.py index baa2d4b3..a9753961 100644 --- a/tests/taskmodules/test_transformer_re_text_classification.py +++ b/tests/taskmodules/test_transformer_re_text_classification.py @@ -31,7 +31,7 @@ def cfg(request): def taskmodule(cfg): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerRETextClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, **cfg + tokenizer_name_or_path=tokenizer_name_or_path, relation_annotation="relations", **cfg ) assert not taskmodule.is_from_pretrained @@ -605,6 +605,7 @@ def test_encode_with_partition(documents): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerRETextClassificationTaskModule( tokenizer_name_or_path=tokenizer_name_or_path, + relation_annotation="relations", partition_annotation="sentences", ) assert not taskmodule.is_from_pretrained @@ -700,6 +701,7 @@ def test_encode_with_windowing(documents): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerRETextClassificationTaskModule( tokenizer_name_or_path=tokenizer_name_or_path, + relation_annotation="relations", max_window=12, ) assert not taskmodule.is_from_pretrained diff --git a/tests/taskmodules/test_transformer_seq2seq.py b/tests/taskmodules/test_transformer_seq2seq.py index 082aa680..39189f4f 100644 --- a/tests/taskmodules/test_transformer_seq2seq.py +++ b/tests/taskmodules/test_transformer_seq2seq.py @@ -14,7 +14,11 @@ @pytest.fixture(scope="module") def taskmodule(): transformer_model = "Babelscape/rebel-large" - taskmodule = TransformerSeq2SeqTaskModule(tokenizer_name_or_path=transformer_model) + taskmodule = TransformerSeq2SeqTaskModule( + tokenizer_name_or_path=transformer_model, + entity_annotation="entities", + relation_annotation="relations", + ) assert not taskmodule.is_from_pretrained return taskmodule diff --git a/tests/taskmodules/test_transformer_span_classification.py b/tests/taskmodules/test_transformer_span_classification.py index 67ac677a..95b1fc0f 100644 --- a/tests/taskmodules/test_transformer_span_classification.py +++ b/tests/taskmodules/test_transformer_span_classification.py @@ -12,7 +12,8 @@ def taskmodule(): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerSpanClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", ) assert not taskmodule.is_from_pretrained diff --git a/tests/taskmodules/test_transformer_token_classification.py b/tests/taskmodules/test_transformer_token_classification.py index 1aab0d53..e2805b9f 100644 --- a/tests/taskmodules/test_transformer_token_classification.py +++ b/tests/taskmodules/test_transformer_token_classification.py @@ -51,7 +51,7 @@ def unprepared_taskmodule(config): """ return TransformerTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", **config + tokenizer_name_or_path="bert-base-uncased", entity_annotation="entities", **config ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dbdc2adb..7994caa2 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -18,7 +18,8 @@ def taskmodule(): tokenizer_name_or_path = "bert-base-cased" taskmodule = TransformerSpanClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", ) return taskmodule diff --git a/tests/train/test_training.py b/tests/train/test_training.py index 58c9abbf..232fa56c 100644 --- a/tests/train/test_training.py +++ b/tests/train/test_training.py @@ -12,6 +12,7 @@ def prepared_taskmodule(documents): taskmodule = TransformerTokenClassificationTaskModule( tokenizer_name_or_path=MODEL_NAME, + entity_annotation="entities", max_length=128, ) taskmodule.prepare(documents)