Skip to content

Commit

Permalink
add graceful failure for cases with failure to generate valid prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
ronentk committed Aug 28, 2024
1 parent ac9ae65 commit 4acb233
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PromptCase(EnumDictKey):

def return_fallback(input):
return ParserChainOutput(
answer=[],
answer=Answer(sub_answers=list(), debug={"errors": "fallback"}),
pparser_type=ParserChainType.MULTI_REF_TAGGER,
extra={"errors": "fallback"},
)
Expand Down
50 changes: 30 additions & 20 deletions nlp/desci_sense/shared_functions/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,24 +298,30 @@ def convert_ref_tags_to_rdf_triplets(
]
else:
# non zero refs - add labels as predicates to corresponding urls
assert len(all_reference_tags) == len(reference_urls)
for ref_tags, ref_url in zip(all_reference_tags, reference_urls):
if len(ref_tags) == 0:
# if no ref tags provided, add default mention label
# add warning since this should be handled prior
logger.warning("No ref tags provided, adding default label!")
updated_ref_tags = [ontology.default_mention_label()]
else:
updated_ref_tags = ref_tags
for label in updated_ref_tags:
concept = ontology.get_concept_by_label(label)
assert concept.can_be_predicate()
triplets += [
RDFTriplet(
predicate=URIRef(concept.uri),
object=URIRef(ref_url),
)
]
if len(all_reference_tags) == len(reference_urls):
for ref_tags, ref_url in zip(all_reference_tags, reference_urls):
if len(ref_tags) == 0:
# if no ref tags provided, add default mention label
# add warning since this should be handled prior
logger.warning("No ref tags provided, adding default label!")
updated_ref_tags = [ontology.default_mention_label()]
else:
updated_ref_tags = ref_tags
for label in updated_ref_tags:
concept = ontology.get_concept_by_label(label)
assert concept.can_be_predicate()
triplets += [
RDFTriplet(
predicate=URIRef(concept.uri),
object=URIRef(ref_url),
)
]
else:
logger.warning(
"No predictions in input corresponding to references!\n"
f"{all_reference_tags} != {reference_urls}"
)
# logger.warning(f"{all_reference_tags} != {reference_urls}")
return triplets

# # for each tag decide if it's the object or predicate
Expand Down Expand Up @@ -492,13 +498,17 @@ def get_support_data(
ontology_interface: OntologyInterface,
metadata_list: List[RefMetadata],
) -> ParserSupport:
md_dict = {} # Initialize an empty dictionary
# Initialize an empty dictionary
md_dict = {}

for m in metadata_list:
if hasattr(m, "url"):
md_dict[m.url] = m

return ParserSupport(ontology=ontology_interface, refs_meta=md_dict)
return ParserSupport(
ontology=ontology_interface,
refs_meta=md_dict,
)


def post_process_firebase(
Expand Down

0 comments on commit 4acb233

Please sign in to comment.