diff --git a/nlp/desci_sense/shared_functions/parsers/multi_reference_tagger.py b/nlp/desci_sense/shared_functions/parsers/multi_reference_tagger.py index a9b2de60..8d38f3b4 100644 --- a/nlp/desci_sense/shared_functions/parsers/multi_reference_tagger.py +++ b/nlp/desci_sense/shared_functions/parsers/multi_reference_tagger.py @@ -37,7 +37,7 @@ class PromptCase(EnumDictKey): def return_fallback(input): return ParserChainOutput( - answer=[], + answer=Answer(sub_answers=list(), debug={"errors": "fallback"}), pparser_type=ParserChainType.MULTI_REF_TAGGER, extra={"errors": "fallback"}, ) diff --git a/nlp/desci_sense/shared_functions/postprocessing/__init__.py b/nlp/desci_sense/shared_functions/postprocessing/__init__.py index c3ad7302..a332f810 100644 --- a/nlp/desci_sense/shared_functions/postprocessing/__init__.py +++ b/nlp/desci_sense/shared_functions/postprocessing/__init__.py @@ -298,24 +298,30 @@ def convert_ref_tags_to_rdf_triplets( ] else: # non zero refs - add labels as predicates to corresponding urls - assert len(all_reference_tags) == len(reference_urls) - for ref_tags, ref_url in zip(all_reference_tags, reference_urls): - if len(ref_tags) == 0: - # if no ref tags provided, add default mention label - # add warning since this should be handled prior - logger.warning("No ref tags provided, adding default label!") - updated_ref_tags = [ontology.default_mention_label()] - else: - updated_ref_tags = ref_tags - for label in updated_ref_tags: - concept = ontology.get_concept_by_label(label) - assert concept.can_be_predicate() - triplets += [ - RDFTriplet( - predicate=URIRef(concept.uri), - object=URIRef(ref_url), - ) - ] + if len(all_reference_tags) == len(reference_urls): + for ref_tags, ref_url in zip(all_reference_tags, reference_urls): + if len(ref_tags) == 0: + # if no ref tags provided, add default mention label + # add warning since this should be handled prior + logger.warning("No ref tags provided, adding default label!") + updated_ref_tags = [ontology.default_mention_label()] + else: + updated_ref_tags = ref_tags + for label in updated_ref_tags: + concept = ontology.get_concept_by_label(label) + assert concept.can_be_predicate() + triplets += [ + RDFTriplet( + predicate=URIRef(concept.uri), + object=URIRef(ref_url), + ) + ] + else: + logger.warning( + "No predictions in input corresponding to references!\n" + f"{all_reference_tags} != {reference_urls}" + ) + # logger.warning(f"{all_reference_tags} != {reference_urls}") return triplets # # for each tag decide if it's the object or predicate @@ -492,13 +498,17 @@ def get_support_data( ontology_interface: OntologyInterface, metadata_list: List[RefMetadata], ) -> ParserSupport: - md_dict = {} # Initialize an empty dictionary + # Initialize an empty dictionary + md_dict = {} for m in metadata_list: if hasattr(m, "url"): md_dict[m.url] = m - return ParserSupport(ontology=ontology_interface, refs_meta=md_dict) + return ParserSupport( + ontology=ontology_interface, + refs_meta=md_dict, + ) def post_process_firebase(