From 93f8119e182f8dd2dbbec36691492e7bd9f1a5bd Mon Sep 17 00:00:00 2001
From: Jonah Rangnow <104314597+minglu7@users.noreply.github.com>
Date: Wed, 29 May 2024 09:25:30 +0800
Subject: [PATCH] Update node.py (#13771)

---
 .../llama_index/core/postprocessor/node.py | 12 +++++++++---
 .../tests/test_vespavectorstore.py         |  7 ++++++-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/llama-index-core/llama_index/core/postprocessor/node.py b/llama-index-core/llama_index/core/postprocessor/node.py
index beb5d7d135822..8ea55d51fdbd8 100644
--- a/llama-index-core/llama_index/core/postprocessor/node.py
+++ b/llama-index-core/llama_index/core/postprocessor/node.py
@@ -4,6 +4,7 @@
 from typing import Dict, List, Optional, cast
 
 from llama_index.core.bridge.pydantic import Field, validator
+from llama_index.core.llms import LLM
 from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from llama_index.core.prompts.base import PromptTemplate
 from llama_index.core.response_synthesizers import (
@@ -12,6 +13,7 @@
 )
 from llama_index.core.schema import NodeRelationship, NodeWithScore, QueryBundle
 from llama_index.core.service_context import ServiceContext
+from llama_index.core.settings import Settings, llm_from_settings_or_context
 from llama_index.core.storage.docstore import BaseDocumentStore
 
 logger = logging.getLogger(__name__)
@@ -278,11 +280,13 @@ class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
     """
 
     docstore: BaseDocumentStore
-    service_context: ServiceContext
+    service_context: Optional[ServiceContext] = None
+    llm: Optional[LLM] = None
     num_nodes: int = Field(default=1)
     infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
     refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
     verbose: bool = Field(default=False)
+    response_mode: ResponseMode = Field(default=ResponseMode.TREE_SUMMARIZE)
 
     class Config:
         """Configuration for this pydantic object."""
@@ -310,6 +314,8 @@ def _postprocess_nodes(
         query_bundle: Optional[QueryBundle] = None,
     ) -> List[NodeWithScore]:
         """Postprocess nodes."""
+        llm = self.llm or llm_from_settings_or_context(Settings, self.service_context)
+
         if query_bundle is None:
             raise ValueError("Missing query bundle.")
 
@@ -324,10 +330,10 @@ def _postprocess_nodes(
         # use response builder instead of llm directly
         # to be more robust to handling long context
         response_builder = get_response_synthesizer(
-            service_context=self.service_context,
+            llm=llm,
             text_qa_template=infer_prev_next_prompt,
             refine_template=refine_infer_prev_next_prompt,
-            response_mode=ResponseMode.TREE_SUMMARIZE,
+            response_mode=self.response_mode,
        )
         raw_pred = response_builder.get_response(
             text_chunks=[node.node.get_content()],
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-vespa/tests/test_vespavectorstore.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-vespa/tests/test_vespavectorstore.py
index 44bbc6b90d5b4..7555e31dbd4ae 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-vespa/tests/test_vespavectorstore.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-vespa/tests/test_vespavectorstore.py
@@ -23,7 +23,12 @@
 @pytest.fixture(scope="session")
 def vespa_app():
     app_package: ApplicationPackage = hybrid_template
-    return VespaVectorStore(application_package=app_package, deployment_target="local")
+    try:
+        return VespaVectorStore(
+            application_package=app_package, deployment_target="local"
+        )
+    except RuntimeError as e:
+        pytest.skip(f"Could not create VespaVectorStore: {e}")
 
 
 @pytest.fixture(scope="session")
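
Note (not part of the patch): a minimal usage sketch of the changed API. With
this change, AutoPrevNextNodePostprocessor accepts an LLM directly (falling
back to Settings or the given ServiceContext when omitted) and exposes
response_mode as a configurable field. The MockLLM / SimpleDocumentStore setup
below is illustrative only and does not appear in the patch:

    from llama_index.core.llms import MockLLM
    from llama_index.core.postprocessor.node import AutoPrevNextNodePostprocessor
    from llama_index.core.response_synthesizers import ResponseMode
    from llama_index.core.storage.docstore import SimpleDocumentStore

    # Illustrative setup; a real docstore would already hold adjacent nodes
    # linked via NodeRelationship.PREVIOUS / NodeRelationship.NEXT.
    docstore = SimpleDocumentStore()

    postprocessor = AutoPrevNextNodePostprocessor(
        docstore=docstore,
        llm=MockLLM(),  # new: pass an LLM directly instead of a ServiceContext
        num_nodes=2,
        response_mode=ResponseMode.TREE_SUMMARIZE,  # new field; previously hard-coded
    )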