Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Fix eip1153_tstore from EF #428

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/codegen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ impl<'c> OperationCtx<'c> {
key: Value<'c, 'c>,
value: Value<'c, 'c>,
location: Location<'c>,
) {
) -> Result<Value, CodegenError> {
syscall::mlir::transient_storage_write_syscall(
self.mlir_context,
self.syscall_ctx,
Expand Down
79 changes: 22 additions & 57 deletions src/codegen/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5544,7 +5544,6 @@ fn codegen_tload<'c, 'r>(

// Allocate a pointer for the key
let key_ptr = allocate_and_store_value(op_ctx, &ok_block, key, location)?;

// Allocate a pointer for the value
let read_value_ptr = ok_block
.append_operation(llvm::alloca(
Expand Down Expand Up @@ -5583,30 +5582,19 @@ fn codegen_tstore<'c, 'r>(
let start_block = region.append_block(Block::new(&[]));
let context = &op_ctx.mlir_context;
let location = Location::unknown(context);
let uint256 = IntegerType::new(context, 256);
let ptr_type = pointer(context, 0);
let pointer_size = start_block
.append_operation(arith::constant(
context,
IntegerAttribute::new(uint256.into(), 1_i64).into(),
location,
))
.result(0)?
.into();

let flag = check_stack_has_at_least(context, &start_block, 2)?;
let gas_flag = consume_gas(context, &start_block, gas_cost::TSTORE)?;
let ok_context_flag = check_context_is_not_static(op_ctx, &start_block)?;
//Check there are enough arguments in stack
let ok_stack_flag = check_stack_has_at_least(context, &start_block, 2)?;

let condition = start_block
.append_operation(arith::andi(gas_flag, flag, location))
let ok_block = region.append_block(Block::new(&[]));
let ok_flag = start_block
.append_operation(arith::andi(ok_context_flag, ok_stack_flag, location))
.result(0)?
.into();

let ok_block = region.append_block(Block::new(&[]));

start_block.append_operation(cf::cond_br(
context,
condition,
ok_flag,
&ok_block,
&op_ctx.revert_block,
&[],
Expand All @@ -5617,47 +5605,24 @@ fn codegen_tstore<'c, 'r>(
let key = stack_pop(context, &ok_block)?;
let value = stack_pop(context, &ok_block)?;

// Allocate a pointer for the key
let key_ptr = ok_block
.append_operation(llvm::alloca(
context,
pointer_size,
ptr_type,
location,
AllocaOptions::new().elem_type(Some(TypeAttribute::new(uint256.into()))),
))
.result(0)?
.into();
let res = ok_block.append_operation(llvm::store(
context,
key,
key_ptr,
location,
LoadStoreOptions::default(),
));
assert!(res.verify());
let key_ptr = allocate_and_store_value(op_ctx, &ok_block, key, location)?;
let value_ptr = allocate_and_store_value(op_ctx, &ok_block, value, location)?;

// Allocate a pointer for the value
let value_ptr = ok_block
.append_operation(llvm::alloca(
context,
pointer_size,
ptr_type,
location,
AllocaOptions::new().elem_type(Some(TypeAttribute::new(uint256.into()))),
))
.result(0)?
.into();
let res = ok_block.append_operation(llvm::store(
let gas_cost =
op_ctx.transient_storage_write_syscall(&ok_block, key_ptr, value_ptr, location)?;
let gas_flag = consume_gas_as_value(context, &ok_block, gas_cost)?;

let end_block = region.append_block(Block::new(&[]));

ok_block.append_operation(cf::cond_br(
context,
value,
value_ptr,
gas_flag,
&end_block,
&op_ctx.revert_block,
&[],
&[],
location,
LoadStoreOptions::default(),
));
assert!(res.verify());

op_ctx.transient_storage_write_syscall(&ok_block, key_ptr, value_ptr, location);

Ok((start_block, ok_block))
Ok((start_block, end_block))
}
19 changes: 19 additions & 0 deletions src/journal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,18 @@ impl From<&JournalAccount> for AccountInfo {

type AccountState = HashMap<Address, JournalAccount>;
type ContractState = HashMap<B256, Bytecode>;
type TransientStorage = HashMap<(Address, U256), JournalStorageSlot>;

#[derive(Default, Debug)]
pub struct Journal<'a> {
accounts: AccountState,
contracts: ContractState,
block_hashes: HashMap<U256, B256>,
db: Option<&'a mut Db>,
// The storage of a transaction
//
// Incorporated in [EIP-1153]: https://eips.ethereum.org/EIPS/eip-1153
pub transient_storage: TransientStorage, // ((addr, key), (current, original))
}

// TODO: Handle unwraps and panics
Expand Down Expand Up @@ -329,6 +334,7 @@ impl<'a> Journal<'a> {
contracts: self.contracts.clone(),
block_hashes: self.block_hashes.clone(),
db: self.db.take(),
transient_storage: self.transient_storage.clone(),
}
}

Expand Down Expand Up @@ -382,4 +388,17 @@ impl<'a> Journal<'a> {
.unwrap_or_default();
JournalStorageSlot::from(value)
}

pub fn read_tx_storage(&self, address: Address, key: U256) -> JournalStorageSlot {
self.transient_storage
.get(&(address, key))
.cloned()
.unwrap_or(JournalStorageSlot::default())
}

pub fn write_tx_storage(&mut self, address: Address, key: U256, value: U256) {
let mut slot = self.read_tx_storage(address, key);
slot.present_value = value;
self.transient_storage.insert((address, key), slot);
}
}
49 changes: 28 additions & 21 deletions src/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use crate::{
};
use melior::ExecutionEngine;
use sha3::{Digest, Keccak256};
use std::collections::HashMap;

/// Function type for the main entrypoint of the generated code
pub type MainFunc = extern "C" fn(&mut SyscallContext, initial_gas: u64) -> u8;
Expand Down Expand Up @@ -187,7 +186,6 @@ pub struct SyscallContext<'c> {
pub inner_context: InnerContext,
pub halt_reason: Option<HaltReason>,
initial_gas: u64,
pub transient_storage: HashMap<(Address, EU256), EU256>, // TODO: Move this to Journal
}

#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
Expand All @@ -212,7 +210,6 @@ impl<'c> SyscallContext<'c> {
call_frame,
halt_reason: None,
inner_context: Default::default(),
transient_storage: Default::default(),
}
}

Expand Down Expand Up @@ -1112,25 +1109,31 @@ impl<'c> SyscallContext<'c> {
}

pub extern "C" fn read_transient_storage(&mut self, stg_key: &U256, stg_value: &mut U256) {
let sender_address = self.env.tx.get_address();
let key = stg_key.to_primitive_u256();
let address = self.env.tx.get_address();

let result = self
.transient_storage
.get(&(address, key))
.cloned()
.unwrap_or(EU256::zero());
.journal
.read_tx_storage(sender_address, key)
.present_value;

stg_value.hi = (result >> 128).low_u128();
stg_value.lo = result.low_u128();
}

pub extern "C" fn write_transient_storage(&mut self, stg_key: &U256, stg_value: &mut U256) {
let address = self.env.tx.get_address();

pub extern "C" fn write_transient_storage(
&mut self,
stg_key: &U256,
stg_value: &mut U256,
) -> i64 {
let sender_address = self.env.tx.get_address();
let key = stg_key.to_primitive_u256();
let value = stg_value.to_primitive_u256();
self.transient_storage.insert((address, key), value);

let _slot = self.journal.read_tx_storage(sender_address, key);
self.journal.write_tx_storage(sender_address, key, value);

gas_cost::TSTORE
}
}

Expand Down Expand Up @@ -1901,7 +1904,7 @@ pub(crate) mod mlir {
context,
StringAttribute::new(context, symbols::TRANSIENT_STORAGE_WRITE),
r#TypeAttribute::new(
FunctionType::new(context, &[ptr_type, ptr_type, ptr_type], &[]).into(),
FunctionType::new(context, &[ptr_type, ptr_type, ptr_type], &[uint64]).into(),
),
Region::new(),
attributes,
Expand Down Expand Up @@ -2195,14 +2198,18 @@ pub(crate) mod mlir {
key: Value<'c, 'c>,
value: Value<'c, 'c>,
location: Location<'c>,
) {
block.append_operation(func::call(
mlir_ctx,
FlatSymbolRefAttribute::new(mlir_ctx, symbols::TRANSIENT_STORAGE_WRITE),
&[syscall_ctx, key, value],
&[],
location,
));
) -> Result<Value<'c, 'c>, CodegenError> {
let uint64 = IntegerType::new(mlir_ctx, 64);
let value = block
.append_operation(func::call(
mlir_ctx,
FlatSymbolRefAttribute::new(mlir_ctx, symbols::TRANSIENT_STORAGE_WRITE),
&[syscall_ctx, key, value],
&[uint64.into()],
location,
))
.result(0)?;
Ok(value.into())
}

/// Receives log data and appends a log to the logs vector
Expand Down
2 changes: 1 addition & 1 deletion tests/ef_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fn get_ignored_groups() -> HashSet<String> {
HashSet::from([
"stEIP4844-blobtransactions".into(),
"stEIP5656-MCOPY".into(),
"eip1153_tstore".into(),
"stEIP3651-warmcoinbase".into(),
"stTimeConsuming".into(), // this will be tested with the time_consuming_test binary
"stRevertTest".into(),
Expand Down Expand Up @@ -64,7 +65,6 @@ fn get_ignored_groups() -> HashSet<String> {
"stMemoryTest".into(),
"stInitCodeTest".into(),
"stBadOpcode".into(),
"eip1153_tstore".into(),
"stSolidityTest".into(),
"yul".into(),
"stEIP3607".into(),
Expand Down