Skip to content

Commit

Permalink
feat: fuse with statements across if statements
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Feb 23, 2024
1 parent 3b130ad commit d93140a
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 160 deletions.
66 changes: 36 additions & 30 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def if_statement(
branch = self.helper("branch", condition)
self.statement("with", branch, "as", "_bolt_condition", lineno=lineno)
with self.block():
self.statement(f"if _bolt_condition")
self.statement("if", "_bolt_condition")
with self.block():
yield
self.condition_inverse = inverse
Expand All @@ -304,7 +304,7 @@ def dup(self, target: str, *, lineno: Any = None) -> str:
dup = self.make_variable()
value = self.helper("get_dup", target)
self.statement(f"{dup} = {value}", lineno=lineno)
self.statement(f"if {dup} is not None")
self.statement("if", f"{dup} is not None")
with self.block():
self.statement(f"{target} = {dup}()")
return dup
Expand All @@ -314,13 +314,13 @@ def rebind(self, target: str, op: str, value: str, *, lineno: Any = None):
rebind = self.helper("get_rebind", target)
self.statement(f"_bolt_rebind = {rebind}", lineno=lineno)
self.statement(f"{target} {op} {value}")
self.statement(f"if _bolt_rebind is not None")
self.statement("if", "_bolt_rebind is not None")
with self.block():
self.statement(f"{target} = _bolt_rebind({target})")

def rebind_dup(self, target: str, dup: str, value: str, *, lineno: Any = None):
"""Emit __rebind__() if target was __dup__()."""
self.statement(f"if {dup} is not None")
self.statement("if", f"{dup} is not None")
with self.block():
self.rebind(target, "=", value, lineno=lineno)
self.statement("else")
Expand All @@ -345,7 +345,9 @@ class WithStatementFusion:
@classmethod
def finalize(cls, acc: Accumulator):
with_statement_fusion = cls()
acc.statements = [with_statement_fusion.fuse(statement, acc) for statement in acc.statements]
acc.statements = [
with_statement_fusion.fuse(statement, acc) for statement in acc.statements
]

def convert(self, statement: CodegenStatement, exit_stack: str) -> CodegenStatement:
code = (f"{exit_stack}.enter_context({statement.code[1]})",)
Expand All @@ -356,29 +358,31 @@ def convert(self, statement: CodegenStatement, exit_stack: str) -> CodegenStatem
def fuse(self, statement: CodegenStatement, acc: Accumulator) -> CodegenStatement:
children = [self.fuse(child, acc) for child in statement.children]

if statement.code[0] == "with" and children[-1].code[0] == "with":
nested_statement = children.pop()
if statement.code[0] != "with":
return replace(statement, children=children)

if nested_statement.code[1] == acc.helper("exit_stack"):
exit_stack = nested_statement.code[3]
code = nested_statement.code
else:
exit_stack = f"_bolt_fused_with_statement{self.counter}"
self.counter += 1
code = ("with", acc.helper("exit_stack"), "as", exit_stack)
children.append(self.convert(nested_statement, exit_stack))

return replace(
statement,
code=code,
children=[
self.convert(statement, exit_stack),
*children,
*nested_statement.children,
],
)
nested_children = children
while nested_children[-1].code[0] == "if":
nested_children = nested_children[-1].children

if nested_children[-1].code[0] != "with":
return replace(statement, children=children)

nested_statement = nested_children.pop()

return replace(statement, children=children)
if nested_statement.code[1] == acc.helper("exit_stack"):
exit_stack = nested_statement.code[3]
code = nested_statement.code
else:
exit_stack = f"_bolt_fused_with_statement{self.counter}"
self.counter += 1
code = ("with", acc.helper("exit_stack"), "as", exit_stack)
nested_children.append(self.convert(nested_statement, exit_stack))

children.insert(0, self.convert(statement, exit_stack))
nested_children.extend(nested_statement.children)

return replace(statement, code=code, children=children)


@dataclass
Expand Down Expand Up @@ -711,7 +715,7 @@ def memo(
acc.header[storage] = "None"
if not acc.root_scope:
acc.statement(f"global {storage}", lineno=node)
acc.statement(f"if {storage} is None")
acc.statement("if", f"{storage} is None")
with acc.block():
acc.statement(
f"{storage} = _bolt_runtime.memo.registry[__file__][{acc.make_ref(node)}, {file_index}]"
Expand All @@ -723,7 +727,7 @@ def memo(
invocation = f"_bolt_memo_invocation_{node.persistent_id.hex}"
acc.statement(f"{invocation} = {storage}[({path}, {' '.join(keys)})]")

acc.statement(f"if {invocation}.cached")
acc.statement("if", f"{invocation}.cached")
with acc.block():
acc.statement(f"_bolt_runtime.memo.restore(_bolt_runtime, {invocation})")
if cached_identifiers:
Expand Down Expand Up @@ -798,7 +802,7 @@ def function(

for arg in signature.arguments:
if isinstance(arg, AstFunctionSignatureArgument) and arg.default:
acc.statement(f"if {arg.name} is {acc.missing()}")
acc.statement("if", f"{arg.name} is {acc.missing()}")
with acc.block():
value = yield from visit_single(arg.default, required=True)
acc.statement(f"{arg.name} = {value}")
Expand Down Expand Up @@ -1023,7 +1027,9 @@ def while_statement(
with acc.block():
acc.statement("_bolt_runtime.commands.extend(_bolt_condition_commands)")

acc.statement("if not _bolt_loop_overridden", lineno=node.arguments[0])
acc.statement(
"if", "not _bolt_loop_overridden", lineno=node.arguments[0]
)
with acc.block():
acc.statement(f"{condition} = bool({condition})")

Expand Down
24 changes: 12 additions & 12 deletions tests/snapshots/bolt__parse_132__1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ with _bolt_helper_exit_stack() as _bolt_fused_with_statement0:
with _bolt_helper_branch(_bolt_var1) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[1].commands)
with _bolt_helper_branch(_bolt_var1_inverse) as _bolt_condition:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var1_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var2_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
with _bolt_helper_branch(_bolt_var2_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_var3 = 4
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_var3 = 4
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var3))
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_var5 = _bolt_helper_replace(_bolt_refs[4], commands=_bolt_helper_children(_bolt_var4))
---
output = _bolt_var5
Expand Down
30 changes: 16 additions & 14 deletions tests/snapshots/bolt__parse_133__1.txt
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
_bolt_lineno = [1, 12, 19, 26], [1, 3, 5, 7]
_bolt_lineno = [1, 13, 21, 28], [1, 3, 5, 7]
_bolt_helper_operator_not = _bolt_runtime.helpers['operator_not']
_bolt_helper_branch = _bolt_runtime.helpers['branch']
_bolt_helper_children = _bolt_runtime.helpers['children']
_bolt_helper_replace = _bolt_runtime.helpers['replace']
_bolt_helper_exit_stack = _bolt_runtime.helpers['exit_stack']
with _bolt_runtime.scope() as _bolt_var4:
_bolt_var0 = 1
_bolt_var0_inverse = _bolt_helper_operator_not(_bolt_var0)
with _bolt_helper_branch(_bolt_var0) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[0].commands)
with _bolt_helper_branch(_bolt_var0_inverse) as _bolt_condition:
with _bolt_helper_exit_stack() as _bolt_fused_with_statement0:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var0_inverse))
if _bolt_condition:
_bolt_var1 = 2
_bolt_var1_inverse = _bolt_helper_operator_not(_bolt_var1)
with _bolt_helper_branch(_bolt_var1) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[1].commands)
with _bolt_helper_branch(_bolt_var1_inverse) as _bolt_condition:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var1_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var2_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
with _bolt_helper_branch(_bolt_var2_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_var3 = 4
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_var3 = 4
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var3))
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_runtime.commands.append(_bolt_refs[4])
_bolt_var5 = _bolt_helper_replace(_bolt_refs[5], commands=_bolt_helper_children(_bolt_var4))
---
Expand Down
30 changes: 15 additions & 15 deletions tests/snapshots/bolt__parse_134__1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ with _bolt_helper_exit_stack() as _bolt_fused_with_statement0:
with _bolt_helper_branch(_bolt_var1) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[1].commands)
with _bolt_helper_branch(_bolt_var1_inverse) as _bolt_condition:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var1_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var2_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
with _bolt_helper_branch(_bolt_var2_inverse) as _bolt_condition:
_bolt_var3 = 4
_bolt_var3_inverse = _bolt_helper_operator_not(_bolt_var3)
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_var3 = 4
_bolt_var3_inverse = _bolt_helper_operator_not(_bolt_var3)
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
with _bolt_helper_branch(_bolt_var3_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[4].commands)
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var3_inverse))
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[4].commands)
_bolt_var5 = _bolt_helper_replace(_bolt_refs[5], commands=_bolt_helper_children(_bolt_var4))
---
output = _bolt_var5
Expand Down
36 changes: 19 additions & 17 deletions tests/snapshots/bolt__parse_135__1.txt
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
_bolt_lineno = [1, 12, 19, 26, 33], [1, 3, 5, 7, 9]
_bolt_lineno = [1, 13, 21, 28, 35], [1, 3, 5, 7, 9]
_bolt_helper_operator_not = _bolt_runtime.helpers['operator_not']
_bolt_helper_branch = _bolt_runtime.helpers['branch']
_bolt_helper_children = _bolt_runtime.helpers['children']
_bolt_helper_replace = _bolt_runtime.helpers['replace']
_bolt_helper_exit_stack = _bolt_runtime.helpers['exit_stack']
with _bolt_runtime.scope() as _bolt_var4:
_bolt_var0 = 1
_bolt_var0_inverse = _bolt_helper_operator_not(_bolt_var0)
with _bolt_helper_branch(_bolt_var0) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[0].commands)
with _bolt_helper_branch(_bolt_var0_inverse) as _bolt_condition:
with _bolt_helper_exit_stack() as _bolt_fused_with_statement0:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var0_inverse))
if _bolt_condition:
_bolt_var1 = 2
_bolt_var1_inverse = _bolt_helper_operator_not(_bolt_var1)
with _bolt_helper_branch(_bolt_var1) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[1].commands)
with _bolt_helper_branch(_bolt_var1_inverse) as _bolt_condition:
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var1_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var2_inverse))
if _bolt_condition:
_bolt_var2 = 3
_bolt_var2_inverse = _bolt_helper_operator_not(_bolt_var2)
with _bolt_helper_branch(_bolt_var2) as _bolt_condition:
_bolt_var3 = 4
_bolt_var3_inverse = _bolt_helper_operator_not(_bolt_var3)
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
with _bolt_helper_branch(_bolt_var2_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_var3 = 4
_bolt_var3_inverse = _bolt_helper_operator_not(_bolt_var3)
with _bolt_helper_branch(_bolt_var3) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
with _bolt_helper_branch(_bolt_var3_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[4].commands)
_bolt_runtime.commands.extend(_bolt_refs[3].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var3_inverse))
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[4].commands)
_bolt_runtime.commands.append(_bolt_refs[5])
_bolt_var5 = _bolt_helper_replace(_bolt_refs[6], commands=_bolt_helper_children(_bolt_var4))
---
Expand Down
6 changes: 3 additions & 3 deletions tests/snapshots/bolt__parse_35__1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ with _bolt_helper_exit_stack() as _bolt_fused_with_statement0:
with _bolt_helper_branch(_bolt_var1) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[1].commands)
with _bolt_helper_branch(_bolt_var1_inverse) as _bolt_condition:
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_condition = _bolt_fused_with_statement0.enter_context(_bolt_helper_branch(_bolt_var1_inverse))
if _bolt_condition:
_bolt_runtime.commands.extend(_bolt_refs[2].commands)
_bolt_var6 = _bolt_helper_replace(_bolt_refs[3], commands=_bolt_helper_children(_bolt_var5))
---
output = _bolt_var6
Expand Down
Loading

0 comments on commit d93140a

Please sign in to comment.