Skip to content

Commit

Permalink
Add linalg.fill_rng_2d (#41)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanda Tang <[email protected]>
  • Loading branch information
amanda849 and Amanda Tang authored Oct 14, 2024
1 parent 8d78c03 commit 28b8ada
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
29 changes: 29 additions & 0 deletions mlir/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ class LinalgFill(DialectOp):
" outs ( {out_id.ssa_id} : {out_type.type} )"
" -> {res_type.type}")]


@dataclass
class FillRng2DOp(DialectOp):
min_id: mast.SsaId
min_type: mast.Type
max_id: mast.SsaId
max_type: mast.Type
seed_id: mast.SsaId
seed_type: mast.Type
out_id: mast.SsaId
out_type: mast.Type
res_type: Optional[mast.Type] = None
attr: Optional[mast.Attribute] = None

_syntax_ = [("linalg.fill_rng_2d"
" ins ( {min_id.ssa_id} , {max_id.ssa_id} , {seed_id.ssa_id} : {min_type.type} , {max_type.type} , {seed_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"),
("linalg.fill_rng_2d"
" ins ( {min_id.ssa_id} , {max_id.ssa_id} , {seed_id.ssa_id} : {min_type.type} , {max_type.type} , {seed_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value}"),
("linalg.fill_rng_2d"
" ins ( {min_id.ssa_id} , {max_id.ssa_id} , {seed_id.ssa_id} : {min_type.type} , {max_type.type} , {seed_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} ) -> {res_type.type}"),
("linalg.fill_rng_2d"
" ins ( {min_id.ssa_id} , {max_id.ssa_id} , {seed_id.ssa_id} : {min_type.type} , {max_type.type} , {seed_type.type} )"
" outs ( {out_id.ssa_id} : {out_type.type} )"
" {attr.attribute_value} -> {res_type.type}")]


@dataclass
class LinalgGeneric(DialectOp):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def test_fill():
}""")


def test_fill_rng_2d():
assert_roundtrip_equivalence("""module {
func.func @fill_rng_2d(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.fill_rng_2d ins ( %min , %max , %seed : f64 , f64 , i32 ) outs ( %O : tensor<16x32xf32> ) -> tensor<16x32xf32>
%1 = linalg.fill_rng_2d ins ( %min , %max , %seed : f64 , f64 , i32 ) outs ( %O : tensor<16x32xf32> )
return %1 : tensor<16x32xf32>
}
}""")


def test_generic():
assert_roundtrip_equivalence("""module {
func.func @example(%A: memref<?x?xf64>, %B: memref<?x?xf64>, %C: memref<?x?xf64>) {
Expand Down

0 comments on commit 28b8ada

Please sign in to comment.