Skip to content

Commit

Permalink
reorganize pytorch_ie.core (#166)
Browse files Browse the repository at this point in the history
* move taskmodule to pytorch_ie.core

* move document to pytorch_ie.core

* move Annotation to pytorch_ie.core.document

* move Metadata to pytorch_ie.core.taskmodule

* move TextDocument to documents

* rename registerable.py to registrable.py

* rename pytorch_ie.py to model.py

* export TaskEncoding and Auto classes

* do not export any entry from annotations and documents on package level



Co-authored-by: Arne Binder <[email protected]>
  • Loading branch information
ArneBinder and ArneBinder authored May 5, 2022
1 parent 138046f commit 5e452b0
Show file tree
Hide file tree
Showing 52 changed files with 175 additions and 172 deletions.
4 changes: 3 additions & 1 deletion datasets/conll2002/conll2002.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion datasets/conll2003/conll2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion datasets/conllpp/conllpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans

_VERSION = "1.0.0"
Expand Down
4 changes: 3 additions & 1 deletion datasets/germaner/germaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion datasets/germeval_14/germeval_14.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion datasets/ncbi_disease/ncbi_disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion datasets/wikiann/wikiann.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans

_VERSION = "1.1.0"
Expand Down
4 changes: 3 additions & 1 deletion datasets/wnut_17/wnut_17.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import datasets
import pytorch_ie.data.builder
from pytorch_ie import AnnotationList, LabeledSpan, TextDocument, annotation_field
from pytorch_ie import AnnotationList, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans


Expand Down
4 changes: 3 additions & 1 deletion examples/predict/ner_span_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass

from pytorch_ie import AnnotationList, LabeledSpan, Pipeline, TextDocument, annotation_field
from pytorch_ie import AnnotationList, Pipeline, annotation_field
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerSpanClassificationModel
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule

Expand Down
11 changes: 3 additions & 8 deletions examples/predict/re_generative.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from dataclasses import dataclass

from pytorch_ie import (
AnnotationList,
BinaryRelation,
LabeledSpan,
Pipeline,
TextDocument,
annotation_field,
)
from pytorch_ie import AnnotationList, Pipeline, annotation_field
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerSeq2SeqModel
from pytorch_ie.taskmodules import TransformerSeq2SeqTaskModule

Expand Down
11 changes: 3 additions & 8 deletions examples/predict/re_text_classification.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from dataclasses import dataclass

from pytorch_ie import (
AnnotationList,
BinaryRelation,
LabeledSpan,
Pipeline,
TextDocument,
annotation_field,
)
from pytorch_ie import AnnotationList, Pipeline, annotation_field
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import TransformerTextClassificationModel
from pytorch_ie.taskmodules import TransformerRETextClassificationTaskModule

Expand Down
18 changes: 4 additions & 14 deletions src/pytorch_ie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
# flake8: noqa

from .annotations import (
BinaryRelation,
Label,
LabeledMultiSpan,
LabeledSpan,
MultiLabel,
MultiLabeledBinaryRelation,
MultiLabeledMultiSpan,
MultiLabeledSpan,
Span,
)
from .data import *
from .document import AnnotationList, Document, TextDocument, annotation_field
from .pipeline import Pipeline
from pytorch_ie.auto import AutoModel, AutoPipeline, AutoTaskModule
from pytorch_ie.core import *
from pytorch_ie.data import *
from pytorch_ie.pipeline import Pipeline
36 changes: 3 additions & 33 deletions src/pytorch_ie/annotations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from pytorch_ie.document import AnnotationList
from pytorch_ie.core.document import Annotation


def _validate_single_label(self):
Expand Down Expand Up @@ -29,36 +29,6 @@ def _validate_multi_label(self):
)


@dataclass(eq=True, frozen=True)
class Annotation:
    """Base class for immutable annotations.

    Instances are frozen dataclasses. The private ``_target`` field is not an
    ``__init__`` argument, is hidden from ``repr`` and is excluded from the
    generated hash; it is written through :meth:`set_target`.
    """

    _target: Optional[Union[AnnotationList, str]] = field(
        default=None, init=False, repr=False, hash=False
    )

    def set_target(self, value: Union[AnnotationList, str, None]):
        """Store ``value`` as the target of this annotation."""
        # The dataclass is frozen, so plain attribute assignment would raise;
        # go through object.__setattr__ to bypass the generated guard.
        object.__setattr__(self, "_target", value)

    @property
    def target(self) -> Optional[Union[AnnotationList, str]]:
        """Return the current target, or ``None`` if none has been set."""
        return self._target

    def asdict(self) -> Dict[str, Any]:
        """Return the fields as a dict, with ``_id`` added and ``_target`` removed.

        ``_id`` is ``hash(self)``.
        NOTE(review): hashes of ``str`` values are salted per interpreter
        process unless PYTHONHASHSEED is fixed, so ``_id`` is presumably only
        meaningful within one run -- confirm against callers.
        """
        # Inside a method the bare name resolves to dataclasses.asdict, not
        # to this method (class scope is not searched from function bodies).
        dct = asdict(self)
        dct["_id"] = hash(self)
        del dct["_target"]
        return dct

    @classmethod
    def fromdict(
        cls,
        dct: Dict[str, Any],
        annotation_store: Optional[Dict[int, Tuple[str, "Annotation"]]] = None,
    ):
        """Build an instance from ``dct``, ignoring an ``_id`` entry if present.

        ``dct`` itself is not modified. ``annotation_store`` is accepted but
        not used by this base implementation.
        """
        tmp_dct = dict(dct)
        tmp_dct.pop("_id", None)
        return cls(**tmp_dct)


@dataclass(eq=True, frozen=True)
class Label(Annotation):
label: str
Expand Down
5 changes: 2 additions & 3 deletions src/pytorch_ie/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download

from pytorch_ie import Pipeline
from pytorch_ie.core import PyTorchIEModel
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.core.hf_hub_mixin import PyTorchIEModelHubMixin, PyTorchIETaskmoduleModelHubMixin
from pytorch_ie.taskmodules import TaskModule
from pytorch_ie.pipeline import Pipeline


class AutoTaskModule(PyTorchIETaskmoduleModelHubMixin):
Expand Down
4 changes: 3 additions & 1 deletion src/pytorch_ie/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .pytorch_ie import PyTorchIEModel
from .document import Annotation, AnnotationList, Document, annotation_field
from .model import PyTorchIEModel
from .taskmodule import TaskEncoding, TaskModule
43 changes: 31 additions & 12 deletions src/pytorch_ie/document.py → src/pytorch_ie/core/document.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import dataclasses
import typing
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union, overload

if TYPE_CHECKING:
from pytorch_ie.annotations import Annotation
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, overload


def _depth_first_search(lst: List[str], visited: Set[str], graph: Dict[str, List[str]], node: str):
Expand All @@ -24,6 +21,36 @@ def annotation_field(target: Optional[str] = None):
return dataclasses.field(metadata=dict(target=target), init=False, repr=False)


@dataclasses.dataclass(eq=True, frozen=True)
class Annotation:
    """Base class for immutable annotations.

    Instances are frozen dataclasses. The private ``_target`` field is not an
    ``__init__`` argument, is hidden from ``repr`` and is excluded from the
    generated hash; it is written through :meth:`set_target`.
    """

    _target: Optional[Union["AnnotationList", str]] = dataclasses.field(
        default=None, init=False, repr=False, hash=False
    )

    def set_target(self, value: Union["AnnotationList", str, None]):
        """Store ``value`` as the target of this annotation."""
        # The dataclass is frozen, so plain attribute assignment would raise;
        # go through object.__setattr__ to bypass the generated guard.
        object.__setattr__(self, "_target", value)

    @property
    def target(self) -> Optional[Union["AnnotationList", str]]:
        """Return the current target, or ``None`` if none has been set."""
        return self._target

    def asdict(self) -> Dict[str, Any]:
        """Return the fields as a dict, with ``_id`` added and ``_target`` removed.

        ``_id`` is ``hash(self)``.
        NOTE(review): hashes of ``str`` values are salted per interpreter
        process unless PYTHONHASHSEED is fixed, so ``_id`` is presumably only
        meaningful within one run -- confirm against callers.
        """
        dct = dataclasses.asdict(self)
        dct["_id"] = hash(self)
        del dct["_target"]
        return dct

    @classmethod
    def fromdict(
        cls,
        dct: Dict[str, Any],
        annotation_store: Optional[Dict[int, Tuple[str, "Annotation"]]] = None,
    ):
        """Build an instance from ``dct``, ignoring an ``_id`` entry if present.

        ``dct`` itself is not modified. ``annotation_store`` is accepted but
        not used by this base implementation.
        """
        tmp_dct = dict(dct)
        tmp_dct.pop("_id", None)
        return cls(**tmp_dct)


