diff --git a/src/parser/WASMParser.cpp b/src/parser/WASMParser.cpp index a5394cf70..09616d5fb 100644 --- a/src/parser/WASMParser.cpp +++ b/src/parser/WASMParser.cpp @@ -244,6 +244,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { Type m_returnValueType; size_t m_position; std::vector m_vmStack; + std::vector m_parameterPositions; uint32_t m_functionStackSizeSoFar; bool m_shouldRestoreVMStackAtEnd; bool m_byteCodeGenerationStopped; @@ -276,6 +277,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { auto endIter = binaryReader.m_vmStack.rbegin() + param.size(); auto iter = binaryReader.m_vmStack.rbegin(); while (iter != endIter) { + m_parameterPositions.push_back(iter->nonOptimizedPosition()); if (iter->hasValidLocalIndex()) { binaryReader.generateMoveCodeIfNeeds(iter->position(), iter->nonOptimizedPosition(), iter->valueType()); iter->setPosition(iter->nonOptimizedPosition()); @@ -512,10 +514,10 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { return std::make_pair(Walrus::Optional(), 0); } - std::pair, size_t> readAheadLocalGetIfExists() // return localIndex and code length if exists + std::pair, size_t> readAheadUint32OpcodeIfExists(uint8_t opcode) // return localIndex and code length if exists { Walrus::Optional mayLoadGetCode = lookaheadUnsigned8(); - if (mayLoadGetCode.hasValue() && mayLoadGetCode.value() == 0x21) { + if (mayLoadGetCode.hasValue() && mayLoadGetCode.value() == opcode) { auto r = lookaheadUnsigned32(1); if (r.first) { return std::make_pair(r.first, r.second + 1); @@ -995,13 +997,23 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { if (!m_inPreprocess) { // if there is local.set code ahead, // we can use local variable position as expr target position - auto localSetInfo = readAheadLocalGetIfExists(); + auto localSetInfo = readAheadUint32OpcodeIfExists(0x21); if (localSetInfo.first) { auto pos = resolveLocalOffsetAndSize(localSetInfo.first.value()).first; // skip local.set opcode *m_readerOffsetPointer += localSetInfo.second; return pos; } + + auto localTeeInfo = readAheadUint32OpcodeIfExists(0x22); + if (localTeeInfo.first + && canUseDirectReference(*m_readerOffsetPointer, m_functionStackSizeSoFar, localTeeInfo.first.value())) { + auto pos = resolveLocalOffsetAndSize(localTeeInfo.first.value()).first; + // skip local.tee opcode + *m_readerOffsetPointer += localTeeInfo.second; + pushVMStack(type, pos, localTeeInfo.first.value()); + return pos; + } } return pushVMStack(type); @@ -1033,23 +1045,33 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { } } - virtual void OnLocalGetExpr(Index localIndex) override + bool canUseDirectReference(size_t codePosition, size_t stackPosition, Index localIndex) { - auto r = resolveLocalOffsetAndSize(localIndex); - auto localValueType = m_localInfo[localIndex].m_valueType; - - bool canUseDirectReference = true; - size_t pos = *m_readerOffsetPointer; + for (const auto& bi : m_blockInfo) { + for (uint32_t p : bi.m_parameterPositions) { + if (stackPosition == p) { + return false; + } + } + } + bool ret = true; for (const auto& r : m_localVariableUsage) { - if (r.m_localIndex == localIndex && r.m_startPosition <= pos && pos <= r.m_endPosition) { + if (r.m_localIndex == localIndex && r.m_startPosition <= codePosition && codePosition <= r.m_endPosition) { if (r.m_hasWriteUsage) { - canUseDirectReference = false; + ret = false; break; } } } + return ret; + } - if (canUseDirectReference) { + virtual void OnLocalGetExpr(Index localIndex) override + { + auto r = resolveLocalOffsetAndSize(localIndex); + auto localValueType = m_localInfo[localIndex].m_valueType; + bool direct = canUseDirectReference(*m_readerOffsetPointer, r.first, localIndex); + if (direct) { pushVMStack(localValueType, r.first, localIndex); } else { auto pos = m_functionStackSizeSoFar; @@ -1311,7 +1333,6 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate { for (size_t i = start; i < m_vmStack.size(); i++) { dropValueSize += m_vmStack[i].stackAllocatedSize(); } - if (iter->m_blockType == BlockInfo::Loop) { if (iter->m_returnValueType.IsIndex()) { auto ft = m_result.m_functionTypes[iter->m_returnValueType];