diff --git a/notebooks/Vanilla_RAG.ipynb b/notebooks/Vanilla_RAG.ipynb index 2f85039..ee301b5 100644 --- a/notebooks/Vanilla_RAG.ipynb +++ b/notebooks/Vanilla_RAG.ipynb @@ -679,9 +679,9 @@ " # Process citations, assigning numbers based on unique document_ids\n", " for citation in citations:\n", " citation_numbers = []\n", - " for document_id in sorted(citation[\"document_ids\"]):\n", + " for document_id in citation[\"document_ids\"]:\n", " if document_id not in document_id_to_number:\n", - " citation_number += 1 # Increment for a new document_id\n", + " citation_number = int(document_id.split('_')[1]) # extract the document id\n", " document_id_to_number[document_id] = citation_number\n", " citation_numbers.append(document_id_to_number[document_id])\n", "\n", @@ -697,7 +697,7 @@ "\n", " # Prepare citations for listing at the bottom, ensuring unique document_ids are listed once\n", " unique_citations = {number: doc_id for doc_id, number in document_id_to_number.items()}\n", - " citation_list = '\\n'.join([f'[{doc_id}] source: \"{documents[doc_id - 1][\"snippet\"]}\"' for doc_id, number in sorted(unique_citations.items(), key=lambda item: item[1])])\n", + " citation_list = '\\n'.join([f'[{doc_id}] source: \"{documents[doc_id][\"snippet\"]}\"' for doc_id, number in in dict(sorted(unique_citations.items(), key=lambda item: item[1])).items()])\n", " text_with_citations = f'{text}\\n\\n{citation_list}'\n", "\n", " return text_with_citations\n",