T = TypeVar("T", covariant=False, bound="Annotation")


Expand Down Expand Up @@ -200,11 +227,3 @@ def fromdict(cls, dct):
getattr(doc, field_name).append(annotation)

return doc


@dataclasses.dataclass
class TextDocument(Document):
    """A ``Document`` carrying a text string plus an optional id and metadata."""

    # Required text of the document.
    text: str
    # Optional identifier; defaults to None.
    id: Optional[str] = None
    # Free-form metadata; every instance gets its own fresh dict.
    metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
    # Not an __init__ argument and hidden from repr; fixed default "text".
    # NOTE(review): presumably names the field that annotations are rooted
    # at -- confirm against Document.
    _root_annotation: str = dataclasses.field(default="text", init=False, repr=False)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytorch_lightning import LightningModule

from pytorch_ie.core.hf_hub_mixin import PyTorchIEModelHubMixin
from pytorch_ie.core.registerable import Registrable
from pytorch_ie.core.registrable import Registrable


class PyTorchIEModel(LightningModule, Registrable, PyTorchIEModelHubMixin):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
overload,
)

from pytorch_ie import Dataset, Document
from pytorch_ie.annotations import Annotation
from pytorch_ie.core.document import Annotation, Document
from pytorch_ie.core.hf_hub_mixin import PyTorchIETaskmoduleModelHubMixin
from pytorch_ie.core.registerable import Registrable
from pytorch_ie.data import Metadata
from pytorch_ie.core.registrable import Registrable
from pytorch_ie.data import Dataset

"""
workflow:
Expand Down Expand Up @@ -48,6 +47,9 @@ class InplaceNotSupportedException(Exception):
pass


Metadata = Dict[str, Any]


class TaskEncoding(Generic[DocumentType, InputEncoding, TargetEncoding]):
def __init__(
self,
Expand Down
5 changes: 0 additions & 5 deletions src/pytorch_ie/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from typing import Any, Dict

from .builder import GeneratorBasedBuilder
from .dataset import Dataset
from .dataset_formatter import DocumentFormatter

Metadata = Dict[str, Any]

__all__ = [
"Metadata",
"GeneratorBasedBuilder",
"Dataset",
"DocumentFormatter",
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_ie/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from datasets.load import load_dataset_builder

import datasets
from pytorch_ie.core.document import Document
from pytorch_ie.data.dataset import Dataset, decorate_convert_to_dict_of_lists
from pytorch_ie.document import Document


class GeneratorBasedBuilder(datasets.builder.GeneratorBasedBuilder):
Expand Down
5 changes: 2 additions & 3 deletions src/pytorch_ie/data/datamodules/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any, Dict, Generic, List, Optional, Sequence
from typing import Any, Dict, Generic, Optional, Sequence

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from pytorch_ie.data.datasets import PIEDatasetDict
from pytorch_ie.taskmodules.taskmodule import (
from pytorch_ie.core.taskmodule import (
DocumentType,
InputEncoding,
TargetEncoding,
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_ie/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import wraps
from typing import TYPE_CHECKING, Callable, List, Optional, Type, Union
from typing import Callable, List, Optional, Type, Union

import pandas as pd
from datasets.formatting import _register_formatter

import datasets
from pytorch_ie.core.document import Document
from pytorch_ie.data.dataset_formatter import DocumentFormatter
from pytorch_ie.document import Document

_register_formatter(DocumentFormatter, "document")

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_ie/data/dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pyarrow as pa
from datasets.formatting.formatting import Formatter

from pytorch_ie.document import Document
from pytorch_ie.core.document import Document


class DocumentFormatter(Formatter[Document, list, List[Document]]):
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_ie/data/datasets/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

from datasets import DatasetDict, load_dataset
from pytorch_ie import BinaryRelation, Document, LabeledMultiSpan, LabeledSpan
from pytorch_ie.annotations import Annotation
from pytorch_ie import Document
from pytorch_ie.annotations import Annotation, BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.data.datasets import HF_DATASETS_ROOT

DEFAULT_HEAD_ARGUMENT_NAME: str = "Arg1"
Expand Down
Loading

0 comments on commit 5e452b0

Please sign in to comment.