diff --git a/README.md b/README.md index 19813ec..7d17754 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,11 @@ contents of the file. $ bulloak scaffold -wf ./**/*.tree ``` +Note all tests are showing as passing when their body is empty. To prevent this, +you can use the `-S` (or `--vm-skip`) option to add a `vm.skip(true);` at the +beginning of each test function. This option will also add an import for +forge-std's `Test.sol` and all test contracts will inherit from it. + ### Check That Your Code And Spec Match You can use `bulloak check` to make sure that your Solidity files match your diff --git a/benches/bench.rs b/benches/bench.rs index c0bcf2f..b4f7399 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -4,7 +4,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; fn big_tree(c: &mut Criterion) { let tree = std::fs::read_to_string("benches/bench_data/cancel.tree").unwrap(); - let scaffolder = bulloak::scaffold::Scaffolder::new("some_version"); + let scaffolder = bulloak::scaffold::Scaffolder::new("some_version", true); let mut group = c.benchmark_group("sample-size-10"); // group.sample_size(10); group.bench_function("big-tree", |b| { diff --git a/src/check/context.rs b/src/check/context.rs index e5c0dd4..ee30e0b 100644 --- a/src/check/context.rs +++ b/src/check/context.rs @@ -38,7 +38,7 @@ impl Context { pub(crate) fn new(tree: PathBuf) -> Result { let tree_path_cow = tree.to_string_lossy(); let tree_contents = try_read_to_string(&tree)?; - let hir = crate::hir::translate(&tree_contents).map_err(|e| { + let hir = crate::hir::translate(&tree_contents, false).map_err(|e| { Violation::new( ViolationKind::ParsingFailed(e), Location::File(tree_path_cow.into_owned()), diff --git a/src/check/violation.rs b/src/check/violation.rs index 5c75bec..3144290 100644 --- a/src/check/violation.rs +++ b/src/check/violation.rs @@ -178,7 +178,8 @@ impl ViolationKind { pub(crate) fn fix(&self, mut ctx: Context) -> Context { match self { ViolationKind::ContractMissing(_) => { - let pt = sol::Translator::new(INTERNAL_DEFAULT_SOL_VERSION).translate(&ctx.hir); + let pt = + sol::Translator::new(INTERNAL_DEFAULT_SOL_VERSION, false).translate(&ctx.hir); let source = sol::Formatter::new().emit(pt.clone()); let parsed = parse(&source).expect("should parse solidity string"); ctx.from_parsed(parsed) diff --git a/src/hir/combiner.rs b/src/hir/combiner.rs index 2e84e57..fabe9ce 100644 --- a/src/hir/combiner.rs +++ b/src/hir/combiner.rs @@ -267,13 +267,13 @@ mod tests { use crate::syntax::parser::Parser; use crate::syntax::tokenizer::Tokenizer; - fn translate(text: &str) -> Result { + fn translate(text: &str, with_vm_skip: bool) -> Result { let tokens = Tokenizer::new().tokenize(&text)?; let ast = Parser::new().parse(&text, &tokens)?; let mut discoverer = modifiers::ModifierDiscoverer::new(); let modifiers = discoverer.discover(&ast); - Ok(hir::translator::Translator::new().translate(&ast, modifiers)) + Ok(hir::translator::Translator::new().translate(&ast, modifiers, with_vm_skip)) } fn combine(text: &str, hirs: Vec) -> Result { @@ -307,6 +307,10 @@ mod tests { }) } + fn statement(ty: hir::StatementType) -> Hir { + Hir::Statement(hir::Statement { ty }) + } + fn comment(lexeme: String) -> Hir { Hir::Comment(hir::Comment { lexeme }) } @@ -317,7 +321,10 @@ mod tests { "::orphanedFunction\n└── when something bad happens\n └── it should revert", "Contract::function\n└── when something bad happens\n └── it should revert", ]; - let hirs = trees.iter().map(|tree| translate(tree).unwrap()).collect(); + let hirs = trees + .iter() + .map(|tree| translate(tree, true).unwrap()) + .collect(); let text = trees.join("\n\n"); let result = combine(&text, hirs); @@ -330,7 +337,10 @@ mod tests { "Contract::function\n└── when something bad happens\n └── it should revert", "::orphanedFunction\n└── when something bad happens\n └── it should revert", ]; - let hirs = trees.iter().map(|tree| translate(tree).unwrap()).collect(); + let hirs = trees + .iter() + .map(|tree| translate(tree, true).unwrap()) + .collect(); let expected = r"••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••••• bulloak error: contract name missing at tree root #2"; @@ -348,7 +358,10 @@ bulloak error: contract name missing at tree root #2"; "Contract::function1\n└── when something bad happens\n └── it should revert", "Contract::function2\n└── when something shit happens\n └── it should revert", ]; - let mut hirs: Vec<_> = trees.iter().map(|tree| translate(tree).unwrap()).collect(); + let mut hirs: Vec<_> = trees + .iter() + .map(|tree| translate(tree, true).unwrap()) + .collect(); // Append a comment HIR to the hirs. hirs.push(root(vec![comment("this is a random comment".to_owned())])); @@ -369,14 +382,20 @@ bulloak error: contract name missing at tree root #2"; hir::FunctionTy::Function, Span::new(Position::new(20, 2, 1), Position::new(86, 3, 24)), None, - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), function( "test_Function2RevertWhen_SomethingShitHappens".to_owned(), hir::FunctionTy::Function, Span::new(Position::new(20, 2, 1), Position::new(87, 3, 24)), None, - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), ] )] @@ -389,7 +408,10 @@ bulloak error: contract name missing at tree root #2"; "Contract::function1\n└── when something bad happens\n └── given something else happens\n └── it should revert", "Contract::function2\n└── when something bad happens\n └── given the caller is 0x1337\n └── it should revert", ]; - let mut hirs: Vec<_> = trees.iter().map(|tree| translate(tree).unwrap()).collect(); + let mut hirs: Vec<_> = trees + .iter() + .map(|tree| translate(tree, true).unwrap()) + .collect(); // Append a comment HIR to the hirs. hirs.push(root(vec![comment("this is a random comment".to_owned())])); @@ -418,14 +440,20 @@ bulloak error: contract name missing at tree root #2"; hir::FunctionTy::Function, Span::new(Position::new(61, 3, 5), Position::new(133, 4, 28)), Some(vec!["whenSomethingBadHappens".to_owned()]), - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), function( "test_Function2RevertGiven_TheCallerIs0x1337".to_owned(), hir::FunctionTy::Function, Span::new(Position::new(61, 3, 5), Position::new(131, 4, 28)), Some(vec!["whenSomethingBadHappens".to_owned()]), - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), ] )] diff --git a/src/hir/hir.rs b/src/hir/hir.rs index 31cc213..576d002 100644 --- a/src/hir/hir.rs +++ b/src/hir/hir.rs @@ -22,6 +22,8 @@ pub enum Hir { FunctionDefinition(FunctionDefinition), /// A comment. Comment(Comment), + /// A Statement. + Statement(Statement), } impl Default for Hir { @@ -129,3 +131,17 @@ pub struct Comment { /// The contract name. pub lexeme: String, } + +/// The statements which are currently supported. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StatementType { + /// The `vm.skip(true);` statement. + VmSkip, +} + +/// A statement node. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Statement { + /// The statement. + pub ty: StatementType, +} diff --git a/src/hir/mod.rs b/src/hir/mod.rs index 709de6d..d9bc6c0 100644 --- a/src/hir/mod.rs +++ b/src/hir/mod.rs @@ -12,29 +12,35 @@ pub use hir::*; /// /// This function leverages `translate_tree_to_hir` to generate the HIR for each tree, /// and `crate::hir::combiner::Combiner::combine` to combine the HIRs into a single HIR. -pub fn translate(text: &str) -> anyhow::Result { - Ok(translate_and_combine_trees(text)?) +pub fn translate(text: &str, add_vm_skip: bool) -> anyhow::Result { + Ok(translate_and_combine_trees(text, add_vm_skip)?) } /// Generates the HIR for a single tree. /// /// This function leverages `crate::syntax::parse` and `crate::hir::translator::Translator::translate` /// to hide away most of the complexity of `bulloak`'s internal compiler. -pub fn translate_tree_to_hir(tree: &str) -> crate::error::Result { +pub fn translate_tree_to_hir( + tree: &str, + add_vm_skip: bool, +) -> crate::error::Result { let ast = crate::syntax::parse(tree)?; let mut discoverer = crate::scaffold::modifiers::ModifierDiscoverer::new(); let modifiers = discoverer.discover(&ast); - Ok(crate::hir::translator::Translator::new().translate(&ast, modifiers)) + Ok(crate::hir::translator::Translator::new().translate(&ast, modifiers, add_vm_skip)) } /// High-level function that returns a HIR given the contents of a `.tree` file. /// /// This function leverages `translate_tree_to_hir` to generate the HIR for each tree, /// and `crate::hir::combiner::Combiner::combine` to combine the HIRs into a single HIR. -pub(crate) fn translate_and_combine_trees(text: &str) -> crate::error::Result { +pub(crate) fn translate_and_combine_trees( + text: &str, + add_vm_skip: bool, +) -> crate::error::Result { let trees = crate::utils::split_trees(text); let hirs = trees - .map(translate_tree_to_hir) + .map(|tree| translate_tree_to_hir(tree, add_vm_skip)) .collect::>>()?; Ok(crate::hir::combiner::Combiner::new().combine(text, hirs)?) } diff --git a/src/hir/translator.rs b/src/hir/translator.rs index 3a1b106..f43ef8b 100644 --- a/src/hir/translator.rs +++ b/src/hir/translator.rs @@ -1,6 +1,5 @@ //! The implementation of a translator between a bulloak tree AST and a //! high-level intermediate representation (HIR) -- AST -> HIR. - use indexmap::IndexMap; use crate::hir::{self, Hir}; @@ -27,8 +26,13 @@ impl Translator { /// /// This function is the entry point of the translator. #[must_use] - pub fn translate(&self, ast: &ast::Ast, modifiers: &IndexMap) -> Hir { - TranslatorI::new(modifiers).translate(ast) + pub fn translate( + &self, + ast: &ast::Ast, + modifiers: &IndexMap, + with_vm_skip: bool, + ) -> Hir { + TranslatorI::new(modifiers, with_vm_skip).translate(ast) } } @@ -50,14 +54,17 @@ struct TranslatorI<'a> { /// to improve performance. Otherwise each title would be converted /// to a modifier every time it is used. modifiers: &'a IndexMap, + /// Whether to add `vm.skip(true)` at the beginning of each test. + with_vm_skip: bool, } impl<'a> TranslatorI<'a> { /// Creates a new internal translator. - fn new(modifiers: &'a IndexMap) -> Self { + fn new(modifiers: &'a IndexMap, with_vm_skip: bool) -> Self { Self { modifier_stack: Vec::new(), modifiers, + with_vm_skip, } } @@ -107,7 +114,15 @@ impl<'a> Visitor for TranslatorI<'a> { let test_name = sanitize(&test_name); let test_name = format!("test_{test_name}"); - let hirs = self.visit_action(action)?; + let mut hirs = self.visit_action(action)?; + + // Include any optional statement for the first function node. + if self.with_vm_skip { + hirs.push(Hir::Statement(hir::Statement { + ty: hir::StatementType::VmSkip, + })); + } + let hir = Hir::FunctionDefinition(hir::FunctionDefinition { identifier: test_name, ty: hir::FunctionTy::Function, @@ -226,6 +241,13 @@ impl<'a> Visitor for TranslatorI<'a> { Some(self.modifier_stack.iter().map(|&m| m.to_owned()).collect()) }; + // Add a `vm.skip(true);` at the start of the function. + if self.with_vm_skip { + actions.push(Hir::Statement(hir::Statement { + ty: hir::StatementType::VmSkip, + })); + } + let hir = Hir::FunctionDefinition(hir::FunctionDefinition { identifier: function_name, ty: hir::FunctionTy::Function, @@ -289,13 +311,13 @@ mod tests { use crate::syntax::parser::Parser; use crate::syntax::tokenizer::Tokenizer; - fn translate(text: &str) -> Result { + fn translate(text: &str, with_vm_skip: bool) -> Result { let tokens = Tokenizer::new().tokenize(&text)?; let ast = Parser::new().parse(&text, &tokens)?; let mut discoverer = modifiers::ModifierDiscoverer::new(); let modifiers = discoverer.discover(&ast); - Ok(hir::translator::Translator::new().translate(&ast, modifiers)) + Ok(hir::translator::Translator::new().translate(&ast, modifiers, with_vm_skip)) } fn root(children: Vec) -> Hir { @@ -325,6 +347,10 @@ mod tests { }) } + fn statement(ty: hir::StatementType) -> Hir { + Hir::Statement(hir::Statement { ty }) + } + fn comment(lexeme: String) -> Hir { Hir::Comment(hir::Comment { lexeme }) } @@ -332,7 +358,11 @@ mod tests { #[test] fn one_child() { assert_eq!( - translate("Foo_Test\n└── when something bad happens\n └── it should revert").unwrap(), + translate( + "Foo_Test\n└── when something bad happens\n └── it should revert", + true + ) + .unwrap(), root(vec![contract( "Foo_Test".to_owned(), vec![function( @@ -340,7 +370,10 @@ mod tests { hir::FunctionTy::Function, Span::new(Position::new(9, 2, 1), Position::new(74, 3, 23)), None, - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ),] )]) ); @@ -354,7 +387,8 @@ mod tests { ├── when stuff called │ └── it should revert └── given not stuff called - └── it should revert" + └── it should revert", + true ) .unwrap(), root(vec![contract( @@ -365,14 +399,20 @@ mod tests { hir::FunctionTy::Function, Span::new(Position::new(19, 2, 1), Position::new(77, 3, 23)), None, - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), function( "test_RevertGiven_NotStuffCalled".to_owned(), hir::FunctionTy::Function, Span::new(Position::new(79, 4, 1), Position::new(140, 5, 23)), None, - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), ] )]) @@ -394,7 +434,7 @@ Foo_Test ); assert_eq!( - translate(&file_contents)?, + translate(&file_contents, true)?, root(vec![contract( "Foo_Test".to_owned(), vec![ @@ -412,7 +452,8 @@ Foo_Test Some(vec!["whenStuffCalled".to_owned()]), Some(vec![ comment("It should do stuff.".to_owned()), - comment("It should do more.".to_owned()) + comment("It should do more.".to_owned()), + statement(hir::StatementType::VmSkip) ]) ), function( @@ -420,14 +461,20 @@ Foo_Test hir::FunctionTy::Function, Span::new(Position::new(76, 5, 5), Position::new(135, 6, 28)), Some(vec!["whenStuffCalled".to_owned()]), - Some(vec![comment("it should revert".to_owned())]) + Some(vec![ + comment("it should revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), function( "test_WhenBCalled".to_owned(), hir::FunctionTy::Function, Span::new(Position::new(174, 8, 5), Position::new(235, 9, 32)), Some(vec!["whenStuffCalled".to_owned()]), - Some(vec![comment("it should not revert".to_owned())]) + Some(vec![ + comment("it should not revert".to_owned()), + statement(hir::StatementType::VmSkip) + ]) ), ] )]) diff --git a/src/hir/visitor.rs b/src/hir/visitor.rs index 3a74d36..dcbb764 100644 --- a/src/hir/visitor.rs +++ b/src/hir/visitor.rs @@ -15,6 +15,8 @@ pub trait Visitor { type FunctionDefinitionOutput; /// The result of visiting a `Comment`. type CommentOutput; + /// The result of visiting a `Statement`. + type StatementOutput; /// An error that might occur when visiting the HIR. type Error; @@ -59,4 +61,16 @@ pub trait Visitor { /// A `Result` containing either the output of visiting the comment node or an error. fn visit_comment(&mut self, comment: &hir::Comment) -> Result; + + /// Visits a statement node within the HIR. + /// + /// # Arguments + /// * `statement` - A reference to the statement node in the HIR. + /// + /// # Returns + /// A `Result` containing either the output of visiting the statement node or an error. + fn visit_statement( + &mut self, + statement: &hir::Statement, + ) -> Result; } diff --git a/src/scaffold/emitter.rs b/src/scaffold/emitter.rs index d99b55e..c9ca315 100644 --- a/src/scaffold/emitter.rs +++ b/src/scaffold/emitter.rs @@ -66,6 +66,9 @@ impl<'s> EmitterI<'s> { Hir::ContractDefinition(ref inner) => self.visit_contract(inner).unwrap(), Hir::FunctionDefinition(ref inner) => self.visit_function(inner).unwrap(), Hir::Comment(ref inner) => self.visit_comment(inner).unwrap(), + Hir::Statement(_) => { + unreachable!("a statement can't be a top-level source unit in Solidity") + } } } @@ -154,6 +157,7 @@ impl<'s> Visitor for EmitterI<'s> { type ContractDefinitionOutput = String; type FunctionDefinitionOutput = String; type CommentOutput = String; + type StatementOutput = String; type Error = (); fn visit_root(&mut self, root: &hir::Root) -> result::Result { @@ -215,6 +219,8 @@ impl<'s> Visitor for EmitterI<'s> { for child in children { if let Hir::Comment(comment) = child { emitted.push_str(&self.visit_comment(comment)?); + } else if let Hir::Statement(statement) = child { + emitted.push_str(&self.visit_statement(statement)?); } } } @@ -236,6 +242,23 @@ impl<'s> Visitor for EmitterI<'s> { Ok(emitted) } + + fn visit_statement( + &mut self, + statement: &hir::Statement, + ) -> result::Result { + let mut emitted = String::new(); + let indentation = self.emitter.indent().repeat(2); + + // Match any supported statement to its string representation + match statement.ty { + hir::StatementType::VmSkip => { + emitted.push_str(format!("{}vm.skip(true);\n", indentation).as_str()); + } + } + + Ok(emitted) + } } #[cfg(test)] @@ -244,16 +267,21 @@ mod tests { use crate::constants::INTERNAL_DEFAULT_SOL_VERSION; use crate::error::Result; - use crate::hir::translate_and_combine_trees; + use crate::hir::{translate_and_combine_trees, Hir, Statement, StatementType}; use crate::scaffold::emitter; - fn scaffold_with_flags(text: &str, indent: usize, version: &str) -> Result { - let hir = translate_and_combine_trees(text)?; + fn scaffold_with_flags( + text: &str, + indent: usize, + version: &str, + with_vm_skip: bool, + ) -> Result { + let hir = translate_and_combine_trees(text, with_vm_skip)?; Ok(emitter::Emitter::new(indent, version).emit(&hir)) } fn scaffold(text: &str) -> Result { - scaffold_with_flags(text, 2, INTERNAL_DEFAULT_SOL_VERSION) + scaffold_with_flags(text, 2, INTERNAL_DEFAULT_SOL_VERSION, false) } #[test] @@ -297,7 +325,7 @@ contract FileTest { let file_contents = String::from("FileTest\n├── it should do st-ff\n└── It never reverts."); assert_eq!( - &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION)?, + &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION, false)?, r"// SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.0; @@ -320,7 +348,7 @@ contract FileTest { ); assert_eq!( - &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION)?, + &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION, false)?, r"// SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.0; @@ -344,7 +372,7 @@ contract FileTest { ); assert_eq!( - &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION)?, + &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION, false)?, r"// SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.0; @@ -372,7 +400,7 @@ contract FileTest { String::from("Fi-eTest\n└── when something bad happens\n └── it should not revert"); assert_eq!( - &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION)?, + &scaffold_with_flags(&file_contents, 2, INTERNAL_DEFAULT_SOL_VERSION, false)?, r"// SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.0; @@ -392,7 +420,7 @@ contract Fi_eTest { String::from("FileTest\n└── when something bad happens\n └── it should not revert"); assert_eq!( - &scaffold_with_flags(&file_contents, 4, INTERNAL_DEFAULT_SOL_VERSION)?, + &scaffold_with_flags(&file_contents, 4, INTERNAL_DEFAULT_SOL_VERSION, false)?, r"// SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.0; @@ -406,6 +434,37 @@ contract FileTest { Ok(()) } + #[test] + fn with_vm_skip() -> Result<()> { + let file_contents = + String::from("FileTest\n└── when something bad happens\n └── it should not revert"); + + assert_eq!( + &scaffold_with_flags(&file_contents, 4, INTERNAL_DEFAULT_SOL_VERSION, true)?, + r"// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.0; + +contract FileTest { + function test_WhenSomethingBadHappens() external { + // it should not revert + vm.skip(true); + } +}" + ); + + Ok(()) + } + + #[test] + #[should_panic] + fn with_vm_skip_top_level_statement() { + let hir = Hir::Statement(Statement { + ty: StatementType::VmSkip, + }); + + let _ = emitter::Emitter::new(4, INTERNAL_DEFAULT_SOL_VERSION).emit(&hir); + } + #[test] fn two_children() -> Result<()> { let file_contents = String::from( diff --git a/src/scaffold/mod.rs b/src/scaffold/mod.rs index e1b92e3..917265e 100644 --- a/src/scaffold/mod.rs +++ b/src/scaffold/mod.rs @@ -39,11 +39,14 @@ pub struct Scaffold { /// Sets a Solidity version for the test contracts. #[arg(short = 's', long, default_value = INTERNAL_DEFAULT_SOL_VERSION)] solidity_version: String, + /// Whether to add vm.skip(true) at the begining of each test. + #[arg(short = 'S', long = "vm-skip", default_value = "false")] + with_vm_skip: bool, } impl Scaffold { pub fn run(self) -> anyhow::Result<()> { - let scaffolder = Scaffolder::new(&self.solidity_version); + let scaffolder = Scaffolder::new(&self.solidity_version, self.with_vm_skip); // For each input file, compile it and print it or write it // to the filesystem. @@ -100,19 +103,24 @@ impl Scaffold { pub struct Scaffolder<'s> { /// Sets a Solidity version for the test contracts. solidity_version: &'s str, + /// Whether to add vm.skip(true) at the begining of each test. + with_vm_skip: bool, } impl<'s> Scaffolder<'s> { /// Creates a new scaffolder with the provided configuration. #[must_use] - pub fn new(solidity_version: &'s str) -> Self { - Scaffolder { solidity_version } + pub fn new(solidity_version: &'s str, with_vm_skip: bool) -> Self { + Scaffolder { + solidity_version, + with_vm_skip, + } } /// Generates Solidity code from a `.tree` file. pub fn scaffold(&self, text: &str) -> crate::error::Result { - let hir = translate_and_combine_trees(text)?; - let pt = sol::Translator::new(self.solidity_version).translate(&hir); + let hir = translate_and_combine_trees(text, self.with_vm_skip)?; + let pt = sol::Translator::new(self.solidity_version, self.with_vm_skip).translate(&hir); let source = sol::Formatter::new().emit(pt); let formatted = fmt(&source).expect("should format the emitted solidity code"); diff --git a/src/sol/fmt.rs b/src/sol/fmt.rs index 49ff488..202e8e3 100644 --- a/src/sol/fmt.rs +++ b/src/sol/fmt.rs @@ -10,6 +10,16 @@ use crate::utils::sanitize; use super::visitor::Visitor; +trait Identified { + fn name(&self) -> String; +} + +impl Identified for Base { + fn name(&self) -> String { + self.name.identifiers[0].name.clone() + } +} + pub(crate) struct Formatter; impl Formatter { @@ -73,10 +83,14 @@ impl Visitor for Formatter { result.push(' '); } + // Include any base contract inherited. if !contract.base.is_empty() { + result.push_str("is "); + let mut bases = vec![]; for b in &mut contract.base { - bases.push(format!("{b}")); + let base_name = &b.name(); + bases.push(base_name.to_string()); } result.push_str(&bases.join(", ")); result.push(' '); @@ -208,6 +222,7 @@ impl Visitor for Formatter { Ok(format!("{identifier}")) } } + Expression::FunctionCall(_, _, _) => Ok(format!("{expression};")), expression => Ok(format!("{expression}")), } } diff --git a/src/sol/translator.rs b/src/sol/translator.rs index c982528..41ebd53 100644 --- a/src/sol/translator.rs +++ b/src/sol/translator.rs @@ -18,8 +18,8 @@ use std::cell::Cell; use solang_parser::pt::{ Base, ContractDefinition, ContractPart, ContractTy, Expression, FunctionAttribute, - FunctionDefinition, FunctionTy, Identifier, IdentifierPath, Loc, SourceUnit, SourceUnitPart, - Statement, StringLiteral, Type, VariableDeclaration, Visibility, + FunctionDefinition, FunctionTy, Identifier, IdentifierPath, Import, ImportPath, Loc, + SourceUnit, SourceUnitPart, Statement, StringLiteral, Type, VariableDeclaration, Visibility, }; use crate::hir::visitor::Visitor; @@ -35,14 +35,17 @@ use crate::utils::sanitize; pub(crate) struct Translator { /// The Solidity version to be used in the pragma directive. sol_version: String, + /// A flag indicating if there is a forge-std dependency. + with_forge_std: bool, } impl Translator { /// Create a new translator. #[must_use] - pub(crate) fn new(sol_version: &str) -> Self { + pub(crate) fn new(sol_version: &str, with_forge_std: bool) -> Self { Self { sol_version: sol_version.to_owned(), + with_forge_std: with_forge_std.to_owned(), } } @@ -229,9 +232,12 @@ impl TranslatorI { fn gen_function_statements(&mut self, children: &Vec) -> Result, ()> { let mut stmts = Vec::with_capacity(children.len()); for child in children { + if let Hir::Statement(statement) = child { + stmts.push(self.visit_statement(statement)?); + } if let Hir::Comment(comment) = child { stmts.push(self.visit_comment(comment)?); - }; + } } // If there is at least one child, we add a '\n' @@ -290,6 +296,7 @@ impl Visitor for TranslatorI { type ContractDefinitionOutput = SourceUnitPart; type FunctionDefinitionOutput = ContractPart; type CommentOutput = Statement; + type StatementOutput = Statement; type Error = (); /// Visits the root node of a High-Level Intermediate Representation (HIR) and translates @@ -298,9 +305,9 @@ impl Visitor for TranslatorI { /// of the HIR into a corresponding PT structure. /// /// The translation involves creating a `SourceUnit`, starting with a pragma directive - /// based on the translator's Solidity version, and then iterating over each child node - /// within the root. Each contract definition, is translated and incorporated into the - /// `SourceUnit`. + /// based on the translator's Solidity version as well as optional file imports (e.g. forge-std) + /// if required. It then iterates over each child node within the root. + /// Each contract definition is translated and incorporated into the `SourceUnit`. /// /// # Arguments /// * `root` - A reference to the root of the HIR structure, representing the highest level @@ -337,6 +344,35 @@ impl Visitor for TranslatorI { )); self.bump(";\n"); + // Add the forge-std's Test import, if needed. + if self.translator.with_forge_std { + // Getting the relevant offsets for `import {Test} from "forge-std/Test.sol"`. + let loc_import_start = self.offset.get(); + self.bump("import { "); + let loc_identifier = self.bump("Test"); + self.bump(" } from \""); + let loc_path = self.bump("forge-std/Test.sol"); + + // The import directive `Rename` corresponds to `import {x} from y.sol`. + source_unit.push(SourceUnitPart::ImportDirective(Import::Rename( + ImportPath::Filename(StringLiteral { + loc: loc_path, + unicode: false, + string: "forge-std/Test.sol".to_string(), + }), + vec![( + Identifier { + loc: loc_identifier, + name: "Test".to_string(), + }, + None, + )], + Loc::File(0, loc_import_start, loc_path.end()), + ))); + + self.bump("\";\n"); + } + for child in &root.children { if let Hir::ContractDefinition(contract) = child { source_unit.push(self.visit_contract(contract)?); @@ -375,7 +411,29 @@ impl Visitor for TranslatorI { loc: self.bump(&contract_name), name: contract.identifier.clone(), }); - self.bump(" {"); // `{` after contract identifier. + + let mut contract_base = vec![]; + + // If there is an import, inherit the base contract as well. + if self.translator.with_forge_std { + let base_start = self.offset.get(); + self.bump(" is "); + let base_loc = self.bump("Test"); + let base_identifier_path = IdentifierPath { + loc: base_loc, + identifiers: vec![Identifier { + loc: base_loc, + name: "Test".to_string(), + }], + }; + + contract_base = vec![Base { + loc: Loc::File(0, base_start, base_loc.end()), + name: base_identifier_path, + args: None, + }]; + } + self.bump(" {"); // `{` after contract identifier and base. let mut parts = Vec::with_capacity(contract.children.len()); for child in &contract.children { @@ -388,13 +446,37 @@ impl Visitor for TranslatorI { loc: Loc::File(0, contract_start, self.offset.get()), name: contract_name, ty: contract_ty, - base: vec![], + base: contract_base, parts, }; Ok(SourceUnitPart::ContractDefinition(Box::new(contract_def))) } + /// Visits a `FunctionDefinition` node in the High-Level Intermediate Representation (HIR) + /// and translates it into a `ContractPart` for inclusion in the `solang_parser` parse tree (PT). + /// This function handles the translation of function definitions, converting them into a format + /// suitable for the PT. + /// + /// The translation process involves several steps: + /// 1. Determining the function type and translating it to the corresponding PT representation. + /// 2. Translating the function identifier and storing its location information. + /// 3. Generating function attributes based on the HIR function definition. + /// 4. Translating the function body, including statements and comments, into PT statements. + /// 5. Constructing the final `FunctionDefinition` object with the translated components. + /// + /// # Arguments + /// * `function` - A reference to the `FunctionDefinition` node in the HIR, representing a + /// single function within the HIR structure. + /// + /// # Returns + /// A `Result` containing the `ContractPart::FunctionDefinition` representing the translated + /// function if the translation is successful, or an `Error` otherwise. The `ContractPart` + /// encapsulates the function's PT representation. + /// + /// # Errors + /// This function may return an error if the translation of any component within the function + /// encounters issues, such as failing to translate the function body. fn visit_function( &mut self, function: &hir::FunctionDefinition, @@ -495,4 +577,48 @@ impl Visitor for TranslatorI { Ok(definition) } + + /// Visits a supported statement node and match based on its type. + fn visit_statement( + &mut self, + statement: &hir::Statement, + ) -> Result { + let start_offset = self.offset.get(); + + match statement.ty { + hir::StatementType::VmSkip => { + let loc_vm = self.bump("vm"); + self.bump("."); + let loc_skip = self.bump("skip"); + self.bump("("); + let loc_arg = self.bump("true"); + self.bump(");"); + + let vm_interface = Expression::MemberAccess( + Loc::File(0, start_offset, loc_skip.end()), + Box::new(Expression::Variable(solang_parser::pt::Identifier { + loc: loc_vm, + name: "vm".to_owned(), + })), + solang_parser::pt::Identifier { + loc: loc_skip, + name: "skip".to_owned(), + }, + ); + + let vm_skip_arg = vec![Expression::BoolLiteral(loc_arg, true)]; + + let vm_skip_call = Expression::FunctionCall( + Loc::File(0, loc_skip.start(), loc_arg.end()), + Box::new(vm_interface), + vm_skip_arg, + ); + + Ok(Statement::Expression( + Loc::File(0, start_offset, self.offset.get()), + vm_skip_call, + )) + } + } + } } diff --git a/tests/scaffold.rs b/tests/scaffold.rs index 8531de0..90958f3 100644 --- a/tests/scaffold.rs +++ b/tests/scaffold.rs @@ -34,6 +34,40 @@ fn scaffolds_trees() { } } +#[test] +fn scaffolds_trees_with_vm_skip() { + let cwd = env::current_dir().unwrap(); + let binary_path = get_binary_path(); + let tests_path = cwd.join("tests").join("scaffold"); + let trees = [ + "basic.tree", + "complex.tree", + "multiple_roots.tree", + "removes_invalid_title_chars.tree", + ]; + let args = vec!["--vm-skip"]; + + for tree_name in trees { + let tree_path = tests_path.join(tree_name); + let output = cmd(&binary_path, "scaffold", &tree_path, &args); + let actual = String::from_utf8(output.stdout).unwrap(); + + let mut trimmed_extension = tree_path.clone(); + trimmed_extension.set_extension(""); + + let mut output_file_str = trimmed_extension.into_os_string(); + output_file_str.push("_vm_skip"); + + let mut output_file: std::path::PathBuf = output_file_str.into(); + output_file.set_extension("t.sol"); + + let expected = fs::read_to_string(output_file).unwrap(); + + // We trim here because we don't care about ending newlines. + assert_eq!(expected.trim(), actual.trim()); + } +} + #[test] fn skips_trees_when_file_exists() { let cwd = env::current_dir().unwrap(); diff --git a/tests/scaffold/basic_vm_skip.t.sol b/tests/scaffold/basic_vm_skip.t.sol new file mode 100644 index 0000000..c4880ad --- /dev/null +++ b/tests/scaffold/basic_vm_skip.t.sol @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.0; + +import {Test} from "forge-std/Test.sol"; + +contract HashPairTestSanitize is Test { + function test_ShouldNeverRevert() external { + // It should never revert. + vm.skip(true); + } + + modifier whenFirstArgIsSmallerThanSecondArg() { + _; + } + + function test_WhenFirstArgIsSmallerThanSecondArg() external whenFirstArgIsSmallerThanSecondArg { + // It should match the result of `keccak256(abi.encodePacked(a,b))`. + vm.skip(true); + } + + function test_WhenFirstArgIsZero() external whenFirstArgIsSmallerThanSecondArg { + // It should do something. + vm.skip(true); + } + + function test_WhenFirstArgIsBiggerThanSecondArg() external { + // It should match the result of `keccak256(abi.encodePacked(b,a))`. + vm.skip(true); + } +} diff --git a/tests/scaffold/complex_vm_skip.t.sol b/tests/scaffold/complex_vm_skip.t.sol new file mode 100644 index 0000000..f41dfdc --- /dev/null +++ b/tests/scaffold/complex_vm_skip.t.sol @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.0; + +import {Test} from "forge-std/Test.sol"; + +contract CancelTest is Test { + function test_RevertWhen_DelegateCalled() external { + // it should revert + vm.skip(true); + } + + modifier whenNotDelegateCalled() { + _; + } + + function test_RevertGiven_TheIdReferencesANullStream() external whenNotDelegateCalled { + // it should revert + vm.skip(true); + } + + modifier givenTheIdDoesNotReferenceANullStream() { + _; + } + + modifier givenTheStreamIsCold() { + _; + } + + function test_RevertGiven_TheStreamsStatusIsDEPLETED() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsCold + { + // it should revert + vm.skip(true); + } + + function test_RevertGiven_TheStreamsStatusIsCANCELED() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsCold + { + // it should revert + vm.skip(true); + } + + function test_RevertGiven_TheStreamsStatusIsSETTLED() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsCold + { + // it should revert + vm.skip(true); + } + + modifier givenTheStreamIsWarm() { + _; + } + + modifier whenTheCallerIsUnauthorized() { + _; + } + + function test_RevertWhen_TheCallerIsAMaliciousThirdParty() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsUnauthorized + { + // it should revert + vm.skip(true); + } + + function test_RevertWhen_TheCallerIsAnApprovedThirdParty() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsUnauthorized + { + // it should revert + vm.skip(true); + } + + function test_RevertWhen_TheCallerIsAFormerRecipient() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsUnauthorized + { + // it should revert + vm.skip(true); + } + + modifier whenTheCallerIsAuthorized() { + _; + } + + function test_RevertGiven_TheStreamIsNotCancelable() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + { + // it should revert + vm.skip(true); + } + + modifier givenTheStreamIsCancelable() { + _; + } + + function test_GivenTheStreamsStatusIsPENDING() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + { + // it should cancel the stream + // it should mark the stream as depleted + // it should make the stream not cancelable + vm.skip(true); + } + + modifier givenTheStreamsStatusIsSTREAMING() { + _; + } + + modifier whenTheCallerIsTheSender() { + _; + } + + function test_GivenTheRecipientIsNotAContract() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheSender + { + // it should cancel the stream + // it should mark the stream as canceled + vm.skip(true); + } + + modifier givenTheRecipientIsAContract() { + _; + } + + function test_GivenTheRecipientDoesNotImplementTheHook() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheSender + givenTheRecipientIsAContract + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the recipient hook + // it should ignore the revert + vm.skip(true); + } + + modifier givenTheRecipientImplementsTheHook() { + _; + } + + function test_WhenTheRecipientReverts() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheSender + givenTheRecipientIsAContract + givenTheRecipientImplementsTheHook + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the recipient hook + // it should ignore the revert + vm.skip(true); + } + + modifier whenTheRecipientDoesNotRevert() { + _; + } + + function test_WhenThereIsReentrancy1() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheSender + givenTheRecipientIsAContract + givenTheRecipientImplementsTheHook + whenTheRecipientDoesNotRevert + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the recipient hook + // it should ignore the revert + vm.skip(true); + } + + function test_WhenThereIsNoReentrancy1() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheSender + givenTheRecipientIsAContract + givenTheRecipientImplementsTheHook + whenTheRecipientDoesNotRevert + { + // it should cancel the stream + // it should mark the stream as canceled + // it should make the stream not cancelable + // it should update the refunded amount + // it should refund the sender + // it should call the recipient hook + // it should emit a {CancelLockupStream} event + // it should emit a {MetadataUpdate} event + vm.skip(true); + } + + modifier whenTheCallerIsTheRecipient() { + _; + } + + function test_GivenTheSenderIsNotAContract() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheRecipient + { + // it should cancel the stream + // it should mark the stream as canceled + vm.skip(true); + } + + modifier givenTheSenderIsAContract() { + _; + } + + function test_GivenTheSenderDoesNotImplementTheHook() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheRecipient + givenTheSenderIsAContract + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the sender hook + // it should ignore the revert + vm.skip(true); + } + + modifier givenTheSenderImplementsTheHook() { + _; + } + + function test_WhenTheSenderReverts() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheRecipient + givenTheSenderIsAContract + givenTheSenderImplementsTheHook + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the sender hook + // it should ignore the revert + vm.skip(true); + } + + modifier whenTheSenderDoesNotRevert() { + _; + } + + function test_WhenThereIsReentrancy2() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheRecipient + givenTheSenderIsAContract + givenTheSenderImplementsTheHook + whenTheSenderDoesNotRevert + { + // it should cancel the stream + // it should mark the stream as canceled + // it should call the sender hook + // it should ignore the revert + vm.skip(true); + } + + function test_WhenThereIsNoReentrancy2() + external + whenNotDelegateCalled + givenTheIdDoesNotReferenceANullStream + givenTheStreamIsWarm + whenTheCallerIsAuthorized + givenTheStreamIsCancelable + givenTheStreamsStatusIsSTREAMING + whenTheCallerIsTheRecipient + givenTheSenderIsAContract + givenTheSenderImplementsTheHook + whenTheSenderDoesNotRevert + { + // it should cancel the stream + // it should mark the stream as canceled + // it should make the stream not cancelable + // it should update the refunded amount + // it should refund the sender + // it should call the sender hook + // it should emit a {MetadataUpdate} event + // it should emit a {CancelLockupStream} event + vm.skip(true); + } +} diff --git a/tests/scaffold/multiple_roots_vm_skip.t.sol b/tests/scaffold/multiple_roots_vm_skip.t.sol new file mode 100644 index 0000000..7b0b93b --- /dev/null +++ b/tests/scaffold/multiple_roots_vm_skip.t.sol @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.0; + +import {Test} from "forge-std/Test.sol"; + +contract MultipleRootsTreeTest is Test { + function test_Function1ShouldNeverRevert() external { + // It should never revert. + vm.skip(true); + } + + function test_Function1WhenFirstArgIsBiggerThanSecondArg() external { + // It is all good + vm.skip(true); + } + + function test_Function2WhenStuffHappens() external { + // It should do something simple + vm.skip(true); + } +} \ No newline at end of file diff --git a/tests/scaffold/removes_invalid_title_chars_vm_skip.t.sol b/tests/scaffold/removes_invalid_title_chars_vm_skip.t.sol new file mode 100644 index 0000000..43d74bf --- /dev/null +++ b/tests/scaffold/removes_invalid_title_chars_vm_skip.t.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity 0.8.0; + +import {Test} from "forge-std/Test.sol"; + +contract Foo is Test { + function test_CantDoX() external { + // It can’t do, X. + vm.skip(true); + } +}