diff --git a/src/trapi_predict_kit/trapi_parser.py b/src/trapi_predict_kit/trapi_parser.py index c54af85..d49e667 100644 --- a/src/trapi_predict_kit/trapi_parser.py +++ b/src/trapi_predict_kit/trapi_parser.py @@ -126,7 +126,6 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""): if subjs_to_predict and pred_to_predict and objs_to_predict: subject_ids = subjs_to_predict.get("ids", []) object_ids = objs_to_predict.get("ids", []) - labels_dict = get_entities_labels(subject_ids + object_ids) try: log.info(f"🔮⏳️ Getting predictions for: {subject_ids} | {object_ids}") @@ -150,6 +149,12 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""): log.error(f"Error getting the predictions: {e}") prediction_json = [] + # Get the labels of all entities returned by the prediction function + all_ids = [pred["subject"] for pred in prediction_json] + [ + pred["subject"] for pred in prediction_json + ] + labels_dict = get_entities_labels(list(set(all_ids))) + for association in prediction_json: # id/type of nodes are registered in a dict to avoid duplicate in knowledge_graph.nodes # Build dict of node ID : label @@ -186,8 +191,12 @@ def resolve_trapi_query(reasoner_query, endpoints_list, infores: str = ""): ), } - if subject_id in labels_dict and labels_dict[subject_id]: - node_dict[subject_id]["label"] = labels_dict[subject_id]["id"]["label"] + if "subject_label" in association: + node_dict[subject_id]["label"] = association["subject_label"] + else: + if subject_id in labels_dict and labels_dict[subject_id]: + node_dict[subject_id]["label"] = labels_dict[subject_id]["id"]["label"] + if "object_label" in association: node_dict[object_id]["label"] = association["object_label"] else: