Skip to content

Commit

Permalink
Merge pull request #69 from DocShow-AI/add_mypy
Browse files: browse the repository at this point in the history
Add mypy
  • Loading branch information
liberty-rising authored Dec 3, 2023
2 parents 159e6ae + 5f64c0c commit 9837559
Show file tree
Hide file tree
Showing 15 changed files with 91 additions and 39 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/python-quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- 'backend/**'

jobs:
black-check:
python-check:
runs-on: ubuntu-latest
steps:
- name: Check out code
Expand All @@ -17,11 +17,14 @@ jobs:
with:
python-version: '3.8'

- name: Install Black
run: pip install black flake8
- name: Install dependencies
run: pip install black flake8 mypy

- name: Check Python code formatting with Black
run: black --check backend/

- name: Check Python code style with Flake8
run: flake8 backend/

- name: Run Mypy Type Checking
run: mypy backend/
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ repos:
hooks:
- id: flake8
files: ^backend/ # Run flake8 only on files in the backend/ directory

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.7.1' # Pinned release tag of the mypy mirror; bump this tag to upgrade
hooks:
- id: mypy
files: ^backend/

12 changes: 8 additions & 4 deletions backend/databases/chat_history_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from sqlalchemy.orm import Session
from models.chat_history import ChatHistory # Replace with your actual import
from typing import Optional

import json

from sqlalchemy.orm import Session
from models.chat_history import ChatHistory # Replace with your actual import


class ChatHistoryManager:
"""
Expand Down Expand Up @@ -41,14 +43,16 @@ def get_new_chat_id(self) -> int:
"""
Get a new chat id. Used for storing a new chat.
"""
latest_chat = (
latest_chat: Optional[ChatHistory] = (
self.session.query(ChatHistory).order_by(ChatHistory.chat_id.desc()).first()
)

if not latest_chat: # If there are no chats in database
return 0

return latest_chat.chat_id + 1
new_chat_id: int = latest_chat.chat_id + 1

return new_chat_id

def get_llm_chat_history_for_user(self, user_id: int, llm_type: str):
"""
Expand Down
4 changes: 3 additions & 1 deletion backend/databases/dashboard_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def get_dashboards(self) -> List[Dashboard]:
List[Dashboard]: List of Dashboard objects.
"""
try:
return self.db_session.query(Dashboard).all()
dashboards: List[Dashboard] = self.db_session.query(Dashboard).all()
return dashboards
except Exception as e:
# Handle exception
print(f"Database error: {str(e)}")
return []

def save_dashboard(self, dashboard: Dashboard):
self.db_session.add(dashboard)
Expand Down
7 changes: 4 additions & 3 deletions backend/databases/sql_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy import inspect, text
from sqlalchemy.orm import Session
from typing import List

import pandas as pd

Expand Down Expand Up @@ -71,16 +72,16 @@ def get_all_table_names_as_str(self) -> str:
print(f"An error occurred: {e}")
raise

def get_all_table_names_as_list(self) -> list:
def get_all_table_names_as_list(self) -> List:
try:
# Using the engine from the session to get table names
table_names = self.session.bind.table_names()
table_names: List = self.session.bind.table_names()
return table_names
except Exception as e:
print(f"An error occurred: {e}")
raise

def get_table_columns(self, table_name: str) -> list:
def get_table_columns(self, table_name: str) -> List:
try:
engine = self.session.bind
inspector = inspect(engine)
Expand Down
2 changes: 1 addition & 1 deletion backend/databases/table_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def determine_table(self, sample_content: str, extra_desc: str) -> str:
table_metadata
)

table_name = self.llm.fetch_table_name_from_sample(
table_name: str = self.llm.fetch_table_name_from_sample(
sample_content, extra_desc, formatted_table_metadata
)
return table_name
Expand Down
6 changes: 5 additions & 1 deletion backend/databases/table_metadata_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ def get_all_metadata(self) -> List[TableMetadata]:
List[TableMetadata]: List of TableMetadata objects.
"""
try:
return self.db_session.query(TableMetadata).all()
table_metadatas: List[TableMetadata] = self.db_session.query(
TableMetadata
).all()
return table_metadatas
except Exception as e:
# Handle exception
print(f"Database error: {str(e)}")
return []

def get_metadata(self, table_name: str) -> TableMetadata:
"""Retrieve metadata for a single table"""
Expand Down
10 changes: 6 additions & 4 deletions backend/databases/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
This module provides a UserManager class that handles database operations related to the User model.
It uses SQLAlchemy for ORM operations and facilitates CRUD operations on user data.
"""
from typing import Optional

from sqlalchemy.orm import Session
from models.user import User

Expand Down Expand Up @@ -95,10 +97,10 @@ def get_users_without_password(self, skip: int = 0, limit: int = 10):
def update_user(
self,
user_id: int,
email: str = None,
organization_id: int = None,
role: str = None,
refresh_token: str = None,
email: Optional[str],
organization_id: Optional[int],
role: Optional[str],
refresh_token: Optional[str],
):
"""
Update a user's details in the database.
Expand Down
12 changes: 8 additions & 4 deletions backend/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,21 @@ def generate_text(self, prompt: str) -> str:
"""
raise NotImplementedError("This method should be overridden by subclass")

def generate_create_statement(
self, sample_content: str, existing_table_names: str, extra_desc: str
async def generate_create_statement(
self,
sample_content: str,
header: str,
existing_table_names: str,
extra_desc: str,
) -> str:
raise NotImplementedError

def generate_table_desc(
async def generate_table_desc(
self, create_query: str, sample_content: str, extra_desc: str
) -> str:
raise NotImplementedError

def fetch_table_name_from_sample(
self, sample_content: str, extra_desc: str, table_metadata: str
):
) -> str:
raise NotImplementedError
25 changes: 15 additions & 10 deletions backend/llms/gpt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from openai import ChatCompletion
from typing import Optional

import json
import openai
import tiktoken

from .base import BaseLLM
from databases.chat_history_manager import ChatHistoryManager
from databases.database_manager import DatabaseManager
from models.user import User
from settings import OPENAI_API_KEY
from utils.nivo_assistant import NivoAssistant

import json
import openai
import tiktoken


class GPTLLM(BaseLLM):
"""
Expand All @@ -21,7 +24,7 @@ class GPTLLM(BaseLLM):

