diff --git a/libs/community/langchain_community/chat_models/kinetica.py b/libs/community/langchain_community/chat_models/kinetica.py index a8c52e6a6533d..26003e2ecff5a 100644 --- a/libs/community/langchain_community/chat_models/kinetica.py +++ b/libs/community/langchain_community/chat_models/kinetica.py @@ -8,10 +8,11 @@ import re from importlib.metadata import version from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast, Dict if TYPE_CHECKING: import gpudb + import pandas as pd from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel @@ -24,7 +25,6 @@ from langchain_core.output_parsers.transform import BaseOutputParser from langchain_core.outputs import ChatGeneration, ChatResult, Generation from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from pandas import DataFrame LOG = logging.getLogger(__name__) @@ -36,11 +36,11 @@ class _KdtSuggestContext(BaseModel): table: Optional[str] = Field(default=None, title="Name of table") description: Optional[str] = Field(default=None, title="Table description") - columns: list[str] = Field(default=None, title="Table columns list") - rules: Optional[list[str]] = Field( + columns: List[str] = Field(default=None, title="Table columns list") + rules: Optional[List[str]] = Field( default=None, title="Rules that apply to the table." ) - samples: Optional[dict] = Field( + samples: Optional[Dict] = Field( default=None, title="Samples that apply to the entire context." ) @@ -77,7 +77,7 @@ class _KdtSuggestPayload(BaseModel): """pydantic API request type""" question: Optional[str] - context: list[_KdtSuggestContext] + context: List[_KdtSuggestContext] def get_system_str(self) -> str: lines = [] @@ -88,7 +88,7 @@ def get_system_str(self) -> str: lines.append(context_str) return "\n\n".join(lines) - def get_messages(self) -> list[dict]: + def get_messages(self) -> List[Dict]: messages = [] for context in self.context: if context.samples is None: @@ -101,7 +101,7 @@ def get_messages(self) -> list[dict]: messages.append(dict(role="assistant", content=answer)) return messages - def to_completion(self) -> dict: + def to_completion(self) -> Dict: messages = [] messages.append(dict(role="system", content=self.get_system_str())) messages.extend(self.get_messages()) @@ -146,7 +146,7 @@ class _KdtSqlResponse(BaseModel): object: str created: int model: str - choices: list[_KdtChoice] + choices: List[_KdtChoice] usage: _KdtUsage prompt: str = Field(default=None, title="The input question") @@ -165,14 +165,14 @@ class _KineticaLlmFileContextParser: PARSER = re.compile(r"^<\|(?P\w+)\|>\W*(?P.*)$", re.DOTALL) @classmethod - def parse_dialogue_file(cls, input_file: os.PathLike) -> dict: + def parse_dialogue_file(cls, input_file: os.PathLike) -> Dict: path = Path(input_file) schema = path.name.removesuffix(".txt") lines = open(input_file).read() return cls.parse_dialogue(lines, schema) @classmethod - def parse_dialogue(cls, text: str, schema: str) -> dict: + def parse_dialogue(cls, text: str, schema: str) -> Dict: messages = [] system = None @@ -286,7 +286,7 @@ def _llm_type(self) -> str: return "kinetica-sqlassist" @property - def _identifying_params(self) -> dict[str, Any]: + def _identifying_params(self) -> Dict[str, Any]: return dict( kinetica_version=str(self.kdbc.server_version), api_version=version("gpudb") ) @@ -320,7 +320,7 @@ def _generate( llm_output=llm_output, ) - def load_messages_from_context(self, context_name: str) -> list: + def load_messages_from_context(self, context_name: str) -> List: """Load a lanchain prompt from a Kinetica context.""" # query kinetica for the prompt @@ -340,7 +340,7 @@ def load_messages_from_context(self, context_name: str) -> list: messages = [self._convert_message_from_dict(m) for m in dict_messages] return messages - def _submit_completion(self, messages: list[dict]) -> _KdtSqlResponse: + def _submit_completion(self, messages: List[Dict]) -> _KdtSqlResponse: """Submit a /chat/completions request to Kinetica.""" request = dict(messages=messages) @@ -438,7 +438,7 @@ def _convert_dict_to_messages(cls, sa_data: Dict) -> List[BaseMessage]: messages = sa_data["messages"] LOG.info(f"Importing prompt for schema: {schema}") - result_list: list[BaseMessage] = [] + result_list: List[BaseMessage] = [] result_list.append(SystemMessage(content=system)) result_list.extend([cls._convert_message_from_dict(m) for m in messages]) return result_list @@ -448,7 +448,7 @@ class KineticaSqlResponse(BaseModel): """Response containing SQL and the fetched data.""" sql: str = Field(description="Result SQL") - dataframe: DataFrame = Field(description="Result Data") + dataframe: "pd.DataFrame" = Field(description="Result Data") class Config: """Configuration for this pydantic object."""