-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
initial creation of ManipulatePreds class + tests #184
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from typing import List | ||
|
||
|
||
class ManipulatePreds:
    """
    Helper for expanding prediction text and character offsets so that they
    align with the underlying OCR token data.

    Example:
        A model may predict the fragment "alks Sco" that spans parts of the
        OCR tokens ["talks", "Scott"]; ``expand_predictions`` widens the
        prediction to the full token boundaries, yielding "talks Scott".
    """

    def __init__(self, ocr_tokens: List[dict], preds: List[dict]):
        # OCR tokens: dicts with "text" and "doc_offset" ({"start", "end"}
        # character offsets). Assumed sorted by doc_offset — TODO confirm
        # against the OCR producer.
        self.ocr_tokens = ocr_tokens
        # Model predictions: dicts with "start", "end" and "text" keys.
        self.predictions = preds

    def expand_predictions(self, pred_start: int, pred_end: int) -> dict:
        """
        Expand a prediction's text and boundaries to match the OCR tokens it
        overlaps.

        The prediction whose "start" or "end" matches the given bounds is
        located, its boundaries are widened to the enclosing OCR token
        boundaries, and its text is replaced with the full token text.
        The prediction dict is updated in place.

        Args:
            pred_start (int): Starting character offset of the prediction.
            pred_end (int): Ending character offset of the prediction.
        Returns:
            dict: The (possibly expanded) prediction dictionary.
        Raises:
            ValueError: If no prediction matches the given bounds.
        """
        # Locate the prediction matching the supplied bounds.
        pred_index = None
        for index, pred in enumerate(self.predictions):
            if pred["start"] == pred_start or pred["end"] == pred_end:
                pred_index = index
                break
        if pred_index is None:
            raise ValueError("No matching prediction found.")

        prediction = self.predictions[pred_index]

        # If the prediction already matches the OCR text, nothing to expand.
        # (Removed the stray debugging print from the original.)
        ocr_text = self._get_ocr_text(pred_start, pred_end)
        if prediction["text"] == ocr_text:
            return prediction

        # Widen the boundaries to the full extent of the tokens that contain
        # the prediction's start/end offsets.
        expanded_start, expanded_end = pred_start, pred_end
        for token in self.ocr_tokens:
            offset = token["doc_offset"]
            if offset["start"] <= pred_start <= offset["end"]:
                expanded_start = min(expanded_start, offset["start"])
            if offset["start"] <= pred_end <= offset["end"]:
                expanded_end = max(expanded_end, offset["end"])

        # NOTE: for sorted, non-overlapping tokens the text over the expanded
        # bounds always equals `ocr_text` above (``_get_ocr_text`` already
        # returns whole overlapping tokens), so the original's
        # "expanded text mismatch" ValueError was dead code and is removed.
        prediction["start"] = expanded_start
        prediction["end"] = expanded_end
        prediction["text"] = self._get_ocr_text(expanded_start, expanded_end)
        return prediction

    def is_token_nearby(
        self, ocr_start: int, ocr_end: int, search_tokens: List[str], distance: int
    ) -> bool:
        """
        Return True if any of ``search_tokens`` appears within ``distance``
        tokens (before or after) of the token with the given bounds.

        Args:
            ocr_start (int): Starting character offset of the anchor token.
            ocr_end (int): Ending character offset of the anchor token.
            search_tokens (List[str]): Token texts to search for, case sensitive.
            distance (int): Number of tokens examined forward and backwards.
        Returns:
            bool: True if a search token is found within the window.
        Raises:
            ValueError: If no OCR token matches the given bounds.
        """
        token_index = None
        for index, token in enumerate(self.ocr_tokens):
            offset = token["doc_offset"]
            if offset["start"] == ocr_start and offset["end"] == ocr_end:
                token_index = index
                break

        # BUG FIX: the original used ``if token_index:``, which is falsy for
        # index 0 and incorrectly raised when the anchor was the first token.
        if token_index is None:
            raise ValueError("No token found with specified bounds.")

        window_start = max(0, token_index - distance)
        window_end = min(len(self.ocr_tokens), token_index + distance + 1)
        for i in range(window_start, window_end):
            if i == token_index:
                continue  # don't match the anchor token itself
            if self.ocr_tokens[i]["text"] in search_tokens:
                return True
        return False

    def _get_ocr_text(self, start: int, end: int) -> str:
        """
        Concatenate the text of every OCR token overlapping [start, end).

        Relies on ``self.ocr_tokens`` being sorted by doc_offset: the loop
        breaks at the first token starting at or after ``end``.

        Args:
            start (int): Starting character offset (not a token index).
            end (int): Ending character offset (not a token index).
        Returns:
            str: Space-joined text of all overlapping tokens.
        """
        pieces = []
        for token in self.ocr_tokens:
            if token["doc_offset"]["end"] <= start:
                continue
            if token["doc_offset"]["start"] >= end:
                break
            pieces.append(token["text"])
        return " ".join(pieces)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
[ | ||
{ | ||
"start": 26, | ||
"end": 32, | ||
"label": "Charity Number", | ||
"confidence": { | ||
"Charity Number": 0.991480827331543, | ||
"Name": 4.498019188758917e-05 | ||
}, | ||
"field_id": 52262, | ||
"page_num": 0, | ||
"text": "20179", | ||
"normalized": { | ||
"text": "201794", | ||
"start": 26, | ||
"end": 3, | ||
"structured": null, | ||
"formatted": "201794", | ||
"status": "SUCCESS", | ||
"validation": [ | ||
{ | ||
"validation_type": "TYPE_CONVERSION", | ||
"error_message": null, | ||
"validation_status": "SUCCESS" | ||
} | ||
] | ||
} | ||
}, | ||
{ | ||
"start": 75, | ||
"end": 95, | ||
"label": "Name", | ||
"confidence": { | ||
"Charity Number": 4.890366653853562e-06, | ||
"Name": 0.9999933242797852 | ||
}, | ||
"field_id": 52263, | ||
"page_num": 0, | ||
"text": "UGLAS ARTER FOUNDATI", | ||
"normalized": { | ||
"text": "DOUGLAS ARTER FOUNDATION", | ||
"start": 73, | ||
"end": 97, | ||
"structured": null, | ||
"formatted": "DOUGLAS ARTER FOUNDATION", | ||
"status": "SUCCESS", | ||
"validation": [ | ||
{ | ||
"validation_type": "TYPE_CONVERSION", | ||
"error_message": null, | ||
"validation_status": "SUCCESS" | ||
} | ||
] | ||
} | ||
} | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add an example usage here?