Skip to content

Commit

Permalink
work in progress tg agent for game info
Browse files Browse the repository at this point in the history
  • Loading branch information
moshemalawach committed Jan 10, 2025
1 parent 79a3f65 commit eebece3
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 21 deletions.
44 changes: 32 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ package-mode = false

[tool.poetry.dependencies]
python = "^3.10"
libertai-agents = "0.1.0"
libertai-agents = {path = "../libertai-agents/libertai_agents", develop = true}
sqlalchemy = "^2.0.36"
aiosqlite = "^0.20.0"
greenlet = "^3.1.1"
pytelegrambotapi = "^4.25.0"
yfinance = "^0.2.51"
pypdf2 = "^3.0.1"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.1"
Expand Down
17 changes: 15 additions & 2 deletions src/commands/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import telebot.types as telebot_types
from libertai_agents.interfaces.messages import Message as LibertaiMessage
from libertai_agents.interfaces.messages import MessageRoleEnum
import aiohttp

from src.config import config
from src.utils.telegram import (
Expand All @@ -12,6 +13,8 @@
# Max number of messages we will pass

Check failure on line 13 in src/commands/message.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

src/commands/message.py:1:1: I001 Import block is un-sorted or un-formatted
MESSAGES_NUMBER = 50

CHAT_SESSIONS = {}


async def text_message_handler(message: telebot_types.Message):
"""
Expand All @@ -33,12 +36,15 @@ async def text_message_handler(message: telebot_types.Message):
should_reply = should_reply_to_message(message)
if should_reply is False:
span.info("Message not intended for the bot")
return None

# Send an initial response
# TODO: select a phrase randomly from a list to get a more dynamic result
result = "I'm thinking..."
reply = await config.BOT.reply_to(message, result)

context = await config.KNOWLEDGEBASE.query(message.text, min=0.25, top_k=2)

messages: list[LibertaiMessage] = []

chat_history = await config.DATABASE.get_chat_last_messages(
Expand All @@ -47,7 +53,7 @@ async def text_message_handler(message: telebot_types.Message):
# Iterate over the messages we've pulled
for chat_msg in reversed(chat_history):
message_username = get_formatted_username(chat_msg.from_user)
message_content = get_formatted_message_content(chat_msg)
message_role, message_content = get_formatted_message_content(chat_msg)
# TODO: support multiple users with names
role = (
MessageRoleEnum.assistant
Expand All @@ -56,9 +62,16 @@ async def text_message_handler(message: telebot_types.Message):
)
messages.append(LibertaiMessage(role=role, content=message_content))

for title, content, similarity in context:
messages.append(LibertaiMessage(role=MessageRoleEnum.system, content=f"Relevant knowledgebase entry: {title}\n{content}"))

Check failure on line 67 in src/commands/message.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

src/commands/message.py:67:1: W293 Blank line contains whitespace

if chat_id not in CHAT_SESSIONS:
CHAT_SESSIONS[chat_id] = aiohttp.ClientSession()

# TODO: pass system prompt with chat details when libertai-agents new version released
# and also tell it to answer directly (to avoid including "username in reply to username" in its answer)
async for response_msg in config.AGENT.generate_answer(messages):
async for response_msg in config.AGENT.generate_answer(messages, session=CHAT_SESSIONS[chat_id]):
if response_msg.content != result:
result = response_msg.content
# Update the message
Expand Down
9 changes: 6 additions & 3 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.tools import tools
from src.utils.database import AsyncDatabase
from src.utils.logger import Logger

from src.tools.knowledgebase import KnowledgeBase

class _Config:

Check failure on line 15 in src/config.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

src/config.py:1:1: I001 Import block is un-sorted or un-formatted
BOT_COMMANDS: list[tuple[str, str]]
Expand All @@ -31,8 +31,9 @@ def __init__(self):
]

# Logger
log_path = os.getenv("LOG_PATH")
log_path = os.getenv("LOG_PATH", None)
debug = os.getenv("DEBUG", "False") == "True"

self.LOGGER = Logger(log_path, debug)

try:
Expand All @@ -46,11 +47,13 @@ def __init__(self):
database_path = os.getenv("DATABASE_PATH", ":memory:")
self.DATABASE = AsyncDatabase(database_path)

self.KNOWLEDGEBASE = KnowledgeBase(os.getenv("KNOWLEDGEBASE_PATH", "knowledgebase.json"))

# LibertAI Agent
self.LOGGER.info("Setting up agent...")
self.AGENT = ChatAgent(
model=get_model("NousResearch/Hermes-3-Llama-3.1-8B"),
system_prompt="You are a helpful assistant",
system_prompt="CAPTAIN LASERHAWK world chatbot. Assistant has access to a knowledgebase of documents and can answer questions about the captain laserhawk world.",
tools=tools,
expose_api=False,
)
Expand Down
4 changes: 2 additions & 2 deletions src/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
)

tools: list[Tool] = [
Tool.from_function(get_current_stock_price),
Tool.from_function(get_current_cryptocurrency_price_usd),
# Tool.from_function(get_current_stock_price),
# Tool.from_function(get_current_cryptocurrency_price_usd),
]
17 changes: 17 additions & 0 deletions src/tools/citizens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

import aiohttp

# Base URL for citizen metadata: get_citizen_information() appends
# "/<number>" to it and fetches the resulting URL.
metadata_base_url = "https://aleph.sh/vm/bd8291cca9a3de79937c452a606d81efc912ab7d223cd88031da5ca5e2a868dd/QmVpCmKPnD3dAFK61WE5czz7ucV3GHuqHqNsj2wNfWVjXf"


async def get_citizen_information(number: int) -> dict:
    """
    Get the information of a citizen
    Args:
        number: The citizen number
    Returns:
        The information of the citizen as a JSON object
    """
    # A short-lived session per call: the function is self-contained and
    # always releases its connection when it returns or raises.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{metadata_base_url}/{number}") as response:
            # Fail with an explicit aiohttp.ClientResponseError on HTTP error
            # statuses (e.g. an unknown citizen number) instead of trying to
            # decode an error page as JSON.
            response.raise_for_status()
            return await response.json()

Check failure on line 17 in src/tools/citizens.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W292)

src/tools/citizens.py:17:41: W292 No newline at end of file
137 changes: 137 additions & 0 deletions src/tools/knowledgebase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import json
import numpy as np
import aiohttp
import asyncio
from PyPDF2 import PdfReader
from .. import config

# Define a global variable for the default API URL

Check failure on line 9 in src/tools/knowledgebase.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

src/tools/knowledgebase.py:1:1: I001 Import block is un-sorted or un-formatted
# Default embedding endpoint: it is KnowledgeBase's default `api_url` and the
# default of the `--api_url` command-line option; embed_text() POSTs
# {'content': <text>} to it as JSON.
DEFAULT_API_URL = "https://curated.aleph.cloud/vm/ee1b2a8e5bd645447739d8b234ef495c9a2b4d0b98317d510a3ccf822808ebe5/embedding"

# Async function to embed text using the external embedding tool
async def embed_text(text, session, api_url, tries=3):
    """Request an embedding for *text* from the embedding service.

    Args:
        text: The text to embed; sent as the ``content`` field of a JSON body.
        session: An object exposing ``post(url, json=...)`` as an async
            context manager (an ``aiohttp.ClientSession`` in this module).
        api_url: URL of the embedding endpoint.
        tries: Maximum number of attempts before giving up.

    Returns:
        The value of the ``embedding`` key of the JSON response, or ``[]``
        when the response has no such key.

    Raises:
        RuntimeError: If no attempt produced a response. The message lists
            the error of every failed attempt.
    """
    backoff = 1  # Seconds to wait before the next attempt; doubled each time
    payload = {
        'content': text,
    }

    response = None
    errors = []
    for attempt in range(tries):
        try:
            async with session.post(api_url, json=payload) as resp:
                resp.raise_for_status()
                response = await resp.json()
                break
        except (aiohttp.ClientError, asyncio.TimeoutError) as error:
            # Timeouts are retried like any other transport failure.
            errors.append(str(error))
            print(f"Error embedding text: {error}")
            # Only back off when another attempt will follow; sleeping after
            # the last failure would just delay the error.
            if attempt < tries - 1:
                await asyncio.sleep(backoff)
                backoff *= 2

    if response is None:
        raise RuntimeError('Failed to generate embedding: ' + '; '.join(errors))

    return response.get('embedding', [])

class KnowledgeBase:
    """JSON-file-backed store of text entries and their embeddings.

    Entries are kept in ``self.db`` as
    ``{title: {'content': ..., 'embedding': ...}}`` and persisted to
    ``db_path``. Embeddings come from ``embed_text`` called against
    ``api_url``; ``query`` ranks the stored entries by cosine similarity.
    """

    def __init__(self, db_path, api_url=DEFAULT_API_URL):
        """Remember the storage path and endpoint, then load existing entries.

        Args:
            db_path: Path of the JSON file holding the entries.
            api_url: URL of the embedding endpoint.
        """
        self.db_path = db_path
        self.api_url = api_url
        # Created by the first add_entry()/query() call, released by close().
        self.session = None
        self.load_db()

    async def close(self):
        """Close the HTTP session, if one was opened."""
        if self.session is not None:
            await self.session.close()
            # Drop the closed session so a later call opens a fresh one
            # instead of reusing a closed connection pool.
            self.session = None

    def load_db(self):
        """Load ``self.db`` from ``db_path``; start empty if the file is missing."""
        if os.path.exists(self.db_path):
            with open(self.db_path, 'r', encoding='utf-8') as f:
                self.db = json.load(f)
        else:
            self.db = {}

    def save_db(self):
        """Write ``self.db`` to ``db_path`` as JSON, replacing previous content."""
        with open(self.db_path, 'w', encoding='utf-8') as f:
            json.dump(self.db, f)

    def _ensure_session(self):
        """Return the shared HTTP session, creating it on first use."""
        if self.session is None:
            self.session = aiohttp.ClientSession()
        return self.session

    async def add_entry(self, title, content):
        """Embed *content*, store it under *title* and persist the database.

        An existing entry with the same title is overwritten.
        """
        embedding = await embed_text(content, self._ensure_session(), self.api_url)
        self.db[title] = {
            'content': content,
            'embedding': embedding
        }
        self.save_db()

    def _rank(self, query_embedding, top_k, min):
        """Rank stored entries by cosine similarity to *query_embedding*.

        Returns at most *top_k* ``(title, content, similarity)`` tuples, best
        first, keeping only similarities strictly greater than *min*.
        """
        query_vector = np.asarray(query_embedding, dtype=float)
        query_norm = np.linalg.norm(query_vector)
        if query_norm == 0:
            # Cosine similarity is undefined for an empty or all-zero query.
            return []

        results = []
        for title, data in self.db.items():
            entry_vector = np.asarray(data['embedding'], dtype=float)
            entry_norm = np.linalg.norm(entry_vector)
            if entry_norm == 0:
                # Skip entries without a usable embedding: dividing by zero
                # would yield NaN, which breaks the ordering of the sort.
                continue
            similarity = float(np.dot(query_vector, entry_vector) / (query_norm * entry_norm))
            results.append((title, data['content'], similarity))

        # Sort results by similarity in descending order and return top_k results
        results.sort(key=lambda item: item[2], reverse=True)
        return [item for item in results[:top_k] if item[2] > min]

    async def query(self, query_text, top_k=3, min=0.1):
        """Return the stored entries most similar to *query_text*.

        Args:
            query_text: Text to embed and compare against the stored entries.
            top_k: Maximum number of results to return.
            min: Similarity threshold; only results strictly above it are
                returned. The name shadows the builtin but is kept because
                callers pass it by keyword.

        Returns:
            A list of ``(title, content, similarity)`` tuples, best first.
        """
        query_embedding = await embed_text(query_text, self._ensure_session(), self.api_url)
        return self._rank(query_embedding, top_k, min)

def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated.

    Pages are joined in document order with no separator between them.
    """
    reader = PdfReader(pdf_path)
    # join() builds the result in one pass instead of growing a string page by
    # page; `or ""` tolerates a page for which no text is returned.
    return "".join(page.extract_text() or "" for page in reader.pages)

def chunk_text(text, chunk_size=200):
    """Yield *text* in consecutive chunks of at most *chunk_size* words.

    Words are the whitespace-separated tokens of *text*; each chunk is those
    words re-joined with single spaces. Empty or whitespace-only text yields
    nothing.

    Raises:
        ValueError: If *chunk_size* is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would silently produce no chunks and a zero step fails
    # with an unrelated range() message, so reject both explicitly.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    words = text.split()
    for start in range(0, len(words), chunk_size):
        yield ' '.join(words[start:start + chunk_size])

async def populate_db_from_pdf(db_path, api_url, pdf_path):
    """Fill the knowledge base at *db_path* with the text of a PDF.

    The PDF text is split with chunk_text() and every chunk is added as an
    entry titled "Chunk 1", "Chunk 2", ... in order. The knowledge base is
    closed afterwards, whether or not an error occurred.
    """
    knowledge_base = KnowledgeBase(db_path, api_url)
    try:
        chunks = chunk_text(extract_text_from_pdf(pdf_path))
        for i, chunk in enumerate(chunks):
            await knowledge_base.add_entry(f"Chunk {i+1}", chunk)
    finally:
        await knowledge_base.close()

async def search_knowledgebase(query_text: str, top_k: int = 3, min: float = 0.2) -> list[tuple[str, str, float]]:
    """
    Search the knowledge base for entries that match the query text.
    Args:
        query_text (str): The text to query the knowledge base with.
        top_k (int, optional): The maximum number of top results to return. Defaults to 3.
        min (float, optional): The minimum similarity threshold for results to be included. Defaults to 0.2.
    Returns:
        list[tuple[str, str, float]]: A list of tuples containing the title, content, and similarity score of the top matching entries.
    """
    # KNOWLEDGEBASE is an attribute of the `config` object exported by
    # src.config, whereas the module-level `from .. import config` binds the
    # src.config module itself. Import the object here, at call time: src.config
    # imports this module, so a top-level import of it would be circular.
    from src.config import config as app_config

    return await app_config.KNOWLEDGEBASE.query(query_text, top_k=top_k, min=min)

if __name__ == "__main__":
    # Command-line entry point: embed a PDF into a knowledge database file.
    import argparse

    cli = argparse.ArgumentParser(description='Populate the knowledge database from a PDF.')
    cli.add_argument('db_path', type=str, help='Path to the knowledge database file.')
    cli.add_argument('pdf_path', type=str, help='Path to the PDF file containing the lorebook.')
    cli.add_argument(
        '--api_url',
        type=str,
        default=DEFAULT_API_URL,
        help='API URL for the embedding service (default: %(default)s).',
    )

    options = cli.parse_args()
    asyncio.run(populate_db_from_pdf(options.db_path, options.api_url, options.pdf_path))




2 changes: 1 addition & 1 deletion src/utils/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_formatted_message_content(message: Message) -> str:
reply_to_username = get_formatted_username(message.reply_to_message.from_user)
sender = f"{sender} (in reply to {reply_to_username})"

return f"{sender}\n{message.text}"
return sender, message.text


def should_reply_to_message(message: Message) -> bool:
Expand Down

0 comments on commit eebece3

Please sign in to comment.