Skip to content

Commit

Permalink
Merge pull request #14 from autonomys/thread-persistence
Browse files Browse the repository at this point in the history
Add thread persistence to agents
  • Loading branch information
jfrank-summit authored Oct 26, 2024
2 parents d278d13 + e1b4e97 commit a4d2487
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 28 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ node_modules/
# SQLite database files
*.sqlite
*.sqlite3
*.db
*.db

# SQLite database
*.sqlite
*.sqlite-journal
4 changes: 3 additions & 1 deletion auto-content-creator/agents/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
"dotenv": "^16.4.1",
"express": "5.0.0",
"winston": "^3.14.2",
"zod": "^3.23.8"
"zod": "^3.23.8",
"sqlite3": "^5.1.7",
"sqlite": "^5.1.1"
},
"devDependencies": {
"@types/express": "5.0.0",
Expand Down
19 changes: 15 additions & 4 deletions auto-content-creator/agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import express, { Request, Response, NextFunction } from 'express';
import { writerRouter } from './routes/writer';
import logger from './logger';
import { config } from './config';
import { initializeStorage } from './services/writerAgent';

// Create an Express application
const app = express();
Expand All @@ -28,7 +29,17 @@ app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
});
});

// Start server
app.listen(port, () => {
logger.info(`Agents service is running on http://localhost:${port}`);
});
// Initialize storage before starting the server
const startServer = async () => {
try {
await initializeStorage();
app.listen(port, () => {
logger.info(`Agents service is running on http://localhost:${port}`);
});
} catch (error) {
logger.error('Failed to start server:', error);
process.exit(1);
}
};

startServer();
25 changes: 25 additions & 0 deletions auto-content-creator/agents/src/routes/writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,29 @@ router.post('/:threadId/feedback', (req, res, next) => {
})();
});

// Add new endpoint to check thread state
router.get('/:threadId/state', (req, res, next) => {
logger.info('Received request to get thread state:', req.params.threadId);
(async () => {
try {
const threadId = req.params.threadId;
const threadState = await writerAgent.getThreadState(threadId);

if (!threadState) {
return res.status(404).json({
error: 'Thread not found'
});
}

res.json({
threadId,
lastOutput: threadState.lastOutput,
});
} catch (error) {
logger.error('Error getting thread state:', error);
next(error);
}
})();
});