def __init__(
self,
chat_id: int = None,
chat_id: Optional[int],
user: User = None,
store_history: bool = False,
llm_type: str = "generic",
Expand Down Expand Up @@ -77,12 +80,14 @@ def _set_response_format(self, is_json: bool):

async def _api_call(self, payload: dict) -> str:
"""Make an API call to get a response based on the conversation history."""
completion = await openai.ChatCompletion.acreate(
completion: ChatCompletion = await openai.ChatCompletion.acreate(
model=self.model,
messages=payload["messages"],
response_format=self.response_format,
)
return completion.choices[0].message.content
return str(
completion.choices[0].message.content
) # TODO: str() only coerces the value to satisfy mypy; narrow the type properly instead (if content is ever None this returns the string "None")

def _count_tokens(self, text: str) -> int:
"""Count the number of tokens in the given text."""
Expand Down Expand Up @@ -181,7 +186,7 @@ def _add_system_message(self, assistant_type: str) -> None:

self.llm_type = assistant_type

async def _send_and_receive_message(self, prompt: str):
async def _send_and_receive_message(self, prompt: str) -> str:
user_message = self._create_message("user", prompt)
self.history.append(user_message)

Expand All @@ -196,7 +201,7 @@ async def _send_and_receive_message(self, prompt: str):
self.history.append(assistant_message)

if self.store_history:
self._save_messages()
self._save_messages(user_message, assistant_message)

return assistant_message_content

Expand Down Expand Up @@ -276,7 +281,7 @@ async def generate_create_statement(
f"\n\nAdditional information about the sample data: \n{extra_desc}"
)

gpt_response = await self._send_and_receive_message(prompt)
gpt_response: str = await self._send_and_receive_message(prompt)

return gpt_response

Expand Down
7 changes: 5 additions & 2 deletions backend/routes/chart_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from llms.gpt import GPTLLM
from models.user import User
from models.chart import Chart, ChartCreate
from models.table_metadata import TableMetadata
from security import get_current_user
from utils.nivo_assistant import NivoAssistant
from utils.utils import execute_select_query
Expand Down Expand Up @@ -58,7 +59,7 @@ async def create_chart_config(
chat_id = request.chat_id
msg = request.msg
chart_config = request.chart_config
table_name = chart_config.get("table")
table_name = chart_config.get("table", "")
chart_type = chart_config.get("type")
nivo_config = chart_config.get("nivoConfig")
table_metadata = get_table_metadata(table_name)
Expand All @@ -85,7 +86,9 @@ async def create_chart_config(
return updated_chart_config, gpt.chat_id


def get_table_metadata(table_name: str, current_user: User = Depends(get_current_user)):
def get_table_metadata(
table_name: str, current_user: User = Depends(get_current_user)
) -> TableMetadata:
"""Get table metadata"""
with DatabaseManager() as session:
metadata_manager = TableMetadataManager(session)
Expand Down
6 changes: 3 additions & 3 deletions backend/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def create_token(data: dict, expires_delta: timedelta) -> str:
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
encoded_jwt: str = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


Expand Down Expand Up @@ -191,15 +191,15 @@ def set_tokens_in_cookies(response: Response, access_token: str, refresh_token:
httponly=True,
max_age=1800,
secure=True,
samesite="Lax",
samesite="lax",
)
response.set_cookie(
key="refresh_token",
value=f"Bearer {refresh_token}",
httponly=True,
max_age=60 * 60 * 24 * 7,
secure=True,
samesite="Lax",
samesite="lax",
)


Expand Down
3 changes: 2 additions & 1 deletion backend/utils/sql_string_manipulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,6 @@ def extract_sql_query_from_text(self) -> Optional[str]:
match = re.findall(r"CREATE TABLE [^;]+;", self.sql_string)
if match:
# Extract only the last "CREATE TABLE" statement and add a space after "CREATE TABLE"
last_statement = match[-1]
last_statement: str = match[-1]
return "CREATE TABLE " + last_statement.split("CREATE TABLE")[-1].strip()
return None
12 changes: 10 additions & 2 deletions backend/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from csv import Sniffer
from fastapi import File, UploadFile
from io import StringIO
from typing import Any
from typing import Any, Dict, Optional

import logging
import os
Expand Down Expand Up @@ -30,10 +30,18 @@ def process_file(file: UploadFile, encoding: str) -> Any:
Returns:
Any: Processed content of the file.
"""
if file.filename is None:
raise ValueError("File must have a filename")

# Find file type by file extension
file_type = file.filename.split(".")[-1].lower()

files = {"processed_df": None, "sample_file_content": None}
# Define the dictionary with types for values
files: Dict[str, Optional[str]] = {
"processed_df": None,
"sample_file_content_str": None,
"header_str": None,
}

if file_type == "csv":
# Sniff the first 1024 bytes to check for a header
Expand Down
8 changes: 8 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[mypy]
python_version = 3.8
check_untyped_defs = True
ignore_missing_imports = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
no_implicit_optional = True

0 comments on commit 9837559

Please sign in to comment.