save entries of token_offset_mapping as lists in metadata to fix deserialization
ArneBinder committed Oct 5, 2023
1 parent 3f07b1c commit 96407e2
Showing 2 changed files with 21 additions and 9 deletions.
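For context, the underlying issue is that JSON has no tuple type: offset tuples written into metadata come back as lists after a serialization round trip and then no longer compare equal to the originals. A minimal standalone sketch of that behavior (not part of the changed code):

import json

# offsets as produced by a tokenizer: (start, end) tuples
offsets = [(0, 1), (2, 8)]
restored = json.loads(json.dumps(offsets))
assert restored == [[0, 1], [2, 8]]  # tuples come back as lists
assert restored != offsets           # equality with the original tuples fails

Storing the list form up front keeps the metadata identical before and after (de-)serialization.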
25 changes: 18 additions & 7 deletions src/pytorch_ie/data/document_conversion.py
@@ -46,18 +46,24 @@ def text_based_document_to_token_based(
 
     # save text, token_offset_mapping and char_to_token (if available) in metadata
     result.metadata["text"] = doc.text
+    token_offset_mapping_lists: Optional[List[List[int]]]
     if token_offset_mapping is not None:
+        # convert offset tuples to lists because serialization and deserialization again
+        # will produce lists in any way (json does not know tuples)
+        token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping]
         if (
             "token_offset_mapping" in doc.metadata
-            and doc.metadata["token_offset_mapping"] != token_offset_mapping
+            and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists
         ):
             logger.warning(
                 "token_offset_mapping in metadata is different from the new token_offset_mapping, "
                 "overwrite the metadata"
             )
-        result.metadata["token_offset_mapping"] = token_offset_mapping
+        result.metadata["token_offset_mapping"] = token_offset_mapping_lists
     else:
-        token_offset_mapping = doc.metadata.get("token_offset_mapping")
+        token_offset_mapping_lists = doc.metadata.get("token_offset_mapping")
+        if token_offset_mapping_lists is not None:
+            token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists]  # type: ignore
     if char_to_token is not None:
         if "char_to_token" in doc.metadata and doc.metadata["char_to_token"] != char_to_token:
             logger.warning(
@@ -167,29 +173,34 @@ def token_based_document_to_text_based(
             raise ValueError(
                 "if join_tokens_with is None, text must be provided, but got None as well"
             )
-        token_offset_mapping = (
+        token_offset_mapping_lists = (
             doc.metadata.get("token_offset_mapping")
             if token_offset_mapping is None
             else token_offset_mapping
         )
-        if token_offset_mapping is None:
+        if token_offset_mapping_lists is None:
             raise ValueError(
                 "if join_tokens_with is None, token_offsets must be provided, but got None as well"
             )
+        else:
+            token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists]  # type: ignore
 
     result = document_type(text=text, id=doc.id, metadata=deepcopy(doc.metadata))
     if "tokens" in doc.metadata and doc.metadata["tokens"] != list(doc.tokens):
         logger.warning("tokens in metadata are different from new tokens, overwrite the metadata")
     result.metadata["tokens"] = list(doc.tokens)
+    # convert offset tuples to lists because serialization and deserialization again
+    # will produce lists in any way (json does not know tuples)
+    token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping]
     if (
         "token_offset_mapping" in doc.metadata
-        and doc.metadata["token_offset_mapping"] != token_offset_mapping
+        and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists
     ):
         logger.warning(
             "token_offset_mapping in metadata is different from the new token_offset_mapping, "
             "overwrite the metadata"
         )
-        result.metadata["token_offset_mapping"] = token_offset_mapping
+    result.metadata["token_offset_mapping"] = token_offset_mapping_lists
 
     token_targeting_layers = [
         annotation_field.name
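Taken out of the diff, the convention this change adopts is simply: write offsets to metadata as lists, read them back as tuples. A rough standalone sketch of that pattern (the helper names here are illustrative only, not part of pytorch_ie):

from typing import Any, Dict, List, Optional, Sequence, Tuple

def offsets_to_metadata(token_offset_mapping: Sequence[Tuple[int, int]]) -> List[List[int]]:
    # store lists so the metadata value survives a JSON round trip unchanged
    return [list(offsets) for offsets in token_offset_mapping]

def offsets_from_metadata(metadata: Dict[str, Any]) -> Optional[List[Tuple[int, int]]]:
    # restore (start, end) tuples for code that expects tuple offsets
    offsets = metadata.get("token_offset_mapping")
    if offsets is None:
        return None
    return [tuple(pair) for pair in offsets]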
5 changes: 3 additions & 2 deletions tests/data/test_document_conversion.py
@@ -44,10 +44,11 @@ def test_text_based_document_to_token_based(documents, tokenizer):
         # check (de-)serialization
         tokenized_doc.copy()
 
+        offset_mapping_lists = [list(offsets) for offsets in tokenized_text.offset_mapping]
         if i == 0:
             assert doc.id == "train_doc1"
             assert tokenized_doc.metadata["text"] == doc.text == "A single sentence."
-            assert tokenized_doc.metadata["token_offset_mapping"] == tokenized_text.offset_mapping
+            assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists
             assert tokenized_doc.metadata.get("char_to_token") is None
             assert tokenized_doc.tokens == ("[CLS]", "A", "single", "sentence", ".", "[SEP]")
             assert len(tokenized_doc.sentences) == len(doc.sentences) == 1
@@ -91,7 +92,7 @@ def test_text_based_document_to_token_based(documents, tokenizer):
         elif i == 2:
             assert doc.id == "train_doc3"
             assert tokenized_doc.metadata["text"] == doc.text == "Entity C and D."
-            assert tokenized_doc.metadata["token_offset_mapping"] == tokenized_text.offset_mapping
+            assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists
             assert tokenized_doc.metadata["char_to_token"] == tokenized_text.char_to_token
             assert tokenized_doc.tokens == (
                 "[CLS]",
