diff --git a/README.rst b/README.rst index 47b08e4b..bf581e30 100644 --- a/README.rst +++ b/README.rst @@ -60,21 +60,22 @@ Span-classification-based Named Entity Recognition from dataclasses import dataclass - from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field + from pytorch_ie.annotations import LabeledSpan from pytorch_ie.auto import AutoPipeline - + from pytorch_ie.core import AnnotationList, annotation_field + from pytorch_ie.documents import TextDocument @dataclass class ExampleDocument(TextDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - # see below for the long version - ner_pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0) - document = ExampleDocument( "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio." ) + # see below for the long version + ner_pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0) + ner_pipeline(document, predict_field="entities") for entity in document.entities.predictions: @@ -89,8 +90,8 @@ To create the same pipeline as above without `AutoPipeline`: .. code:: python - from pytorch_ie import Pipeline from pytorch_ie.auto import AutoTaskModule, AutoModel + from pytorch_ie.pipeline import Pipeline model_name_or_path = "pie/example-ner-spanclf-conll03" ner_taskmodule = AutoTaskModule.from_pretrained(model_name_or_path) @@ -101,7 +102,7 @@ Or, without `Auto` classes at all: .. code:: python - from pytorch_ie import Pipeline + from pytorch_ie.pipeline import Pipeline from pytorch_ie.models import TransformerSpanClassificationModel from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule @@ -118,8 +119,10 @@ Text-classification-based Relation Extraction from dataclasses import dataclass - from pytorch_ie import AnnotationList, BinaryRelation, LabeledSpan, TextDocument, annotation_field + from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.auto import AutoPipeline + from pytorch_ie.core import AnnotationList, annotation_field + from pytorch_ie.documents import TextDocument @dataclass @@ -127,13 +130,12 @@ Text-classification-based Relation Extraction entities: AnnotationList[LabeledSpan] = annotation_field(target="text") relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - re_pipeline = AutoPipeline.from_pretrained("pie/example-re-textclf-tacred", device=-1, num_workers=0) - document = ExampleDocument( "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio." ) + re_pipeline = AutoPipeline.from_pretrained("pie/example-re-textclf-tacred", device=-1, num_workers=0) + for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]: document.entities.append(LabeledSpan(start=start, end=end, label=label)) diff --git a/datasets/conll2002/conll2002.py b/datasets/conll2002/conll2002.py index 511e8452..5e81da14 100644 --- a/datasets/conll2002/conll2002.py +++ b/datasets/conll2002/conll2002.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/conll2003/conll2003.py b/datasets/conll2003/conll2003.py index 00b637d3..0c2354a9 100644 --- a/datasets/conll2003/conll2003.py +++ b/datasets/conll2003/conll2003.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/conllpp/conllpp.py b/datasets/conllpp/conllpp.py index 3b33d7cb..a67995f1 100644 --- a/datasets/conllpp/conllpp.py +++ b/datasets/conllpp/conllpp.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/german_legal_entity_recognition/german_legal_entity_recognition.py b/datasets/german_legal_entity_recognition/german_legal_entity_recognition.py index f138bdd0..7ba5c2d3 100644 --- a/datasets/german_legal_entity_recognition/german_legal_entity_recognition.py +++ b/datasets/german_legal_entity_recognition/german_legal_entity_recognition.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/germaner/germaner.py b/datasets/germaner/germaner.py index 01cacc31..e1cade6e 100644 --- a/datasets/germaner/germaner.py +++ b/datasets/germaner/germaner.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/germeval_14/germeval_14.py b/datasets/germeval_14/germeval_14.py index 548c5884..259725ef 100644 --- a/datasets/germeval_14/germeval_14.py +++ b/datasets/germeval_14/germeval_14.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/ncbi_disease/ncbi_disease.py b/datasets/ncbi_disease/ncbi_disease.py index 84446deb..b5a7bc03 100644 --- a/datasets/ncbi_disease/ncbi_disease.py +++ b/datasets/ncbi_disease/ncbi_disease.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/wikiann/wikiann.py b/datasets/wikiann/wikiann.py index dce5d348..b9c437fd 100644 --- a/datasets/wikiann/wikiann.py +++ b/datasets/wikiann/wikiann.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/datasets/wnut_17/wnut_17.py b/datasets/wnut_17/wnut_17.py index cb42a563..6a499959 100644 --- a/datasets/wnut_17/wnut_17.py +++ b/datasets/wnut_17/wnut_17.py @@ -2,8 +2,8 @@ import datasets import pytorch_ie.data.builder -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans diff --git a/examples/predict/ner_span_classification.py b/examples/predict/ner_span_classification.py index 7c8965fb..9437d969 100644 --- a/examples/predict/ner_span_classification.py +++ b/examples/predict/ner_span_classification.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule diff --git a/examples/predict/re_generative.py b/examples/predict/re_generative.py index 9361df84..1f5d76a5 100644 --- a/examples/predict/re_generative.py +++ b/examples/predict/re_generative.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSeq2SeqModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerSeq2SeqTaskModule diff --git a/examples/predict/re_text_classification.py b/examples/predict/re_text_classification.py index e398293d..218e70de 100644 --- a/examples/predict/re_text_classification.py +++ b/examples/predict/re_text_classification.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerTextClassificationModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerRETextClassificationTaskModule diff --git a/src/pytorch_ie/__init__.py b/src/pytorch_ie/__init__.py index 494fd451..29badd60 100644 --- a/src/pytorch_ie/__init__.py +++ b/src/pytorch_ie/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from pytorch_ie.auto import AutoModel, AutoPipeline, AutoTaskModule -from pytorch_ie.core import * from pytorch_ie.data import * +from pytorch_ie.models import * from pytorch_ie.pipeline import Pipeline +from pytorch_ie.taskmodules import * diff --git a/src/pytorch_ie/data/__init__.py b/src/pytorch_ie/data/__init__.py index cc1545e8..269be2db 100644 --- a/src/pytorch_ie/data/__init__.py +++ b/src/pytorch_ie/data/__init__.py @@ -1,9 +1,16 @@ +from typing import Dict, Union + +from datasets import Split + from .builder import GeneratorBasedBuilder from .dataset import Dataset from .dataset_formatter import DocumentFormatter +DatasetDict = Dict[Union[str, Split], Dataset] + __all__ = [ "GeneratorBasedBuilder", "Dataset", + "DatasetDict", "DocumentFormatter", ] diff --git a/src/pytorch_ie/data/datasets/__init__.py b/src/pytorch_ie/data/datasets/__init__.py index 702174f0..ffbe5d00 100644 --- a/src/pytorch_ie/data/datasets/__init__.py +++ b/src/pytorch_ie/data/datasets/__init__.py @@ -1,9 +1,3 @@ import pathlib -from typing import Dict, List, Union - -from datasets import Split -from pytorch_ie import Document HF_DATASETS_ROOT = pathlib.Path(__file__).parent / "hf_datasets" - -PIEDatasetDict = Dict[Union[str, Split], List[Document]] diff --git a/src/pytorch_ie/documents.py b/src/pytorch_ie/documents.py index dd41c36b..1e6b2ea8 100644 --- a/src/pytorch_ie/documents.py +++ b/src/pytorch_ie/documents.py @@ -1,7 +1,7 @@ import dataclasses from typing import Any, Dict, Optional -from pytorch_ie import Document +from pytorch_ie.core import Document @dataclasses.dataclass diff --git a/src/pytorch_ie/models/transformer_seq2seq.py b/src/pytorch_ie/models/transformer_seq2seq.py index 9ea40db9..ed54f7cd 100644 --- a/src/pytorch_ie/models/transformer_seq2seq.py +++ b/src/pytorch_ie/models/transformer_seq2seq.py @@ -4,7 +4,7 @@ from transformers import AutoModelForSeq2SeqLM, BatchEncoding from transformers.modeling_outputs import Seq2SeqLMOutput -from pytorch_ie import PyTorchIEModel +from pytorch_ie.core import PyTorchIEModel from pytorch_ie.core.taskmodule import Metadata from pytorch_ie.documents import TextDocument diff --git a/src/pytorch_ie/models/transformer_span_classification.py b/src/pytorch_ie/models/transformer_span_classification.py index 1808991b..3d5b3d84 100644 --- a/src/pytorch_ie/models/transformer_span_classification.py +++ b/src/pytorch_ie/models/transformer_span_classification.py @@ -11,7 +11,7 @@ get_linear_schedule_with_warmup, ) -from pytorch_ie import PyTorchIEModel +from pytorch_ie.core import PyTorchIEModel from pytorch_ie.models.modules.mlp import MLP TransformerSpanClassificationModelBatchEncoding = BatchEncoding diff --git a/src/pytorch_ie/models/transformer_text_classification.py b/src/pytorch_ie/models/transformer_text_classification.py index 1f1dcf3f..1352c356 100644 --- a/src/pytorch_ie/models/transformer_text_classification.py +++ b/src/pytorch_ie/models/transformer_text_classification.py @@ -4,7 +4,7 @@ from torch import Tensor, nn from transformers import AdamW, AutoConfig, AutoModel, get_linear_schedule_with_warmup -from pytorch_ie import PyTorchIEModel +from pytorch_ie.core import PyTorchIEModel TransformerTextClassificationModelBatchEncoding = MutableMapping[str, Any] TransformerTextClassificationModelBatchOutput = Dict[str, Any] diff --git a/src/pytorch_ie/models/transformer_token_classification.py b/src/pytorch_ie/models/transformer_token_classification.py index 73978da0..7a54db90 100644 --- a/src/pytorch_ie/models/transformer_token_classification.py +++ b/src/pytorch_ie/models/transformer_token_classification.py @@ -5,7 +5,7 @@ from torch import Tensor, nn from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding -from pytorch_ie import PyTorchIEModel +from pytorch_ie.core import PyTorchIEModel TransformerTokenClassificationModelBatchEncoding = BatchEncoding TransformerTokenClassificationModelBatchOutput = Dict[str, Any] diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py index aeeed259..3d9e8fbf 100644 --- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py @@ -7,8 +7,8 @@ from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy -from pytorch_ie import TaskEncoding, TaskModule from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span +from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models import ( TransformerTextClassificationModelBatchOutput, diff --git a/src/pytorch_ie/taskmodules/transformer_seq2seq.py b/src/pytorch_ie/taskmodules/transformer_seq2seq.py index ff00232c..e8bda111 100644 --- a/src/pytorch_ie/taskmodules/transformer_seq2seq.py +++ b/src/pytorch_ie/taskmodules/transformer_seq2seq.py @@ -6,8 +6,8 @@ from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import TruncationStrategy -from pytorch_ie import Annotation, TaskEncoding, TaskModule from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span +from pytorch_ie.core import Annotation, TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models import ( TransformerSeq2SeqModelBatchOutput, diff --git a/src/pytorch_ie/taskmodules/transformer_span_classification.py b/src/pytorch_ie/taskmodules/transformer_span_classification.py index 5774589d..b4ca546e 100644 --- a/src/pytorch_ie/taskmodules/transformer_span_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_span_classification.py @@ -8,8 +8,8 @@ from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy -from pytorch_ie import TaskEncoding, TaskModule from pytorch_ie.annotations import LabeledSpan, MultiLabeledSpan, Span +from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models.transformer_span_classification import ( TransformerSpanClassificationModelBatchOutput, diff --git a/src/pytorch_ie/taskmodules/transformer_text_classification.py b/src/pytorch_ie/taskmodules/transformer_text_classification.py index 8f848a7e..6b23940b 100644 --- a/src/pytorch_ie/taskmodules/transformer_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_text_classification.py @@ -17,8 +17,8 @@ from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import TruncationStrategy -from pytorch_ie import TaskEncoding, TaskModule from pytorch_ie.annotations import Label, MultiLabel +from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models.transformer_text_classification import ( TransformerTextClassificationModelBatchOutput, diff --git a/src/pytorch_ie/taskmodules/transformer_token_classification.py b/src/pytorch_ie/taskmodules/transformer_token_classification.py index bf0add0d..93c26993 100644 --- a/src/pytorch_ie/taskmodules/transformer_token_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_token_classification.py @@ -8,8 +8,8 @@ from transformers.file_utils import PaddingStrategy from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy -from pytorch_ie import TaskEncoding, TaskModule from pytorch_ie.annotations import LabeledSpan, Span +from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.documents import TextDocument from pytorch_ie.models.transformer_token_classification import ( TransformerTokenClassificationModelBatchOutput, diff --git a/tests/conftest.py b/tests/conftest.py index 6a5a79fe..9b41afc9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ import pytest import datasets -from pytorch_ie import AnnotationList, Dataset, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.data import Dataset from pytorch_ie.documents import TextDocument from tests import FIXTURES_ROOT diff --git a/tests/data/datasets/test_brat.py b/tests/data/datasets/test_brat.py index 882dbe96..af75836d 100644 --- a/tests/data/datasets/test_brat.py +++ b/tests/data/datasets/test_brat.py @@ -1,5 +1,5 @@ # type: ignore - +""" import os import pytest @@ -29,6 +29,7 @@ "T2\tperson 25 37\tJenny Durkan\n", "R1\tmayor_of head:T2 tail:T1\n", ] +""" # def get_doc1(with_ids: bool = False, **kwargs) -> TextDocument: diff --git a/tests/pipeline/test_ner_span_classification.py b/tests/pipeline/test_ner_span_classification.py index b68f08f2..74e44bf7 100644 --- a/tests/pipeline/test_ner_span_classification.py +++ b/tests/pipeline/test_ner_span_classification.py @@ -2,10 +2,11 @@ import pytest -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule diff --git a/tests/pipeline/test_re_generative.py b/tests/pipeline/test_re_generative.py index a0492161..b57d8088 100644 --- a/tests/pipeline/test_re_generative.py +++ b/tests/pipeline/test_re_generative.py @@ -2,10 +2,11 @@ import pytest -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSeq2SeqModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerSeq2SeqTaskModule diff --git a/tests/pipeline/test_re_text_classification.py b/tests/pipeline/test_re_text_classification.py index 994da4c8..50c8ce6c 100644 --- a/tests/pipeline/test_re_text_classification.py +++ b/tests/pipeline/test_re_text_classification.py @@ -3,10 +3,11 @@ import pytest -from pytorch_ie import AnnotationList, Pipeline, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerTextClassificationModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules import TransformerRETextClassificationTaskModule diff --git a/tests/taskmodules/test_transformer_span_classification.py b/tests/taskmodules/test_transformer_span_classification.py index 6952118e..00e04c46 100644 --- a/tests/taskmodules/test_transformer_span_classification.py +++ b/tests/taskmodules/test_transformer_span_classification.py @@ -2,7 +2,7 @@ import pytest import torch -from pytorch_ie import TaskModule +from pytorch_ie.core import TaskModule from pytorch_ie.taskmodules.transformer_span_classification import ( TransformerSpanClassificationTaskModule, ) diff --git a/tests/taskmodules/test_transformer_token_classification.py b/tests/taskmodules/test_transformer_token_classification.py index 0fcae901..62486e0b 100644 --- a/tests/taskmodules/test_transformer_token_classification.py +++ b/tests/taskmodules/test_transformer_token_classification.py @@ -5,7 +5,7 @@ # import torch # from numpy.testing import assert_almost_equal -# from pytorch_ie import LabeledSpan +# from pytorch_ie.annotations import LabeledSpan # from pytorch_ie.taskmodules import TransformerTokenClassificationTaskModule # from tests.fixtures.document import ( # DOC1_ENTITY_BERLIN, diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 527fb5f9..082ad5d3 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -3,7 +3,6 @@ import pytest -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import ( BinaryRelation, Label, @@ -15,6 +14,7 @@ MultiLabeledSpan, Span, ) +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument diff --git a/tests/test_auto.py b/tests/test_auto.py index 69c8cca1..8a038bb1 100644 --- a/tests/test_auto.py +++ b/tests/test_auto.py @@ -2,15 +2,9 @@ import pytest -from pytorch_ie import ( - AnnotationList, - AutoModel, - AutoPipeline, - AutoTaskModule, - TaskModule, - annotation_field, -) from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.auto import AutoModel, AutoPipeline, AutoTaskModule +from pytorch_ie.core import AnnotationList, TaskModule, annotation_field from pytorch_ie.documents import TextDocument from pytorch_ie.models import TransformerSpanClassificationModel from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule diff --git a/tests/test_document.py b/tests/test_document.py index 2c9eab50..011163fb 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -3,8 +3,8 @@ import pytest -from pytorch_ie import AnnotationList, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextDocument diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d3ed0a07..8ca845fc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -6,9 +6,9 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling import pytorch_ie.models.modules.mlp -from pytorch_ie import Pipeline from pytorch_ie.core.taskmodule import InplaceNotSupportedException from pytorch_ie.models.transformer_span_classification import TransformerSpanClassificationModel +from pytorch_ie.pipeline import Pipeline from pytorch_ie.taskmodules.transformer_span_classification import ( TransformerSpanClassificationTaskModule, ) diff --git a/tests/train/test_training.py b/tests/train/test_training.py index 616eb50d..df34e39a 100644 --- a/tests/train/test_training.py +++ b/tests/train/test_training.py @@ -4,7 +4,7 @@ import pytorch_lightning as pl from torch.utils.data import DataLoader -from pytorch_ie import Document, PyTorchIEModel, TaskModule +from pytorch_ie.core import Document, PyTorchIEModel, TaskModule from pytorch_ie.models import TransformerTokenClassificationModel from pytorch_ie.taskmodules import TransformerTokenClassificationTaskModule