Skip to content

Commit

Permalink
remove features not declared in the target document type
Browse files Browse the repository at this point in the history
  • Loading branch information
RainbowRivey committed Sep 30, 2024
1 parent 0f6ed10 commit f09940c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/pie_datasets/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ def dataset_to_document_type(
# remove the document converters because they are not valid anymore
result.document_converters = {}

# remove features not declared in the target document type
if result.features is not None:
original_field_names = set(result.features)
target_field_names = {field.name for field in document_type.fields()}
remove_field_names = original_field_names - target_field_names
result = result.remove_columns(list(remove_field_names))

return result


Expand Down
26 changes: 17 additions & 9 deletions tests/unit/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy
import pytest
import torch
from pyexpat import features
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
from pytorch_ie.core import AnnotationList, annotation_field
Expand Down Expand Up @@ -203,18 +204,25 @@ def test_register_document_converter_mapping(dataset_with_converter_mapping):


def test_to_document_type_function(dataset_with_converter_functions):
assert set(dataset_with_converter_functions.features) == {
"entities",
"relations",
"metadata",
"sentences",
"id",
"text",
}
# Features are only available for Dataset type (not for IterableDataset)
if isinstance(dataset_with_converter_functions, Dataset):
assert set(dataset_with_converter_functions.features) == {
"entities",
"relations",
"metadata",
"sentences",
"id",
"text",
}
else:
assert dataset_with_converter_functions.features is None
assert dataset_with_converter_functions.document_type == TestDocument
converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel)
assert converted_dataset.document_type == TestDocumentWithLabel
assert set(converted_dataset.features) == {"id", "label", "metadata", "text"}
if isinstance(converted_dataset, Dataset):
assert set(converted_dataset.features) == {"id", "label", "metadata", "text"}
else:
assert converted_dataset.features is None

assert len(converted_dataset.document_converters) == 0
for doc in converted_dataset:
Expand Down

0 comments on commit f09940c

Please sign in to comment.