diff --git a/brownie-config.yaml b/brownie-config.yaml index 17e0433..1666623 100644 --- a/brownie-config.yaml +++ b/brownie-config.yaml @@ -13,7 +13,7 @@ compiler: version: 0.8.11 remappings: - "@openzeppelin=OpenZeppelin/openzeppelin-contracts@4.1.0" - + viaIR: true optimizer: details: yul: true diff --git a/contracts/VM.sol b/contracts/VM.sol index f4558a0..d73de20 100644 --- a/contracts/VM.sol +++ b/contracts/VM.sol @@ -29,6 +29,8 @@ abstract contract VM { self = address(this); } + function dispatch(bytes memory inputs) internal virtual returns (bool success, bytes memory ret) {} + function _execute(bytes32[] calldata commands, bytes[] memory state) internal returns (bytes[] memory) { @@ -61,14 +63,28 @@ abstract contract VM { ) ); } else if (flags & FLAG_CT_MASK == FLAG_CT_CALL) { - (success, outdata) = address(uint160(uint256(command))).call( // target - // inputs - state.buildInputs( - //selector - bytes4(command), - indices - ) - ); + address _target = address(uint160(uint256(command))); + bytes memory inputs = state.buildInputs( + //selector + bytes4(command), + indices + ); + success = false; + + if (_target == address(this)) { + (success, outdata) = dispatch(inputs); + } + + if (!success) { + (success, outdata) = _target.call( // target + // inputs + state.buildInputs( + //selector + bytes4(command), + indices + ) + ); + } } else if (flags & FLAG_CT_MASK == FLAG_CT_STATICCALL) { (success, outdata) = address(uint160(uint256(command))).staticcall( // target // inputs diff --git a/contracts/test/TestableVMWithMath.sol b/contracts/test/TestableVMWithMath.sol new file mode 100644 index 0000000..2928968 --- /dev/null +++ b/contracts/test/TestableVMWithMath.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.11; + +import "../VM.sol"; + +contract TestableVMWithMath is VM { + function execute(bytes32[] calldata commands, bytes[] memory state) + public + payable + returns (bytes[] memory) + { + return _execute(commands, state); + } + + function sum(uint256 a, uint256 b) external pure returns (uint256) { + return a + b; + } + + + function dispatch(bytes memory inputs) + internal + override + returns (bool _success, bytes memory _ret) + { + bytes4 _selector = bytes4(bytes32(inputs)); + if (this.sum.selector == _selector) { + uint256 a; + uint256 b; + assembly { + a := mload(add(inputs, 36)) + b := mload(add(inputs, 68)) + } + uint256 res = this.sum(a, b); + _ret = new bytes(32); + assembly { + mstore(add(_ret, 32), res) + } + return (true, _ret); + } + } +} diff --git a/tests/conftest.py b/tests/conftest.py index df09379..8623212 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,10 @@ def math(alice, Math): math_brownie = alice.deploy(Math) yield WeirollContract.createLibrary(math_brownie) +@pytest.fixture(scope="module") +def weiroll_vm_with_math(alice, TestableVMWithMath): + vm = alice.deploy(TestableVMWithMath) + yield vm @pytest.fixture(scope="module") def testContract(alice, TestContract): diff --git a/tests/test_weiroll_local.py b/tests/test_weiroll_local.py new file mode 100644 index 0000000..bdd780b --- /dev/null +++ b/tests/test_weiroll_local.py @@ -0,0 +1,20 @@ +from brownie import Contract, accounts, Wei, chain, TestableVM +from weiroll import WeirollContract, WeirollPlanner + + +def test_vm_with_math(weiroll_vm_with_math): + weiroll_vm = weiroll_vm_with_math + whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) + + planner = WeirollPlanner(weiroll_vm) + sum = planner.call(weiroll_vm, "sum", 1, 2) + sum_2 = planner.call(weiroll_vm, "sum", sum, 3) + sum_2 = planner.call(weiroll_vm, "sum", 3, sum_2) + + cmds, state = planner.plan() + weiroll_tx = weiroll_vm.execute( + cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} + ) + + print(weiroll_tx.return_value) + #assert False