Support phi-2, add custom conv templates
Signed-off-by: Jay Wang <[email protected]>
xiaohk committed Jan 23, 2024
1 parent d61778a commit 9eb8ec2
Showing 6 changed files with 95 additions and 50 deletions.
48 changes: 24 additions & 24 deletions package.json
@@ -12,29 +12,29 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "lit": "^3.0.0"
+    "lit": "^3.1.1"
   },
   "devDependencies": {
-    "@floating-ui/dom": "^1.5.3",
-    "@mlc-ai/web-llm": "^0.2.16",
-    "@tiptap/core": "^2.1.12",
-    "@tiptap/extension-bubble-menu": "^2.1.12",
-    "@tiptap/extension-document": "^2.1.12",
-    "@tiptap/extension-highlight": "^2.1.12",
-    "@tiptap/extension-paragraph": "^2.1.12",
-    "@tiptap/extension-placeholder": "^2.1.13",
-    "@tiptap/extension-text": "^2.1.12",
-    "@tiptap/pm": "^2.1.12",
-    "@tiptap/starter-kit": "^2.1.12",
+    "@floating-ui/dom": "^1.5.4",
+    "@mlc-ai/web-llm": "^0.2.17",
+    "@tiptap/core": "^2.1.16",
+    "@tiptap/extension-bubble-menu": "^2.1.16",
+    "@tiptap/extension-document": "^2.1.16",
+    "@tiptap/extension-highlight": "^2.1.16",
+    "@tiptap/extension-paragraph": "^2.1.16",
+    "@tiptap/extension-placeholder": "^2.1.16",
+    "@tiptap/extension-text": "^2.1.16",
+    "@tiptap/pm": "^2.1.16",
+    "@tiptap/starter-kit": "^2.1.16",
     "@types/d3-array": "^3.2.1",
     "@types/d3-format": "^3.0.4",
     "@types/d3-random": "^3.0.3",
     "@types/d3-time-format": "^4.0.3",
-    "@types/diff": "^5.0.7",
-    "@types/diff-match-patch": "^1.0.35",
+    "@types/diff": "^5.0.9",
+    "@types/diff-match-patch": "^1.0.36",
     "@types/emoji-regex": "^9.2.0",
     "@types/uuid": "^9.0.7",
-    "@typescript-eslint/eslint-plugin": "^6.9.0",
+    "@typescript-eslint/eslint-plugin": "^6.19.1",
     "@webgpu/types": "^0.1.40",
     "@xiaohk/utils": "^0.0.6",
     "d3-array": "^3.2.4",
@@ -44,19 +44,19 @@
     "diff": "^5.1.0",
     "diff-match-patch": "^1.0.5",
     "emoji-regex": "^10.3.0",
-    "eslint": "^8.52.0",
-    "eslint-config-prettier": "^9.0.0",
-    "eslint-plugin-lit": "^1.10.1",
-    "eslint-plugin-prettier": "^5.0.1",
+    "eslint": "^8.56.0",
+    "eslint-config-prettier": "^9.1.0",
+    "eslint-plugin-lit": "^1.11.0",
+    "eslint-plugin-prettier": "^5.1.3",
     "eslint-plugin-wc": "^2.0.4",
-    "gh-pages": "^6.0.0",
+    "gh-pages": "^6.1.1",
     "idb-keyval": "^6.2.1",
-    "prettier": "^3.0.3",
+    "prettier": "^3.2.4",
     "tippy.js": "^6.3.7",
-    "typescript": "^5.2.2",
+    "typescript": "^5.3.3",
     "uuid": "^9.0.1",
-    "vite": "^4.5.0",
-    "vite-plugin-dts": "^3.6.1",
+    "vite": "^4.5.2",
+    "vite-plugin-dts": "^3.7.1",
     "vite-plugin-web-components-hmr": "^0.1.3"
   }
 }
4 changes: 2 additions & 2 deletions src/components/panel-setting/panel-setting.ts
@@ -56,10 +56,10 @@ const apiKeyDescriptionMap: Record<ModelFamily, TemplateResult> = {
 
 const localModelSizeMap: Record<SupportedLocalModel, string> = {
   [SupportedLocalModel['tinyllama-1.1b']]: '630 MB',
-  [SupportedLocalModel['llama-2-7b']]: '3.6 GB'
+  [SupportedLocalModel['llama-2-7b']]: '3.6 GB',
+  [SupportedLocalModel['phi-2']]: '1.5 GB'
   // [SupportedLocalModel['gpt-2']]: '311 MB'
   // [SupportedLocalModel['mistral-7b-v0.2']]: '3.5 GB'
-  // [SupportedLocalModel['phi-2']]: '1.5 GB'
 };
 
 const LOCAL_MODEL_MESSAGES = {
2 changes: 1 addition & 1 deletion src/components/text-editor/text-editor.ts
@@ -994,9 +994,9 @@ export class WordflowTextEditor extends LitElement {
         break;
       }
 
-      // case SupportedLocalModel['phi-2']:
       // case SupportedLocalModel['mistral-7b-v0.2']:
       // case SupportedLocalModel['gpt-2']:
+      case SupportedLocalModel['phi-2']:
       case SupportedLocalModel['llama-2-7b']:
       case SupportedLocalModel['tinyllama-1.1b']: {
         runRequest = new Promise<TextGenMessage>(resolve => {
2 changes: 1 addition & 1 deletion src/components/wordflow/user-config.ts
@@ -27,9 +27,9 @@ export const supportedModelReverseLookup: Record<
   [SupportedRemoteModel['gemini-pro']]: 'gemini-pro',
   [SupportedLocalModel['tinyllama-1.1b']]: 'tinyllama-1.1b',
   [SupportedLocalModel['llama-2-7b']]: 'llama-2-7b',
+  [SupportedLocalModel['phi-2']]: 'phi-2'
   // [SupportedLocalModel['gpt-2']]: 'gpt-2'
   // [SupportedLocalModel['mistral-7b-v0.2']]: 'mistral-7b-v0.2'
-  [SupportedLocalModel['phi-2']]: 'phi-2'
 };
 
 export enum ModelFamily {
4 changes: 2 additions & 2 deletions src/config/config.ts
@@ -14,8 +14,8 @@ const customColors = {
 
 const urls = {
   wordflowEndpoint:
-    // 'https://62uqq9jku8.execute-api.us-east-1.amazonaws.com/prod/records'
-    'https://p6gnc71v0w.execute-api.localhost.localstack.cloud:4566/prod/records'
+    'https://62uqq9jku8.execute-api.us-east-1.amazonaws.com/prod/records'
+    // 'https://p6gnc71v0w.execute-api.localhost.localstack.cloud:4566/prod/records'
 };
 
 const colors = {
85 changes: 65 additions & 20 deletions src/llms/web-llm.ts
@@ -1,6 +1,7 @@
 import * as webllm from '@mlc-ai/web-llm';
 import { SupportedLocalModel } from '../components/wordflow/user-config';
 import type { TextGenWorkerMessage } from '../types/common-types';
+import type { ConvTemplateConfig } from '@mlc-ai/web-llm/lib/config';
 
 export type TextGenLocalWorkerMessage =
   | TextGenWorkerMessage
@@ -29,7 +30,7 @@ export type TextGenLocalWorkerMessage =
 //==========================================================================||
 // Worker Initialization                                                    ||
 //==========================================================================||
-const appConfig: webllm.AppConfig = {
+const APP_CONFIGS: webllm.AppConfig = {
   model_list: [
     {
       model_url:
@@ -45,35 +46,74 @@ const appConfig: webllm.AppConfig = {
       model_lib_url:
         'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm'
     },
-    // {
-    //   model_url: 'https://huggingface.co/mlc-ai/gpt2-q0f16-MLC/resolve/main/',
-    //   local_id: 'gpt2-q0f16',
-    //   model_lib_url:
-    //     'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gpt2/gpt2-q0f16-ctx1k-webgpu.wasm'
-    // }
-    // {
-    //   model_url:
-    //     'https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC/resolve/main/',
-    //   local_id: 'Mistral-7B-Instruct-v0.2-q3f16_1',
-    //   model_lib_url:
-    //     'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm'
-    // }
+    {
+      model_url: 'https://huggingface.co/mlc-ai/gpt2-q0f16-MLC/resolve/main/',
+      local_id: 'gpt2-q0f16',
+      model_lib_url:
+        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gpt2/gpt2-q0f16-ctx1k-webgpu.wasm'
+    },
+    {
+      model_url:
+        'https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC/resolve/main/',
+      local_id: 'Mistral-7B-Instruct-v0.2-q3f16_1',
+      model_lib_url:
+        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm'
+    },
     {
       model_url:
         'https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/',
-      local_id: 'phi-2-q4f16_1',
+      local_id: 'Phi2-q4f16_1',
      model_lib_url:
-        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q0f16-ctx2k-webgpu.wasm'
+        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f16_1-ctx2k-webgpu.wasm',
+      vram_required_MB: 3053.97,
+      low_resource_required: false,
+      required_features: ['shader-f16']
     }
   ]
 };
 
+const CONV_TEMPLATES: Record<
+  SupportedLocalModel,
+  Partial<ConvTemplateConfig>
+> = {
+  [SupportedLocalModel['tinyllama-1.1b']]: {
+    system: '<|im_start|><|im_end|> ',
+    roles: ['<|im_start|>user', '<|im_start|>assistant'],
+    offset: 0,
+    seps: ['', ''],
+    separator_style: 'Two',
+    stop_str: '<|im_end|>',
+    add_bos: false,
+    stop_tokens: [2]
+  },
+  [SupportedLocalModel['llama-2-7b']]: {
+    system: '[INST] <<SYS>><</SYS>>\n\n ',
+    roles: ['[INST]', '[/INST]'],
+    offset: 0,
+    seps: [' ', ' '],
+    separator_style: 'Two',
+    stop_str: '[INST]',
+    add_bos: true,
+    stop_tokens: [2]
+  },
+  [SupportedLocalModel['phi-2']]: {
+    system: '',
+    roles: ['Instruct', 'Output'],
+    offset: 0,
+    seps: ['\n'],
+    separator_style: 'Two',
+    stop_str: '<|endoftext|>',
+    add_bos: false,
+    stop_tokens: [50256]
+  }
+};
+
 const modelMap: Record<SupportedLocalModel, string> = {
   [SupportedLocalModel['tinyllama-1.1b']]: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
   [SupportedLocalModel['llama-2-7b']]: 'Llama-2-7b-chat-hf-q4f16_1',
+  [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1'
   // [SupportedLocalModel['gpt-2']]: 'gpt2-q0f16'
   // [SupportedLocalModel['mistral-7b-v0.2']]: 'Mistral-7B-Instruct-v0.2-q3f16_1'
-  [SupportedLocalModel['phi-2']]: 'phi-2-q4f16_1'
 };
 
 const chat = new webllm.ChatModule();
@@ -136,9 +176,11 @@ const startLoadModel = async (
   _temperature = temperature;
   const curModel = modelMap[model];
   const chatOption: webllm.ChatOptions = {
-    temperature: temperature
+    temperature: temperature,
+    conv_config: CONV_TEMPLATES[model],
+    conv_template: 'custom'
   };
-  _modelLoadingComplete = chat.reload(curModel, chatOption, appConfig);
+  _modelLoadingComplete = chat.reload(curModel, chatOption, APP_CONFIGS);
   await _modelLoadingComplete;
 
   try {
@@ -178,6 +220,9 @@ const startTextGen = async (prompt: string, temperature: number) => {
 
   const response = await chat.generate(prompt);
 
+  // Reset the chat cache to avoid memorizing previous messages
+  await chat.resetChat();
+
   // Send back the data to the main thread
   const message: TextGenLocalWorkerMessage = {
     command: 'finishTextGen',
@@ -210,7 +255,7 @@
 
 export const hasLocalModelInCache = async (model: SupportedLocalModel) => {
   const curModel = modelMap[model];
-  const inCache = await webllm.hasModelInCache(curModel, appConfig);
+  const inCache = await webllm.hasModelInCache(curModel, APP_CONFIGS);
   return inCache;
 };
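
For context, here is how the pieces above fit together at runtime: the worker passes a model's partial ConvTemplateConfig to chat.reload through ChatOptions.conv_config with conv_template set to 'custom', so web-llm assembles prompts from the custom template (e.g., phi-2's Instruct:/Output: turns) instead of a built-in named one, and resetChat() clears conversation state after each generation. Below is a minimal standalone sketch of that flow using only the phi-2 entries from this diff; the temperature value, the example prompt, and the names phi2Template/main are illustrative placeholders, not part of the commit.

import * as webllm from '@mlc-ai/web-llm';
import type { ConvTemplateConfig } from '@mlc-ai/web-llm/lib/config';

// Single-model app config, copied from the phi-2 entry in APP_CONFIGS above.
const appConfig: webllm.AppConfig = {
  model_list: [
    {
      model_url: 'https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/',
      local_id: 'Phi2-q4f16_1',
      model_lib_url:
        'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f16_1-ctx2k-webgpu.wasm'
    }
  ]
};

// Phi-2's conversation template from CONV_TEMPLATES above: turns render as
// "Instruct: ...\nOutput: ..." and generation stops at <|endoftext|>
// (token id 50256).
const phi2Template: Partial<ConvTemplateConfig> = {
  system: '',
  roles: ['Instruct', 'Output'],
  offset: 0,
  seps: ['\n'],
  separator_style: 'Two',
  stop_str: '<|endoftext|>',
  add_bos: false,
  stop_tokens: [50256]
};

const main = async () => {
  const chat = new webllm.ChatModule();

  // 'custom' tells web-llm to build prompts from conv_config rather than a
  // built-in named template; appConfig supplies the weights and wasm URLs.
  await chat.reload(
    'Phi2-q4f16_1',
    { temperature: 0.2, conv_config: phi2Template, conv_template: 'custom' },
    appConfig
  );

  const response = await chat.generate('Rewrite this sentence to be concise.');

  // Clear the conversation cache so the next request starts fresh,
  // mirroring the resetChat() call added in startTextGen.
  await chat.resetChat();
  return response;
};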
