diff --git a/extension/controller/backgroundScript.js b/extension/controller/backgroundScript.js index dd54ec1b..bff1e8d0 100644 --- a/extension/controller/backgroundScript.js +++ b/extension/controller/backgroundScript.js @@ -72,7 +72,7 @@ let translationRequestsByTab = new Map(); let outboundRequestsByTab = new Map(); const translateAsBrowseMap = new Map(); let isMochitest = false; -const languageModelFileTypes = ["model", "lex", "vocab", "qualityModel"]; +const languageModelFileTypes = ["model", "lex", "vocab", "qualityModel", "srcvocab", "trgvocab"]; const CACHE_NAME = "fxtranslations"; const init = () => { @@ -509,6 +509,9 @@ const getLanguageModels = async (tabId, languagePairs) => { languageModels.forEach((languageModel, index) => { let clonedLanguagePair = { ...languagePairs[index] }; clonedLanguagePair.languageModelBlobs = languageModel; + clonedLanguagePair.precision = modelRegistry[clonedLanguagePair.name].model.name.endsWith("intgemm8.bin") + ? "int8shiftAll" + : "int8shiftAlphaAll"; result.push(clonedLanguagePair); }); return result; @@ -516,17 +519,17 @@ const getLanguageModels = async (tabId, languagePairs) => { const getLanguageModel = async (tabId, languagePair) => { let languageModelPromise = []; - languageModelFileTypes + let filesToLoad = languageModelFileTypes .filter(fileType => fileType !== "qualityModel" || languagePair.withQualityEstimation) - .filter(fileType => Reflect.apply(Object.prototype.hasOwnProperty, modelRegistry[languagePair.name], [fileType])) - .forEach(fileType => languageModelPromise.push(downloadFile(tabId, fileType, languagePair.name))); // eslint-disable-line no-use-before-define + .filter(fileType => fileType in modelRegistry[languagePair.name]); + filesToLoad.forEach(fileType => languageModelPromise.push(downloadFile(tabId, fileType, languagePair.name))); // eslint-disable-line no-use-before-define let buffers = await Promise.all(languageModelPromise); // create Blobs from buffers and return let files = {}; buffers.forEach((buffer, index) => { - files[languageModelFileTypes[index]] = new Blob([buffer]); + files[filesToLoad[index]] = new Blob([buffer]); }); return files; }; diff --git a/extension/controller/translation/translationWorker.js b/extension/controller/translation/translationWorker.js index b00346d8..d8248429 100644 --- a/extension/controller/translation/translationWorker.js +++ b/extension/controller/translation/translationWorker.js @@ -35,6 +35,8 @@ class TranslationHelper { "lex": 64, "vocab": 64, "qualityModel": 64, + "srcvocab": 64, + "trgvocab": 64, } } @@ -381,17 +383,11 @@ class TranslationHelper { } } - getLanguageModelForPair(languageModels, languagePair) { - let languageModel = languageModels.find(languageModel => { - return languageModel.name === languagePair; - }); - return languageModel.languageModelBlobs; - } - // eslint-disable-next-line max-lines-per-function async constructTranslationModelHelper(languageModels, languagePair, withQualityEstimation) { console.log(`Constructing translation model ${languagePair}`); const modelConfigQualityEstimation = !withQualityEstimation; + let languageModel = languageModels.find(lm => lm.name === languagePair); /* * for available configuration options, @@ -410,33 +406,46 @@ class TranslationHelper { cpu-threads: 0 quiet: true quiet-translation: true - gemm-precision: int8shiftAlphaAll + gemm-precision: ${languageModel.precision} alignment: soft `; // download files into buffers - let languageModelBlobs = this.getLanguageModelForPair(languageModels, languagePair); + let languageModelBlobs = languageModel.languageModelBlobs; let downloadedBuffersPromises = []; - Object.entries(this.modelFileAlignments) + + let filesToLoad = Object.entries(this.modelFileAlignments) .filter(([fileType]) => fileType !== "qualityModel" || withQualityEstimation) - .filter(([fileType]) => Reflect.apply(Object.prototype.hasOwnProperty, languageModelBlobs, [fileType])) - .map(([fileType, fileAlignment]) => downloadedBuffersPromises.push(this.fetchFile(fileType, fileAlignment, languageModelBlobs))); + .filter(([fileType]) => fileType in languageModelBlobs); + filesToLoad.map(([fileType, fileAlignment]) => downloadedBuffersPromises.push(this.fetchFile(fileType, fileAlignment, languageModelBlobs))); let downloadedBuffers = await Promise.all(downloadedBuffersPromises); // prepare aligned memories from buffers - let alignedMemories = []; - downloadedBuffers.forEach(entry => alignedMemories.push(this.prepareAlignedMemoryFromBuffer(entry.buffer, entry.fileAlignment))); + let alignedMemories = Object.assign({}, ...filesToLoad.map(([name, alignment], index) => ( + { [name]: this.prepareAlignedMemoryFromBuffer(downloadedBuffers[index].buffer, alignment) }))); + + const alignedModelMemory = alignedMemories.model; + const alignedShortlistMemory = alignedMemories.lex; + let alignedMemoryLogMessage = `Aligned memory sizes: Model:${alignedModelMemory.size()}, Shortlist:${alignedShortlistMemory.size()}, `; - const alignedModelMemory = alignedMemories[0]; - const alignedShortlistMemory = alignedMemories[1]; const alignedVocabMemoryList = new this.WasmEngineModule.AlignedMemoryList(); - alignedVocabMemoryList.push_back(alignedMemories[2]); + if ("vocab" in alignedMemories) { + alignedVocabMemoryList.push_back(alignedMemories.vocab); + alignedMemoryLogMessage += ` Vocab: ${alignedMemories.vocab.size()}`; + } else if (("srcvocab" in alignedMemories) && ("trgvocab" in alignedMemories)) { + alignedVocabMemoryList.push_back(alignedMemories.srcvocab); + alignedVocabMemoryList.push_back(alignedMemories.trgvocab); + alignedMemoryLogMessage += ` Src Vocab: ${alignedMemories.srcvocab.size()}`; + alignedMemoryLogMessage += ` Trg Vocab: ${alignedMemories.trgvocab.size()}`; + } else { + throw new Error("vocabulary key is not found"); + } + let alignedQEMemory = null; - let alignedMemoryLogMessage = `Aligned memory sizes: Model:${alignedModelMemory.size()}, Shortlist:${alignedShortlistMemory.size()}, Vocab:${alignedMemories[2].size()}, `; - if (alignedMemories.length === Object.entries(this.modelFileAlignments).length) { - alignedQEMemory = alignedMemories[3]; - alignedMemoryLogMessage += `QualityModel: ${alignedQEMemory.size()}`; + if ("qualityModel" in alignedMemories) { + alignedQEMemory = alignedMemories.qualityModel; + alignedMemoryLogMessage += ` QualityModel: ${alignedQEMemory.size()}`; } console.log(`Translation Model config: ${modelConfig}`); console.log(alignedMemoryLogMessage); diff --git a/extension/model/modelRegistry.js b/extension/model/modelRegistry.js index 28a9626f..6f386950 100644 --- a/extension/model/modelRegistry.js +++ b/extension/model/modelRegistry.js @@ -1,7 +1,7 @@ /* eslint-disable no-unused-vars */ /* eslint-disable max-lines */ -const modelRegistryVersion = "0.2.18"; +const modelRegistryVersion = "0.3.0"; let modelRegistryRootURL = `https://storage.googleapis.com/bergamot-models-sandbox/${modelRegistryVersion}`; const modelRegistryRootURLTest = "https://example.com/browser/browser/extensions/translations/test/browser"; @@ -113,19 +113,19 @@ const modelRegistry = { "expectedSha256Hash": "e19c77231bf977988e31ff8db15fe79966b5170564bd3e10613f239e7f461d97", "modelType": "prod" }, - "vocab": { - "name": "vocab.csen.spm", - "size": 769763, - "estimatedCompressedSize": 366392, - "expectedSha256Hash": "f71cc5d045e479607078e079884f44032f5a0b82547fb96eefa29cd1eb47c6f3", - "modelType": "prod" - }, "qualityModel": { "name": "qualityModel.encs.bin", "size": 68, "estimatedCompressedSize": 108, "expectedSha256Hash": "d7eba90036a065e4a1e93e889befe09f93a7d9a3417f3edffdb09a0db88fe83a", "modelType": "prod" + }, + "vocab": { + "name": "vocab.csen.spm", + "size": 769763, + "estimatedCompressedSize": 366392, + "expectedSha256Hash": "f71cc5d045e479607078e079884f44032f5a0b82547fb96eefa29cd1eb47c6f3", + "modelType": "prod" } }, "ende": { @@ -464,6 +464,36 @@ const modelRegistry = { "modelType": "dev" } }, + "enuk": { + "model": { + "name": "model.enuk.intgemm8.bin", + "size": 25315747, + "estimatedCompressedSize": 18227138, + "expectedSha256Hash": "326aa67032b19dfd979267ea88f066c8ca394b01bedece00e0bf6a722a42a099", + "modelType": "dev" + }, + "lex": { + "name": "lex.enuk.s2t.bin", + "size": 10294724, + "estimatedCompressedSize": 5706473, + "expectedSha256Hash": "2b07001be2cad9eca0a26dfb8cc8a9cc8f4f8a8359b53cc5c77474e54cb1f94a", + "modelType": "dev" + }, + "trgvocab": { + "name": "trgvocab.enuk.spm", + "size": 1003426, + "estimatedCompressedSize": 436542, + "expectedSha256Hash": "04f3110c139f80a4e72aeb2b6802a0be50b94e36aa89647cab53318a0917e442", + "modelType": "dev" + }, + "srcvocab": { + "name": "srcvocab.enuk.spm", + "size": 789110, + "estimatedCompressedSize": 394528, + "expectedSha256Hash": "dd44ee771e3be2fce4986beb4f4386fa0a5b233dfb5602d3cb78461053a6a50e", + "modelType": "dev" + } + }, "faen": { "model": { "name": "model.faen.intgemm.alphas.bin", @@ -555,5 +585,35 @@ const modelRegistry = { "expectedSha256Hash": "aaf9a325c0a988c507d0312cb6ba1a02bac7a370bcd879aedee626a40bfbda78", "modelType": "dev" } + }, + "uken": { + "model": { + "name": "model.uken.intgemm8.bin", + "size": 25315747, + "estimatedCompressedSize": 18520747, + "expectedSha256Hash": "90b6e21644af5bf5ce26442c724f55848a005d75e8bf688a51d2e64d6bc6b249", + "modelType": "dev" + }, + "lex": { + "name": "lex.uken.s2t.bin", + "size": 9761460, + "estimatedCompressedSize": 5402306, + "expectedSha256Hash": "763b9e0add9fd712305bc031ab86a58fb15f719dcad296046742176937b86841", + "modelType": "dev" + }, + "srcvocab": { + "name": "srcvocab.uken.spm", + "size": 984214, + "estimatedCompressedSize": 426936, + "expectedSha256Hash": "797de9759ff722c124c64663f3b75538516a059cfce3e6cf9446f39d1063cb6d", + "modelType": "dev" + }, + "trgvocab": { + "name": "trgvocab.uken.spm", + "size": 803064, + "estimatedCompressedSize": 402483, + "expectedSha256Hash": "d933cbf156c925ef42c064cbd6f85f18516f3ccac49bee7025b19a4a5c0ef711", + "modelType": "dev" + } } } \ No newline at end of file