diff --git a/wasm/bindings/service_bindings.cpp b/wasm/bindings/service_bindings.cpp index 8e4fe4d14..d56615dc6 100644 --- a/wasm/bindings/service_bindings.cpp +++ b/wasm/bindings/service_bindings.cpp @@ -45,11 +45,15 @@ std::vector> prepareVocabsSmartMemories(std::vect } MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, - std::vector uniqueVocabsMemories) { + std::vector uniqueVocabsMemories, + AlignedMemory* qualityEstimatorMemory) { MemoryBundle memoryBundle; memoryBundle.model = std::move(*modelMemory); memoryBundle.shortlist = std::move(*shortlistMemory); memoryBundle.vocabs = std::move(prepareVocabsSmartMemories(uniqueVocabsMemories)); + if (qualityEstimatorMemory != nullptr) { + memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory); + } return memoryBundle; } @@ -57,9 +61,9 @@ MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shor // This allows only shared_ptrs to be operational in JavaScript, according to emscripten. // https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers std::shared_ptr TranslationModelFactory(const std::string& config, AlignedMemory* model, - AlignedMemory* shortlist, - std::vector vocabs) { - MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs); + AlignedMemory* shortlist, std::vector vocabs, + AlignedMemory* qualityEstimator) { + MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator); return std::make_shared(config, std::move(memoryBundle)); } diff --git a/wasm/test_page/js/worker.js b/wasm/test_page/js/worker.js index fcbb37aa2..3327d8a3a 100644 --- a/wasm/test_page/js/worker.js +++ b/wasm/test_page/js/worker.js @@ -12,6 +12,14 @@ const MODEL_REGISTRY = "../models/registry.json"; const MODEL_ROOT_URL = "../models/"; const PIVOT_LANGUAGE = 'en'; +// Information corresponding to each file type +const fileInfo = [ + {"type": "model", "alignment": 256}, + {"type": "lex", "alignment": 64}, + {"type": "vocab", "alignment": 64}, + {"type": "qualityModel", "alignment": 64} +]; + const encoder = new TextEncoder(); // string to utf-8 converter const decoder = new TextDecoder(); // utf-8 to string converter @@ -169,12 +177,17 @@ const _downloadAsArrayBuffer = async(url) => { // Constructs and initializes the AlignedMemory from the array buffer and alignment size const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => { var byteArray = new Int8Array(buffer); - log(`Constructing Aligned memory. Size: ${byteArray.byteLength} bytes, Alignment: ${alignmentSize}`); var alignedMemory = new Module.AlignedMemory(byteArray.byteLength, alignmentSize); - log(`Aligned memory construction done`); const alignedByteArrayView = alignedMemory.getByteArrayView(); alignedByteArrayView.set(byteArray); - log(`Aligned memory initialized`); + return alignedMemory; +} + +async function prepareAlignedMemory(file, languagePair) { + const fileName = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair][file.type].name}`; + const buffer = await _downloadAsArrayBuffer(fileName); + const alignedMemory = await _prepareAlignedMemoryFromBuffer(buffer, file.alignment); + log(`"${file.type}" aligned memory prepared. Size:${alignedMemory.size()} bytes, alignment:${file.alignment}`); return alignedMemory; } @@ -201,45 +214,26 @@ gemm-precision: int8shiftAlphaAll alignment: soft `; - const modelFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["model"].name}`; - const shortlistFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["lex"].name}`; - const vocabFiles = [`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`, - `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`]; - - const uniqueVocabFiles = new Set(vocabFiles); - log(`modelFile: ${modelFile}\nshortlistFile: ${shortlistFile}\nNo. of unique vocabs: ${uniqueVocabFiles.size}`); - uniqueVocabFiles.forEach(item => log(`unique vocabFile: ${item}`)); + const promises = []; + fileInfo.filter(file => modelRegistry[languagePair].hasOwnProperty(file.type)) + .map((file) => { + promises.push(prepareAlignedMemory(file, languagePair)); + }); - // Download the files as buffers from the given urls - let start = Date.now(); - const downloadedBuffers = await Promise.all([_downloadAsArrayBuffer(modelFile), _downloadAsArrayBuffer(shortlistFile)]); - const modelBuffer = downloadedBuffers[0]; - const shortListBuffer = downloadedBuffers[1]; + const alignedMemories = await Promise.all(promises); - const downloadedVocabBuffers = []; - for (let item of uniqueVocabFiles.values()) { - downloadedVocabBuffers.push(await _downloadAsArrayBuffer(item)); - } - log(`Total Download time for all files of '${languagePair}': ${(Date.now() - start) / 1000} secs`); - - // Construct AlignedMemory objects with downloaded buffers - let constructedAlignedMemories = await Promise.all([_prepareAlignedMemoryFromBuffer(modelBuffer, 256), - _prepareAlignedMemoryFromBuffer(shortListBuffer, 64)]); - let alignedModelMemory = constructedAlignedMemories[0]; - let alignedShortlistMemory = constructedAlignedMemories[1]; - let alignedVocabsMemoryList = new Module.AlignedMemoryList; - for(let item of downloadedVocabBuffers) { - let alignedMemory = await _prepareAlignedMemoryFromBuffer(item, 64); - alignedVocabsMemoryList.push_back(alignedMemory); + log(`Translation Model config: ${modelConfig}`); + log(`Aligned memory sizes: Model:${alignedMemories[0].size()} Shortlist:${alignedMemories[1].size()} Vocab:${alignedMemories[2].size()}`); + const alignedVocabMemoryList = new Module.AlignedMemoryList(); + alignedVocabMemoryList.push_back(alignedMemories[2]); + let translationModel; + if (alignedMemories.length === fileInfo.length) { + log(`QE:${alignedMemories[3].size()}`); + translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, alignedMemories[3]); } - for (let vocabs=0; vocabs < alignedVocabsMemoryList.size(); vocabs++) { - log(`Aligned vocab memory${vocabs+1} size: ${alignedVocabsMemoryList.get(vocabs).size()}`); + else { + translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, null); } - log(`Aligned model memory size: ${alignedModelMemory.size()}`); - log(`Aligned shortlist memory size: ${alignedShortlistMemory.size()}`); - - log(`Translation Model config: ${modelConfig}`); - var translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList); languagePairToTranslationModels.set(languagePair, translationModel); }