From 88b81ab73d62d1bea81be79bd4a8a2fbae825d48 Mon Sep 17 00:00:00 2001 From: Berke Ates Date: Fri, 2 Jul 2021 00:08:02 +0200 Subject: [PATCH] Generic successor blocks (#16) --- AUTHORS | 1 + mlir/astnodes.py | 3 +++ mlir/lark/mlir.lark | 4 +++- tests/test_syntax.py | 19 +++++++++++++++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 5047613..6583143 100644 --- a/AUTHORS +++ b/AUTHORS @@ -5,3 +5,4 @@ ETH Zurich Tal Ben-Nun Kaushik Kulkarni Mehdi Amini +Berke Ates \ No newline at end of file diff --git a/mlir/astnodes.py b/mlir/astnodes.py index e15c71f..db8b57c 100644 --- a/mlir/astnodes.py +++ b/mlir/astnodes.py @@ -508,6 +508,7 @@ class Op(Node): class GenericOperation(Op): name: str args: Optional[List[SsaId]] + successors: Optional[List[BlockId]] attributes: Optional[AttributeDict] type: List[Type] @@ -519,6 +520,8 @@ def dump(self, indent: int = 0) -> str: result += ', '.join(dump_or_value(arg, indent) for arg in self.args) result += ')' + if self.successors: + result += '[' + dump_or_value(self.successors, indent) + ']' if self.attributes: result += ' ' + dump_or_value(self.attributes, indent) if isinstance(self.type, list): diff --git a/mlir/lark/mlir.lark b/mlir/lark/mlir.lark index 8636fc9..d54dffe 100644 --- a/mlir/lark/mlir.lark +++ b/mlir/lark/mlir.lark @@ -167,7 +167,7 @@ location : string_literal ":" decimal_literal ":" decimal_literal trailing_location : ("loc" "(" location ")") // Undefined operations in all dialects -generic_operation : string_literal "(" optional_ssa_use_list ")" optional_attr_dict trailing_type +generic_operation : string_literal "(" optional_ssa_use_list ")" optional_successor_list optional_attr_dict trailing_type custom_operation : bare_id "." bare_id optional_ssa_use_list trailing_type // Final operation definition @@ -184,6 +184,7 @@ ssa_id_and_type_list : ssa_id_and_type ("," ssa_id_and_type)* operation_list: operation+ block_label : block_id optional_block_arg_list ":" +successor_list : "[" block_id? ("," block_id)* "]" block : optional_block_label operation_list region : "{" block* "}" @@ -211,6 +212,7 @@ region : "{" block* "}" ?optional_memory_space : ("," memory_space)? -> optional ?optional_block_label : block_label? -> optional ?optional_symbol_use_list : symbol_use_list? -> optional +?optional_successor_list : successor_list? -> optional // ---------------------------------------------------------------------- // Modules and functions diff --git a/tests/test_syntax.py b/tests/test_syntax.py index a785b23..37208de 100644 --- a/tests/test_syntax.py +++ b/tests/test_syntax.py @@ -215,6 +215,24 @@ def test_generic_dialect_std(parser: Optional[Parser] = None): module = parser.parse(code) print(module.pretty()) +def test_generic_dialect_std_cond_br(parser: Optional[Parser] = None): + code = ''' +"module"() ( { +"func"() ( { +^bb0(%arg0: i32): // no predecessors + %c1_i32 = "std.constant"() {value = 1 : i32} : () -> i32 + %0 = "std.cmpi"(%arg0, %c1_i32) {predicate = 3 : i64} : (i32, i32) -> i1 + "std.cond_br"(%0)[^bb1, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () +^bb1: // pred: ^bb0 + "std.return"(%c1_i32) : (i32) -> () +^bb2: // pred: ^bb0 + "std.return"(%c1_i32) : (i32) -> () +}) {sym_name = "mlir_entry", type = (i32) -> i32} : () -> () +}) : () -> () + ''' + parser = parser or Parser() + module = parser.parse(code) + print(module.pretty()) def test_generic_dialect_llvm(parser: Optional[Parser] = None): code = ''' @@ -255,5 +273,6 @@ def test_integer_sign(parser: Optional[Parser] = None): test_affine(p) test_definitions(p) test_generic_dialect_std(p) + test_generic_dialect_std_cond_br(p) test_generic_dialect_llvm(p) test_integer_sign(p)