Skip to content

Commit

Permalink
JS: Using supervised QE models for available language pairs (#378)
Browse files: browse the repository at this point in the history
* JS: Refactored model loading
 - Passing single vocab memory via JS
* JS: Use supervised QE models when available
* Ran clang format
  • Loading branch information
abhi-agg authored Mar 15, 2022
1 parent 2c0e65c commit 0a52a6d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 42 deletions.
12 changes: 8 additions & 4 deletions wasm/bindings/service_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,25 @@ std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vect
}

// Assembles a MemoryBundle from the individually supplied aligned memories.
//
// The contents of *modelMemory and *shortlistMemory are moved into the bundle's
// `model` and `shortlist` members, and the bundle's `vocabs` member is filled with
// the result of prepareVocabsSmartMemories(uniqueVocabsMemories).
// qualityEstimatorMemory may be nullptr: it is only dereferenced (and moved into
// the bundle's `qualityEstimatorMemory` member) after the null check below.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The first
// `std::vector<AlignedMemory*> uniqueVocabsMemories) {` line below looks like the
// pre-change end of the parameter list, superseded by the two lines that follow
// it -- confirm against the repository before compiling.
MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory,
std::vector<AlignedMemory*> uniqueVocabsMemories) {
std::vector<AlignedMemory*> uniqueVocabsMemories,
AlignedMemory* qualityEstimatorMemory) {
MemoryBundle memoryBundle;
memoryBundle.model = std::move(*modelMemory);
memoryBundle.shortlist = std::move(*shortlistMemory);
memoryBundle.vocabs = std::move(prepareVocabsSmartMemories(uniqueVocabsMemories));
// The quality estimator is optional; skip it when no memory was supplied.
if (qualityEstimatorMemory != nullptr) {
memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory);
}

return memoryBundle;
}

// This allows only shared_ptrs to be operational in JavaScript, according to emscripten.
// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers
//
// Factory for TranslationModel: packs the supplied memories into a MemoryBundle via
// prepareMemoryBundle() and returns a std::shared_ptr<TranslationModel> constructed
// from `config` and that bundle (the bundle is moved into the model).
// `qualityEstimator` is forwarded unchanged to prepareMemoryBundle(), which checks it
// for nullptr before use.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The
// three-argument prepareMemoryBundle(...) call and the parameter lines preceding it
// look like the pre-change version, superseded by the lines that follow -- confirm
// against the repository before compiling.
std::shared_ptr<TranslationModel> TranslationModelFactory(const std::string& config, AlignedMemory* model,
AlignedMemory* shortlist,
std::vector<AlignedMemory*> vocabs) {
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs);
AlignedMemory* shortlist, std::vector<AlignedMemory*> vocabs,
AlignedMemory* qualityEstimator) {
MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator);
return std::make_shared<TranslationModel>(config, std::move(memoryBundle));
}

Expand Down
70 changes: 32 additions & 38 deletions wasm/test_page/js/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ const MODEL_REGISTRY = "../models/registry.json";
const MODEL_ROOT_URL = "../models/";
const PIVOT_LANGUAGE = 'en';

// Information corresponding to each file type that can make up a translation model.
//  - "type": key looked up in a language pair's registry entry
//            (modelRegistry[languagePair][type]) to find the file to download.
//  - "alignment": alignment size passed on when constructing the AlignedMemory
//                 that holds the file's contents.
// NOTE(review): the order of the entries appears to matter -- the code that builds
// the TranslationModel indexes the prepared memories positionally (0: model, 1: lex,
// 2: vocab, 3: qualityModel). Confirm before reordering or inserting entries.
const fileInfo = [
{"type": "model", "alignment": 256},
{"type": "lex", "alignment": 64},
{"type": "vocab", "alignment": 64},
{"type": "qualityModel", "alignment": 64}
];

const encoder = new TextEncoder(); // string to utf-8 converter
const decoder = new TextDecoder(); // utf-8 to string converter

