diff --git a/procedural/src/abi_derive/derive_flags.rs b/procedural/src/abi_derive/derive_flags.rs index f1a62f6..ae1fcca 100644 --- a/procedural/src/abi_derive/derive_flags.rs +++ b/procedural/src/abi_derive/derive_flags.rs @@ -2,11 +2,7 @@ use proc_macro2::TokenStream; use quote::quote; use super::extract_docs; -use crate::structs::common::{ - get_left_and_mask, get_right_and_mask, BitMath, FieldInfo, StructInfo, -}; - -const ABI_TYPE_SIZE: usize = 4; +use crate::structs::common::{BitMath, FieldInfo, StructInfo}; pub fn impl_can_be_placed_in_vec(ident: &syn::Ident) -> TokenStream { quote! { @@ -14,70 +10,121 @@ pub fn impl_can_be_placed_in_vec(ident: &syn::Ident) -> TokenStream { } } -pub fn impl_struct_abi_type(name: &syn::Ident) -> TokenStream { - quote! { +fn align_size(name: &syn::Ident, total_bytes: usize) -> syn::Result { + Ok(match total_bytes { + 1 => 1, + 2..=4 => 4, + 5..=8 => 8, + _ => { + return Err(syn::Error::new( + name.span(), + format!("Unsupported struct size: {total_bytes}"), + )) + } + }) +} + +pub fn impl_struct_abi_type(name: &syn::Ident, total_bytes: usize) -> syn::Result { + let sub_type = match total_bytes { + 1 => quote! { u8 }, + 2..=4 => quote! { u32 }, + 5..=8 => quote! { u64 }, + _ => { + return Err(syn::Error::new( + name.span(), + format!("Unsupported struct size: {total_bytes}"), + )) + } + }; + + Ok(quote! { impl ::evm_coder::abi::AbiType for #name { - const SIGNATURE: ::evm_coder::custom_signature::SignatureUnit = <(u32) as ::evm_coder::abi::AbiType>::SIGNATURE; + const SIGNATURE: ::evm_coder::custom_signature::SignatureUnit = <(#sub_type) as ::evm_coder::abi::AbiType>::SIGNATURE; fn is_dynamic() -> bool { - <(u32) as ::evm_coder::abi::AbiType>::is_dynamic() + <(#sub_type) as ::evm_coder::abi::AbiType>::is_dynamic() } fn size() -> usize { - <(u32) as ::evm_coder::abi::AbiType>::size() + <(#sub_type) as ::evm_coder::abi::AbiType>::size() } } - } + }) } -pub fn impl_struct_abi_read(name: &syn::Ident, total_bytes: usize) -> TokenStream { +pub fn impl_struct_abi_read(name: &syn::Ident, total_bytes: usize) -> syn::Result { + let aligned_size = align_size(name, total_bytes)?; let bytes = (0..total_bytes).map(|i| { - let index = ABI_TYPE_SIZE - i - 1; + let index = total_bytes - i - 1; quote! { value[#index] } }); - quote!( + Ok(quote!( impl ::evm_coder::abi::AbiRead for #name { fn abi_read(reader: &mut ::evm_coder::abi::AbiReader) -> ::evm_coder::abi::Result { - let value = reader.uint32()?.to_le_bytes(); + let value = reader.bytes_padleft::<#aligned_size>()?; Ok(#name::from_bytes([#(#bytes),*])) } } - ) + )) } -pub fn impl_struct_abi_write(name: &syn::Ident, total_bytes: usize) -> TokenStream { - let bytes = (0..ABI_TYPE_SIZE).map(|i| { - let index = ABI_TYPE_SIZE - i - 1; - if i < ABI_TYPE_SIZE - total_bytes { +pub fn impl_struct_abi_write(name: &syn::Ident, total_bytes: usize) -> syn::Result { + let aligned_size = align_size(name, total_bytes)?; + let bytes = (0..aligned_size).map(|i| { + if total_bytes < 1 + i { quote! { 0 } } else { - quote! { bytes[#index] } + let index = total_bytes - 1 - i; + quote! { value[#index] } } }); - quote!( + Ok(quote!( impl ::evm_coder::abi::AbiWrite for #name { fn abi_write(&self, writer: &mut ::evm_coder::abi::AbiWriter) { - let bytes = self.clone().into_bytes(); - let value = u32::from_le_bytes([#(#bytes),*]); - <(u32) as ::evm_coder::abi::AbiWrite>::abi_write(&(value), writer) + let value = self.clone().into_bytes(); + writer.bytes_padleft(&[#(#bytes),*]); } } - ) + )) } pub fn impl_struct_solidity_type<'a>( name: &syn::Ident, docs: &[String], + total_bytes: usize, fields: impl Iterator + Clone, ) -> TokenStream { let solidity_name = name.to_string(); let solidity_fields = fields.map(|f| { let name = f.ident.as_ref().to_string(); let docs = f.docs.clone(); - let value = apply_le_math_to_mask(f); - quote! { - SolidityConstant { - docs: &[#(#docs),*], - name: #name, - value: #value, + let (amount_of_bits, zeros_on_left, _, _) = BitMath::from_field(f) + .map(|math| math.into_tuple()) + .unwrap_or((0, 0, 0, 0)); + let zeros_on_right = 8 - (zeros_on_left + amount_of_bits); + if amount_of_bits == 0 { + quote! { + SolidityFlagsField::Bool(SolidityFlagsBool { + docs: &[#(#docs),*], + name: #name, + value: 0, + }) + } + } else if amount_of_bits == 1 { + let value: u8 = 1 << zeros_on_right; + quote! { + SolidityFlagsField::Bool(SolidityFlagsBool { + docs: &[#(#docs),*], + name: #name, + value: #value, + }) + } + } else { + quote! { + SolidityFlagsField::Number(SolidityFlagsNumber { + docs: &[#(#docs),*], + name: #name, + start_bit: #zeros_on_right, + amount_of_bits: #amount_of_bits, + }) } } }); @@ -91,6 +138,7 @@ pub fn impl_struct_solidity_type<'a>( let interface = SolidityLibrary { docs: &[#(#docs),*], name: #solidity_name, + total_bytes: #total_bytes, fields: Vec::from([#( #solidity_fields, )*]), @@ -104,26 +152,6 @@ pub fn impl_struct_solidity_type<'a>( } } -fn apply_le_math_to_mask(field: &FieldInfo) -> TokenStream { - let (amount_of_bits, zeros_on_left, available_bits_in_first_byte, ..) = - if let Ok(math) = BitMath::from_field(field) { - math.into_tuple() - } else { - return quote! { 0 }; - }; - if 8 < (zeros_on_left + amount_of_bits) { - return quote! { 0 }; - } - let zeros_on_right = 8 - (zeros_on_left + amount_of_bits); - // combining the left and right masks will give us a mask that keeps the amount og bytes we - // have in the position we need them to be in for this byte. we use available_bytes for - // right mask because param is amount of 1's on the side specified (right), and - // available_bytes is (8 - zeros_on_left) which is equal to ones_on_right. - let mask = - get_right_and_mask(available_bits_in_first_byte) & get_left_and_mask(8 - zeros_on_right); - quote! { #mask } -} - pub fn impl_struct_solidity_type_name(name: &syn::Ident) -> TokenStream { quote! { #[cfg(feature = "stubgen")] @@ -143,7 +171,7 @@ pub fn impl_struct_solidity_type_name(name: &syn::Ident) -> TokenStream { writer: &mut impl ::core::fmt::Write, tc: &::evm_coder::solidity::TypeCollector, ) -> ::core::fmt::Result { - write!(writer, "{}(0)", tc.collect_struct::()) + write!(writer, "{}.wrap(0)", tc.collect_struct::()) } } } @@ -173,10 +201,11 @@ pub fn expand_flags(ds: &syn::DataStruct, ast: &syn::DeriveInput) -> syn::Result let total_bytes = struct_info.total_bytes(); let can_be_plcaed_in_vec = impl_can_be_placed_in_vec(name); - let abi_type = impl_struct_abi_type(name); - let abi_read = impl_struct_abi_read(name, total_bytes); - let abi_write = impl_struct_abi_write(name, total_bytes); - let solidity_type = impl_struct_solidity_type(name, &docs, struct_info.fields.iter()); + let abi_type = impl_struct_abi_type(name, total_bytes)?; + let abi_read = impl_struct_abi_read(name, total_bytes)?; + let abi_write = impl_struct_abi_write(name, total_bytes)?; + let solidity_type = + impl_struct_solidity_type(name, &docs, total_bytes, struct_info.fields.iter()); let solidity_type_name = impl_struct_solidity_type_name(name); Ok(quote! { #can_be_plcaed_in_vec diff --git a/procedural/src/structs/common.rs b/procedural/src/structs/common.rs index c2fc39d..28f06db 100644 --- a/procedural/src/structs/common.rs +++ b/procedural/src/structs/common.rs @@ -11,36 +11,6 @@ use crate::{ }, }; -/// Returns a u8 mask with provided `num` amount of 1's on the left side (most significant bit) -pub fn get_left_and_mask(num: usize) -> u8 { - match num { - 8 => 0b11111111, - 7 => 0b11111110, - 6 => 0b11111100, - 5 => 0b11111000, - 4 => 0b11110000, - 3 => 0b11100000, - 2 => 0b11000000, - 1 => 0b10000000, - _ => 0b00000000, - } -} - -/// Returns a u8 mask with provided `num` amount of 1's on the right side (least significant bit) -pub fn get_right_and_mask(num: usize) -> u8 { - match num { - 8 => 0b11111111, - 7 => 0b01111111, - 6 => 0b00111111, - 5 => 0b00011111, - 4 => 0b00001111, - 3 => 0b00000111, - 2 => 0b00000011, - 1 => 0b00000001, - _ => 0b00000000, - } -} - pub struct BitMath { pub amount_of_bits: usize, pub zeros_on_left: usize, diff --git a/src/abi/mod.rs b/src/abi/mod.rs index 0e02137..5e7e762 100644 --- a/src/abi/mod.rs +++ b/src/abi/mod.rs @@ -155,6 +155,20 @@ impl<'i> AbiReader<'i> { self.read_padright() } + /// Read [`[u8; S]`] padded left at current position, then advance + pub fn bytes_padleft(&mut self) -> Result<[u8; S]> { + let offset = self.offset; + self.offset += ABI_ALIGNMENT; + Self::read_pad( + self.buf, + offset, + offset, + offset + ABI_ALIGNMENT - S, + offset + ABI_ALIGNMENT - S, + offset + ABI_ALIGNMENT, + ) + } + /// Read [`Vec`] at current position, then advance pub fn bytes(&mut self) -> Result> { let mut subresult = self.subresult(None)?; @@ -357,6 +371,11 @@ impl AbiWriter { self.memory(value); } + /// Write [`bytes`] to end of buffer + pub fn bytes_padleft(&mut self, block: &[u8]) { + self.write_padleft(block); + } + /// Finish writer, concatenating all internal buffers #[must_use] pub fn finish(mut self) -> Vec { diff --git a/src/solidity/mod.rs b/src/solidity/mod.rs index 3eb39c5..efc1bd9 100644 --- a/src/solidity/mod.rs +++ b/src/solidity/mod.rs @@ -50,19 +50,19 @@ impl TypeCollector { pub fn collect_tuple(&self) -> String { let names = T::fields(self); if let Some(id) = self.anonymous.borrow().get(&names).cloned() { - return format!("Tuple{}", id); + return format!("Tuple{id}"); } let id = self.next_id(); let mut str = String::new(); writeln!(str, "/// @dev anonymous struct").unwrap(); - writeln!(str, "struct Tuple{} {{", id).unwrap(); + writeln!(str, "struct Tuple{id} {{").unwrap(); for (i, name) in names.iter().enumerate() { - writeln!(str, "\t{} field_{};", name, i).unwrap(); + writeln!(str, "\t{name} field_{i};").unwrap(); } writeln!(str, "}}").unwrap(); self.collect(str); self.anonymous.borrow_mut().insert(names, id); - format!("Tuple{}", id) + format!("Tuple{id}") } pub fn collect_struct(&self) -> String { T::generate_solidity_interface(self) @@ -269,9 +269,9 @@ impl SolidityFunctions for SolidityF writer: &mut impl fmt::Write, tc: &TypeCollector, ) -> fmt::Result { - let hide_comment = self.hide.then_some("// ").unwrap_or(""); + let hide_comment = if self.hide { "// " } else { "" }; for doc in self.docs { - writeln!(writer, "\t{hide_comment}///{}", doc)?; + writeln!(writer, "\t{hide_comment}///{doc}")?; } writeln!( writer, @@ -364,7 +364,7 @@ impl SolidityInterface { ) -> fmt::Result { const ZERO_BYTES: [u8; 4] = [0; 4]; for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } if self.selector != ZERO_BYTES { writeln!( @@ -385,7 +385,7 @@ impl SolidityInterface { if i != 0 { write!(out, ",")?; } - write!(out, " {}", n)?; + write!(out, " {n}")?; } } writeln!(out, " {{")?; @@ -437,11 +437,12 @@ where { fn solidity_name(&self, out: &mut impl fmt::Write, tc: &TypeCollector) -> fmt::Result { for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } write!(out, "\t")?; T::solidity_name(out, tc)?; - writeln!(out, " {};", self.name)?; + let field_name = self.name; + writeln!(out, " {field_name};",)?; Ok(()) } } @@ -457,7 +458,7 @@ where { pub fn format(&self, out: &mut impl fmt::Write, tc: &TypeCollector) -> fmt::Result { for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } writeln!(out, "struct {} {{", self.name)?; self.fields.solidity_name(out, tc)?; @@ -473,7 +474,7 @@ pub struct SolidityEnumVariant { impl SolidityItems for SolidityEnumVariant { fn solidity_name(&self, out: &mut impl fmt::Write, _tc: &TypeCollector) -> fmt::Result { for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } write!(out, "\t{}", self.name)?; Ok(()) @@ -487,9 +488,10 @@ pub struct SolidityEnum { impl SolidityEnum { pub fn format(&self, out: &mut impl fmt::Write, tc: &TypeCollector) -> fmt::Result { for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } - write!(out, "enum {} {{", self.name)?; + let name = self.name; + write!(out, "enum {name} {{")?; for (i, field) in self.fields.iter().enumerate() { if i != 0 { write!(out, ",")?; @@ -503,35 +505,79 @@ impl SolidityEnum { } } -pub struct SolidityConstant { +pub enum SolidityFlagsField { + Bool(SolidityFlagsBool), + Number(SolidityFlagsNumber), +} + +impl SolidityFlagsField { + pub fn docs(&self) -> &'static [&'static str] { + match self { + Self::Bool(field) => field.docs, + Self::Number(field) => field.docs, + } + } +} + +pub struct SolidityFlagsBool { pub docs: &'static [&'static str], pub name: &'static str, pub value: u8, } +pub struct SolidityFlagsNumber { + pub docs: &'static [&'static str], + pub name: &'static str, + pub start_bit: usize, + pub amount_of_bits: usize, +} + pub struct SolidityLibrary { pub docs: &'static [&'static str], pub name: &'static str, - pub fields: Vec, + pub total_bytes: usize, + pub fields: Vec, } impl SolidityLibrary { pub fn format(&self, out: &mut impl fmt::Write) -> fmt::Result { for doc in self.docs { - writeln!(out, "///{}", doc)?; + writeln!(out, "///{doc}")?; } - writeln!(out, "type {} is uint32;", self.name)?; - write!(out, "library {}Lib {{", self.name)?; - for (i, field) in self.fields.iter().enumerate() { + let total_bytes = self.total_bytes; + let abi_type = match total_bytes { + 1 => "uint8", + 2..=4 => "uint32", + 5..=8 => "uint64", + _ => return Err(fmt::Error), + }; + let lib_name = self.name; + writeln!(out, "type {lib_name} is {abi_type};")?; + write!(out, "library {lib_name}Lib {{")?; + for field in self.fields.iter() { writeln!(out)?; - for doc in field.docs { - writeln!(out, "///{}", doc)?; + for doc in field.docs() { + writeln!(out, "///{doc}")?; + } + match field { + SolidityFlagsField::Bool(field) => { + let field_name = field.name; + let field_value = field.value; + write!( + out, + "\t{lib_name} constant {field_name}Field = {lib_name}.wrap({field_value});" + )?; + } + SolidityFlagsField::Number(field) => { + let field_name = field.name; + let amount_of_bits = field.amount_of_bits; + let start_bit = field.start_bit; + write!( + out, + "\tfunction {field_name}Field({abi_type} value) public pure returns ({lib_name}) {{\n\t\trequire(value < 1 << {amount_of_bits}, \"out of bound value\");\n\t\treturn {lib_name}.wrap(value << {start_bit});\n\t}}" + )?; + } } - write!( - out, - "\t{} constant {}Field = {}.wrap({});", - self.name, field.name, self.name, field.value - )?; } writeln!(out)?; writeln!(out, "}}")?; diff --git a/tests/abi_derive_generation.rs b/tests/abi_derive_generation.rs index 1c2e8b1..1c27bd2 100644 --- a/tests/abi_derive_generation.rs +++ b/tests/abi_derive_generation.rs @@ -758,23 +758,6 @@ mod test_flags { blue: u8, } - #[derive(AbiCoderFlags, Bitfields, Debug, PartialEq, Default, Clone, Copy)] - #[bondrewd(enforce_bytes = 2)] - struct MultipleBytes { - #[bondrewd(bits = "0..1")] - a: bool, - #[bondrewd(bits = "1..2")] - b: bool, - #[bondrewd(bits = "2..8")] - c: u8, - #[bondrewd(bits = "8..14")] - d: u8, - #[bondrewd(bits = "14..15")] - e: bool, - #[bondrewd(bits = "15..16")] - f: bool, - } - #[test] fn empty() {} @@ -790,9 +773,7 @@ mod test_flags { ::SIGNATURE .as_str() .unwrap(), - ::SIGNATURE - .as_str() - .unwrap() + ::SIGNATURE.as_str().unwrap() ); } @@ -800,7 +781,7 @@ mod test_flags { fn impl_abi_type_is_dynamic_same_for_structs() { assert_eq!( ::is_dynamic(), - ::is_dynamic() + ::is_dynamic() ); } @@ -808,7 +789,7 @@ mod test_flags { fn impl_abi_type_size_same_for_structs() { assert_eq!( ::size(), - ::size() + ::size() ); } @@ -831,9 +812,9 @@ mod test_flags { let encoded_u32 = { let mut writer = evm_coder::abi::AbiWriter::new_call(FUNCTION_IDENTIFIER); - let color_int = u32::from_le_bytes([0, 0, 0, color.into_bytes()[0]]); + let color_int = u8::from_le_bytes(color.into_bytes()); - ::abi_write(&color_int, &mut writer); + ::abi_write(&color_int, &mut writer); writer.finish() }; @@ -847,6 +828,23 @@ mod test_flags { } } + #[derive(AbiCoderFlags, Bitfields, Debug, PartialEq, Default, Clone, Copy)] + #[bondrewd(enforce_bytes = 2)] + struct MultipleBytes { + #[bondrewd(bits = "0..1")] + a: bool, + #[bondrewd(bits = "1..2")] + b: bool, + #[bondrewd(bits = "2..8")] + c: u8, + #[bondrewd(bits = "8..14")] + d: u8, + #[bondrewd(bits = "14..15")] + e: bool, + #[bondrewd(bits = "15..16")] + f: bool, + } + #[test] fn test_coder_two_bytes() { const FUNCTION_IDENTIFIER: u32 = 0xdeadbeef; @@ -869,7 +867,7 @@ mod test_flags { let encoded_u32 = { let mut writer = evm_coder::abi::AbiWriter::new_call(FUNCTION_IDENTIFIER); let bytes = data.into_bytes(); - let data_int = u32::from_le_bytes([0, 0, bytes[1], bytes[0]]); + let data_int = u32::from_be_bytes([bytes[1], bytes[0], 0, 0]); ::abi_write(&data_int, &mut writer); writer.finish() @@ -884,4 +882,53 @@ mod test_flags { assert_eq!(restored_flags_data, data); } } + + /// Cross account struct + #[derive(AbiCoderFlags, Bitfields, Clone, Copy, PartialEq, Eq, Debug, Default)] + #[bondrewd(enforce_bytes = 1)] + pub struct Flags { + #[bondrewd(bits = "0..1")] + pub a: bool, + #[bondrewd(bits = "1..2")] + pub b: bool, + #[bondrewd(bits = "2..7", reverse)] + pub c: u8, + #[bondrewd(bits = "7..8")] + pub d: bool, + } + + #[test] + fn test_creation_from_flags() { + const FUNCTION_IDENTIFIER: u32 = 0xdeadbeef; + + let data = Flags { + a: true, + b: true, + c: 3, + d: false, + }; + + let data_int = (1u8 << 7) + (1u8 << 6) + (3u8 << 1); + + let encoded_flags = { + let mut writer = evm_coder::abi::AbiWriter::new_call(FUNCTION_IDENTIFIER); + ::abi_write(&data, &mut writer); + writer.finish() + }; + + let encoded_u8 = { + let mut writer = evm_coder::abi::AbiWriter::new_call(FUNCTION_IDENTIFIER); + ::abi_write(&data_int, &mut writer); + writer.finish() + }; + + similar_asserts::assert_eq!(encoded_flags, encoded_u8); + + { + let (_, mut decoder) = evm_coder::abi::AbiReader::new_call(&encoded_u8).unwrap(); + let restored_flags_data = + ::abi_read(&mut decoder).unwrap(); + assert_eq!(restored_flags_data, data); + } + } }