Skip to content

Commit

Permalink
Refactor 'func' dialect ops out of the standard dialect (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
blaine-fs authored Mar 23, 2023
1 parent 22bb0c5 commit dde3e93
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 97 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ ast1 = mlir.parse_path('/path/to/file.mlir')
ast2 = mlir.parse_file(open('/path/to/file.mlir', 'r'))
ast3 = mlir.parse_string('''
module {
func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
func.func @toy_func(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> {
%t_tensor = "toy.transpose"(%tensor) { inplace = true } : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %t_tensor : tensor<3x2xf64>
}
Expand Down Expand Up @@ -150,7 +150,7 @@ print(mlirfile.dump())
prints:
```mlir
module {
func @hello_world(%a: f64, %b: f64) {
func.func @hello_world(%a: f64, %b: f64) {
%_pymlir_ssa = addf %a , %b : f64
return %_pymlir_ssa : f64
}
Expand Down
8 changes: 4 additions & 4 deletions mlir/astnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,11 +641,11 @@ class Function(Node):
location: Optional[Location] = None

def dump(self, indent=0) -> str:
result = 'func'
result = 'func.func'
result += ' %s' % self.name.dump(indent)
if self.args:
result += '(%s)' % ', '.join(
dump_or_value(arg, indent) for arg in self.args)
arg_list = self.args if self.args else []
result += '(%s)' % ', '.join(
dump_or_value(arg, indent) for arg in arg_list)
if self.result_types:
if not isinstance(self.result_types, list):
result += ' -> ' + dump_or_value(self.result_types, indent)
Expand Down
24 changes: 16 additions & 8 deletions mlir/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mlir.astnodes as mast
import mlir.dialects.standard as std
import mlir.dialects.affine as affine
import mlir.dialects.func as func
from typing import Optional, Tuple, Union, List, Any
from contextlib import contextmanager
from mlir.builder.match import Reads, Writes, Isa, All, And, Or, Not # noqa: F401
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self):

self._dialects = {
"affine": AffineBuilder(self),
"func": FuncBuilder(self),
"std": self, # std dialect ops can also be globally referenced
}

Expand Down Expand Up @@ -460,14 +462,6 @@ def float_constant(self, value: float, type: mast.FloatType,

# }}}

def ret(self, values: Optional[List[mast.SsaId]] = None,
types: Optional[List[mast.Type]] = None):

op = std.ReturnOperation(match=0, values=values, types=types)
self._insert_op_in_block([], op)
self.block = None
self.position = 0


@dataclass
class DialectBuilder:
Expand Down Expand Up @@ -528,5 +522,19 @@ def store(self, address: mast.SsaId, memref: mast.SsaId,
type=memref_type)
self.core_builder._insert_op_in_block([], op)

class FuncBuilder(DialectBuilder):
"""
Func dialect ops builder.
.. automethod:: ret
"""
def ret(self, values: Optional[List[mast.SsaId]] = None,
types: Optional[List[mast.Type]] = None):

op = func.ReturnOperation(match=0, values=values, types=types)
self.core_builder._insert_op_in_block([], op)
self.block = None
self.position = 0


# vim: fdm=marker
3 changes: 2 additions & 1 deletion mlir/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .standard import standard as std_dialect
from .scf import scf as scf_dialect
from .linalg import linalg
from .func import func as func_dialect


STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg]
STANDARD_DIALECTS = [affine_dialect, std_dialect, scf_dialect, linalg, func_dialect]
60 changes: 60 additions & 0 deletions mlir/dialects/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import inspect
import sys
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass

import mlir.astnodes as mast
from mlir.dialect import Dialect, DialectOp, is_op

Literal = Union[mast.StringLiteral, float, int, bool]
SsaUse = Union[mast.SsaId, Literal]

@dataclass
class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call_indirect {func.symbol_ref_id} () : {type.function_type}',
'func.call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']


@dataclass
class CallOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['func.call {func.symbol_ref_id} () : {type.function_type}',
'func.call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}']

@dataclass
class ConstantOperation(DialectOp):
value: mast.SymbolRefId
type: mast.Type
_syntax_ = ['func.constant {value.symbol_ref_id} : {type.type}']

# Note: The 'func.func' operation is defined as 'function' in mlir.lark.

@dataclass
class ReturnOperation(DialectOp):
values: Optional[List[SsaUse]] = None
types: Optional[List[mast.Type]] = None
_syntax_ = ['return',
'return {values.ssa_use_list} : {types.type_list_no_parens}']

def dump(self, indent: int = 0) -> str:
output = 'return'
if self.values:
output += ' ' + ', '.join([v.dump(indent) for v in self.values])
if self.types:
output += ' : ' + ', '.join([t.dump(indent) for t in self.types])

return output



# Inspect current module to get all classes defined above
func = Dialect('func', ops=[m[1] for m in inspect.getmembers(
sys.modules[__name__], lambda obj: is_op(obj, __name__))])
36 changes: 0 additions & 36 deletions mlir/dialects/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,43 +30,7 @@ class CondBrOperation(DialectOp):
_syntax_ = ['cond_br {cond.ssa_use} , {block_true.block_id} , {block_false.block_id}']


@dataclass
class ReturnOperation(DialectOp):
values: Optional[List[SsaUse]] = None
types: Optional[List[mast.Type]] = None
_syntax_ = ['return',
'return {values.ssa_use_list} : {types.type_list_no_parens}']
def dump(self, indent: int = 0) -> str:
output = 'return'
if self.values:
output += ' ' + ', '.join([v.dump(indent) for v in self.values])
if self.types:
output += ' : ' + ', '.join([t.dump(indent) for t in self.types])

return output


# Core Operations
@dataclass
class CallOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['call {func.symbol_ref_id} () : {type.function_type}',
'call {func.symbol_ref_id} ( {args.ssa_use_list} ) : {argtypes.function_type}']


@dataclass
class CallIndirectOperation(DialectOp):
func: mast.SymbolRefId
type: mast.FunctionType
args: Optional[List[SsaUse]] = None
argtypes: Optional[List[mast.Type]] = None
_syntax_ = ['call_indirect {func.symbol_ref_id} () : {type.function_type}',
'call_indirect {func.symbol_ref_id} ( {args.ssa_use_list} ) : {type.function_type}']


@dataclass
class DimOperation(DialectOp):
operand: mast.SsaId
Expand Down
2 changes: 1 addition & 1 deletion mlir/lark/mlir.lark
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ function_result_list_parens : ("(" ")") | ("(" function_result_list_no_parens ")

// Definition
module : "module" optional_symbol_ref_id optional_func_mod_attrs region optional_trailing_loc
function : "func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc
function : "func.func" symbol_ref_id "(" optional_arg_list ")" optional_fn_result_list optional_func_mod_attrs optional_fn_body optional_trailing_loc
generic_module : string_literal "(" optional_arg_list ")" "(" region ")" optional_attr_dict trailing_type optional_trailing_loc

// ----------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
lark-parser==0.7.8
parse==1.14.0
pytest
6 changes: 3 additions & 3 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def test_saxpy_builder():
axpyi = builder.addf(builder.affine.load(y, i, Mref1D), axi, F64)
builder.affine.store(axpyi, y, i, Mref1D)

builder.ret()
builder.func.ret()

print(mlirfile.dump())


def test_query():
block = parse_string("""
func @saxpy(%a : f64, %x : memref<?xf64>, %y : memref<?xf64>) {
func.func @saxpy(%a : f64, %x : memref<?xf64>, %y : memref<?xf64>) {
%c0 = constant 0 : index
%n = dim %x, %c0 : memref<?xf64>
Expand Down Expand Up @@ -96,7 +96,7 @@ def index(expr):
with builder.goto_after(Reads(b0) & Isa(AddfOperation)):
builder.addf(c0, c1, F64)

builder.ret()
builder.func.ret()

assert index(Reads(b0)) == 0
assert index(Reads(c0)) == 1
Expand Down
3 changes: 2 additions & 1 deletion tests/test_custom_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mlir import parse_string
from mlir.astnodes import Node, dump_or_value, SsaId, StringLiteral, TensorType, MemRefType, Dimension
from mlir.dialect import Dialect, DialectOp, DialectType
from mlir.dialects.func import func
from dataclasses import dataclass


Expand Down Expand Up @@ -77,7 +78,7 @@ class DensifyOp(DialectOp):

def test_custom_dialect():
code = '''module {
func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> {
func.func @toy_test(%ragged: !toy.ragged<coo+csr, 32x14xf64>) -> tensor<32x14xf64> {
%t_tensor = toy.densify %ragged : tensor<32x14xf64>
return %t_tensor : tensor<32x14xf64>
}
Expand Down
27 changes: 14 additions & 13 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import mlir
from mlir.dialects.func import func

# All source strings have been taken from MLIR's codebase.
# See llvm-project/mlir/test/Dialect/Linalg
Expand All @@ -11,7 +12,7 @@ def assert_roundtrip_equivalence(source):

def test_batch_matmul():
assert_roundtrip_equivalence("""module {
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>, %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>, %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
linalg.batch_matmul ins( %a3 , %b3 : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs( %c3 : memref<?x?x?xf32> )
linalg.batch_matmul ins( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) outs( %c3 : memref<?x?x?xf32> )
%res1 = linalg.batch_matmul ins( %ta3 , %tb3 : tensor<?x?x?xf32> , tensor<?x?x?xf32> ) init( %tc3 : tensor<?x?x?xf32> ) -> tensor<?x?x?xf32>
Expand All @@ -23,15 +24,15 @@ def test_batch_matmul():

def test_conv():
assert_roundtrip_equivalence("""module {
func @conv1d_no_symbols(%in: memref<?xf32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
func.func @conv1d_no_symbols(%in: memref<?xf32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
linalg.conv_1d ins( %in , %filter : memref<?xf32> , memref<?xf32> ) outs( %out : memref<?xf32> )
return
}
func @conv2d_no_symbols(%in: memref<?x?xf32>, %filter: memref<?x?xf32>, %out: memref<?x?xf32>) {
func.func @conv2d_no_symbols(%in: memref<?x?xf32>, %filter: memref<?x?xf32>, %out: memref<?x?xf32>) {
linalg.conv_2d ins( %in , %filter : memref<?x?xf32> , memref<?x?xf32> ) outs( %out : memref<?x?xf32> )
return
}
func @conv3d_no_symbols(%in: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %out: memref<?x?x?xf32>) {
func.func @conv3d_no_symbols(%in: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %out: memref<?x?x?xf32>) {
linalg.conv_3d ins( %in , %filter : memref<?x?x?xf32> , memref<?x?x?xf32> ) outs( %out : memref<?x?x?xf32> )
return
}
Expand All @@ -40,11 +41,11 @@ def test_conv():

def test_copy():
assert_roundtrip_equivalence("""module {
func @copy_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>) {
func.func @copy_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>) {
linalg.copy( %arg0 , %arg1 ) : memref<?xf32, offset: ?, strides: [1]> , memref<?xf32, offset: ?, strides: [1]>
return
}
func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
func.func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.copy( %arg0 , %arg1 ) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>, outputPermutation = affine_map<(i, j, k) -> (k, j, i)>} : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> , memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
return
}
Expand All @@ -53,7 +54,7 @@ def test_copy():

def test_dot():
assert_roundtrip_equivalence("""module {
func @dot(%arg0: memref<?xi8>, %M: index) {
func.func @dot(%arg0: memref<?xi8>, %M: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%1 = view %arg0 [ %c0 ] [ %M ] : memref<?xi8> to memref<?xf32>
Expand All @@ -67,7 +68,7 @@ def test_dot():

def test_fill():
assert_roundtrip_equivalence("""module {
func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
func.func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
linalg.fill( %arg0 , %arg1 ) : memref<?xf32, offset: ?, strides: [1]> , f32
return
}
Expand All @@ -76,7 +77,7 @@ def test_fill():

def test_generic():
assert_roundtrip_equivalence("""module {
func @example(%A: memref<?x?xf64>, %B: memref<?x?xf64>, %C: memref<?x?xf64>) {
func.func @example(%A: memref<?x?xf64>, %B: memref<?x?xf64>, %C: memref<?x?xf64>) {
linalg.generic {indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"]} ins( %A, %B : memref<?x?xf64>, memref<?x?xf64> ) outs( %C : memref<?x?xf64> ) {
^bb0 (%a: f64, %b: f64, %c: f64):
%c0 = constant 3.14 : f64
Expand All @@ -90,7 +91,7 @@ def test_generic():

def test_indexed_generic():
assert_roundtrip_equivalence("""module {
func @indexed_generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
func.func @indexed_generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins( %arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]> ) outs( %arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> ) {
^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%result_1 = mulf %a , %b : f32
Expand All @@ -108,7 +109,7 @@ def test_indexed_generic():

def test_view():
assert_roundtrip_equivalence("""module {
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
func.func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%c0 = constant 0 : index
%0 = muli %arg0 , %arg0 : index
%1 = alloc (%0) : memref<?xi8>
Expand All @@ -127,7 +128,7 @@ def test_view():

def test_matmul():
assert_roundtrip_equivalence("""module {
func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
func.func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%A = view %arg0 [ %c0 ] [ %M, %K ] : memref<?xi8> to memref<?x?xf32>
Expand All @@ -141,7 +142,7 @@ def test_matmul():

def test_matvec():
assert_roundtrip_equivalence("""module {
func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
func.func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%2 = view %arg0 [ %c0 ] [ %M, %N ] : memref<?xi8> to memref<?x?xf32>
Expand Down
Loading

0 comments on commit dde3e93

Please sign in to comment.