Skip to content

Commit

Permalink
fix _generate_document_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
idalr committed Nov 7, 2023
1 parent ae0f032 commit f7b3753
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 34 deletions.
2 changes: 1 addition & 1 deletion dataset_builders/pie/cdcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The document type for this dataset is `CDCPDocument` which defines the following

- `text` (str)
- `id` (str, optional)
- `metadata` (dictionary, dataclasses)
- `metadata` (dictionary, optional)

and the following annotation layers:

Expand Down
24 changes: 12 additions & 12 deletions dataset_builders/pie/cdcp/cdcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ class CDCPDocument(TextBasedDocument):

def example_to_document(
example: Dict[str, Any],
relation_int2str: Callable[[int], str],
proposition_int2str: Callable[[int], str],
relation_label: Callable[[int], str],
proposition_label: Callable[[int], str],
):
document = CDCPDocument(id=example["id"], text=example["text"])
for proposition_dict in dl2ld(example["propositions"]):
proposition = LabeledSpan(
start=proposition_dict["start"],
end=proposition_dict["end"],
label=proposition_int2str(proposition_dict["label"]),
label=proposition_label.int2str(proposition_dict["label"]),
)
document.propositions.append(proposition)
if proposition_dict.get("url", "") != "":
Expand All @@ -58,7 +58,7 @@ def example_to_document(
relation = BinaryRelation(
head=document.propositions[relation_dict["head"]],
tail=document.propositions[relation_dict["tail"]],
label=relation_int2str(relation_dict["label"]),
label=relation_label.int2str(relation_dict["label"]),
)
document.relations.append(relation)

Expand All @@ -67,8 +67,8 @@ def example_to_document(

def document_to_example(
document: CDCPDocument,
relation_str2int: Callable[[str], int],
proposition_str2int: Callable[[str], int],
relation_label: Callable[[int], str],
proposition_label: Callable[[int], str],
) -> Dict[str, Any]:
result = {"id": document.id, "text": document.text}
proposition2dict = {}
Expand All @@ -77,7 +77,7 @@ def document_to_example(
proposition2dict[proposition] = {
"start": proposition.start,
"end": proposition.end,
"label": proposition_str2int(proposition.label),
"label": proposition_label.str2int(proposition.label),
"url": "",
}
proposition2idx[proposition] = idx
Expand All @@ -92,7 +92,7 @@ def document_to_example(
{
"head": proposition2idx[relation.head],
"tail": proposition2idx[relation.tail],
"label": relation_str2int(relation.label),
"label": relation_label.str2int(relation.label),
}
for relation in document.relations
]
Expand Down Expand Up @@ -132,11 +132,11 @@ class CDCP(pytorch_ie.data.builder.GeneratorBasedBuilder):

def _generate_document_kwargs(self, dataset):
return {
"relation_int2str": dataset.features["relations"].feature["label"].int2str,
"proposition_int2str": dataset.features["propositions"].feature["label"].int2str,
"relation_label": dataset.features["relations"].feature["label"],
"proposition_label": dataset.features["propositions"].feature["label"],
}

def _generate_document(self, example, relation_int2str, proposition_int2str):
def _generate_document(self, example, relation_label, proposition_label):
return example_to_document(

Check warning on line 140 in dataset_builders/pie/cdcp/cdcp.py

View check run for this annotation

Codecov / codecov/patch

dataset_builders/pie/cdcp/cdcp.py#L140

Added line #L140 was not covered by tests
example, relation_int2str=relation_int2str, proposition_int2str=proposition_int2str
example, relation_label=relation_label, proposition_label=proposition_label
)
3 changes: 0 additions & 3 deletions src/pie_datasets/document/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ class Attribute(Annotation):
value: Optional[str] = None
score: float = 1.0

def __post_init__(self) -> None:
_post_init_single_label(self)


# ========================= Document Types ========================= #

Expand Down
23 changes: 5 additions & 18 deletions tests/dataset_builders/pie/test_cdcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,17 @@ def test_hf_example(hf_example, split):

@pytest.fixture(scope="module")
def generate_document_kwargs(hf_dataset, split):
return {
"relation_int2str": hf_dataset[split].features["relations"].feature["label"].int2str,
"proposition_int2str": hf_dataset[split].features["propositions"].feature["label"].int2str,
}


@pytest.fixture(scope="module")
def convert_back_kwargs(hf_dataset, split):
return {
"relation_str2int": hf_dataset[split].features["relations"].feature["label"].str2int,
"proposition_str2int": hf_dataset[split].features["propositions"].feature["label"].str2int,
}
return CDCP()._generate_document_kwargs(hf_dataset[split])


def test_example_to_document(hf_example, generate_document_kwargs):
doc = example_to_document(hf_example, **generate_document_kwargs)
assert doc is not None


def test_example_to_document_and_back(hf_example, generate_document_kwargs, convert_back_kwargs):
def test_example_to_document_and_back(hf_example, generate_document_kwargs):
doc = example_to_document(hf_example, **generate_document_kwargs)
hf_example_back = document_to_example(doc, **convert_back_kwargs)
hf_example_back = document_to_example(doc, **generate_document_kwargs)
_deep_compare(
obj=hf_example_back,
obj_expected=hf_example,
Expand Down Expand Up @@ -155,13 +144,11 @@ class TextDocumentWithEntities(TextBasedDocument):
_assert_no_span_overlap(document=doc1, text_field="text", span_layer="entities")


def test_example_to_document_and_back_all(
hf_dataset, generate_document_kwargs, convert_back_kwargs, split
):
def test_example_to_document_and_back_all(hf_dataset, generate_document_kwargs, split):
for hf_ex in hf_dataset[split]:
doc = example_to_document(hf_ex, **generate_document_kwargs)
_assert_no_span_overlap(document=doc, text_field="text", span_layer="propositions")
hf_example_back = document_to_example(doc, **convert_back_kwargs)
hf_example_back = document_to_example(doc, **generate_document_kwargs)
_deep_compare(
obj=hf_example_back,
obj_expected=hf_ex,
Expand Down

0 comments on commit f7b3753

Please sign in to comment.