Skip to content

Commit

Permalink
Fix load file order (#139)
Browse files Browse the repository at this point in the history
* fix load file order

* v2.0.1

* sync
  • Loading branch information
ngxson authored Dec 3, 2024
1 parent 3b9276a commit 48e36d1
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 32 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@wllama/wllama",
"version": "2.0.0",
"version": "2.0.1",
"description": "Low-level WASM binding for llama.cpp",
"main": "index.js",
"type": "module",
Expand Down
54 changes: 32 additions & 22 deletions src/model-manager.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ const SPLIT_MODEL =

test.sequential('parseModelUrl handles single model URL', () => {
const urls = ModelManager.parseModelUrl(TINY_MODEL);
expect(urls).toEqual([TINY_MODEL]);
expect(urls.length).toBe(1);
expect(urls[0]).toBe(TINY_MODEL);
});

test.sequential('parseModelUrl handles array of URLs', () => {
Expand All @@ -33,27 +34,36 @@ test.sequential('download split model', async () => {
expect(model.files[2].size).toBe(5773312);
});

test.sequential(
'interrupt download split model (partial files downloaded)',
async () => {
return; // skip on CI, only run locally with a slow connection
const manager = new ModelManager();
await manager.clear();
const controller = new AbortController();
const downloadPromise = manager.downloadModel(SPLIT_MODEL, {
signal: controller.signal,
progressCallback: ({ loaded, total }) => {
const progress = loaded / total;
if (progress > 0.8) {
controller.abort();
}
},
});
await expect(downloadPromise).rejects.toThrow('aborted');
expect((await manager.getModels()).length).toBe(0);
expect((await manager.getModels({ includeInvalid: true })).length).toBe(1);
}
);
test.sequential('get downloaded split model', async () => {
const manager = new ModelManager();
const models = await manager.getModels();
const model = models.find((m) => m.url === SPLIT_MODEL);
expect(model).toBeDefined();
if (!model) throw new Error();
// check names
expect(model.files[0].metadata.originalURL).toMatch(/-00001-of-00003\.gguf$/);
expect(model.files[1].metadata.originalURL).toMatch(/-00002-of-00003\.gguf$/);
expect(model.files[2].metadata.originalURL).toMatch(/-00003-of-00003\.gguf$/);
});

// skip on CI, only run locally with a slow connection
test.skip('interrupt download split model (partial files downloaded)', async () => {
const manager = new ModelManager();
await manager.clear();
const controller = new AbortController();
const downloadPromise = manager.downloadModel(SPLIT_MODEL, {
signal: controller.signal,
progressCallback: ({ loaded, total }) => {
const progress = loaded / total;
if (progress > 0.8) {
controller.abort();
}
},
});
await expect(downloadPromise).rejects.toThrow('aborted');
expect((await manager.getModels()).length).toBe(0);
expect((await manager.getModels({ includeInvalid: true })).length).toBe(1);
});

test.sequential('download invalid model URL', async () => {
const manager = new ModelManager();
Expand Down
2 changes: 1 addition & 1 deletion src/model-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ export class ModelManager {
);
}
const model = new Model(this, url, undefined);
const validity = await model.validate();
const validity = model.validate();
if (validity !== ModelValidationStatus.VALID) {
await model.refresh(options);
}
Expand Down
78 changes: 78 additions & 0 deletions src/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import {
isString,
delay,
absoluteUrl,
parseShardNumber,
parseModelUrl,
sortFileByShard,
} from './utils';

describe('joinBuffers', () => {
Expand Down Expand Up @@ -74,3 +77,78 @@ describe('absoluteUrl', () => {
expect(absoluteUrl('/test.html')).toBe('http://example.com/test.html');
});
});

describe('shard processing', () => {
test('parseShardNumber extracts correct info', () => {
expect(parseShardNumber('abcdef-123456-00001-of-00005.gguf')).toEqual({
baseURL: 'abcdef-123456',
current: 1,
total: 5,
});

expect(parseShardNumber('abcdef-123456.9090-q8_0.gguf')).toEqual({
baseURL: 'abcdef-123456.9090-q8_0.gguf',
current: 1,
total: 1,
});
});

test('parseModelUrl generates correct shard URLs', () => {
const singleFile = 'model.gguf';
expect(parseModelUrl(singleFile)).toEqual(['model.gguf']);

const shardedFile = 'model-00001-of-00003.gguf';
expect(parseModelUrl(shardedFile)).toEqual([
'model-00001-of-00003.gguf',
'model-00002-of-00003.gguf',
'model-00003-of-00003.gguf',
]);

const complexPath = 'https://example.com/models/llama-00001-of-00002.gguf';
expect(parseModelUrl(complexPath)).toEqual([
'https://example.com/models/llama-00001-of-00002.gguf',
'https://example.com/models/llama-00002-of-00002.gguf',
]);
});

test('sortFileByShard sorts files by shard number', () => {
const files = [
new File(
[],
'e2fc714c4727ee9395f324cd2e7f331f-model-00003-of-00005.gguf'
),
new File(
[],
'187ef4436122d1cc2f40dc2b92f0eba0-model-00001-of-00005.gguf'
),
new File(
[],
'c4357687ea2b461cb07cf0a0a3de939f-model-00002-of-00005.gguf'
),
new File(
[],
'6a4d40512eabd63221cbdf3df4636cd7-model-00005-of-00005.gguf'
),
new File(
[],
'0952e4c6ba320f5278605eb5333eec0f-model-00004-of-00005.gguf'
),
];

sortFileByShard(files);

expect(files.map((f) => parseShardNumber(f.name).current)).toEqual([
1, 2, 3, 4, 5,
]);

// Single file should not be affected
const singleFile = [new File([], 'model.gguf')];
sortFileByShard(singleFile);
expect(singleFile[0].name).toBe('model.gguf');

// Regular blobs should not be affected
const blobs = [new Blob(), new Blob()];
sortFileByShard(blobs);
expect(blobs.length).toBe(2);
});
});
62 changes: 58 additions & 4 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,68 @@ export const getWModuleConfig = (pathConfig: {
};
};

export interface ShardInfo {
baseURL: string;
current: number;
total: number;
}

const URL_PARTS_REGEX = /-(\d{5})-of-(\d{5})\.gguf$/;

/**
* Parse shard number and total from a file name or URL
*/
export const parseShardNumber = (fnameOrUrl: string): ShardInfo => {
const matches = fnameOrUrl.match(URL_PARTS_REGEX);
if (!matches) {
return {
baseURL: fnameOrUrl,
current: 1,
total: 1,
};
} else {
return {
baseURL: fnameOrUrl.replace(URL_PARTS_REGEX, ''),
current: parseInt(matches[1]),
total: parseInt(matches[2]),
};
}
};

/**
* Parses a model URL and returns an array of URLs based on the following patterns:
* - If the input URL is an array, it returns the array itself.
* - If the input URL is a string in the `gguf-split` format, it returns an array containing the URL of each shard in ascending order.
* - Otherwise, it returns an array containing the input URL as a single element array.
* @param modelUrl URL or list of URLs
*/
export const parseModelUrl = (modelUrl: string): string[] => {
const { baseURL, current, total } = parseShardNumber(modelUrl);
if (current == total && total == 1) {
return [modelUrl];
} else {
const paddedShardIds = Array.from({ length: total }, (_, index) =>
(index + 1).toString().padStart(5, '0')
);
return paddedShardIds.map(
(current) =>
`${baseURL}-${current}-of-${total.toString().padStart(5, '0')}.gguf`
);
}
};

/**
* Check if the given blobs are files or not, then sort them by name
* Check if the given blobs are files or not, then sort them by shard number
*/
export const maybeSortFileByName = (blobs: Blob[]): void => {
export const sortFileByShard = (blobs: Blob[]): void => {
const isFiles = blobs.every((b) => !!(b as File).name);
if (isFiles) {
if (isFiles && blobs.length > 1) {
const files = blobs as File[];
files.sort((a, b) => a.name.localeCompare(b.name));
files.sort((a, b) => {
const infoA = parseShardNumber(a.name);
const infoB = parseShardNumber(b.name);
return infoA.current - infoB.current;
});
}
};

Expand Down
4 changes: 2 additions & 2 deletions src/wasm-from-cdn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// Do not edit this file directly

const WasmFromCDN = {
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].0/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].0/src/multi-thread/wllama.wasm',
'single-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].1/src/single-thread/wllama.wasm',
'multi-thread/wllama.wasm': 'https://cdn.jsdelivr.net/npm/@wllama/[email protected].1/src/multi-thread/wllama.wasm',
};

export default WasmFromCDN;
4 changes: 2 additions & 2 deletions src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
isString,
isSupportMultiThread,
joinBuffers,
maybeSortFileByName,
sortFileByShard,
padDigits,
} from './utils';
import CacheManager, { DownloadOptions } from './cache-manager';
Expand Down Expand Up @@ -438,7 +438,7 @@ export class Wllama {
'load_error'
);
}
maybeSortFileByName(blobs);
sortFileByShard(blobs);
const hasMultipleBuffers = blobs.length > 1;
if (this.proxy) {
throw new WllamaError('Module is already initialized', 'load_error');
Expand Down

0 comments on commit 48e36d1

Please sign in to comment.