Refactor test_general_matmul_tilelang_impl.py and test_general_matmul_tilelang_kernel.py to use centered random values for input tensors
LeiWang1999 committed Oct 2, 2024
1 parent 826255d commit 48dc94e
Showing 2 changed files with 8 additions and 8 deletions.
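
The commit message does not spell out the motivation, but a likely one is numerical: torch.rand samples uniformly from [0, 1), so every partial product A[i, k] * B[j, k] is non-negative and each output element of the matmul accumulates roughly K * 0.25. For the half-precision inputs these tests use, that magnitude quickly reaches a range where float16 spacing is 1.0 or coarser, costing precision in the reference comparison. Subtracting 0.5 centers the inputs on zero, so products cancel in expectation and accumulated values stay small. A minimal sketch of the effect (not part of the commit; shapes are illustrative, and it uses float32 on CPU for portability while the tests themselves run half precision on CUDA):

import torch

# Illustrative sizes; the real tests parameterize M, N, K and run on CUDA.
M, N, K = 256, 256, 4096

# Uniform [0, 1) inputs: every partial product is non-negative, so each
# output element sums to about K * E[a] * E[b] = K / 4 (~1024 here).
# Around 1024, float16 can only resolve steps of 1.0.
A = torch.rand(M, K)
B = torch.rand(N, K)
print((A @ B.T).mean())  # ~1024, grows linearly with K

# Centered inputs in [-0.5, 0.5): products cancel in expectation, so the
# accumulated magnitude grows like sqrt(K) rather than K.
A_c = torch.rand(M, K) - 0.5
B_c = torch.rand(N, K) - 0.5
print((A_c @ B_c.T).abs().mean())  # single digits, safe for float16
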
12 changes: 6 additions & 6 deletions testing/python/operators/test_general_matmul_tilelang_impl.py
@@ -53,8 +53,8 @@ def assert_matmul_blocked_correctness(M,
 # src_code is the generated cuda source
 assert src_code is not None

-A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
+B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

 mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
@@ -111,8 +111,8 @@ def assert_matmul_macro_tensorcore_correctness(
 # src_code represents generated cuda source
 assert src_code is not None

-A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
+B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

 mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
@@ -170,8 +170,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(
 # src_code is the generated cuda source
 assert src_code is not None

-A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
+B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

 ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
4 changes: 2 additions & 2 deletions testing/python/operators/test_general_matmul_tilelang_kernel.py
@@ -267,8 +267,8 @@ def assert_matmul_fine_grained_apply_config_correctness(
 # src_code is the generated cuda source
 assert src_code is not None

-A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
+B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5
 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

 mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