Expand Down Expand Up @@ -169,12 +177,17 @@ const _downloadAsArrayBuffer = async(url) => {
// Creates a Module.AlignedMemory as large as the given array buffer, using the
// requested alignment size, and copies the buffer's bytes into it. Each step is
// reported through log(). Resolves to the populated AlignedMemory.
const _prepareAlignedMemoryFromBuffer = async (buffer, alignmentSize) => {
  const sourceBytes = new Int8Array(buffer);
  const byteCount = sourceBytes.byteLength;

  log(`Constructing Aligned memory. Size: ${byteCount} bytes, Alignment: ${alignmentSize}`);
  const memory = new Module.AlignedMemory(byteCount, alignmentSize);
  log(`Aligned memory construction done`);

  // Fill the aligned region through the byte view it exposes.
  memory.getByteArrayView().set(sourceBytes);
  log(`Aligned memory initialized`);

  return memory;
};

// Downloads the file of the given kind for a language pair and loads it into an
// AlignedMemory.
//   file:         an entry with a `type` (key into the language pair's registry
//                 entry) and an `alignment` (alignment size for the memory).
//   languagePair: key into modelRegistry; also used as the directory name under
//                 MODEL_ROOT_URL.
// Resolves to the prepared AlignedMemory.
async function prepareAlignedMemory(file, languagePair) {
  const registryEntry = modelRegistry[languagePair][file.type];
  const url = `${MODEL_ROOT_URL}/${languagePair}/${registryEntry.name}`;

  const downloaded = await _downloadAsArrayBuffer(url);
  const memory = await _prepareAlignedMemoryFromBuffer(downloaded, file.alignment);

  log(`"${file.type}" aligned memory prepared. Size:${memory.size()} bytes, alignment:${file.alignment}`);
  return memory;
}

Expand All @@ -201,45 +214,26 @@ gemm-precision: int8shiftAlphaAll
alignment: soft
`;

const modelFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["model"].name}`;
const shortlistFile = `${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["lex"].name}`;
const vocabFiles = [`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`,
`${MODEL_ROOT_URL}/${languagePair}/${modelRegistry[languagePair]["vocab"].name}`];

const uniqueVocabFiles = new Set(vocabFiles);
log(`modelFile: ${modelFile}\nshortlistFile: ${shortlistFile}\nNo. of unique vocabs: ${uniqueVocabFiles.size}`);
uniqueVocabFiles.forEach(item => log(`unique vocabFile: ${item}`));
const promises = [];
fileInfo.filter(file => modelRegistry[languagePair].hasOwnProperty(file.type))
.map((file) => {
promises.push(prepareAlignedMemory(file, languagePair));
});

// Download the files as buffers from the given urls
let start = Date.now();
const downloadedBuffers = await Promise.all([_downloadAsArrayBuffer(modelFile), _downloadAsArrayBuffer(shortlistFile)]);
const modelBuffer = downloadedBuffers[0];
const shortListBuffer = downloadedBuffers[1];
const alignedMemories = await Promise.all(promises);

const downloadedVocabBuffers = [];
for (let item of uniqueVocabFiles.values()) {
downloadedVocabBuffers.push(await _downloadAsArrayBuffer(item));
}
log(`Total Download time for all files of '${languagePair}': ${(Date.now() - start) / 1000} secs`);

// Construct AlignedMemory objects with downloaded buffers
let constructedAlignedMemories = await Promise.all([_prepareAlignedMemoryFromBuffer(modelBuffer, 256),
_prepareAlignedMemoryFromBuffer(shortListBuffer, 64)]);
let alignedModelMemory = constructedAlignedMemories[0];
let alignedShortlistMemory = constructedAlignedMemories[1];
let alignedVocabsMemoryList = new Module.AlignedMemoryList;
for(let item of downloadedVocabBuffers) {
let alignedMemory = await _prepareAlignedMemoryFromBuffer(item, 64);
alignedVocabsMemoryList.push_back(alignedMemory);
log(`Translation Model config: ${modelConfig}`);
log(`Aligned memory sizes: Model:${alignedMemories[0].size()} Shortlist:${alignedMemories[1].size()} Vocab:${alignedMemories[2].size()}`);
const alignedVocabMemoryList = new Module.AlignedMemoryList();
alignedVocabMemoryList.push_back(alignedMemories[2]);
let translationModel;
if (alignedMemories.length === fileInfo.length) {
log(`QE:${alignedMemories[3].size()}`);
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, alignedMemories[3]);
}
for (let vocabs=0; vocabs < alignedVocabsMemoryList.size(); vocabs++) {
log(`Aligned vocab memory${vocabs+1} size: ${alignedVocabsMemoryList.get(vocabs).size()}`);
else {
translationModel = new Module.TranslationModel(modelConfig, alignedMemories[0], alignedMemories[1], alignedVocabMemoryList, null);
}
log(`Aligned model memory size: ${alignedModelMemory.size()}`);
log(`Aligned shortlist memory size: ${alignedShortlistMemory.size()}`);

log(`Translation Model config: ${modelConfig}`);
var translationModel = new Module.TranslationModel(modelConfig, alignedModelMemory, alignedShortlistMemory, alignedVocabsMemoryList);
languagePairToTranslationModels.set(languagePair, translationModel);
}

Expand Down

0 comments on commit 0a52a6d

Please sign in to comment.