Update memref strided layout syntax (#26)

amanda849 authored Oct 17, 2023
1 parent dde3e93 commit e16a6a2

Showing 5 changed files with 25 additions and 22 deletions.
9 changes: 6 additions & 3 deletions mlir/astnodes.py
@@ -217,12 +217,15 @@ class MemRefType(Type):

 @dataclass
 class StridedLayout(Node):
+    strided: Optional[List[int]] = None
     offset: int = 0
-    strides: Optional[List[int]] = None

     def dump(self, indent: int = 0) -> str:
-        return 'offset: %s, strides: [%s]' % (dump_or_value(
-            self.offset, indent), dump_or_value(self.strides, indent))
+        result = 'strided<[%s]' % dump_or_value(self.strided, indent)
+        if self.offset is not None:
+            result += ', offset: %s' % dump_or_value(self.offset, indent)
+        result += '>'
+        return result


 @dataclass
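
For reference, a minimal sketch of what the new dump format produces. The StridedLayout fields come from the diff above; the example assumes dump_or_value renders a list as comma-separated values, so the output lines are illustrative rather than verified.

    from mlir.astnodes import StridedLayout

    # Strides only: the offset clause is omitted when offset is None.
    print(StridedLayout(strided=[6, 7], offset=None).dump())
    # expected: strided<[6, 7]>

    # Strides plus a static offset.
    print(StridedLayout(strided=[6, 7], offset=5).dump())
    # expected: strided<[6, 7], offset: 5>
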
14 changes: 7 additions & 7 deletions mlir/builder/builder.py
@@ -200,27 +200,27 @@ def MemRefType(self,
                    dtype: mast.Type,
                    shape: Optional[Tuple[Optional[int], ...]],
                    offset: Optional[int] = None,
-                   strides: Optional[Tuple[Optional[int], ...]] = None
+                   strided: Optional[Tuple[Optional[int], ...]] = None
                    ) -> mast.MemRefType:
         """
         Returns an instance of :class:`mlir.astnodes.UnrankedMemRefType` if shape is
         *None*, else returns a :class:`mlir.astnodes.RankedMemRefType`.
         """
         if shape is None:
-            assert strides is None
+            assert strided is None
             return mast.UnrankedMemRefType(dtype)
         else:
             shape = tuple(mast.Dimension(dim) for dim in shape)
-            if strides is None and offset is None:
+            if strided is None and offset is None:
                 layout = None
             else:
                 if offset is None:
                     offset = 0
-                if strides is not None:
-                    if len(shape) != len(strides):
-                        raise ValueError("shapes and strides must be of tuples"
+                if strided is not None:
+                    if len(shape) != len(strided):
+                        raise ValueError("shapes and strided must be of tuples"
                                          " of same dimensionality.")
-                layout = mast.StridedLayout(strides, offset)
+                layout = mast.StridedLayout(strided, offset)

             return mast.RankedMemRefType(shape, dtype, layout)
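
A hedged usage sketch of the renamed keyword argument. The MemRefType signature is taken from the diff above; IRBuilder and its F32 attribute are assumptions about the surrounding builder API, and the printed type is the expected shape of the output, not a verified result.

    from mlir.builder import IRBuilder

    builder = IRBuilder()
    f32 = builder.F32  # assumed convenience attribute for the f32 type

    # len(strided) must match len(shape); None shape dimensions stand for '?'.
    memref = builder.MemRefType(dtype=f32, shape=(None, None),
                                offset=0, strided=(4, 1))
    print(memref.dump())
    # expected, roughly: memref<?x?xf32, strided<[4, 1], offset: 0>>
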
2 changes: 1 addition & 1 deletion mlir/lark/mlir.lark
@@ -79,7 +79,7 @@ tensor_type : ranked_tensor_type | unranked_tensor_type

 // Memref type
 stride_list : "[" (dimension ("," dimension)*)? "]"
-strided_layout : "offset:" dimension "," "strides: " stride_list
+strided_layout : "strided" "<" stride_list ("," "offset:" dimension)? ">"
 ?layout_specification : semi_affine_map | strided_layout
 ?memory_space : integer_literal // | TODO(mlir): address_space_id
 ranked_memref_type : "memref" "<" dimension_list_ranked tensor_memref_element_type optional_layout_specification optional_memory_space ">"
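
A small round-trip sketch of the new grammar rule, assuming pymlir's top-level parse_string entry point; the module text mirrors the updated tests below.

    import mlir

    # A memref type using the new strided<[...], offset: ...> layout syntax.
    source = """module {
      func.func @copy_view(%arg0: memref<?xf32, strided<[1], offset: ?>>) {
        return
      }
    }"""

    ast = mlir.parse_string(source)
    print(ast.dump())  # should round-trip to an equivalent module string
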
20 changes: 10 additions & 10 deletions tests/test_linalg.py
@@ -41,12 +41,12 @@ def test_conv():

 def test_copy():
     assert_roundtrip_equivalence("""module {
-  func.func @copy_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>) {
-    linalg.copy( %arg0 , %arg1 ) : memref<?xf32, offset: ?, strides: [1]> , memref<?xf32, offset: ?, strides: [1]>
+  func.func @copy_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: memref<?xf32, strided<[1], offset: ?>>) {
+    linalg.copy( %arg0 , %arg1 ) : memref<?xf32, strided<[1], offset: ?>> , memref<?xf32, strided<[1], offset: ?>>
     return
   }
-  func.func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-    linalg.copy( %arg0 , %arg1 ) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>, outputPermutation = affine_map<(i, j, k) -> (k, j, i)>} : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> , memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  func.func @copy_view3(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
+    linalg.copy( %arg0 , %arg1 ) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>, outputPermutation = affine_map<(i, j, k) -> (k, j, i)>} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> , memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>
     return
   }
 }""")

@@ -68,8 +68,8 @@ def test_dot():

 def test_fill():
     assert_roundtrip_equivalence("""module {
-  func.func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
-    linalg.fill( %arg0 , %arg1 ) : memref<?xf32, offset: ?, strides: [1]> , f32
+  func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32) {
+    linalg.fill( %arg0 , %arg1 ) : memref<?xf32, strided<[1], offset: ?>> , f32
     return
   }
 }""")

@@ -91,8 +91,8 @@ def test_generic():

 def test_indexed_generic():
     assert_roundtrip_equivalence("""module {
-  func.func @indexed_generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
-    linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins( %arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]> ) outs( %arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> ) {
+  func.func @indexed_generic_region(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?>>, %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, %arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
+    linalg.indexed_generic {args_in = 1, args_out = 2, iterator_types = ["parallel", "parallel", "parallel"], indexing_maps = [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (i, j, k)>, affine_map<(i, j, k) -> (i, k, j)>], library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"} ins( %arg0 : memref<?x?xf32, strided<[?, 1], offset: ?>> ) outs( %arg1, %arg2 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> ) {
   ^bb0 (%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
     %result_1 = mulf %a , %b : f32
     %ij = addi %i , %j : index

@@ -116,8 +116,8 @@ def test_view():
     %2 = linalg.range %arg0 : %arg1 : %arg 2 : !linalg.range
     %3 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref<?xi8> to memref<?x?xf32>
     %4 = linalg.slice %3 [ %2, %2 ] : memref<?x?xf32> , !linalg.range, !linalg.range , memref<?x?xf32>
-    %5 = linalg.slice %3 [ %2, %arg2 ] : memref<?x?xf32> , !linalg.range, index , memref<?xf32, offset: ?, strides: [1]>
-    %6 = linalg.slice %3 [ %arg2, %2 ] : memref<?x?xf32> , index, !linalg.range , memref<?xf32, offset: ?, strides: [1]>
+    %5 = linalg.slice %3 [ %2, %arg2 ] : memref<?x?xf32> , !linalg.range, index , memref<?xf32, strided<[1], offset: ?>>
+    %6 = linalg.slice %3 [ %arg2, %2 ] : memref<?x?xf32> , index, !linalg.range , memref<?xf32, strided<[1], offset: ?>>
     %7 = linalg.slice %3 [ %arg2, %arg3 ] : memref<?x?xf32> , index, index , memref<f32>
     %8 = view %1 [ %c0 ] [ %arg0, %arg0 ] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
     dealloc %1 : memref<?xi8>
2 changes: 1 addition & 1 deletion tests/test_syntax.py
@@ -32,7 +32,7 @@ def test_memrefs(parser: Optional[Parser] = None):
         module {
           func.func @myfunc() {
                 %a, %b = "tensor_replicator"(%tensor, %tensor) : (memref<?xbf16, 2>,
-                  memref<?xf32, offset: 5, strides: [6, 7]>,
+                  memref<?xf32, strided<[6, 7], offset: 5>>,
                   memref<*xf32, 8>)
          }
        }
