From 8a46fcb7c0b84ae23ec94dec1e953c06d497d5c9 Mon Sep 17 00:00:00 2001 From: Jeremy Frank Date: Sat, 26 Oct 2024 16:33:08 -0600 Subject: [PATCH 1/2] add thread persistence to agents --- .gitignore | 6 +- auto-content-creator/agents/package.json | 4 +- auto-content-creator/agents/src/index.ts | 19 +- .../agents/src/routes/writer.ts | 25 ++ .../agents/src/services/threadStorage.ts | 231 ++++++++++++++++++ .../agents/src/services/writerAgent.ts | 83 +++++-- auto-content-creator/agents/src/types.ts | 15 ++ 7 files changed, 355 insertions(+), 28 deletions(-) create mode 100644 auto-content-creator/agents/src/services/threadStorage.ts diff --git a/.gitignore b/.gitignore index 1e593cf..9d35922 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,8 @@ node_modules/ # SQLite database files *.sqlite *.sqlite3 -*.db \ No newline at end of file +*.db + +# SQLite database +*.sqlite +*.sqlite-journal diff --git a/auto-content-creator/agents/package.json b/auto-content-creator/agents/package.json index 2ee3b61..51d9185 100644 --- a/auto-content-creator/agents/package.json +++ b/auto-content-creator/agents/package.json @@ -20,7 +20,9 @@ "dotenv": "^16.4.1", "express": "5.0.0", "winston": "^3.14.2", - "zod": "^3.23.8" + "zod": "^3.23.8", + "sqlite3": "^5.1.7", + "sqlite": "^5.1.1" }, "devDependencies": { "@types/express": "5.0.0", diff --git a/auto-content-creator/agents/src/index.ts b/auto-content-creator/agents/src/index.ts index b554e7a..1db43c1 100644 --- a/auto-content-creator/agents/src/index.ts +++ b/auto-content-creator/agents/src/index.ts @@ -2,6 +2,7 @@ import express, { Request, Response, NextFunction } from 'express'; import { writerRouter } from './routes/writer'; import logger from './logger'; import { config } from './config'; +import { initializeStorage } from './services/writerAgent'; // Create an Express application const app = express(); @@ -28,7 +29,17 @@ app.use((err: Error, req: Request, res: Response, next: NextFunction) => { }); }); -// Start server -app.listen(port, () => { - logger.info(`Agents service is running on http://localhost:${port}`); -}); +// Initialize storage before starting the server +const startServer = async () => { + try { + await initializeStorage(); + app.listen(port, () => { + logger.info(`Agents service is running on http://localhost:${port}`); + }); + } catch (error) { + logger.error('Failed to start server:', error); + process.exit(1); + } +}; + +startServer(); diff --git a/auto-content-creator/agents/src/routes/writer.ts b/auto-content-creator/agents/src/routes/writer.ts index 0b449ff..5881248 100644 --- a/auto-content-creator/agents/src/routes/writer.ts +++ b/auto-content-creator/agents/src/routes/writer.ts @@ -65,4 +65,29 @@ router.post('/:threadId/feedback', (req, res, next) => { })(); }); +// Add new endpoint to check thread state +router.get('/:threadId/state', (req, res, next) => { + logger.info('Received request to get thread state:', req.params.threadId); + (async () => { + try { + const threadId = req.params.threadId; + const threadState = await writerAgent.getThreadState(threadId); + + if (!threadState) { + return res.status(404).json({ + error: 'Thread not found' + }); + } + + res.json({ + threadId, + lastOutput: threadState.lastOutput, + }); + } catch (error) { + logger.error('Error getting thread state:', error); + next(error); + } + })(); +}); + export const writerRouter = router; diff --git a/auto-content-creator/agents/src/services/threadStorage.ts b/auto-content-creator/agents/src/services/threadStorage.ts new file mode 100644 index 0000000..4b46774 --- /dev/null +++ b/auto-content-creator/agents/src/services/threadStorage.ts @@ -0,0 +1,231 @@ +import sqlite3 from 'sqlite3'; +import { open, Database } from 'sqlite'; +import { ThreadState } from '../types'; +import logger from '../logger'; +import { AIMessage, HumanMessage } from '@langchain/core/messages'; +import path from 'path'; + +// Database initialization +const initializeDb = async (dbPath: string): Promise => { + logger.info('Initializing SQLite database at:', dbPath); + const db = await open({ + filename: dbPath, + driver: sqlite3.Database + }); + + await db.exec(` + CREATE TABLE IF NOT EXISTS threads ( + thread_id TEXT PRIMARY KEY, + state TEXT NOT NULL, + last_output TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `); + + logger.info('Thread storage initialized with SQLite'); + return db; +}; + +// Message deserialization +const deserializeMessage = (msg: any) => { + try { + // Handle LangChain serialized format + if (msg.kwargs) { + const content = msg.kwargs.content; + return msg.id.includes('HumanMessage') + ? new HumanMessage({ content }) + : new AIMessage({ content }); + } + + // Handle our simplified format + if (msg._type === 'human') { + return new HumanMessage({ content: msg.content }); + } + return new AIMessage({ content: msg.content }); + } catch (error) { + logger.error('Error deserializing message:', error); + return new AIMessage({ content: 'Error deserializing message' }); + } +}; + +// Thread storage operations +export const createThreadStorage = () => { + const dbPath = path.join(process.cwd(), 'thread-storage.sqlite'); + let dbPromise = initializeDb(dbPath); + + const ensureConnection = async (): Promise => { + const db = await dbPromise; + if (!db) { + throw new Error('Database connection not established'); + } + return db; + }; + + const saveThread = async (threadId: string, state: ThreadState): Promise => { + try { + const db = await ensureConnection(); + logger.info(`Saving thread state for ${threadId}`, { + messageCount: state.state.messages?.length ?? 0, + hasLastOutput: !!state.lastOutput + }); + + // Ensure state has all required properties + const stateToSave = { + messages: state.state.messages ?? [], + reflectionScore: state.state.reflectionScore ?? 0, + researchPerformed: state.state.researchPerformed ?? false, + research: state.state.research ?? '', + reflections: state.state.reflections ?? [], + drafts: state.state.drafts ?? [], + feedbackHistory: state.state.feedbackHistory ?? [], + }; + + const serializedState = JSON.stringify(stateToSave, (key, value) => { + if (value instanceof HumanMessage || value instanceof AIMessage) { + return { + _type: value._getType(), + content: value.content, + additional_kwargs: value.additional_kwargs + }; + } + return value; + }); + + const serializedLastOutput = state.lastOutput ? JSON.stringify(state.lastOutput) : null; + + await db.run( + `INSERT OR REPLACE INTO threads (thread_id, state, last_output, updated_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP)`, + [threadId, serializedState, serializedLastOutput] + ); + + logger.info(`Thread state saved successfully: ${threadId}`); + } catch (error) { + logger.error(`Error saving thread ${threadId}:`, error); + throw error; + } + }; + + const loadThread = async (threadId: string): Promise => { + try { + const db = await ensureConnection(); + logger.info(`Loading thread state for ${threadId}`); + + const row = await db.get( + 'SELECT state, last_output FROM threads WHERE thread_id = ?', + threadId + ); + + if (!row) { + logger.warn(`Thread ${threadId} not found`); + return null; + } + + const parsedState = JSON.parse(row.state); + const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined; + + // Ensure messages array exists + if (!parsedState.messages) { + logger.warn(`Invalid state structure for thread ${threadId}: missing messages array`); + return null; + } + + // Force messages into array if not already + const messageArray = Array.isArray(parsedState.messages) + ? parsedState.messages + : [parsedState.messages]; + + // Reconstruct message instances with improved error handling + const messages = messageArray.map((msg: any) => { + try { + if (!msg) { + logger.warn('Null or undefined message found'); + return new AIMessage({ content: 'Invalid message' }); + } + return deserializeMessage(msg); + } catch (error) { + logger.error('Error deserializing message:', { error, msg }); + return new AIMessage({ content: 'Error deserializing message' }); + } + }); + + // Ensure all required state properties exist + const state = { + messages, + reflectionScore: parsedState.reflectionScore ?? 0, + researchPerformed: parsedState.researchPerformed ?? false, + research: parsedState.research ?? '', + reflections: parsedState.reflections ?? [], + drafts: parsedState.drafts ?? [], + feedbackHistory: parsedState.feedbackHistory ?? [], + }; + + logger.info(`Thread state loaded successfully: ${threadId}`, { + messageCount: messages.length, + hasLastOutput: !!lastOutput + }); + + return { state, lastOutput }; + } catch (error) { + logger.error(`Error loading thread ${threadId}:`, error); + throw error; + } + }; + + const getAllThreads = async (): Promise> => { + try { + const db = await ensureConnection(); + const rows = await db.all( + 'SELECT thread_id, created_at, updated_at FROM threads ORDER BY updated_at DESC' + ); + + return rows.map(row => ({ + threadId: row.thread_id, + createdAt: row.created_at, + updatedAt: row.updated_at + })); + } catch (error) { + logger.error('Error listing threads:', error); + throw error; + } + }; + + const deleteThread = async (threadId: string): Promise => { + try { + const db = await ensureConnection(); + await db.run('DELETE FROM threads WHERE thread_id = ?', threadId); + logger.info(`Thread deleted: ${threadId}`); + } catch (error) { + logger.error(`Error deleting thread ${threadId}:`, error); + throw error; + } + }; + + const cleanup = async (olderThanDays: number = 30): Promise => { + try { + const db = await ensureConnection(); + const result = await db.run( + 'DELETE FROM threads WHERE updated_at < datetime("now", ?)', + [`-${olderThanDays} days`] + ); + + const deletedCount = result.changes || 0; + logger.info(`Cleaned up ${deletedCount} old threads`); + return deletedCount; + } catch (error) { + logger.error('Error during cleanup:', error); + throw error; + } + }; + + return { + saveThread, + loadThread, + getAllThreads, + deleteThread, + cleanup + }; +}; + +export type ThreadStorage = ReturnType; diff --git a/auto-content-creator/agents/src/services/writerAgent.ts b/auto-content-creator/agents/src/services/writerAgent.ts index 7316dca..4c9f0ca 100644 --- a/auto-content-creator/agents/src/services/writerAgent.ts +++ b/auto-content-creator/agents/src/services/writerAgent.ts @@ -9,6 +9,7 @@ import { generationSchema, reflectionSchema, researchDecisionSchema, humanFeedba import { generationPrompt, reflectionPrompt, researchDecisionPrompt, humanFeedbackPrompt } from '../prompts'; import logger from '../logger'; import { ChatPromptTemplate } from '@langchain/core/prompts'; +import { createThreadStorage } from './threadStorage'; // Create a list of tools and ToolNode instance const tools = [webSearchTool]; @@ -223,16 +224,34 @@ const workflow = new StateGraph(State) .addEdge('reflect', 'generate') .addEdge('processFeedback', 'generate'); -const app = workflow.compile({ checkpointer: new MemorySaver() }); +// Create a persistent memory store +const memoryStore = new MemorySaver(); -// Add these interfaces +// Update the workflow compilation to use the memory store +const app = workflow.compile({ + checkpointer: memoryStore +}); + +// Add a type for thread state management interface ThreadState { state: typeof State.State; - stream: any; + lastOutput?: WriterAgentOutput; } -// Add a map to store thread states -const threadStates = new Map(); +// Initialize thread storage as a singleton +const threadStorage = createThreadStorage(); + +// Ensure thread storage is initialized before starting the server +export const initializeStorage = async () => { + try { + // Just check if we can connect to the database + const threads = await threadStorage.getAllThreads(); + logger.info('Thread storage initialized successfully', { threadCount: threads.length }); + } catch (error) { + logger.error('Failed to initialize thread storage:', error); + throw error; + } +}; // Split the writerAgent into two exported functions export const writerAgent = { @@ -256,15 +275,21 @@ export const writerAgent = { feedbackHistory: [], }; - const stream = await app.stream(initialState, { configurable: { thread_id: threadId } }); + const config = { + configurable: { + thread_id: threadId + } + }; + + const stream = await app.stream(initialState, config); + const result = await processStream(stream); - // Store the state and stream for later use - threadStates.set(threadId, { + // Store thread state in persistent storage + await threadStorage.saveThread(threadId, { state: initialState, - stream, + lastOutput: result }); - const result = await processStream(stream); return { ...result, threadId }; }, @@ -275,32 +300,46 @@ export const writerAgent = { threadId: string; feedback: string; }): Promise { - const threadState = threadStates.get(threadId); + // Load thread state from storage + const threadState = await threadStorage.loadThread(threadId); if (!threadState) { - throw new Error('Thread not found'); + logger.error('Thread not found', { threadId }); + throw new Error(`Thread ${threadId} not found`); } logger.info('WriterAgent - Continuing draft with feedback', { threadId }); - // Add feedback to the existing state const newMessage = new HumanMessage({ content: `FEEDBACK: ${feedback}` }); - threadState.state.messages.push(newMessage); + const updatedState = { + ...threadState.state, + messages: [...threadState.state.messages, newMessage] + }; + + const config = { + configurable: { + thread_id: threadId + } + }; - // Continue the stream with the updated state - const stream = await app.stream(threadState.state, { configurable: { thread_id: threadId } }); + const stream = await app.stream(updatedState, config); + const result = await processStream(stream); - // Update the stored state - threadStates.set(threadId, { - state: threadState.state, - stream, + // Update thread state in storage + await threadStorage.saveThread(threadId, { + state: updatedState, + lastOutput: result }); - return await processStream(stream); + return result; + }, + + async getThreadState(threadId: string): Promise { + return await threadStorage.loadThread(threadId); } }; // Helper function to process the stream -async function processStream(stream: any): Promise { +const processStream = async (stream: any): Promise => { let finalContent = ''; let research = ''; let reflections: { critique: string; score: number }[] = []; diff --git a/auto-content-creator/agents/src/types.ts b/auto-content-creator/agents/src/types.ts index a0d6dbd..6dc2c63 100644 --- a/auto-content-creator/agents/src/types.ts +++ b/auto-content-creator/agents/src/types.ts @@ -1,3 +1,5 @@ +import { BaseMessage } from '@langchain/core/messages'; + export interface WriterAgentParams { category: string; topic: string; @@ -16,3 +18,16 @@ export interface WriterAgentOutput { drafts: string[]; feedbackHistory?: string[]; } + +export interface ThreadState { + state: { + messages: BaseMessage[]; + reflectionScore: number; + researchPerformed: boolean; + research: string; + reflections: Array<{ critique: string; score: number }>; + drafts: string[]; + feedbackHistory: string[]; + }; + lastOutput?: WriterAgentOutput; +} From e1b4e97d0630ee56b5e51006ac8ca056f4deaa68 Mon Sep 17 00:00:00 2001 From: Jeremy Frank Date: Sat, 26 Oct 2024 16:37:26 -0600 Subject: [PATCH 2/2] refactor threadStorage to be cleaner and more efficient --- .../agents/src/services/threadStorage.ts | 244 +++++++----------- 1 file changed, 94 insertions(+), 150 deletions(-) diff --git a/auto-content-creator/agents/src/services/threadStorage.ts b/auto-content-creator/agents/src/services/threadStorage.ts index 4b46774..c60a2fc 100644 --- a/auto-content-creator/agents/src/services/threadStorage.ts +++ b/auto-content-creator/agents/src/services/threadStorage.ts @@ -2,12 +2,54 @@ import sqlite3 from 'sqlite3'; import { open, Database } from 'sqlite'; import { ThreadState } from '../types'; import logger from '../logger'; -import { AIMessage, HumanMessage } from '@langchain/core/messages'; +import { AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages'; import path from 'path'; -// Database initialization +// Types +type MessageType = { + _type?: string; + content: string; + kwargs?: { content: string }; + id?: string[]; +}; + +type DbRow = { + state: string; + last_output: string | null; +}; + +// Pure functions for message handling +const createMessage = (type: 'human' | 'ai', content: string): BaseMessage => + type === 'human' ? new HumanMessage({ content }) : new AIMessage({ content }); + +const deserializeMessage = (msg: MessageType): BaseMessage => { + if (!msg) return createMessage('ai', 'Invalid message'); + + // Handle LangChain format + if (msg.kwargs?.content) { + return createMessage( + msg.id?.includes('HumanMessage') ? 'human' : 'ai', + msg.kwargs.content + ); + } + + // Handle simplified format + return createMessage( + msg._type === 'human' ? 'human' : 'ai', + msg.content + ); +}; + +const serializeMessage = (msg: BaseMessage) => ({ + _type: msg._getType(), + content: msg.content, + additional_kwargs: msg.additional_kwargs +}); + +// Database operations const initializeDb = async (dbPath: string): Promise => { logger.info('Initializing SQLite database at:', dbPath); + const db = await open({ filename: dbPath, driver: sqlite3.Database @@ -23,188 +65,101 @@ const initializeDb = async (dbPath: string): Promise => { ) `); - logger.info('Thread storage initialized with SQLite'); return db; }; -// Message deserialization -const deserializeMessage = (msg: any) => { - try { - // Handle LangChain serialized format - if (msg.kwargs) { - const content = msg.kwargs.content; - return msg.id.includes('HumanMessage') - ? new HumanMessage({ content }) - : new AIMessage({ content }); - } +// Thread state operations +const parseThreadState = (row: DbRow): ThreadState | null => { + const parsedState = JSON.parse(row.state); + const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined; - // Handle our simplified format - if (msg._type === 'human') { - return new HumanMessage({ content: msg.content }); - } - return new AIMessage({ content: msg.content }); - } catch (error) { - logger.error('Error deserializing message:', error); - return new AIMessage({ content: 'Error deserializing message' }); + if (!Array.isArray(parsedState.messages)) { + logger.warn('Invalid state: messages is not an array'); + return null; } + + return { + state: { + messages: parsedState.messages.map(deserializeMessage), + reflectionScore: parsedState.reflectionScore ?? 0, + researchPerformed: parsedState.researchPerformed ?? false, + research: parsedState.research ?? '', + reflections: parsedState.reflections ?? [], + drafts: parsedState.drafts ?? [], + feedbackHistory: parsedState.feedbackHistory ?? [], + }, + lastOutput + }; }; -// Thread storage operations +// Main storage factory export const createThreadStorage = () => { const dbPath = path.join(process.cwd(), 'thread-storage.sqlite'); - let dbPromise = initializeDb(dbPath); + const dbPromise = initializeDb(dbPath); - const ensureConnection = async (): Promise => { + const ensureConnection = async () => { const db = await dbPromise; - if (!db) { - throw new Error('Database connection not established'); - } + if (!db) throw new Error('Database connection not established'); return db; }; - const saveThread = async (threadId: string, state: ThreadState): Promise => { - try { + return { + async saveThread(threadId: string, state: ThreadState): Promise { const db = await ensureConnection(); - logger.info(`Saving thread state for ${threadId}`, { - messageCount: state.state.messages?.length ?? 0, - hasLastOutput: !!state.lastOutput - }); - // Ensure state has all required properties const stateToSave = { - messages: state.state.messages ?? [], - reflectionScore: state.state.reflectionScore ?? 0, - researchPerformed: state.state.researchPerformed ?? false, - research: state.state.research ?? '', - reflections: state.state.reflections ?? [], - drafts: state.state.drafts ?? [], - feedbackHistory: state.state.feedbackHistory ?? [], + ...state.state, + messages: state.state.messages.map(serializeMessage) }; - const serializedState = JSON.stringify(stateToSave, (key, value) => { - if (value instanceof HumanMessage || value instanceof AIMessage) { - return { - _type: value._getType(), - content: value.content, - additional_kwargs: value.additional_kwargs - }; - } - return value; - }); - - const serializedLastOutput = state.lastOutput ? JSON.stringify(state.lastOutput) : null; - await db.run( `INSERT OR REPLACE INTO threads (thread_id, state, last_output, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP)`, - [threadId, serializedState, serializedLastOutput] + [ + threadId, + JSON.stringify(stateToSave), + state.lastOutput ? JSON.stringify(state.lastOutput) : null + ] ); - logger.info(`Thread state saved successfully: ${threadId}`); - } catch (error) { - logger.error(`Error saving thread ${threadId}:`, error); - throw error; - } - }; + logger.info(`Thread saved: ${threadId}`, { + messageCount: state.state.messages.length + }); + }, - const loadThread = async (threadId: string): Promise => { - try { + async loadThread(threadId: string): Promise { const db = await ensureConnection(); - logger.info(`Loading thread state for ${threadId}`); - const row = await db.get( + const row = await db.get( 'SELECT state, last_output FROM threads WHERE thread_id = ?', threadId ); if (!row) { - logger.warn(`Thread ${threadId} not found`); - return null; - } - - const parsedState = JSON.parse(row.state); - const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined; - - // Ensure messages array exists - if (!parsedState.messages) { - logger.warn(`Invalid state structure for thread ${threadId}: missing messages array`); + logger.warn(`Thread not found: ${threadId}`); return null; } - // Force messages into array if not already - const messageArray = Array.isArray(parsedState.messages) - ? parsedState.messages - : [parsedState.messages]; - - // Reconstruct message instances with improved error handling - const messages = messageArray.map((msg: any) => { - try { - if (!msg) { - logger.warn('Null or undefined message found'); - return new AIMessage({ content: 'Invalid message' }); - } - return deserializeMessage(msg); - } catch (error) { - logger.error('Error deserializing message:', { error, msg }); - return new AIMessage({ content: 'Error deserializing message' }); - } - }); - - // Ensure all required state properties exist - const state = { - messages, - reflectionScore: parsedState.reflectionScore ?? 0, - researchPerformed: parsedState.researchPerformed ?? false, - research: parsedState.research ?? '', - reflections: parsedState.reflections ?? [], - drafts: parsedState.drafts ?? [], - feedbackHistory: parsedState.feedbackHistory ?? [], - }; - - logger.info(`Thread state loaded successfully: ${threadId}`, { - messageCount: messages.length, - hasLastOutput: !!lastOutput - }); - - return { state, lastOutput }; - } catch (error) { - logger.error(`Error loading thread ${threadId}:`, error); - throw error; - } - }; + return parseThreadState(row); + }, - const getAllThreads = async (): Promise> => { - try { + async getAllThreads(): Promise> { const db = await ensureConnection(); - const rows = await db.all( + + return db.all( 'SELECT thread_id, created_at, updated_at FROM threads ORDER BY updated_at DESC' ); + }, - return rows.map(row => ({ - threadId: row.thread_id, - createdAt: row.created_at, - updatedAt: row.updated_at - })); - } catch (error) { - logger.error('Error listing threads:', error); - throw error; - } - }; - - const deleteThread = async (threadId: string): Promise => { - try { + async deleteThread(threadId: string): Promise { const db = await ensureConnection(); await db.run('DELETE FROM threads WHERE thread_id = ?', threadId); logger.info(`Thread deleted: ${threadId}`); - } catch (error) { - logger.error(`Error deleting thread ${threadId}:`, error); - throw error; - } - }; + }, - const cleanup = async (olderThanDays: number = 30): Promise => { - try { + async cleanup(olderThanDays: number = 30): Promise { const db = await ensureConnection(); + const result = await db.run( 'DELETE FROM threads WHERE updated_at < datetime("now", ?)', [`-${olderThanDays} days`] @@ -213,19 +168,8 @@ export const createThreadStorage = () => { const deletedCount = result.changes || 0; logger.info(`Cleaned up ${deletedCount} old threads`); return deletedCount; - } catch (error) { - logger.error('Error during cleanup:', error); - throw error; } }; - - return { - saveThread, - loadThread, - getAllThreads, - deleteThread, - cleanup - }; }; export type ThreadStorage = ReturnType;