fixing GPTQ #147

Open · wants to merge 1 commit into base: gh/HDCharles/7/base
123 changes: 122 additions & 1 deletion GPTQ.py

@@ -129,6 +129,127 @@ def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]


class GPTQMultiTensor(torch.Tensor):
    """
    A tensor subclass that carries a list of tensor values and fans torch ops
    out across all of them; calls to nn.functional.linear are intercepted so
    the weight can be GPTQ-quantized against the accumulated Hessian.
    """
    # todo need default shape/dtype
    @staticmethod
    def __new__(cls, input, **kwargs):
        kwargs["dtype"] = kwargs.get("dtype", input.dtype)
        shape = kwargs.pop("shape", input.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
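    # Note: _make_wrapper_subclass creates a tensor "shell" with the given
    # shape/dtype but no storage of its own; the real tensors live in
    # self.values, and every op is routed through __torch_function__ below.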

    def __init__(self, input, **kwargs):
        self.values = []
        self.append(input)
        self.debug = False

    def append(self, input):
        if isinstance(input, (tuple, list)):
            for inp in input:
                self.values.append(inp)
        elif isinstance(input, torch.Tensor):
            self.values.append(input)

    # def __add__(self, other):
    #     for val in other.values:
    #         self.append(val)

    def count(self):
        return len(self.values)

    def cuda(self):
        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None, skip_quant=False):
        def tensors_to_cuda(args):
            new_args = []
            for x in args:
                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
            return new_args

        kwargs = {} if kwargs is None else kwargs
        # combine args and kwargs
        flat_args, spec = tree_flatten((args, kwargs))
        # move single tensors to cuda
        flat_args = tensors_to_cuda(flat_args)
        # size of biggest MultiTensor
        multi_tensor_size = max(
            [x.count() if isinstance(x, GPTQMultiTensor) else 1 for x in flat_args]
        )
        # convert [a, MultiTensor(b,b,b), MultiTensor(c,c,c)] => [a,b,c], [a,b,c], [a,b,c]
        grouped_args = list(
            zip(
                *[x.values if isinstance(x, GPTQMultiTensor) else [x] * multi_tensor_size for x in flat_args]
            )
        )
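        # A sketch of the grouping, assuming args=(MT(x1, x2), w) and kwargs={}:
        #   flat_args    = [MT(x1, x2), w]     (multi_tensor_size = 2)
        #   grouped_args = [(x1, w), (x2, w)]
        # each group is tree_unflatten-ed back into (args, kwargs) below and
        # run as one ordinary call of func.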

        quantize_linear = (
            func is nn.functional.linear
            # and id(args[1]) in self.id_to_name
            and not skip_quant
            # and not (self.skip_layer_func)
        )

        # run function for each of the multitensors and return a multitensor
        if not quantize_linear:
            outputs = []
            for inp in grouped_args:
                inp = tensors_to_cuda(inp)
                cur_args, cur_kwargs = tree_unflatten(inp, spec)
                with torch._C.DisableTorchFunctionSubclass():
                    out = func(*cur_args, **cur_kwargs)
                # keep results on cpu so long calibration runs don't hold GPU memory
                outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
            return cls(outputs)

        # GPTQ path: accumulate the Hessian H over every calibration input
        total_batches = 0
        H = 0
        for inp in grouped_args:
            inp = tensors_to_cuda(inp)
            cur_args, cur_kwargs = tree_unflatten(inp, spec)
            x = cur_args[0].float()
            shape = x.shape
            n = 1 if len(shape) == 2 else shape[0]
            H *= total_batches / (total_batches + n)
            total_batches += n
            x = ((2 / total_batches) ** (1 / 2)) * x.reshape(-1, shape[-1]).t().float()
            H += x.matmul(x.t())
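            # The update above incrementally maintains
            #   H = (2 / N) * sum_i x_i @ x_i.T
            # over the N rows seen so far: scaling the old H by
            # N_old / (N_old + n) and adding (2 / (N_old + n)) * x @ x.T gives
            # the same result as recomputing the average from scratch, with
            # O(1) extra memory.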
        W = args[1].to(H.device)
        Q, DQ, qparams = args[0].faster_quant(H, W.detach())

        # re-dispatch the linear with the dequantized weight; skip_quant=True
        # prevents recursing back into the GPTQ branch
        new_out = cls.__torch_function__(func, types, (args[0], DQ, *args[2:]), kwargs, skip_quant=True)
        if args[0].debug:
            breakpoint()
        return new_out



        # NOTE: unreachable leftover from an earlier draft (the return above
        # always fires); kept commented out, undefined names
        # (mat1, w_autoquant, do_gptq) and all.
        # if func is torch.nn.functional.linear:
        #     inputs, weight, bias = (
        #         args[0],
        #         args[1],
        #         args[2] if len(args) > 2 else None,
        #     )
        #     if quantize_linear:
        #         cls.do_gptq(input, weight)
        #         return func(mat1, w_autoquant.weight, bias)
        #     try:
        #         with torch._C.DisableTorchFunctionSubclass():
        #             return func(*args, **kwargs)
        #     except Exception:
        #         print(f"ERR: subclass doesn't implement {func}")
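
# A minimal usage sketch for GPTQMultiTensor (hypothetical, not part of this
# diff): wrap several calibration inputs in one MultiTensor, then run a single
# forward pass; every nn.functional.linear call fans out across all values,
# accumulates H, and quantizes its weight. `model` below is assumed.
#
#   calibration_inputs = [torch.randn(1, 4096) for _ in range(8)]
#   multi_input = GPTQMultiTensor(calibration_inputs[0])
#   multi_input.append(calibration_inputs[1:])
#   out = model(multi_input)   # out is a GPTQMultiTensor of 8 outputs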





class GenericGPTQRunner(fx.Interpreter):
    """
    This is a generic GPTQ runner that takes an existing model and applies GPTQ.

@@ -150,7 +271,7 @@ def __init__(
        }

        # trace model for one input
-       one_input = [multi.values[0].cpu() for multi in inputs]
+       one_input = tuple([multi.values[0].cpu() for multi in inputs])
        exported_model = torch._dynamo.export(
            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
        )(*one_input)
20 changes: 20 additions & 0 deletions run.sh

@@ -0,0 +1,20 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
# echo "quant good"

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5