[BugFix] Fix BitBLAS Linear with BFloat16 input (#164)
* Merge branch 'main' of https://github.com/microsoft/BitBLAS into main

* remove debug print

* Refactor Matmul class for improved readability and maintainability

* Refactor Matmul class for improved readability and maintainability

* revert set device

* lint fix

* register fp8 for dynamic

* Linear Fix
LeiWang1999 authored Aug 30, 2024
1 parent 872d6d7 commit f284c32
Showing 1 changed file with 3 additions and 3 deletions: bitblas/module/__init__.py
@@ -265,8 +265,6 @@ def warmup(self, topk=20):
         self.bitblas_matmul.hardware_aware_finetune(topk=topk)

     def forward(self, A, output=None):
-        if A.dtype != torch.float16:
-            A = A.half()
         A = self.bitblas_matmul.transform_input(A)
         stream = torch.cuda.current_stream()

@@ -277,7 +275,9 @@ def forward(self, A, output=None):
         args = [A_void, *self.q_params]
         if output is None:
             output = torch.empty(
-                A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
+                A.shape[:-1] + (self.out_features,),
+                dtype=getattr(torch, self.bitblas_matmul.out_dtype),
+                device=A.device)
         args.append(ctypes.c_void_p(output.data_ptr()))
         if self.bitblas_matmul.dynamic_range is not None:
             m = reduce(operator.mul, A.shape[:-1], 1)
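For context, here is a minimal standalone sketch (plain PyTorch, not the BitBLAS kernel itself) of the dtype handling this diff changes: the input is no longer coerced with A.half(), and the output buffer is allocated from the operator's configured out_dtype string via getattr(torch, ...) instead of inheriting A.dtype. The out_dtype and out_features values below are stand-ins for self.bitblas_matmul.out_dtype and self.out_features.

import torch

# Stand-in for self.bitblas_matmul.out_dtype; in BitBLAS this is a dtype
# name string such as "float16" configured on the matmul operator.
out_dtype = "float16"
out_features = 512

# The activation may now arrive as bfloat16 (the old code forced A = A.half()).
device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(4, 256, dtype=torch.bfloat16, device=device)

# After the fix: resolve the torch dtype from the out_dtype string, so the
# output buffer matches what the kernel writes rather than A.dtype.
output = torch.empty(
    A.shape[:-1] + (out_features,),
    dtype=getattr(torch, out_dtype),  # e.g. torch.float16
    device=A.device)

print(output.dtype)  # torch.float16, regardless of A being bfloat16

Resolving the dtype from out_dtype keeps the output buffer consistent with the operator's declared output type even when the activation dtype differs.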
