[BugFix] Fix BitBLAS Linear with BFloat16 input #164

Merged · 8 commits · Aug 30, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
22 changes: 10 additions & 12 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -992,18 +992,16 @@ def get_idx():
         sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
 
         dequantize_block_local = block_shared_local
-        if ("zeros_mode" in weight_decode_info and
-                weight_decode_info["zeros_mode"] == "quantized"):
-            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-                # pop the scale block
-                auto_inline_producers(sch, block_local_scales)
-
-            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
+        if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+            block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+            sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+            # pop the scale block
+            auto_inline_producers(sch, block_local_scales)
+
+        if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+            block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+            sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)
 
         for producer in weight_producers:
             with suppress(Exception):
10 changes: 5 additions & 5 deletions bitblas/module/__init__.py
@@ -265,8 +265,6 @@ def warmup(self, topk=20):
         self.bitblas_matmul.hardware_aware_finetune(topk=topk)
 
     def forward(self, A, output=None):
-        if A.dtype != torch.float16:
-            A = A.half()
         A = self.bitblas_matmul.transform_input(A)
         stream = torch.cuda.current_stream()
 
@@ -277,7 +275,9 @@ def forward(self, A, output=None):
         args = [A_void, *self.q_params]
         if output is None:
             output = torch.empty(
-                A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
+                A.shape[:-1] + (self.out_features,),
+                dtype=getattr(torch, self.bitblas_matmul.out_dtype),
+                device=A.device)
         args.append(ctypes.c_void_p(output.data_ptr()))
         if self.bitblas_matmul.dynamic_range is not None:
             m = reduce(operator.mul, A.shape[:-1], 1)
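
With the forced `A.half()` cast removed and the output buffer allocated from the kernel's `out_dtype`, a BFloat16 activation now flows through `bitblas.Linear` without being silently downcast. A minimal sketch of the intended behaviour; the constructor keywords shown here are assumptions for illustration, not part of this diff:

```python
# Hedged sketch: BitBLAS Linear with a BFloat16 input after this fix.
# The dtype keywords below are assumptions for illustration only.
import torch
import bitblas

layer = bitblas.Linear(
    in_features=1024,
    out_features=1024,
    bias=False,
    A_dtype="bfloat16",    # assumed: activation dtype
    W_dtype="int4",        # assumed: quantized weight dtype
    out_dtype="bfloat16",  # output buffer now follows this, not A.dtype
).cuda()

x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)               # forward() no longer calls A.half()
assert y.dtype == torch.bfloat16
```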
@@ -312,12 +312,12 @@ def load_and_transform_weight(
         if bias is not None:
             self.bias = bias
 
-    def repack_from_gptq(self, gptq_module):
+    def repack_from_gptq(self, gptq_module, device="cuda"):
         # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
         qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
         intweight = unpack_qweight(qweight, self.bits).contiguous()
         if self.bitblas_matmul.weight_transform is not None:
-            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
+            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device)
         self.qweight = qweight
         # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
         scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
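
`repack_from_gptq` now accepts the target device instead of hard-coding `.cuda()`, so a repacked layer can land on a specific GPU. A hedged usage sketch; `gptq_linear` and `bitblas_linear` are placeholders for an AutoGPTQ-style quant linear and a matching `bitblas.Linear`:

```python
# Hedged sketch: repacking GPTQ weights onto a chosen device.
# `gptq_linear` / `bitblas_linear` are placeholder objects, not defined in this PR.
bitblas_linear.repack_from_gptq(gptq_linear)                   # default target: "cuda"
bitblas_linear.repack_from_gptq(gptq_linear, device="cuda:1")  # explicit secondary GPU
```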
15 changes: 7 additions & 8 deletions bitblas/ops/general_matmul/__init__.py
@@ -603,7 +603,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         weight = weight.contiguous()
         if self.W_dtype == self.A_dtype:
             if self.weight_transform is not None:
-                return self.weight_transform(weight.cpu()).cuda().contiguous()
+                return self.weight_transform(weight.cpu()).to(weight.device).contiguous()
             return weight
 
         source_format, bit = self.source_format, self.bit
@@ -624,7 +624,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
 
         # Apply an optional weight transformation if specified
         if self.weight_transform is not None:
-            weight = self.weight_transform(weight.cpu()).cuda().contiguous()
+            weight = self.weight_transform(weight.cpu()).to(weight.device).contiguous()
 
         # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias
         result = [weight]
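
Both `transform_weight` paths now return the transformed weight on the device of the incoming tensor rather than unconditionally moving it to the default CUDA device. A hedged sketch, assuming an already-configured op instance named `matmul` with matching `A_dtype`/`W_dtype`:

```python
# Hedged sketch: the transformed weight follows the input weight's device.
# `matmul` is a placeholder for a configured bitblas Matmul op instance.
import torch

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda:1")
wq = matmul.transform_weight(w)
assert wq.device == w.device  # the transform path previously landed on the default GPU
```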
@@ -667,15 +667,14 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
             args.append(bias)
         args.append(output)
 
-        if self.dynamic_range is not None:
-            m = reduce(operator.mul, A.shape[:-1], 1)
-            args.append(m)
-
-        stream = torch.cuda.current_stream()
-
         if self.lib is None:
             self._forward_from_torch_func(*args)
         else:
+            if self.dynamic_range is not None:
+                m = reduce(operator.mul, A.shape[:-1], 1)
+                args.append(m)
+
+            stream = torch.cuda.current_stream(device=A.device)
             self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
 
         return output
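
Moving the dynamic-shape and stream setup into the prebuilt-library branch also means the launch stream is taken from `A.device`, so work submitted on a non-default GPU goes to that GPU's current stream. A hedged sketch; `matmul`, `A1`, and `W1` are placeholders for a prebuilt op and inputs that already live on `cuda:1`:

```python
# Hedged sketch: per-device stream selection in forward().
# `matmul`, `A1`, `W1` are placeholders (prebuilt op + tensors on cuda:1).
import torch

s = torch.cuda.Stream(device="cuda:1")
with torch.cuda.stream(s):
    out = matmul.forward(A1, W1)  # launched via current_stream(device=A1.device) == s
s.synchronize()
```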