Skip to content

Commit

Permalink
refactor threadStorage to be cleaner and more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrank-summit committed Oct 26, 2024
1 parent 8a46fcb commit e1b4e97
Showing 1 changed file with 94 additions and 150 deletions.
244 changes: 94 additions & 150 deletions auto-content-creator/agents/src/services/threadStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,54 @@ import sqlite3 from 'sqlite3';
import { open, Database } from 'sqlite';
import { ThreadState } from '../types';
import logger from '../logger';
import { AIMessage, HumanMessage } from '@langchain/core/messages';
import { AIMessage, HumanMessage, BaseMessage } from '@langchain/core/messages';
import path from 'path';

// Database initialization
// Types
type MessageType = {
_type?: string;
content: string;
kwargs?: { content: string };
id?: string[];
};

type DbRow = {
state: string;
last_output: string | null;
};

// Pure functions for message handling
const createMessage = (type: 'human' | 'ai', content: string): BaseMessage =>
type === 'human' ? new HumanMessage({ content }) : new AIMessage({ content });

const deserializeMessage = (msg: MessageType): BaseMessage => {
if (!msg) return createMessage('ai', 'Invalid message');

// Handle LangChain format
if (msg.kwargs?.content) {
return createMessage(
msg.id?.includes('HumanMessage') ? 'human' : 'ai',
msg.kwargs.content
);
}

// Handle simplified format
return createMessage(
msg._type === 'human' ? 'human' : 'ai',
msg.content
);
};

const serializeMessage = (msg: BaseMessage) => ({
_type: msg._getType(),
content: msg.content,
additional_kwargs: msg.additional_kwargs
});

