diff --git a/.gitignore b/.gitignore index ba0430d..9069a60 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -__pycache__/ \ No newline at end of file +__pycache__/ + +src/credentials.json +src/token.json \ No newline at end of file diff --git a/chroma_storage/chroma.sqlite3 b/chroma_storage/chroma.sqlite3 new file mode 100644 index 0000000..7345008 Binary files /dev/null and b/chroma_storage/chroma.sqlite3 differ diff --git a/requirements.txt b/requirements.txt index 83314d7..2c65839 100755 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,10 @@ pypandoc_binary nltk langchain flask -sentence-transformers flask-cors +sentence-transformers +google-auth +google-auth-oauthlib +google-api-python-client +tqdm +markdown \ No newline at end of file diff --git a/src/chroma_storage/chroma.sqlite3 b/src/chroma_storage/chroma.sqlite3 index 6f242d4..f81cf64 100644 Binary files a/src/chroma_storage/chroma.sqlite3 and b/src/chroma_storage/chroma.sqlite3 differ diff --git a/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/data_level0.bin b/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/data_level0.bin index 79f8b0a..ae2f009 100644 Binary files a/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/data_level0.bin and b/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/data_level0.bin differ diff --git a/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/length.bin b/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/length.bin index 5eef315..3286330 100644 Binary files a/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/length.bin and b/src/chroma_storage/d99d339e-cf80-4412-ac9f-80a2a2a28aa5/length.bin differ diff --git a/src/documents/newDoc.txt b/src/documents/newDoc.txt index d5d2e92..b0866a7 100644 --- a/src/documents/newDoc.txt +++ b/src/documents/newDoc.txt @@ -1,2 +1,4 @@ Niels Houben älskar äppelpaj, äpplen, och äppeljuice. -Niels vill alltid äta äpplen! \ No newline at end of file +Niels vill alltid äta äpplen! 
+
+Han gillar även ostkaka
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index cb3cda0..8131f85 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,63 +1,2 @@
-""" Main module to start the application """
-import sys
-import json
-from prompt_request import prompt_request
-# from model_wrapper import prompt_model
-# from query_wrapper import preform_query
-from test_prompts.test_prompts import test_prompts, standard_prompts
-
-from rich.console import Console
-from rich.markdown import Markdown
-console = Console()
-
-
-
-def run_standard_prompts():
-    output = []
-    for prompt in standard_prompts:
-        query_string = prompt
-        result = preform_query(query_string)
-        answer = prompt_model(str(result), prompt)
-        output.append({
-            "usr_prompt": prompt,
-            "answer": answer
-        })
-    with open("test_results/test_output.json", "w+") as f:
-        json.dump(output, f, indent=4, ensure_ascii=False)
-
-
-def question_loop():
-    while True:
-        usr_prompt = input("Ställ din fråga: ")
-        if usr_prompt == "e":
-            break
-        answer = prompt_request(usr_prompt).json()['answer']
-
-        console.print(Markdown(answer))
-        # print(prompt_request(usr_prompt))
-        # query_string = usr_prompt
-        # result = preform_query(query_string)
-        #print("dokument som hittades: ",[res["filepath"] for res in result])
-        #print("results:")
-        # print(str(result))
-        # prompt_model(str(result), usr_prompt)
-
-
-def main() -> None:
-    """ For now just a dummy """
-
-    # if sys.argv[1] == 'test':
-    #     run_standard_prompts()
-    #     return 0
-
-    # sys_prompt = test_prompts[0]["sys_prompt"]
-    # usr_prompt = test_prompts[0]["usr_prompt"]
-    question_loop()
-
-
-
-#could create user interface in json file that can be edited...
-
-if __name__ == '__main__':
-    sys.exit(main())
+""" Runs servers for model, mail, file-watcher """
+# TODO: start the model, mail and watcher servers from here
\ No newline at end of file
diff --git a/src/model_server/test.py b/src/model_server/test.py
new file mode 100644
index 0000000..fb52d99
--- /dev/null
+++ b/src/model_server/test.py
@@ -0,0 +1,12 @@
+# Manual smoke test: POST one prompt to the running model server.
+import requests
+import json
+
+url = 'http://127.0.0.1:5000/prompt'
+headers = {'Content-Type': 'application/json'}
+data = {'prompt': 'Hej, vad är budgeten?'}
+json_data = json.dumps(data)
+response = requests.post(url, headers=headers, data=json_data)
+
+print(response.status_code)
+print(response.json())
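Note on the /prompt contract: judging from the consumers in this diff (prompt_loop.py and mail_utils.py), the endpoint replies with a JSON object carrying at least an 'answer' field (markdown text) and a 'doc_names' list of source documents. A minimal client sketch against that shape; the timeout value is an assumption taken from the "# timeout=30" hint in prompt_request.py:

    import requests

    # Ask the model server a question and unpack the fields used elsewhere in this diff.
    resp = requests.post(
        "http://127.0.0.1:5000/prompt",
        json={"prompt": "Hej, vad är budgeten?"},  # requests sets the Content-Type header
        timeout=30,  # assumption: mirrors the commented-out timeout in prompt_request.py
    )
    resp.raise_for_status()
    payload = resp.json()
    print(payload["answer"])     # markdown-formatted answer
    print(payload["doc_names"])  # documents used to produce the answer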
diff --git a/src/prompt_loop.py b/src/prompt_loop.py
new file mode 100644
index 0000000..2e3efa5
--- /dev/null
+++ b/src/prompt_loop.py
@@ -0,0 +1,63 @@
+""" Main module to start the application """
+import sys
+import json
+from utils.api.prompt_request import prompt_request
+# from model_wrapper import prompt_model
+# from query import preform_query
+from test_prompts.test_prompts import test_prompts, standard_prompts
+
+from rich.console import Console
+from rich.markdown import Markdown
+console = Console()
+
+
+
+# def run_standard_prompts():
+#     output = []
+#     for prompt in standard_prompts:
+#         query_string = prompt
+#         result = preform_query(query_string)
+#         answer = prompt_model(str(result), prompt)
+#         output.append({
+#             "usr_prompt": prompt,
+#             "answer": answer
+#         })
+#     with open("test_results/test_output.json", "w+") as f:
+#         json.dump(output, f, indent=4, ensure_ascii=False)
+
+
+def question_loop():
+    while True:
+        usr_prompt = input("Ställ din fråga: ")
+        if usr_prompt == "e":
+            break
+        answer = prompt_request(usr_prompt)['answer']
+
+        console.print(Markdown(answer))
+        # print(prompt_request(usr_prompt))
+        # query_string = usr_prompt
+        # result = preform_query(query_string)
+        #print("dokument som hittades: ",[res["filepath"] for res in result])
+        #print("results:")
+        # print(str(result))
+        # prompt_model(str(result), usr_prompt)
+
+
+def main() -> None:
+    """ Starts the interactive question loop """
+
+    # if sys.argv[1] == 'test':
+    #     run_standard_prompts()
+    #     return 0
+
+    # sys_prompt = test_prompts[0]["sys_prompt"]
+    # usr_prompt = test_prompts[0]["usr_prompt"]
+    question_loop()
+
+
+
+#could create user interface in json file that can be edited...
+
+if __name__ == '__main__':
+    sys.exit(main())
+
diff --git a/src/server_mail.py b/src/server_mail.py
new file mode 100644
index 0000000..6bfff49
--- /dev/null
+++ b/src/server_mail.py
@@ -0,0 +1,26 @@
+from time import sleep
+
+from utils.mail.mail_utils import get_service, respond_to_mails
+
+# If modifying these SCOPES, delete the file token.json.
+SCOPES = ["https://www.googleapis.com/auth/gmail.modify"]
+SENDER_ADDRESS = "informationhantering@gmail.com"
+USER_ID = "me"
+
+
+def main():
+    """
+    Checks the inbox for mail every 5 seconds for 30 minutes or until keyboard interrupt.
+    TODO: use Google Pub/Sub later to trigger on new mail instead of polling.
+    """
+    service = get_service(SCOPES)
+
+    print("Started mailbox watcher")
+    for _ in range(360):
+        print("Checking...")
+        respond_to_mails(service, SENDER_ADDRESS, USER_ID)
+        sleep(5)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/server.py b/src/server_model.py
similarity index 95%
rename from src/server.py
rename to src/server_model.py
index 8475e92..0b29199 100644
--- a/src/server.py
+++ b/src/server_model.py
@@ -2,7 +2,7 @@
 from flask_cors import CORS
 
 from model_wrapper import prompt_model
-from query_wrapper import preform_query
+from utils.query.query import preform_query
 
 app = Flask(__name__)
 CORS(app)
diff --git a/src/server_watcher.py b/src/server_watcher.py
new file mode 100644
index 0000000..93d90d0
--- /dev/null
+++ b/src/server_watcher.py
@@ -0,0 +1,28 @@
+from utils.watcher.update_db import check_db, cleanse_documents, init_db
+from utils.watcher.watcher_utils import start_watcher
+
+choicetext = """
+start - Start the watcher
+reset - Remove all documents from DB (not directory) and then re-add them
+check - Print all documents that are currently in the DB
+quit - Quit the CLI
+
+> """
+
+if __name__ == "__main__":
+    DIRECTORY = "./documents/"  # Change to path of network folder
+    while True:
+        choice = input(choicetext)
+
+        if choice.lower() == "start":
+            print("Press ctrl+c to stop watcher")
+            start_watcher(DIRECTORY)
+        elif choice.lower() == "reset":
+            cleanse_documents()
+            init_db(DIRECTORY)
+            print("Reset done")
+        elif choice.lower() == "check":
+            check_db()
+        elif choice.lower() == "quit":
+            print("closing CLI")
+            break
diff --git a/src/test.ipynb b/src/test.ipynb
deleted file mode 100644
index e69de29..0000000
diff --git a/src/query_wrapper.py b/src/test_query.py
similarity index 73%
rename from src/query_wrapper.py
rename to src/test_query.py
index b89ecad..d1d59d6 100644
--- a/src/query_wrapper.py
+++ b/src/test_query.py
@@ -1,4 +1,4 @@
-from query_utils.query import preform_query
+from utils.query.query import preform_query
 
 
 if __name__ == "__main__":
diff --git a/src/test_similarity_checks.py b/src/test_similarity_checks.py
new file mode 100644
index 0000000..c53c921
--- /dev/null
+++ b/src/test_similarity_checks.py
@@ -0,0 +1,54 @@
+import chromadb
+import torch
+from transformers import AutoTokenizer, AutoModel
+
+def preprocess_text(text, chunk_size=2000):
+    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+    return chunks
+
+def perform_query(query: str):
+    collection_name = "documents_collection"
+    persist_directory = "chroma_storage"
+
+    # Instantiate the Chroma client or use the appropriate class/method
+    client = chromadb.PersistentClient(path=persist_directory)
+
+    # Get the collection.
+    collection = client.get_collection(name=collection_name)
+
+    print("Querying...\n")
+
+    # Instantiate a Hugging Face tokenizer and model
+    model_name = "bert-base-uncased"  # Replace with your desired Hugging Face model
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name)
+
+    # Preprocess the query into chunks
+    query_chunks = preprocess_text(query)
+
+    # Perform similarity checks on each chunk
+    results = []
+    for chunk in query_chunks:
+        # Tokenize (truncated to BERT's 512-token limit) and mean-pool the embeddings
+        tokens = tokenizer(chunk, return_tensors="pt", truncation=True)
+        with torch.no_grad():
+            embeddings = model(**tokens).last_hidden_state.mean(dim=1).squeeze().numpy()
+
+        # Query by embedding vector instead of stringified floats as query_texts
+        result_chunk = collection.query(
+            query_embeddings=[embeddings.tolist()],
+            n_results=4,
+            include=["documents", "metadatas"],
+        )
+
+        filepaths = [result["filepath"] for result in result_chunk["metadatas"][0]]
+        documents = result_chunk["documents"][0]
+        result_chunk = [{"filepath": filepath, "text": document} for filepath, document in zip(filepaths, documents)]
+        results.extend(result_chunk)
+
+    return results
+
+# Example usage:
+query = "Your query text goes here."
+results = perform_query(query)
+print(results)
diff --git a/src/prompt_request.py b/src/utils/api/prompt_request.py
similarity index 69%
rename from src/prompt_request.py
rename to src/utils/api/prompt_request.py
index a4dc194..205b094 100644
--- a/src/prompt_request.py
+++ b/src/utils/api/prompt_request.py
@@ -1,10 +1,12 @@
 import json
 import requests
 
-def prompt_request(question):
+
+def prompt_request(question: str) -> dict:
+    """ Sends a request to server_model """
     url = 'http://127.0.0.1:5000/prompt'
     headers = {'Content-Type': 'application/json'}
     data = {'prompt': question}
     json_data = json.dumps(data)
     response = requests.post(url, headers=headers, data=json_data) # timeout=30
-    return response
+    return response.json()
diff --git a/src/utils/mail/mail_utils.py b/src/utils/mail/mail_utils.py
new file mode 100644
index 0000000..4f930ec
--- /dev/null
+++ b/src/utils/mail/mail_utils.py
@@ -0,0 +1,243 @@
+import base64
+import os
+import re
+from email.mime.text import MIMEText
+
+from google.auth.transport.requests import Request
+from google.oauth2.credentials import Credentials
+from google_auth_oauthlib.flow import InstalledAppFlow
+from googleapiclient.discovery import build
+import markdown
+
+from utils.api.prompt_request import prompt_request
+
+
+def get_unread_emails(service, user_id="me"):
+    """
+    Checks the mailbox for unread messages
+
+    Args:
+        service: see get_service()
+        user_id: should ALWAYS be "me"
+
+    Returns:
+        list of dictionaries with the following structure:
+        [
+            {
+                'id': (str),
+                'name': (str),
+                'address': (str),
+                'subject': (str),
+                'body': (str)
+            }
+        ]
+    """
+    try:
+        response = (
+            service.users().messages().list(userId=user_id, q="is:unread").execute()
+        )
+        messages = response.get("messages", [])
+        emails = []
+
+        for message in messages:
+            msg_id = message["id"]
+            msg = (
+                service.users()
+                .messages()
+                .get(userId=user_id, id=msg_id, format="full")
+                .execute()
+            )
+
+            # Extract the sender's email address and name using regex
+            sender_info = next(
+                header["value"]
+                for header in msg["payload"]["headers"]
+                if header["name"] == "From"
+            )
+            email_match = re.search(r"<(.+?)>", sender_info)
+            sender_email = email_match.group(1) if email_match else sender_info
+            sender_name = re.search(r"(.+?) <", sender_info)
+            sender_name = sender_name.group(1) if sender_name else ""
+
+            # Extract the subject
+            subject = next(
+                header["value"]
+                for header in msg["payload"]["headers"]
+                if header["name"] == "Subject"
+            )
+
+            # Extract the message body
+            parts = msg["payload"].get("parts", [])
+            body = ""
+            for part in parts:
+                if part["mimeType"] == "text/plain":
+                    body_data = part["body"]["data"]
+                    body = base64.urlsafe_b64decode(body_data).decode("utf-8")
+                    break
+
+            email_info = {
+                "id": msg_id,
+                "name": sender_name,
+                "address": sender_email,
+                "subject": subject,
+                "body": body,
+            }
+            emails.append(email_info)
+
+        return emails
+    except Exception as error:
+        print(f"An error occurred: {error}")
+        return []
+
+
+def mark_as_read(service, user_id, msg_id):
+    """
+    Marks a specified email as read.
+
+    Args:
+        service: The Gmail API service instance.
+        user_id (str): The user's email ID. Usually 'me'.
+        msg_id (str): The unique ID of the email to be marked as read.
+    Returns:
+        True if successfully marked as read else False
+    """
+
+    try:
+        service.users().messages().modify(
+            userId=user_id, id=msg_id, body={"removeLabelIds": ["UNREAD"]}
+        ).execute()
+        print(f"Marked message {msg_id} as read.")
+        return True
+    except Exception as error:
+        print(f"An error occurred: {error}")
+        return False
+
+
+def get_service(scopes):
+    """
+    Creates a Gmail API service client.
+
+    Args:
+        scopes (list): A list of strings representing the API scopes.
+
+    Returns:
+        An authorized Gmail API service client.
+    """
+    creds = None
+    # The file token.json stores the user's access and refresh tokens, and is
+    # created automatically when the authorization flow completes for the first time.
+    if os.path.exists("token.json"):
+        creds = Credentials.from_authorized_user_file("token.json", scopes)
+    # If there are no (valid) credentials available, let the user log in.
+    if not creds or not creds.valid:
+        if creds and creds.expired and creds.refresh_token:
+            creds.refresh(Request())
+        else:
+            flow = InstalledAppFlow.from_client_secrets_file("credentials.json", scopes)
+            # NOTE: port 5000 is also used by the Flask model server; stop it during first-time auth
+            creds = flow.run_local_server(port=5000)
+        # Save the credentials for the next run
+        with open("token.json", "w") as token:
+            token.write(creds.to_json())
+
+    service = build("gmail", "v1", credentials=creds)
+    return service
+
+
+def create_message(sender, to, subject, message_text):
+    """
+    Creates a message
+
+    Args:
+        sender (str): The e-mail address of the sender
+        to (str): The e-mail address of the receiver
+        subject (str): The subject of the mail
+        message_text (str): The message text
+
+    Returns:
+        A raw encoded message ready to be sent using send_message()
+    """
+    message = MIMEText(message_text, "html")
+    message["to"] = to
+    message["from"] = sender
+    message["subject"] = subject
+    print(
+        f"Created message from {sender} to {to}\nSubject: {subject}\nContent: {message_text}"
+    )
+    return {"raw": base64.urlsafe_b64encode(message.as_bytes()).decode()}
+
+
+def send_message(service: build, user_id, message):
+    """
+    Sends an email message.
+
+    Args:
+        service: The Gmail API service instance.
+        user_id (str): The user's email ID. Usually 'me'.
+        message (dict): The message to be sent. Should be created using create_message.
+
+    Returns:
+        bool: True if the message was sent successfully, False otherwise.
+    """
+    try:
+        message = (
+            service.users().messages().send(userId=user_id, body=message).execute()
+        )
+        print(f"Sent {message['id']} successfully")
+        return True
+    except Exception as error:
+        print(f"An error occurred sending the message: {error}")
+        return False
+
+
+def respond_to_mails(service, sender_address, user_id):
+    """
+    Responds to unread emails with a specific subject.
+
+    Args:
+        service: The Gmail API service instance.
+        sender_address (str): The email address of the sender (response sender).
+        user_id (str): The user's email ID. Usually 'me'.
+
+    This function scans for unread emails with the subject 'info',
+    creates a response, sends it, and marks the original email as read.
+    """
+    unread_emails = get_unread_emails(service)
+
+    for email in unread_emails:
+        if email["subject"].lower() == "info":
+            print("\nMAIL FOUND\n==================================")
+            print(
+                f'Sender: {email["name"]}\nAddress: {email["address"]}\nContent: {email["body"]}'
+            )
+
+            prompt = email["body"].rstrip()
+            prompt_res = prompt_request(prompt)
+
+            # Render the markdown answer as HTML for the mail body
+            formatted_answer = markdown.markdown(prompt_res['answer'], extensions=['tables'])
+
+            response_text = f"""
+Hej {email["name"]}!
+
+Här kommer svaret på frågan "{prompt}":
+{formatted_answer}
+
+De här dokumenten har använts:
+{", ".join(prompt_res['doc_names'])}
+
+
+Hälsningar,
+Effektiv Administration
+            """
+
+            message = create_message(
+                sender=sender_address,
+                to=email["address"],
+                subject=f'Svar till frågan {email["body"][:50]}',
+                message_text=response_text,
+            )
+
+            send_message(service, user_id, message)
+            mark_as_read(service, user_id, email["id"])
diff --git a/src/query_utils/query.py b/src/utils/query/query.py
similarity index 76%
rename from src/query_utils/query.py
rename to src/utils/query/query.py
index 97b6878..8928a07 100644
--- a/src/query_utils/query.py
+++ b/src/utils/query/query.py
@@ -1,12 +1,10 @@
 import chromadb
 
-
-
 def preform_query(query: str):
-    """
-    Perform query on Chroma DB collenction
-
-    :param query: The long string of text to be chunked.
+    """
+    Performs a query in the ChromaDB to find relevant documents
+    Args:
+        query (str): The string you want to query by in the vector DB
     """
     collection_name = "documents_collection"
     persist_directory = "chroma_storage"
@@ -26,3 +24,5 @@
     return result
 
+if __name__ == "__main__":
+    preform_query("Budget")
diff --git a/src/watcher_utils/file_reader.py b/src/utils/watcher/file_reader.py
similarity index 54%
rename from src/watcher_utils/file_reader.py
rename to src/utils/watcher/file_reader.py
index bf4669f..a8d2d89 100644
--- a/src/watcher_utils/file_reader.py
+++ b/src/utils/watcher/file_reader.py
@@ -1,9 +1,10 @@
-import pypandoc
 import nltk
+import pypandoc
 from nltk.tokenize import sent_tokenize
 
 # Download the necessary NLTK models (if not already downloaded)
-nltk.download('punkt')
+nltk.download("punkt")
+
 
 def chunk_text(text, max_length=2000):
     """
@@ -18,29 +19,66 @@
     sentences = sent_tokenize(text)
     chunks = []
     current_chunk = []
 
     for sentence in sentences:
-        if len(' '.join(current_chunk) + ' ' + sentence) <= max_length:
+        if len(" ".join(current_chunk) + " " + sentence) <= max_length:
             current_chunk.append(sentence)
         else:
-            chunks.append(' '.join(current_chunk))
+            chunks.append(" ".join(current_chunk))
             current_chunk = [sentence]
 
     if current_chunk:
-        chunks.append(' '.join(current_chunk))
+        chunks.append(" ".join(current_chunk))
 
     return chunks
 
 
 pypandoc_compatible = [
-    "biblatex", "bibtex", "commonmark", "commonmark_x", "creole", "csljson",
-    "csv", "docbook", "docx", "dokuwiki", "endnotexml", "epub", "fb2", "gfm",
-    "haddock", "html", "ipynb", "jats", "jira", "json", "latex", "man",
-    "markdown", "markdown_github", "markdown_mmd", "markdown_phpextra",
-    "markdown_strict", "mediawiki", "muse", "native", "odt", "opml", "org",
-    "ris", "rst", "rtf", "t2t", "textile", "tikiwiki", "tsv", "twiki", "vimwiki"
+    "biblatex",
+    "bibtex",
+    "commonmark",
+    "commonmark_x",
+    "creole",
+    "csljson",
+    "csv",
+    "docbook",
+    "docx",
+    "dokuwiki",
+    "endnotexml",
+    "epub",
+    "fb2",
+    "gfm",
+    "haddock",
+    "html",
+    "ipynb",
+    "jats",
+    "jira",
+    "json",
+    "latex",
+    "man",
+    "markdown",
+    "markdown_github",
+    "markdown_mmd",
+    "markdown_phpextra",
+    "markdown_strict",
+    "mediawiki",
+    "muse",
+    "native",
+    "odt",
+    "opml",
+    "org",
+    "ris",
+    "rst",
+    "rtf",
+    "t2t",
+    "textile",
+    "tikiwiki",
+    "tsv",
+    "twiki",
+    "vimwiki",
 ]
 
+
 def read_normal_file(file_path):
     try:
-        with open(file_path, 'r', encoding='utf-8') as file:
+        with open(file_path, "r", encoding="utf-8") as file:
             return file.read()
     except FileNotFoundError:
         return "File not found."
@@ -50,11 +88,11 @@
 
 def get_chunks(file_path: str):
     file_end = file_path.split(".")[-1]
-    if (file_end in pypandoc_compatible):
-        text: str = pypandoc.convert_file(file_path, 'rst')
+    if file_end in pypandoc_compatible:
+        text: str = pypandoc.convert_file(file_path, "rst")
     else:
         text: str = read_normal_file(file_path)
-    text = text.replace("\r","")
+    text = text.replace("\r", "")
 
     chunks = chunk_text(text)
     return chunks
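Worked example of the chunking above (a sketch; file_reader downloads NLTK's punkt data at import, and the import path assumes the script runs from src/):

    from utils.watcher.file_reader import chunk_text

    # chunk_text keeps whole sentences together and starts a new chunk
    # once joining the next sentence would exceed max_length characters.
    text = "First sentence. Second sentence. Third sentence."
    for i, chunk in enumerate(chunk_text(text, max_length=35)):
        print(i, repr(chunk))
    # 0 'First sentence. Second sentence.'
    # 1 'Third sentence.'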
diff --git a/src/watcher_utils/update_db.py b/src/utils/watcher/update_db.py
similarity index 77%
rename from src/watcher_utils/update_db.py
rename to src/utils/watcher/update_db.py
index 97d56ef..82ff408 100644
--- a/src/watcher_utils/update_db.py
+++ b/src/utils/watcher/update_db.py
@@ -1,14 +1,14 @@
 import os
-import argparse
-import pypandoc
 import shutil
-from watcher_utils.file_reader import get_chunks
+
+import chromadb
 from tqdm import tqdm
+from utils.watcher.file_reader import get_chunks
 
-import chromadb
 
 COLLECTION_NAME = "documents_collection"
-PERSIST_DIR :str= "chroma_storage"
+PERSIST_DIR: str = "chroma_storage"
+
 
 def load_file(filepath):
     documents = []
@@ -19,7 +19,6 @@
         documents.append(chunk)
         metadatas.append({"filepath": filepath, "chunk": i})
 
-
     client = chromadb.PersistentClient(path=PERSIST_DIR)
     collection = client.get_or_create_collection(name=COLLECTION_NAME)
     count = collection.count()
@@ -35,8 +34,6 @@
             metadatas=metadatas[i : i + 100],
         )
 
-    results = collection.get(where= {"filepath": filepath})
-
 
 def remove_file(filepath):
     client = chromadb.PersistentClient(path=PERSIST_DIR)
@@ -44,8 +41,8 @@
     if collection is None:
         print(f"No collection found with name {COLLECTION_NAME}")
         return
-
-    collection.delete(where= {"filepath": filepath})
+
+    collection.delete(where={"filepath": filepath})
 
 
 def remove_chroma_storage_files():
@@ -68,64 +65,64 @@
         shutil.rmtree(item_path)
         print(f"Removed folder: {item_path}")
 
+
 def cleanse_documents():
+    """Remove all documents from ChromaDB (irreversible)"""
+    answer = input("You are deleting everything in the db, do you want to continue? [y/N] ")
+    if answer.lower() != "y":
+        print("Aborted")
+        return
     client = chromadb.PersistentClient(path=PERSIST_DIR)
     client.delete_collection(name=COLLECTION_NAME)
     remove_chroma_storage_files()
-    
+    print("DB deleted")
+
 
 def update_file(filepath):
+    """Update the information of all the documents related to a file.
+    TODO: Investigate if there is a better way to do this than remove and create again.
+    """
     remove_file(filepath)
     load_file(filepath)
 
+
 def check_db():
+    """Prints all documents currently in the DB"""
     client = chromadb.PersistentClient(path="chroma_storage")
     collection = client.get_or_create_collection(name="documents_collection")
     result = collection.get()
     print(result)
 
+
 def init_db(directory):
-    # Read all files in the data directory
+    """Create documents for every file in the given directory"""
     documents = []
     metadatas = []
-    #läser documentes directory
     files = os.listdir(directory)
 
-    #läser alla rader i dokumenten och sparar det i document och metadata list
     for filepath in files:
-        chunks = get_chunks(directory+filepath)
+        chunks = get_chunks(directory + filepath)
         for i, chunk in enumerate(chunks):
             documents.append(chunk)
-            metadatas.append({"filepath": directory+filepath, "chunk": i})
+            metadatas.append({"filepath": directory + filepath, "chunk": i})
 
-    # Instantiate a persistent chroma client in the persist_directory.
     client = chromadb.PersistentClient(path="chroma_storage")
-    # If the collection already exists, we just return it. This allows us to add more
-    # data to an existing collection.
     collection = client.get_or_create_collection(name="documents_collection")
-    # Create ids from the current count
     count = collection.count()
     print(f"Collection already contains {count} documents")
     ids = [str(i) for i in range(count, count + len(documents))]
-    # Load the documents in batches of 100
     for i in tqdm(
         range(0, len(documents), 100), desc="Adding documents", unit_scale=100
     ):
         collection.add(
             ids=ids[i : i + 100],
             documents=documents[i : i + 100],
-            metadatas=metadatas[i : i + 100],  # type: ignore
+            metadatas=metadatas[i : i + 100],
        )
 
     new_count = collection.count()
     print(f"Added {new_count - count} documents")
-
-
-
-if __name__ == "__main__":
-    cleanse_documents()
-    init_db("./documents2/")
diff --git a/src/utils/watcher/watcher_utils.py b/src/utils/watcher/watcher_utils.py
new file mode 100644
index 0000000..74f3470
--- /dev/null
+++ b/src/utils/watcher/watcher_utils.py
@@ -0,0 +1,50 @@
+import time
+
+from watchdog import events
+from watchdog.events import FileSystemEventHandler
+from watchdog.observers.polling import PollingObserver
+
+from utils.watcher.update_db import load_file, remove_file, update_file
+
+
+class WatcherHandler(FileSystemEventHandler):
+    """Mirrors file events in the watched directory into the ChromaDB index."""
+
+    def on_modified(self, event: events.FileModifiedEvent):
+        if event.is_directory:
+            return
+        update_file(event.src_path)
+        print("\n" * 50)
+        print(f"Modified file: {event.src_path}")
+
+    def on_created(self, event: events.FileCreatedEvent):
+        if event.is_directory:
+            return
+        load_file(event.src_path)
+        print("\n" * 50)
+        print(f"Created file: {event.src_path}")
+
+    def on_deleted(self, event: events.FileDeletedEvent):
+        if event.is_directory:
+            return
+        remove_file(event.src_path)
+        print("\n" * 50)
+        print(f"Deleted file: {event.src_path}")
+
+
+def start_watcher(directory):
+    """
+    Starts the watcher on a specified directory that notifies and handles creation/changes/deletion of files within
+    that directory so that the ChromaDB reflects the information contained in the directory"""
+    event_handler = WatcherHandler()
+    observer = PollingObserver()
+    observer.schedule(event_handler, directory, recursive=True)
+
+    observer.start()
+    print(f"Started watcher on {directory}")
+    try:
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        observer.stop()
+    observer.join()
diff --git a/src/watcher_utils/watcher.py b/src/watcher_utils/watcher.py
deleted file mode 100644
index 061ac23..0000000
--- a/src/watcher_utils/watcher.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import time
-from watchdog.events import FileSystemEventHandler
-from watchdog import events
-from watchdog.observers.polling import PollingObserver
-from watchdog.events import FileSystemEventHandler
-from watchdog import events
-
-from watcher_utils.update_db import load_file, remove_file, update_file
-
-class WatcherHandler(FileSystemEventHandler):
-
-    def on_modified(self, event: events.FileModifiedEvent):
-        try:
-            if event.is_directory:
-                return
-            update_file(event.src_path)
-            print("\n" * 50)
-            print(f'Modified file: {event.src_path}')
-        except:
-            pass
-
-    def on_created(self, event:events.FileCreatedEvent):
-        try:
-            if event.is_directory:
-                return
-            load_file(event.src_path)
-            print("\n" * 50)
-            print(f'Created file: {event.src_path}')
-        except:
-            pass
-
-    def on_deleted(self, event:events.FileDeletedEvent):
-        try:
-            if event.is_directory:
-                return
-            remove_file(event.src_path)
-            print("\n" * 50)
-            print(f'Deleted file: {event.src_path}')
-        except:
-            pass
-
-
-
-def start_watcher(directory):
-    event_handler = WatcherHandler()
-    observer = PollingObserver()
-    observer.schedule(event_handler, directory, recursive=True)
-
-    observer.start()
-    print(f"Started watcher on {directory}")
-    try:
-        while True:
-            time.sleep(1)
-    except KeyboardInterrupt:
-        observer.stop()
-        observer.join()
\ No newline at end of file
diff --git a/src/watcher_wrapper.py b/src/watcher_wrapper.py
deleted file mode 100644
index ea38fd7..0000000
--- a/src/watcher_wrapper.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from watcher_utils.watcher import start_watcher
-from watcher_utils.update_db import cleanse_documents, init_db, check_db
-
-if __name__ == "__main__":
-    DIRECTORY = "./documents/"
-
-    """Steg 1, ta bort dokument (Om chroma storage är tom behövs inte detta"""
-    cleanse_documents()
-    """Steg 2 Stoppa in dokument"""
-    init_db(DIRECTORY)
-    """Steg 3 kolla vad som finns i db (optional)"""
-    check_db()
-
-    #start_watcher(DIRECTORY)
-
-
-    
\ No newline at end of file
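How the pieces fit together after this change (a sketch of a local run; assumes the working directory is src/ and that server_model.py still starts Flask on port 5000, as test.py and prompt_request.py expect):

    # 1. Index ./documents/ and watch it for changes (CLI: "reset", then "start")
    python server_watcher.py

    # 2. Serve the model behind POST /prompt on http://127.0.0.1:5000
    python server_model.py

    # 3. Poll the Gmail inbox and answer mails with subject "info"
    python server_mail.py

    # Optional: interactive Q&A loop in the terminal
    python prompt_loop.py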