Skip to content

Commit

Permalink
Update node.py (run-llama#13771)
Browse files Browse the repository at this point in the history
  • Loading branch information
minglu7 authored May 29, 2024
1 parent fd434ac commit 93f8119
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
12 changes: 9 additions & 3 deletions llama-index-core/llama_index/core/postprocessor/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional, cast

from llama_index.core.bridge.pydantic import Field, validator
from llama_index.core.llms import LLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.response_synthesizers import (
Expand All @@ -12,6 +13,7 @@
)
from llama_index.core.schema import NodeRelationship, NodeWithScore, QueryBundle
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import Settings, llm_from_settings_or_context
from llama_index.core.storage.docstore import BaseDocumentStore

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -278,11 +280,13 @@ class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
"""

docstore: BaseDocumentStore
service_context: ServiceContext
service_context: Optional[ServiceContext] = None
llm: Optional[LLM] = None
num_nodes: int = Field(default=1)
infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
verbose: bool = Field(default=False)
response_mode: ResponseMode = Field(default=ResponseMode.TREE_SUMMARIZE)

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -310,6 +314,8 @@ def _postprocess_nodes(
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
llm = self.llm or llm_from_settings_or_context(Settings, self.service_context)

if query_bundle is None:
raise ValueError("Missing query bundle.")

Expand All @@ -324,10 +330,10 @@ def _postprocess_nodes(
# use response builder instead of llm directly
# to be more robust to handling long context
response_builder = get_response_synthesizer(
service_context=self.service_context,
llm=llm,
text_qa_template=infer_prev_next_prompt,
refine_template=refine_infer_prev_next_prompt,
response_mode=ResponseMode.TREE_SUMMARIZE,
response_mode=self.response_mode,
)
raw_pred = response_builder.get_response(
text_chunks=[node.node.get_content()],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
@pytest.fixture(scope="session")
def vespa_app():
app_package: ApplicationPackage = hybrid_template
return VespaVectorStore(application_package=app_package, deployment_target="local")
try:
return VespaVectorStore(
application_package=app_package, deployment_target="local"
)
except RuntimeError as e:
pytest.skip(f"Could not create VespaVectorStore: {e}")


@pytest.fixture(scope="session")
Expand Down

0 comments on commit 93f8119

Please sign in to comment.