From e1b4e97d0630ee56b5e51006ac8ca056f4deaa68 Mon Sep 17 00:00:00 2001 From: Jeremy Frank Date: Sat, 26 Oct 2024 16:37:26 -0600 Subject: [PATCH] refactor threadStorage to be cleaner and more efficient --- .../agents/src/services/threadStorage.ts | 244 +++++++----------- 1 file changed, 94 insertions(+), 150 deletions(-) diff --git a/auto-content-creator/agents/src/services/threadStorage.ts b/auto-content-creator/agents/src/services/threadStorage.ts index 4b46774..c60a2fc 100644 --- a/auto-content-creator/agents/src/services/threadStorage.ts +++ b/auto-content-creator/agents/src/services/threadStorage.ts @@ -2,12 +2,54 @@ import sqlite3 from 'sqlite3'; import { open, Database } from 'sqlite'; import { ThreadState } from '../types'; import logger from '../logger'; -import { AIMessage, HumanMessage } from '@langchain/core/messages'; +import { AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages'; import path from 'path'; -// Database initialization +// Types +type MessageType = { + _type?: string; + content: string; + kwargs?: { content: string }; + id?: string[]; +}; + +type DbRow = { + state: string; + last_output: string | null; +}; + +// Pure functions for message handling +const createMessage = (type: 'human' | 'ai', content: string): BaseMessage => + type === 'human' ? new HumanMessage({ content }) : new AIMessage({ content }); + +const deserializeMessage = (msg: MessageType): BaseMessage => { + if (!msg) return createMessage('ai', 'Invalid message'); + + // Handle LangChain format + if (msg.kwargs?.content) { + return createMessage( + msg.id?.includes('HumanMessage') ? 'human' : 'ai', + msg.kwargs.content + ); + } + + // Handle simplified format + return createMessage( + msg._type === 'human' ? 'human' : 'ai', + msg.content + ); +}; + +const serializeMessage = (msg: BaseMessage) => ({ + _type: msg._getType(), + content: msg.content, + additional_kwargs: msg.additional_kwargs +}); + +// Database operations const initializeDb = async (dbPath: string): Promise => { logger.info('Initializing SQLite database at:', dbPath); + const db = await open({ filename: dbPath, driver: sqlite3.Database @@ -23,188 +65,101 @@ const initializeDb = async (dbPath: string): Promise => { ) `); - logger.info('Thread storage initialized with SQLite'); return db; }; -// Message deserialization -const deserializeMessage = (msg: any) => { - try { - // Handle LangChain serialized format - if (msg.kwargs) { - const content = msg.kwargs.content; - return msg.id.includes('HumanMessage') - ? new HumanMessage({ content }) - : new AIMessage({ content }); - } +// Thread state operations +const parseThreadState = (row: DbRow): ThreadState | null => { + const parsedState = JSON.parse(row.state); + const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined; - // Handle our simplified format - if (msg._type === 'human') { - return new HumanMessage({ content: msg.content }); - } - return new AIMessage({ content: msg.content }); - } catch (error) { - logger.error('Error deserializing message:', error); - return new AIMessage({ content: 'Error deserializing message' }); + if (!Array.isArray(parsedState.messages)) { + logger.warn('Invalid state: messages is not an array'); + return null; } + + return { + state: { + messages: parsedState.messages.map(deserializeMessage), + reflectionScore: parsedState.reflectionScore ?? 0, + researchPerformed: parsedState.researchPerformed ?? false, + research: parsedState.research ?? '', + reflections: parsedState.reflections ?? [], + drafts: parsedState.drafts ?? [], + feedbackHistory: parsedState.feedbackHistory ?? [], + }, + lastOutput + }; }; -// Thread storage operations +// Main storage factory export const createThreadStorage = () => { const dbPath = path.join(process.cwd(), 'thread-storage.sqlite'); - let dbPromise = initializeDb(dbPath); + const dbPromise = initializeDb(dbPath); - const ensureConnection = async (): Promise => { + const ensureConnection = async () => { const db = await dbPromise; - if (!db) { - throw new Error('Database connection not established'); - } + if (!db) throw new Error('Database connection not established'); return db; }; - const saveThread = async (threadId: string, state: ThreadState): Promise => { - try { + return { + async saveThread(threadId: string, state: ThreadState): Promise { const db = await ensureConnection(); - logger.info(`Saving thread state for ${threadId}`, { - messageCount: state.state.messages?.length ?? 0, - hasLastOutput: !!state.lastOutput - }); - // Ensure state has all required properties const stateToSave = { - messages: state.state.messages ?? [], - reflectionScore: state.state.reflectionScore ?? 0, - researchPerformed: state.state.researchPerformed ?? false, - research: state.state.research ?? '', - reflections: state.state.reflections ?? [], - drafts: state.state.drafts ?? [], - feedbackHistory: state.state.feedbackHistory ?? [], + ...state.state, + messages: state.state.messages.map(serializeMessage) }; - const serializedState = JSON.stringify(stateToSave, (key, value) => { - if (value instanceof HumanMessage || value instanceof AIMessage) { - return { - _type: value._getType(), - content: value.content, - additional_kwargs: value.additional_kwargs - }; - } - return value; - }); - - const serializedLastOutput = state.lastOutput ? JSON.stringify(state.lastOutput) : null; - await db.run( `INSERT OR REPLACE INTO threads (thread_id, state, last_output, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)`, - [threadId, serializedState, serializedLastOutput] + [ + threadId, + JSON.stringify(stateToSave), + state.lastOutput ? JSON.stringify(state.lastOutput) : null + ] ); - logger.info(`Thread state saved successfully: ${threadId}`); - } catch (error) { - logger.error(`Error saving thread ${threadId}:`, error); - throw error; - } - }; + logger.info(`Thread saved: ${threadId}`, { + messageCount: state.state.messages.length + }); + }, - const loadThread = async (threadId: string): Promise => { - try { + async loadThread(threadId: string): Promise { const db = await ensureConnection(); - logger.info(`Loading thread state for ${threadId}`); - const row = await db.get( + const row = await db.get( 'SELECT state, last_output FROM threads WHERE thread_id = ?', threadId ); if (!row) { - logger.warn(`Thread ${threadId} not found`); - return null; - } - - const parsedState = JSON.parse(row.state); - const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined; - - // Ensure messages array exists - if (!parsedState.messages) { - logger.warn(`Invalid state structure for thread ${threadId}: missing messages array`); + logger.warn(`Thread not found: ${threadId}`); return null; } - // Force messages into array if not already - const messageArray = Array.isArray(parsedState.messages) - ? parsedState.messages - : [parsedState.messages]; - - // Reconstruct message instances with improved error handling - const messages = messageArray.map((msg: any) => { - try { - if (!msg) { - logger.warn('Null or undefined message found'); - return new AIMessage({ content: 'Invalid message' }); - } - return deserializeMessage(msg); - } catch (error) { - logger.error('Error deserializing message:', { error, msg }); - return new AIMessage({ content: 'Error deserializing message' }); - } - }); - - // Ensure all required state properties exist - const state = { - messages, - reflectionScore: parsedState.reflectionScore ?? 0, - researchPerformed: parsedState.researchPerformed ?? false, - research: parsedState.research ?? '', - reflections: parsedState.reflections ?? [], - drafts: parsedState.drafts ?? [], - feedbackHistory: parsedState.feedbackHistory ?? [], - }; - - logger.info(`Thread state loaded successfully: ${threadId}`, { - messageCount: messages.length, - hasLastOutput: !!lastOutput - }); - - return { state, lastOutput }; - } catch (error) { - logger.error(`Error loading thread ${threadId}:`, error); - throw error; - } - }; + return parseThreadState(row); + }, - const getAllThreads = async (): Promise> => { - try { + async getAllThreads(): Promise> { const db = await ensureConnection(); - const rows = await db.all( + + return db.all( 'SELECT thread_id, created_at, updated_at FROM threads ORDER BY updated_at DESC' ); + }, - return rows.map(row => ({ - threadId: row.thread_id, - createdAt: row.created_at, - updatedAt: row.updated_at - })); - } catch (error) { - logger.error('Error listing threads:', error); - throw error; - } - }; - - const deleteThread = async (threadId: string): Promise => { - try { + async deleteThread(threadId: string): Promise { const db = await ensureConnection(); await db.run('DELETE FROM threads WHERE thread_id = ?', threadId); logger.info(`Thread deleted: ${threadId}`); - } catch (error) { - logger.error(`Error deleting thread ${threadId}:`, error); - throw error; - } - }; + }, - const cleanup = async (olderThanDays: number = 30): Promise => { - try { + async cleanup(olderThanDays: number = 30): Promise { const db = await ensureConnection(); + const result = await db.run( 'DELETE FROM threads WHERE updated_at < datetime("now", ?)', [`-${olderThanDays} days`] @@ -213,19 +168,8 @@ export const createThreadStorage = () => { const deletedCount = result.changes || 0; logger.info(`Cleaned up ${deletedCount} old threads`); return deletedCount; - } catch (error) { - logger.error('Error during cleanup:', error); - throw error; } }; - - return { - saveThread, - loadThread, - getAllThreads, - deleteThread, - cleanup - }; }; export type ThreadStorage = ReturnType;