diff --git a/poetry.lock b/poetry.lock index 25febe387..e63a68fd9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2158,27 +2158,23 @@ files = [ [[package]] name = "pytorch-ie" -version = "0.31.3" +version = "0.31.2" description = "State-of-the-art Information Extraction in PyTorch" optional = false -python-versions = "^3.9" -files = [] -develop = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pytorch_ie-0.31.2-py3-none-any.whl", hash = "sha256:fef91fb3d4dff84b0b6fd973d5bf7e5be51e0e01393195a01d7d824688c8cb3e"}, + {file = "pytorch_ie-0.31.2.tar.gz", hash = "sha256:cd9683ef4ba0191854ff1843f22f431f4c38c8745962ee55ba5b5c52f27afd7c"}, +] [package.dependencies] -absl-py = "^1.0.0" +absl-py = ">=1.0.0,<2.0.0" fsspec = "<2023.9.0" -pandas = "^2.0.0" -pytorch-lightning = "^2" +pandas = ">=2.0.0,<3.0.0" +pytorch-lightning = ">=2,<3" torch = ">=1.10" -torchmetrics = "^1" -transformers = "^4.18" - -[package.source] -type = "git" -url = "https://github.com/ArneBinder/pytorch-ie" -reference = "document/deduplicate_annotations" -resolved_reference = "57bb34386a2ec9922ea7c8e7f36a0a199b02848e" +torchmetrics = ">=1,<2" +transformers = ">=4.18,<5.0" [[package]] name = "pytorch-lightning" @@ -3443,4 +3439,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ab8de51bf6d389468d923b5d391f339c6e7f383396565cf5db925039de957132" +content-hash = "9edc2e1c448159e3f55c1d7fb1c6fa1d2baa9a11b2ef6e8aa80d5f551789ecac" diff --git a/pyproject.toml b/pyproject.toml index d784a1584..69593d7ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -#pytorch-ie = ">=0.31.4,<0.32.0" -# install from branch from https://github.com/ArneBinder/pytorch-ie/pull/436 -pytorch-ie = { git = "https://github.com/ArneBinder/pytorch-ie", branch = "document/deduplicate_annotations" } +pytorch-ie = ">=0.31.2,<0.32.0" pytorch-lightning = "^2.1.0" torchmetrics = "^1" # >=4.35 because of BartModelWithDecoderPositionIds, <4.37 because of generation config diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index e49213109..30c6cd6e7 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -568,9 +568,9 @@ def test_decode_with_add_reversed_relations(): task_outputs = [task_encoding.targets for task_encoding in task_encodings] docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) assert len(docs_with_predictions) == 1 - doc_with_predictions: ExampleDocument = docs_with_predictions[0].deduplicate_annotations() - assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities) - assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations) + doc_with_predictions: ExampleDocument = docs_with_predictions[0] + assert set(doc_with_predictions.entities.predictions) == set(doc_with_predictions.entities) + assert set(doc_with_predictions.relations.predictions) == set(doc_with_predictions.relations) @pytest.fixture()