Skip to content

Commit

Permalink
using existing function for file type check, better wording
Browse files Browse the repository at this point in the history
  • Loading branch information
overmode committed Jan 8, 2025
1 parent 7937634 commit 972d877
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 146 deletions.
252 changes: 108 additions & 144 deletions front/lib/api/files/upsert.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { isSupportedPlainTextContentType } from "@dust-tt/client";
import type {
FileUseCase,
Result,
Expand Down Expand Up @@ -77,142 +78,124 @@ async function generateSnippet(
return new Ok(snippet);
}

switch (file.contentType) {
case "application/msword":
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
case "application/pdf":
case "text/plain":
case "text/markdown":
case "text/html":
case "text/xml":
case "text/calendar":
case "text/css":
case "text/javascript":
case "application/json":
case "application/xml":
case "application/x-sh":
case "text/vnd.dust.attachment.slack.thread":
if (!ENABLE_LLM_SNIPPETS) {
// Take the first 256 characters
if (content.length > 256) {
return new Ok(content.slice(0, 242) + "... (truncated)");
} else {
return new Ok(content);
}
}

const model = getSmallWhitelistedModel(owner);
if (!model) {
return new Err(
new Error(`Failed to find a whitelisted model to generate title`)
);
if (isSupportedPlainTextContentType(file.contentType)) {
if (!ENABLE_LLM_SNIPPETS) {
// Take the first 256 characters
if (content.length > 256) {
return new Ok(content.slice(0, 242) + "... (truncated)");
} else {
return new Ok(content);
}
}

const appConfig = cloneBaseConfig(
DustProdActionRegistry["conversation-file-summarizer"].config
const model = getSmallWhitelistedModel(owner);
if (!model) {
return new Err(
new Error(`Failed to find a whitelisted model to generate title`)
);
appConfig.MODEL.provider_id = model.providerId;
appConfig.MODEL.model_id = model.modelId;

const coreAPI = new CoreAPI(config.getCoreAPIConfig(), logger);
const resTokenize = await coreAPI.tokenize({
text: content,
providerId: model.providerId,
modelId: model.modelId,
});

if (resTokenize.isErr()) {
return new Err(
new Error(
`Error tokenizing content: ${resTokenize.error.code} ${resTokenize.error.message}`
)
);
}
}

const tokensCount = resTokenize.value.tokens.length;
const allowedTokens = model.contextSize * 0.9;
if (tokensCount > allowedTokens) {
// Truncate the content to the context size * 0.9 using cross product
const truncateLength = Math.floor(
(allowedTokens * content.length) / tokensCount
);
const appConfig = cloneBaseConfig(
DustProdActionRegistry["conversation-file-summarizer"].config
);
appConfig.MODEL.provider_id = model.providerId;
appConfig.MODEL.model_id = model.modelId;

const coreAPI = new CoreAPI(config.getCoreAPIConfig(), logger);
const resTokenize = await coreAPI.tokenize({
text: content,
providerId: model.providerId,
modelId: model.modelId,
});

logger.warn(
{
tokensCount,
contentLength: content.length,
contextSize: model.contextSize,
},
`Truncating content to ${truncateLength} characters`
);
if (resTokenize.isErr()) {
return new Err(
new Error(
`Error tokenizing content: ${resTokenize.error.code} ${resTokenize.error.message}`
)
);
}

content = content.slice(0, truncateLength);
}
const tokensCount = resTokenize.value.tokens.length;
const allowedTokens = model.contextSize * 0.9;
if (tokensCount > allowedTokens) {
// Truncate the content to the context size * 0.9 using cross product
const truncateLength = Math.floor(
(allowedTokens * content.length) / tokensCount
);

const res = await runAction(
auth,
"conversation-file-summarizer",
appConfig,
[
{
content: content,
},
]
logger.warn(
{
tokensCount,
contentLength: content.length,
contextSize: model.contextSize,
},
`Truncating content to ${truncateLength} characters`
);

if (res.isErr()) {
return new Err(
new Error(
`Error generating snippet: ${res.error.type} ${res.error.message}`
)
);
}
content = content.slice(0, truncateLength);
}

const {
status: { run },
traces,
results,
} = res.value;
const res = await runAction(
auth,
"conversation-file-summarizer",
appConfig,
[
{
content: content,
},
]
);

switch (run) {
case "errored":
const error = removeNulls(traces.map((t) => t[1][0][0].error)).join(
", "
);
return new Err(new Error(`Error generating snippet: ${error}`));
case "succeeded":
if (!results || results.length === 0) {
return new Err(
new Error(
`Error generating snippet: no results returned while run was successful`
)
);
}
const snippet = results[0][0].value as string;
const endTime = Date.now();
logger.info(
{
workspaceId: owner.sId,
fileId: file.sId,
},
`Snippet generation took ${endTime - startTime}ms`
);
if (res.isErr()) {
return new Err(
new Error(
`Error generating snippet: ${res.error.type} ${res.error.message}`
)
);
}

return new Ok(snippet);
case "running":
const {
status: { run },
traces,
results,
} = res.value;

switch (run) {
case "errored":
const error = removeNulls(traces.map((t) => t[1][0][0].error)).join(
", "
);
return new Err(new Error(`Error generating snippet: ${error}`));
case "succeeded":
if (!results || results.length === 0) {
return new Err(
new Error(
`Snippet generation is still running, should never happen.`
`Error generating snippet: no results returned while run was successful`
)
);
default:
assertNever(run);
}
break;
}
const snippet = results[0][0].value as string;
const endTime = Date.now();
logger.info(
{
workspaceId: owner.sId,
fileId: file.sId,
},
`Snippet generation took ${endTime - startTime}ms`
);

default:
assertNever(file.contentType);
return new Ok(snippet);
case "running":
return new Err(
new Error(`Snippet generation is still running, should never happen.`)
);
default:
assertNever(run);
}
}

return new Err(new Error("Unsupported file type"));
}

// Upload to dataSource
Expand Down Expand Up @@ -334,32 +317,13 @@ const getProcessingFunction = ({
}
}

switch (contentType) {
case "application/msword":
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
case "application/pdf":
case "text/markdown":
case "text/plain":
case "text/vnd.dust.attachment.slack.thread":
case "text/html":
case "text/xml":
case "text/calendar":
case "text/css":
case "text/javascript":
case "application/json":
case "application/xml":
case "application/x-sh":
if (
useCase === "conversation" ||
useCase === "tool_output" ||
useCase === "folder_document"
) {
return upsertDocumentToDatasource;
}
break;

default:
assertNever(contentType);
if (
isSupportedPlainTextContentType(contentType) &&
(useCase === "conversation" ||
useCase === "tool_output" ||
useCase === "folder_document")
) {
return upsertDocumentToDatasource;
}

return undefined;
Expand Down
3 changes: 1 addition & 2 deletions front/lib/swr/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ export function useUpsertFileAsDatasourceEntry(
sendNotification({
type: "success",
title: "File successfully uploaded",
description:
"The file has been successfully uploaded to the data source.",
description: "The file has been successfully uploaded.",
});

const response: UpsertFileToDataSourceResponseBody = await res.json();
Expand Down

0 comments on commit 972d877

Please sign in to comment.