Skip to content

Commit

Permalink
fix python 3.8 errors
Browse files Browse the repository at this point in the history
Fixes for python 3.8.
  • Loading branch information
chadj2 committed Feb 20, 2024
1 parent c1dd200 commit 1c580b4
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions libs/community/langchain_community/chat_models/kinetica.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import re
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast, Dict

Check failure on line 11 in libs/community/langchain_community/chat_models/kinetica.py

View workflow job for this annotation

GitHub Actions / cd libs/community / - / make lint #3.8

Ruff (F811)

langchain_community/chat_models/kinetica.py:11:68: F811 Redefinition of unused `Dict` from line 11

if TYPE_CHECKING:

Check failure on line 13 in libs/community/langchain_community/chat_models/kinetica.py

View workflow job for this annotation

GitHub Actions / cd libs/community / - / make lint #3.8

Ruff (I001)

langchain_community/chat_models/kinetica.py:5:1: I001 Import block is un-sorted or un-formatted
import gpudb
import pandas as pd

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
Expand All @@ -24,7 +25,6 @@
from langchain_core.output_parsers.transform import BaseOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult, Generation
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from pandas import DataFrame

LOG = logging.getLogger(__name__)

Expand All @@ -36,11 +36,11 @@ class _KdtSuggestContext(BaseModel):

table: Optional[str] = Field(default=None, title="Name of table")
description: Optional[str] = Field(default=None, title="Table description")
columns: list[str] = Field(default=None, title="Table columns list")
rules: Optional[list[str]] = Field(
columns: List[str] = Field(default=None, title="Table columns list")
rules: Optional[List[str]] = Field(
default=None, title="Rules that apply to the table."
)
samples: Optional[dict] = Field(
samples: Optional[Dict] = Field(
default=None, title="Samples that apply to the entire context."
)

Expand Down Expand Up @@ -77,7 +77,7 @@ class _KdtSuggestPayload(BaseModel):
"""pydantic API request type"""

question: Optional[str]
context: list[_KdtSuggestContext]
context: List[_KdtSuggestContext]

def get_system_str(self) -> str:
lines = []
Expand All @@ -88,7 +88,7 @@ def get_system_str(self) -> str:
lines.append(context_str)
return "\n\n".join(lines)

def get_messages(self) -> list[dict]:
def get_messages(self) -> List[Dict]:
messages = []
for context in self.context:
if context.samples is None:
Expand All @@ -101,7 +101,7 @@ def get_messages(self) -> list[dict]:
messages.append(dict(role="assistant", content=answer))
return messages

def to_completion(self) -> dict:
def to_completion(self) -> Dict:
messages = []
messages.append(dict(role="system", content=self.get_system_str()))
messages.extend(self.get_messages())
Expand Down Expand Up @@ -146,7 +146,7 @@ class _KdtSqlResponse(BaseModel):
object: str
created: int
model: str
choices: list[_KdtChoice]
choices: List[_KdtChoice]
usage: _KdtUsage
prompt: str = Field(default=None, title="The input question")

Expand All @@ -165,14 +165,14 @@ class _KineticaLlmFileContextParser:
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)

@classmethod
def parse_dialogue_file(cls, input_file: os.PathLike) -> dict:
def parse_dialogue_file(cls, input_file: os.PathLike) -> Dict:
path = Path(input_file)
schema = path.name.removesuffix(".txt")
lines = open(input_file).read()
return cls.parse_dialogue(lines, schema)

@classmethod
def parse_dialogue(cls, text: str, schema: str) -> dict:
def parse_dialogue(cls, text: str, schema: str) -> Dict:
messages = []
system = None

Expand Down Expand Up @@ -286,7 +286,7 @@ def _llm_type(self) -> str:
return "kinetica-sqlassist"

@property
def _identifying_params(self) -> dict[str, Any]:
def _identifying_params(self) -> Dict[str, Any]:
return dict(
kinetica_version=str(self.kdbc.server_version), api_version=version("gpudb")
)
Expand Down Expand Up @@ -320,7 +320,7 @@ def _generate(
llm_output=llm_output,
)

def load_messages_from_context(self, context_name: str) -> list:
def load_messages_from_context(self, context_name: str) -> List:
"""Load a lanchain prompt from a Kinetica context."""

# query kinetica for the prompt
Expand All @@ -340,7 +340,7 @@ def load_messages_from_context(self, context_name: str) -> list:
messages = [self._convert_message_from_dict(m) for m in dict_messages]
return messages

def _submit_completion(self, messages: list[dict]) -> _KdtSqlResponse:
def _submit_completion(self, messages: List[Dict]) -> _KdtSqlResponse:
"""Submit a /chat/completions request to Kinetica."""

request = dict(messages=messages)
Expand Down Expand Up @@ -438,7 +438,7 @@ def _convert_dict_to_messages(cls, sa_data: Dict) -> List[BaseMessage]:
messages = sa_data["messages"]
LOG.info(f"Importing prompt for schema: {schema}")

result_list: list[BaseMessage] = []
result_list: List[BaseMessage] = []
result_list.append(SystemMessage(content=system))
result_list.extend([cls._convert_message_from_dict(m) for m in messages])
return result_list
Expand All @@ -448,7 +448,7 @@ class KineticaSqlResponse(BaseModel):
"""Response containing SQL and the fetched data."""

sql: str = Field(description="Result SQL")
dataframe: DataFrame = Field(description="Result Data")
dataframe: "pd.DataFrame" = Field(description="Result Data")

class Config:
"""Configuration for this pydantic object."""
Expand Down

0 comments on commit 1c580b4

Please sign in to comment.