// Database operations
const initializeDb = async (dbPath: string): Promise<Database> => {
logger.info('Initializing SQLite database at:', dbPath);

const db = await open({
filename: dbPath,
driver: sqlite3.Database
Expand All @@ -23,188 +65,101 @@ const initializeDb = async (dbPath: string): Promise<Database> => {
)
`);

logger.info('Thread storage initialized with SQLite');
return db;
};

// Message deserialization
const deserializeMessage = (msg: any) => {
try {
// Handle LangChain serialized format
if (msg.kwargs) {
const content = msg.kwargs.content;
return msg.id.includes('HumanMessage')
? new HumanMessage({ content })
: new AIMessage({ content });
}
// Thread state operations
const parseThreadState = (row: DbRow): ThreadState | null => {
const parsedState = JSON.parse(row.state);
const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined;

// Handle our simplified format
if (msg._type === 'human') {
return new HumanMessage({ content: msg.content });
}
return new AIMessage({ content: msg.content });
} catch (error) {
logger.error('Error deserializing message:', error);
return new AIMessage({ content: 'Error deserializing message' });
if (!Array.isArray(parsedState.messages)) {
logger.warn('Invalid state: messages is not an array');
return null;
}

return {
state: {
messages: parsedState.messages.map(deserializeMessage),
reflectionScore: parsedState.reflectionScore ?? 0,
researchPerformed: parsedState.researchPerformed ?? false,
research: parsedState.research ?? '',
reflections: parsedState.reflections ?? [],
drafts: parsedState.drafts ?? [],
feedbackHistory: parsedState.feedbackHistory ?? [],
},
lastOutput
};
};

// Thread storage operations
// Main storage factory
export const createThreadStorage = () => {
const dbPath = path.join(process.cwd(), 'thread-storage.sqlite');
let dbPromise = initializeDb(dbPath);
const dbPromise = initializeDb(dbPath);

const ensureConnection = async (): Promise<Database> => {
const ensureConnection = async () => {
const db = await dbPromise;
if (!db) {
throw new Error('Database connection not established');
}
if (!db) throw new Error('Database connection not established');
return db;
};

const saveThread = async (threadId: string, state: ThreadState): Promise<void> => {
try {
return {
async saveThread(threadId: string, state: ThreadState): Promise<void> {
const db = await ensureConnection();
logger.info(`Saving thread state for ${threadId}`, {
messageCount: state.state.messages?.length ?? 0,
hasLastOutput: !!state.lastOutput
});

// Ensure state has all required properties
const stateToSave = {
messages: state.state.messages ?? [],
reflectionScore: state.state.reflectionScore ?? 0,
researchPerformed: state.state.researchPerformed ?? false,
research: state.state.research ?? '',
reflections: state.state.reflections ?? [],
drafts: state.state.drafts ?? [],
feedbackHistory: state.state.feedbackHistory ?? [],
...state.state,
messages: state.state.messages.map(serializeMessage)
};

const serializedState = JSON.stringify(stateToSave, (key, value) => {
if (value instanceof HumanMessage || value instanceof AIMessage) {
return {
_type: value._getType(),
content: value.content,
additional_kwargs: value.additional_kwargs
};
}
return value;
});

const serializedLastOutput = state.lastOutput ? JSON.stringify(state.lastOutput) : null;

await db.run(
`INSERT OR REPLACE INTO threads (thread_id, state, last_output, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)`,
[threadId, serializedState, serializedLastOutput]
[
threadId,
JSON.stringify(stateToSave),
state.lastOutput ? JSON.stringify(state.lastOutput) : null
]
);

logger.info(`Thread state saved successfully: ${threadId}`);
} catch (error) {
logger.error(`Error saving thread ${threadId}:`, error);
throw error;
}
};
logger.info(`Thread saved: ${threadId}`, {
messageCount: state.state.messages.length
});
},

const loadThread = async (threadId: string): Promise<ThreadState | null> => {
try {
async loadThread(threadId: string): Promise<ThreadState | null> {
const db = await ensureConnection();
logger.info(`Loading thread state for ${threadId}`);

const row = await db.get(
const row = await db.get<DbRow>(
'SELECT state, last_output FROM threads WHERE thread_id = ?',
threadId
);

if (!row) {
logger.warn(`Thread ${threadId} not found`);
return null;
}

const parsedState = JSON.parse(row.state);
const lastOutput = row.last_output ? JSON.parse(row.last_output) : undefined;

// Ensure messages array exists
if (!parsedState.messages) {
logger.warn(`Invalid state structure for thread ${threadId}: missing messages array`);
logger.warn(`Thread not found: ${threadId}`);
return null;
}

// Force messages into array if not already
const messageArray = Array.isArray(parsedState.messages)
? parsedState.messages
: [parsedState.messages];

// Reconstruct message instances with improved error handling
const messages = messageArray.map((msg: any) => {
try {
if (!msg) {
logger.warn('Null or undefined message found');
return new AIMessage({ content: 'Invalid message' });
}
return deserializeMessage(msg);
} catch (error) {
logger.error('Error deserializing message:', { error, msg });
return new AIMessage({ content: 'Error deserializing message' });
}
});

// Ensure all required state properties exist
const state = {
messages,
reflectionScore: parsedState.reflectionScore ?? 0,
researchPerformed: parsedState.researchPerformed ?? false,
research: parsedState.research ?? '',
reflections: parsedState.reflections ?? [],
drafts: parsedState.drafts ?? [],
feedbackHistory: parsedState.feedbackHistory ?? [],
};

logger.info(`Thread state loaded successfully: ${threadId}`, {
messageCount: messages.length,
hasLastOutput: !!lastOutput
});

return { state, lastOutput };
} catch (error) {
logger.error(`Error loading thread ${threadId}:`, error);
throw error;
}
};
return parseThreadState(row);
},

const getAllThreads = async (): Promise<Array<{ threadId: string; createdAt: string; updatedAt: string }>> => {
try {
async getAllThreads(): Promise<Array<{ threadId: string; createdAt: string; updatedAt: string }>> {
const db = await ensureConnection();
const rows = await db.all(

return db.all(
'SELECT thread_id, created_at, updated_at FROM threads ORDER BY updated_at DESC'
);
},

return rows.map(row => ({
threadId: row.thread_id,
createdAt: row.created_at,
updatedAt: row.updated_at
}));
} catch (error) {
logger.error('Error listing threads:', error);
throw error;
}
};

const deleteThread = async (threadId: string): Promise<void> => {
try {
async deleteThread(threadId: string): Promise<void> {
const db = await ensureConnection();
await db.run('DELETE FROM threads WHERE thread_id = ?', threadId);
logger.info(`Thread deleted: ${threadId}`);
} catch (error) {
logger.error(`Error deleting thread ${threadId}:`, error);
throw error;
}
};
},

const cleanup = async (olderThanDays: number = 30): Promise<number> => {
try {
async cleanup(olderThanDays: number = 30): Promise<number> {
const db = await ensureConnection();

const result = await db.run(
'DELETE FROM threads WHERE updated_at < datetime("now", ?)',
[`-${olderThanDays} days`]
Expand All @@ -213,19 +168,8 @@ export const createThreadStorage = () => {
const deletedCount = result.changes || 0;
logger.info(`Cleaned up ${deletedCount} old threads`);
return deletedCount;
} catch (error) {
logger.error('Error during cleanup:', error);
throw error;
}
};

return {
saveThread,
loadThread,
getAllThreads,
deleteThread,
cleanup
};
};

export type ThreadStorage = ReturnType<typeof createThreadStorage>;

0 comments on commit e1b4e97

Please sign in to comment.