export const writerRouter = router;
175 changes: 175 additions & 0 deletions auto-content-creator/agents/src/services/threadStorage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import sqlite3 from 'sqlite3';
import { open, Database } from 'sqlite';
import { ThreadState } from '../types';
import logger from '../logger';
import { AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages';
import path from 'path';

// Types
type MessageType = {
_type?: string;
content: string;
kwargs?: { content: string };
id?: string[];
};

type DbRow = {
state: string;
last_output: string | null;
};

// Pure functions for message handling
const createMessage = (type: 'human' | 'ai', content: string): BaseMessage =>
type === 'human' ? new HumanMessage({ content }) : new AIMessage({ content });

const deserializeMessage = (msg: MessageType): BaseMessage => {
if (!msg) return createMessage('ai', 'Invalid message');

// Handle LangChain format
if (msg.kwargs?.content) {
return createMessage(
msg.id?.includes('HumanMessage') ? 'human' : 'ai',
msg.kwargs.content
);
}

// Handle simplified format
return createMessage(
msg._type === 'human' ? 'human' : 'ai',
msg.content
);
};

const serializeMessage = (msg: BaseMessage) => ({
_type: msg._getType(),
content: msg.content,
additional_kwargs: msg.additional_kwargs
});

// Database operations
const initializeDb = async (dbPath: string): Promise<Database> => {
logger.info('Initializing SQLite database at:', dbPath);

const db = await open({
filename: dbPath,
driver: sqlite3.Database
});

await db.exec(`
CREATE TABLE IF NOT EXISTS threads (
thread_id TEXT PRIMARY KEY,
state TEXT NOT NULL,
last_output TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`);

return db;
};

// Thread state operations
const parseThreadState = (row: DbRow): ThreadState | null => {
const parsedState = JSON.parse(row.state);
const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined;

if (!Array.isArray(parsedState.messages)) {
logger.warn('Invalid state: messages is not an array');
return null;
}

return {
state: {
messages: parsedState.messages.map(deserializeMessage),
reflectionScore: parsedState.reflectionScore ?? 0,
researchPerformed: parsedState.researchPerformed ?? false,
research: parsedState.research ?? '',
reflections: parsedState.reflections ?? [],
drafts: parsedState.drafts ?? [],
feedbackHistory: parsedState.feedbackHistory ?? [],
},
lastOutput
};
};

// Main storage factory
export const createThreadStorage = () => {
const dbPath = path.join(process.cwd(), 'thread-storage.sqlite');
const dbPromise = initializeDb(dbPath);

const ensureConnection = async () => {
const db = await dbPromise;
if (!db) throw new Error('Database connection not established');
return db;
};

return {
async saveThread(threadId: string, state: ThreadState): Promise<void> {
const db = await ensureConnection();

const stateToSave = {
...state.state,
messages: state.state.messages.map(serializeMessage)
};

await db.run(
`INSERT OR REPLACE INTO threads (thread_id, state, last_output, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)`,
[
threadId,
JSON.stringify(stateToSave),
state.lastOutput ? JSON.stringify(state.lastOutput) : null
]
);

logger.info(`Thread saved: ${threadId}`, {
messageCount: state.state.messages.length
});
},

async loadThread(threadId: string): Promise<ThreadState | null> {
const db = await ensureConnection();

const row = await db.get<DbRow>(
'SELECT state, last_output FROM threads WHERE thread_id = ?',
threadId
);

if (!row) {
logger.warn(`Thread not found: ${threadId}`);
return null;
}

return parseThreadState(row);
},

async getAllThreads(): Promise<Array<{ threadId: string; createdAt: string; updatedAt: string }>> {
const db = await ensureConnection();

return db.all(
'SELECT thread_id, created_at, updated_at FROM threads ORDER BY updated_at DESC'
);
},

async deleteThread(threadId: string): Promise<void> {
const db = await ensureConnection();
await db.run('DELETE FROM threads WHERE thread_id = ?', threadId);
logger.info(`Thread deleted: ${threadId}`);
},

async cleanup(olderThanDays: number = 30): Promise<number> {
const db = await ensureConnection();

const result = await db.run(
'DELETE FROM threads WHERE updated_at < datetime("now", ?)',
[`-${olderThanDays} days`]
);

const deletedCount = result.changes || 0;
logger.info(`Cleaned up ${deletedCount} old threads`);
return deletedCount;
}
};
};

export type ThreadStorage = ReturnType<typeof createThreadStorage>;
83 changes: 61 additions & 22 deletions auto-content-creator/agents/src/services/writerAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { generationSchema, reflectionSchema, researchDecisionSchema, humanFeedba
import { generationPrompt, reflectionPrompt, researchDecisionPrompt, humanFeedbackPrompt } from '../prompts';
import logger from '../logger';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { createThreadStorage } from './threadStorage';

// Create a list of tools and ToolNode instance
const tools = [webSearchTool];
Expand Down Expand Up @@ -223,16 +224,34 @@ const workflow = new StateGraph(State)
.addEdge('reflect', 'generate')
.addEdge('processFeedback', 'generate');

const app = workflow.compile({ checkpointer: new MemorySaver() });
// Create a persistent memory store
const memoryStore = new MemorySaver();

// Add these interfaces
// Update the workflow compilation to use the memory store
const app = workflow.compile({
checkpointer: memoryStore
});

// Add a type for thread state management
interface ThreadState {
state: typeof State.State;
stream: any;
lastOutput?: WriterAgentOutput;
}

// Add a map to store thread states
const threadStates = new Map<string, ThreadState>();
// Initialize thread storage as a singleton
const threadStorage = createThreadStorage();

// Ensure thread storage is initialized before starting the server
export const initializeStorage = async () => {
try {
// Just check if we can connect to the database
const threads = await threadStorage.getAllThreads();
logger.info('Thread storage initialized successfully', { threadCount: threads.length });
} catch (error) {
logger.error('Failed to initialize thread storage:', error);
throw error;
}
};

// Split the writerAgent into two exported functions
export const writerAgent = {
Expand All @@ -256,15 +275,21 @@ export const writerAgent = {
feedbackHistory: [],
};

const stream = await app.stream(initialState, { configurable: { thread_id: threadId } });
const config = {
configurable: {
thread_id: threadId
}
};

const stream = await app.stream(initialState, config);
const result = await processStream(stream);

// Store the state and stream for later use
threadStates.set(threadId, {
// Store thread state in persistent storage
await threadStorage.saveThread(threadId, {
state: initialState,
stream,
lastOutput: result
});

const result = await processStream(stream);
return { ...result, threadId };
},

Expand All @@ -275,32 +300,46 @@ export const writerAgent = {
threadId: string;
feedback: string;
}): Promise<WriterAgentOutput> {
const threadState = threadStates.get(threadId);
// Load thread state from storage
const threadState = await threadStorage.loadThread(threadId);
if (!threadState) {
throw new Error('Thread not found');
logger.error('Thread not found', { threadId });
throw new Error(`Thread ${threadId} not found`);
}

logger.info('WriterAgent - Continuing draft with feedback', { threadId });

// Add feedback to the existing state
const newMessage = new HumanMessage({ content: `FEEDBACK: ${feedback}` });
threadState.state.messages.push(newMessage);
const updatedState = {
...threadState.state,
messages: [...threadState.state.messages, newMessage]
};

const config = {
configurable: {
thread_id: threadId
}
};

// Continue the stream with the updated state
const stream = await app.stream(threadState.state, { configurable: { thread_id: threadId } });
const stream = await app.stream(updatedState, config);
const result = await processStream(stream);

// Update the stored state
threadStates.set(threadId, {
state: threadState.state,
stream,
// Update thread state in storage
await threadStorage.saveThread(threadId, {
state: updatedState,
lastOutput: result
});

return await processStream(stream);
return result;
},

async getThreadState(threadId: string): Promise<ThreadState | null> {
return await threadStorage.loadThread(threadId);
}
};

// Helper function to process the stream
async function processStream(stream: any): Promise<WriterAgentOutput> {
const processStream = async (stream: any): Promise<WriterAgentOutput> => {
let finalContent = '';
let research = '';
let reflections: { critique: string; score: number }[] = [];
Expand Down
Loading

0 comments on commit a4d2487

Please sign in to comment.