diff --git a/.github/workflows/python-quality.yml b/.github/workflows/python-quality.yml index e1bbbba..8db5c5a 100644 --- a/.github/workflows/python-quality.yml +++ b/.github/workflows/python-quality.yml @@ -6,7 +6,7 @@ on: - 'backend/**' jobs: - black-check: + python-check: runs-on: ubuntu-latest steps: - name: Check out code @@ -17,11 +17,14 @@ jobs: with: python-version: '3.8' - - name: Install Black - run: pip install black flake8 + - name: Install dependencies + run: pip install black flake8 mypy - name: Check Python code formatting with Black run: black --check backend/ - name: Check Python code style with Flake8 run: flake8 backend/ + + - name: Run Mypy Type Checking + run: mypy backend/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3d8c455..e28c0ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,3 +10,10 @@ repos: hooks: - id: flake8 files: ^backend/ # Run flake8 only on files in the backend/ directory + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.7.1' # Use the latest version + hooks: + - id: mypy + files: ^backend/ + diff --git a/backend/databases/chat_history_manager.py b/backend/databases/chat_history_manager.py index 5ce3935..7ce9ac9 100644 --- a/backend/databases/chat_history_manager.py +++ b/backend/databases/chat_history_manager.py @@ -1,8 +1,10 @@ -from sqlalchemy.orm import Session -from models.chat_history import ChatHistory # Replace with your actual import +from typing import Optional import json +from sqlalchemy.orm import Session +from models.chat_history import ChatHistory # Replace with your actual import + class ChatHistoryManager: """ @@ -41,14 +43,16 @@ def get_new_chat_id(self) -> int: """ Get a new chat id. Used for storing a new chat. """ - latest_chat = ( + latest_chat: Optional[ChatHistory] = ( self.session.query(ChatHistory).order_by(ChatHistory.chat_id.desc()).first() ) if not latest_chat: # If there are no chats in database return 0 - return latest_chat.chat_id + 1 + new_chat_id: int = latest_chat.chat_id + 1 + + return new_chat_id def get_llm_chat_history_for_user(self, user_id: int, llm_type: str): """ diff --git a/backend/databases/dashboard_manager.py b/backend/databases/dashboard_manager.py index 7476508..d1139d4 100644 --- a/backend/databases/dashboard_manager.py +++ b/backend/databases/dashboard_manager.py @@ -43,10 +43,12 @@ def get_dashboards(self) -> List[Dashboard]: List[Dashboard]: List of Dashboard objects. """ try: - return self.db_session.query(Dashboard).all() + dashboards: List[Dashboard] = self.db_session.query(Dashboard).all() + return dashboards except Exception as e: # Handle exception print(f"Database error: {str(e)}") + return [] def save_dashboard(self, dashboard: Dashboard): self.db_session.add(dashboard) diff --git a/backend/databases/sql_executor.py b/backend/databases/sql_executor.py index 29bd937..198ba3b 100644 --- a/backend/databases/sql_executor.py +++ b/backend/databases/sql_executor.py @@ -1,5 +1,6 @@ from sqlalchemy import inspect, text from sqlalchemy.orm import Session +from typing import List import pandas as pd @@ -71,16 +72,16 @@ def get_all_table_names_as_str(self) -> str: print(f"An error occurred: {e}") raise - def get_all_table_names_as_list(self) -> list: + def get_all_table_names_as_list(self) -> List: try: # Using the engine from the session to get table names - table_names = self.session.bind.table_names() + table_names: List = self.session.bind.table_names() return table_names except Exception as e: print(f"An error occurred: {e}") raise - def get_table_columns(self, table_name: str) -> list: + def get_table_columns(self, table_name: str) -> List: try: engine = self.session.bind inspector = inspect(engine) diff --git a/backend/databases/table_manager.py b/backend/databases/table_manager.py index 2f17cce..20715e9 100644 --- a/backend/databases/table_manager.py +++ b/backend/databases/table_manager.py @@ -106,7 +106,7 @@ def determine_table(self, sample_content: str, extra_desc: str) -> str: table_metadata ) - table_name = self.llm.fetch_table_name_from_sample( + table_name: str = self.llm.fetch_table_name_from_sample( sample_content, extra_desc, formatted_table_metadata ) return table_name diff --git a/backend/databases/table_metadata_manager.py b/backend/databases/table_metadata_manager.py index c19ed98..4a025cc 100644 --- a/backend/databases/table_metadata_manager.py +++ b/backend/databases/table_metadata_manager.py @@ -32,10 +32,14 @@ def get_all_metadata(self) -> List[TableMetadata]: List[TableMetadata]: List of TableMetadata objects. """ try: - return self.db_session.query(TableMetadata).all() + table_metadatas: List[TableMetadata] = self.db_session.query( + TableMetadata + ).all() + return table_metadatas except Exception as e: # Handle exception print(f"Database error: {str(e)}") + return [] def get_metadata(self, table_name: str) -> TableMetadata: """Retrieve metadata for a single table""" diff --git a/backend/databases/user_manager.py b/backend/databases/user_manager.py index e0797a2..4731005 100644 --- a/backend/databases/user_manager.py +++ b/backend/databases/user_manager.py @@ -2,6 +2,8 @@ This module provides a UserManager class that handles database operations related to the User model. It uses SQLAlchemy for ORM operations and facilitates CRUD operations on user data. """ +from typing import Optional + from sqlalchemy.orm import Session from models.user import User @@ -95,10 +97,10 @@ def get_users_without_password(self, skip: int = 0, limit: int = 10): def update_user( self, user_id: int, - email: str = None, - organization_id: int = None, - role: str = None, - refresh_token: str = None, + email: Optional[str], + organization_id: Optional[int], + role: Optional[str], + refresh_token: Optional[str], ): """ Update a user's details in the database. diff --git a/backend/llms/base.py b/backend/llms/base.py index 1f48297..a7d3433 100644 --- a/backend/llms/base.py +++ b/backend/llms/base.py @@ -11,17 +11,21 @@ def generate_text(self, prompt: str) -> str: """ raise NotImplementedError("This method should be overridden by subclass") - def generate_create_statement( - self, sample_content: str, existing_table_names: str, extra_desc: str + async def generate_create_statement( + self, + sample_content: str, + header: str, + existing_table_names: str, + extra_desc: str, ) -> str: raise NotImplementedError - def generate_table_desc( + async def generate_table_desc( self, create_query: str, sample_content: str, extra_desc: str ) -> str: raise NotImplementedError def fetch_table_name_from_sample( self, sample_content: str, extra_desc: str, table_metadata: str - ): + ) -> str: raise NotImplementedError diff --git a/backend/llms/gpt.py b/backend/llms/gpt.py index 61965da..12da75e 100644 --- a/backend/llms/gpt.py +++ b/backend/llms/gpt.py @@ -1,3 +1,10 @@ +from openai import ChatCompletion +from typing import Optional + +import json +import openai +import tiktoken + from .base import BaseLLM from databases.chat_history_manager import ChatHistoryManager from databases.database_manager import DatabaseManager @@ -5,10 +12,6 @@ from settings import OPENAI_API_KEY from utils.nivo_assistant import NivoAssistant -import json -import openai -import tiktoken - class GPTLLM(BaseLLM): """ @@ -21,7 +24,7 @@ class GPTLLM(BaseLLM): def __init__( self, - chat_id: int = None, + chat_id: Optional[int], user: User = None, store_history: bool = False, llm_type: str = "generic", @@ -77,12 +80,14 @@ def _set_response_format(self, is_json: bool): async def _api_call(self, payload: dict) -> str: """Make an API call to get a response based on the conversation history.""" - completion = await openai.ChatCompletion.acreate( + completion: ChatCompletion = await openai.ChatCompletion.acreate( model=self.model, messages=payload["messages"], response_format=self.response_format, ) - return completion.choices[0].message.content + return str( + completion.choices[0].message.content + ) # TODO: Typecasting is not recommended for mypy def _count_tokens(self, text: str) -> int: """Count the number of tokens in the given text.""" @@ -181,7 +186,7 @@ def _add_system_message(self, assistant_type: str) -> None: self.llm_type = assistant_type - async def _send_and_receive_message(self, prompt: str): + async def _send_and_receive_message(self, prompt: str) -> str: user_message = self._create_message("user", prompt) self.history.append(user_message) @@ -196,7 +201,7 @@ async def _send_and_receive_message(self, prompt: str): self.history.append(assistant_message) if self.store_history: - self._save_messages() + self._save_messages(user_message, assistant_message) return assistant_message_content @@ -276,7 +281,7 @@ async def generate_create_statement( f"\n\nAdditional information about the sample data: \n{extra_desc}" ) - gpt_response = await self._send_and_receive_message(prompt) + gpt_response: str = await self._send_and_receive_message(prompt) return gpt_response diff --git a/backend/routes/chart_routes.py b/backend/routes/chart_routes.py index c4f665c..3b42ace 100644 --- a/backend/routes/chart_routes.py +++ b/backend/routes/chart_routes.py @@ -8,6 +8,7 @@ from llms.gpt import GPTLLM from models.user import User from models.chart import Chart, ChartCreate +from models.table_metadata import TableMetadata from security import get_current_user from utils.nivo_assistant import NivoAssistant from utils.utils import execute_select_query @@ -58,7 +59,7 @@ async def create_chart_config( chat_id = request.chat_id msg = request.msg chart_config = request.chart_config - table_name = chart_config.get("table") + table_name = chart_config.get("table", "") chart_type = chart_config.get("type") nivo_config = chart_config.get("nivoConfig") table_metadata = get_table_metadata(table_name) @@ -85,7 +86,9 @@ async def create_chart_config( return updated_chart_config, gpt.chat_id -def get_table_metadata(table_name: str, current_user: User = Depends(get_current_user)): +def get_table_metadata( + table_name: str, current_user: User = Depends(get_current_user) +) -> TableMetadata: """Get table metadata""" with DatabaseManager() as session: metadata_manager = TableMetadataManager(session) diff --git a/backend/security.py b/backend/security.py index 0f6de36..b41f602 100644 --- a/backend/security.py +++ b/backend/security.py @@ -109,7 +109,7 @@ def create_token(data: dict, expires_delta: timedelta) -> str: to_encode = data.copy() expire = datetime.utcnow() + expires_delta to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM) + encoded_jwt: str = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt @@ -191,7 +191,7 @@ def set_tokens_in_cookies(response: Response, access_token: str, refresh_token: httponly=True, max_age=1800, secure=True, - samesite="Lax", + samesite="lax", ) response.set_cookie( key="refresh_token", @@ -199,7 +199,7 @@ def set_tokens_in_cookies(response: Response, access_token: str, refresh_token: httponly=True, max_age=60 * 60 * 24 * 7, secure=True, - samesite="Lax", + samesite="lax", ) diff --git a/backend/utils/sql_string_manipulator.py b/backend/utils/sql_string_manipulator.py index e57fdb5..b528183 100644 --- a/backend/utils/sql_string_manipulator.py +++ b/backend/utils/sql_string_manipulator.py @@ -70,5 +70,6 @@ def extract_sql_query_from_text(self) -> Optional[str]: match = re.findall(r"CREATE TABLE [^;]+;", self.sql_string) if match: # Extract only the last "CREATE TABLE" statement and add a space after "CREATE TABLE" - last_statement = match[-1] + last_statement: str = match[-1] return "CREATE TABLE " + last_statement.split("CREATE TABLE")[-1].strip() + return None diff --git a/backend/utils/utils.py b/backend/utils/utils.py index b54d601..565d053 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,7 +1,7 @@ from csv import Sniffer from fastapi import File, UploadFile from io import StringIO -from typing import Any +from typing import Any, Dict, Optional import logging import os @@ -30,10 +30,18 @@ def process_file(file: UploadFile, encoding: str) -> Any: Returns: Any: Processed content of the file. """ + if file.filename is None: + raise ValueError("File must have a filename") + # Find file type by file extension file_type = file.filename.split(".")[-1].lower() - files = {"processed_df": None, "sample_file_content": None} + # Define the dictionary with types for values + files: Dict[str, Optional[str]] = { + "processed_df": None, + "sample_file_content_str": None, + "header_str": None, + } if file_type == "csv": # Sniff the first 1024 bytes to check for a header diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..1efd6c8 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +python_version = 3.8 +check_untyped_defs = True +ignore_missing_imports = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +no_implicit_optional = True