From 7b5eff00b044e5915373dafb6ddcbd12380a534e Mon Sep 17 00:00:00 2001 From: Ronen Tamari Date: Tue, 27 Aug 2024 20:46:59 -0400 Subject: [PATCH] #161 add reference ordering information to RefMetadata --- app/firebase-py/functions/main.test.py | 8 ++ nlp/desci_sense/shared_functions/interface.py | 7 ++ .../web_extractors/metadata_extractors.py | 10 ++- nlp/tests/test_multi_chain_app_interface.py | 15 +++- nlp/tests/test_multi_chain_post_processing.py | 81 ++++++++++++++++--- 5 files changed, 108 insertions(+), 13 deletions(-) diff --git a/app/firebase-py/functions/main.test.py b/app/firebase-py/functions/main.test.py index 1d7945c2..324877c8 100644 --- a/app/firebase-py/functions/main.test.py +++ b/app/firebase-py/functions/main.test.py @@ -141,5 +141,13 @@ json_obj = json.loads(serialized) print(f"semantics: {json_obj['semantics']}") +# Sorting the refs metadata dictionary by the 'order' attribute +sorted_refs = sorted( + json_obj["support"]["refs_meta"].items(), + key=lambda x: x[1]["order"], +) + +print(f"ordered references: {[url for url, _ in sorted_refs]}") + with open("last_output.json", "wb") as file: file.write(serialized.encode("utf-8")) diff --git a/nlp/desci_sense/shared_functions/interface.py b/nlp/desci_sense/shared_functions/interface.py index 376c3bea..b3fabd48 100644 --- a/nlp/desci_sense/shared_functions/interface.py +++ b/nlp/desci_sense/shared_functions/interface.py @@ -201,6 +201,13 @@ class RefMetadata(BaseModel): mentioned in a post. """ + ref_id: int = Field( + description="Unique ID of reference (1 indexed)", + ) + order: int = Field( + default=0, + description="1 indexed ordering of reference, sorted by ascending appearance order in the post. 1 - first, 2 - 2nd,. 0 - unassigned", + ) citoid_url: Union[str, None] = Field( description="URL used by citoid (might have different subdomain or final slashes).", ) diff --git a/nlp/desci_sense/shared_functions/web_extractors/metadata_extractors.py b/nlp/desci_sense/shared_functions/web_extractors/metadata_extractors.py index f5a00aa9..2b856cf2 100644 --- a/nlp/desci_sense/shared_functions/web_extractors/metadata_extractors.py +++ b/nlp/desci_sense/shared_functions/web_extractors/metadata_extractors.py @@ -37,12 +37,13 @@ def normalize_citoid_metadata( ): assert len(target_urls) == len(metadata_list) results = [] - for url, metadata in zip(target_urls, metadata_list): + for i, (url, metadata) in enumerate(zip(target_urls, metadata_list)): metadata["original_url"] = url summary = metadata.get("abstractNote", "") results.append( RefMetadata( **{ + "ref_id": i + 1, "citoid_url": metadata.get("url", None), "url": metadata.get("original_url", None), "item_type": metadata.get("itemType", None), @@ -198,12 +199,14 @@ def get_ref_post_metadata_list( post: RefPost, md_dict: Dict[str, RefMetadata], extra_urls: List[str] = None, + add_ordering: bool = True, ) -> List[RefMetadata]: """ Returns list of the post's reference metadata. If extra urls are provided, they are also counted as part of the post ref urls (for example extra_urls could include unprocessed urls due to max length limits). + If `add_ordering`, add reference ordering info to each metadata item. """ all_ref_urls = post.md_ref_urls() @@ -212,9 +215,12 @@ def get_ref_post_metadata_list( all_ref_urls += remove_dups_ordered(extra_urls) md_list = [] - for ref in all_ref_urls: + for i, ref in enumerate(all_ref_urls): if ref in md_dict: md = md_dict.get(ref) if md: + if add_ordering: + # add ordering info (1-indexed) + md.order = i + 1 md_list.append(md) return md_list diff --git a/nlp/tests/test_multi_chain_app_interface.py b/nlp/tests/test_multi_chain_app_interface.py index 61bbd213..c4e87b70 100644 --- a/nlp/tests/test_multi_chain_app_interface.py +++ b/nlp/tests/test_multi_chain_app_interface.py @@ -183,6 +183,10 @@ def test_app_config(): "https://x.com/FDAadcomms/status/1798107142219796794", ] check_uris_in_graph(res.semantics, expected_uris) + # check ordering + for i, url in enumerate(expected_uris): + assert res.support.refs_meta[url].order == i + 1 + assert res.support.refs_meta[url].url == url # "mistralai/mixtral-8x7b-instruct" @@ -190,9 +194,16 @@ def test_app_config(): # "google/gemma-7b-it" if __name__ == "__main__": multi_config = init_multi_chain_parser_config( - llm_type="mistralai/mistral-7b-instruct:free", post_process_type="firebase" + # ref_tagger_llm_type="mistralai/mistral-7b-instruct", + # kw_llm_type="mistralai/mistral-7b-instruct", + # topic_llm_type="mistralai/mistral-7b-instruct", + ref_tagger_llm_type="mistralai/mistral-7b-instruct:free", + kw_llm_type="mistralai/mistral-7b-instruct:free", + topic_llm_type="mistralai/mistral-7b-instruct:free", + post_process_type="firebase", ) mcp = MultiChainParser(multi_config) - thread = get_short_thread() + thread = get_thread_1() pi = ParserInput(thread_post=thread, max_posts=30) res = mcp.process_parser_input(pi) + refs = list(res.support.refs_meta.values()) diff --git a/nlp/tests/test_multi_chain_post_processing.py b/nlp/tests/test_multi_chain_post_processing.py index 8e48e331..1cb98f86 100644 --- a/nlp/tests/test_multi_chain_post_processing.py +++ b/nlp/tests/test_multi_chain_post_processing.py @@ -44,6 +44,13 @@ https://arxiv.org/abs/2402.04607 """ +TEST_POST_TEXT_W_2_REF = """ +I really liked these two papers! +https://arxiv.org/abs/2402.04607 + +https://arxiv.org/abs/2401.14000 +""" + def test_combined_pp(): multi_config = create_multi_config_for_tests() @@ -54,7 +61,22 @@ def test_combined_pp(): assert len(res.keywords) > 0 assert len(res.metadata_list) == 1 assert res.filter_classification == SciFilterClassfication.CITOID_DETECTED_RESEARCH + assert res.metadata_list[0].order == 1 + +def test_combined_2_pp(): + multi_config = create_multi_config_for_tests() + multi_config.post_process_type = PostProcessType.COMBINED + mcp = MultiChainParser(multi_config) + res = mcp.process_text(TEST_POST_TEXT_W_2_REF) + assert res.item_types == ["preprint", "conferencePaper"] + assert len(res.keywords) > 0 + assert len(res.metadata_list) == 2 + assert res.filter_classification == SciFilterClassfication.CITOID_DETECTED_RESEARCH + assert res.metadata_list[0].order == 1 + assert res.metadata_list[0].url == "https://arxiv.org/abs/2402.04607" + assert res.metadata_list[1].order == 2 + assert res.metadata_list[1].url == "https://arxiv.org/abs/2401.14000" def test_firebase_pp(): multi_config = create_multi_config_for_tests() @@ -109,6 +131,7 @@ def test_multi_chain_batch_pp_combined(): == "https://royalsocietypublishing.org/doi/10.1098/rstb.2022.0267" ) assert len(out_0.metadata_list) == 1 + assert out_0.metadata_list[0].order == 1 out_1 = res[1] assert len(out_1.metadata_list) == 1 @@ -116,15 +139,20 @@ def test_multi_chain_batch_pp_combined(): out_1.metadata_list[0].url == "https://write.as/ulrikehahn/some-thoughts-on-social-media-for-science" ) + assert out_1.metadata_list[0].order == 1 out_2 = res[2] assert len(out_2.metadata_list) == 2 + + # ordering not preserved yet for masto so don't test that yet assert set(out_2.reference_urls) == set( [ "https://paragraph.xyz/@sense-nets/sense-nets-intro", "https://paragraph.xyz/@sense-nets/2-project-plan", ] ) + + def test_convert_item_types_to_rdf_triplets_single_entry(): @@ -195,12 +223,47 @@ def test_short_post_no_ref_i146(): if __name__ == "__main__": - multi_config = create_multi_config_for_tests() - multi_config.post_process_type = PostProcessType.FIREBASE - mcp = MultiChainParser(multi_config) - res = mcp.process_text("yup") - print(res.semantics.serialize()) - - # len(res.support.refs_meta) == 1 - # assert "test" in mcp.pparsers - # assert "Google Scholar is manipulatable" in prompt + # get a few posts for input + urls = [ + "https://mastodon.social/@psmaldino@qoto.org/111405098400404613", + "https://mastodon.social/@UlrikeHahn@fediscience.org/111732713776994953", + "https://mastodon.social/@ronent/111687038322549430", + ] + post = scrape_post(urls[2]) +# posts = [scrape_post(url) for url in urls] +# multi_config = create_multi_config_for_tests(llm_type="google/gemma-7b-it:free") +# multi_chain_parser = MultiChainParser(multi_config) +# multi_chain_parser.config.post_process_type = PostProcessType.COMBINED +# res = multi_chain_parser.batch_process_ref_posts(posts) + +# out_0 = res[0] +# assert ( +# out_0.metadata_list[0].url +# == "https://royalsocietypublishing.org/doi/10.1098/rstb.2022.0267" +# ) +# assert len(out_0.metadata_list) == 1 +# assert out_0.metadata_list[0].order == 1 + +# out_1 = res[1] +# assert len(out_1.metadata_list) == 1 +# assert ( +# out_1.metadata_list[0].url +# == "https://write.as/ulrikehahn/some-thoughts-on-social-media-for-science" +# ) +# assert out_1.metadata_list[0].order == 1 + +# out_2 = res[2] +# assert len(out_2.metadata_list) == 2 +# assert set(out_2.reference_urls) == set( +# [ +# "https://paragraph.xyz/@sense-nets/sense-nets-intro", +# "https://paragraph.xyz/@sense-nets/2-project-plan", +# ] +# ) + +# sorted_refs = sorted( +# out_2.metadata_list, +# key=lambda x: x.order, +# ) +# assert sorted_refs[0].url == "https://paragraph.xyz/@sense-nets/sense-nets-intro" +# assert sorted_refs[1].url == "https://paragraph.xyz/@sense-nets/2-project-plan"