Skip to content

Commit

Permalink
Adapt quant lm head (#1671)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Chang <[email protected]>
  • Loading branch information
changwangss authored Aug 1, 2024
1 parent 3e78ae8 commit b400cb9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,6 @@
help="Use determined group to do quantization",
)
# ============AutoRound==================
parser.add_argument(
"--autoround_iters",
default=2048,
type=int,
help="Calibration dataset max or padding max length for AutoRound.",
)
parser.add_argument(
"--lr",
type=float,
Expand Down Expand Up @@ -172,7 +166,6 @@
bits=args.bits,
sym=True if args.scheme == "sym" else False,
group_size=args.group_size,
seq_len=args.seq_len,
compute_dtype=args.compute_dtype,
scale_dtype=args.compute_dtype,
weight_dtype=args.weight_dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,12 @@ def replace_linear(
if modules_to_not_convert is None:
# output_layer is chatglm last layer name
# embed_out is dolly_v2 last layer name
modules_to_not_convert = ["lm_head", "output_layer", "embed_out"]
modules_to_not_convert = []
if quantization_config.llm_int8_skip_modules:
modules_to_not_convert = modules_to_not_convert.extend(
modules_to_not_convert.extend(
quantization_config.llm_int8_skip_modules
)
modules_to_not_convert = list(set(modules_to_not_convert))
model, is_replaced = _replace_linear(
model,
modules_to_not_convert,
Expand Down Expand Up @@ -559,9 +560,11 @@ def convert_to_quantized_model(model, config, device="cpu"):
group_size=config.group_size,
use_layer_wise=config.layer_wise,
)
quant_config.set_local(".*lm_head", RTNConfig(dtype="fp32"))
quant_config.set_local(".*output_layer", RTNConfig(dtype="fp32"))
quant_config.set_local(".*embed_out", RTNConfig(dtype="fp32"))
if config.llm_int8_skip_modules != []:
for module in config.llm_int8_skip_modules:
module_name = ".*" + module
quant_config.set_local(module_name, RTNConfig(dtype="fp32"))
logger.info(f"Do RTN algorithm with config {quant_config}")
model = prepare(model, quant_config)
model = convert(model)
elif config.quant_method.value == "awq":
Expand All @@ -575,9 +578,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
use_auto_clip=config.auto_clip,
folding=True,
)
quant_config.set_local(".*lm_head", AWQConfig(dtype="fp32"))
quant_config.set_local(".*output_layer", AWQConfig(dtype="fp32"))
quant_config.set_local(".*embed_out", AWQConfig(dtype="fp32"))
if config.llm_int8_skip_modules != []:
for module in config.llm_int8_skip_modules:
module_name = ".*" + module
quant_config.set_local(module_name, AWQConfig(dtype="fp32"))
logger.info(f"Do AWQ algorithm with config {quant_config}")
run_fn = default_run_fn
run_args = (
Expand All @@ -601,9 +605,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
use_layer_wise=config.layer_wise,
absorb_to_layer=config.absorb_to_layer
)
quant_config.set_local(".*lm_head", TEQConfig(dtype="fp32"))
quant_config.set_local(".*output_layer", TEQConfig(dtype="fp32"))
quant_config.set_local(".*embed_out", TEQConfig(dtype="fp32"))
if config.llm_int8_skip_modules != []:
for module in config.llm_int8_skip_modules:
module_name = ".*" + module
quant_config.set_local(module_name, TEQConfig(dtype="fp32"))
logger.info(f"Do TEQ algorithm with config {quant_config}")
run_fn = default_run_fn
run_args = (
Expand Down Expand Up @@ -632,9 +637,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
block_size=config.blocksize,
static_groups=config.static_groups,
)
quant_config.set_local(".*lm_head", GPTQConfig(dtype="fp32"))
quant_config.set_local(".*output_layer", GPTQConfig(dtype="fp32"))
quant_config.set_local(".*embed_out", GPTQConfig(dtype="fp32"))
if config.llm_int8_skip_modules != []:
for module in config.llm_int8_skip_modules:
module_name = ".*" + module
quant_config.set_local(module_name, GPTQConfig(dtype="fp32"))
logger.info(f"Do GPTQ algorithm with config {quant_config}")
run_fn = default_run_fn
run_args = (
Expand Down Expand Up @@ -662,10 +668,10 @@ def convert_to_quantized_model(model, config, device="cpu"):
iters=config.iters,
scale_dtype=config.scale_dtype,
)
if config.quant_lm_head is False:
quant_config.set_local(".*lm_head", AutoRoundConfig(dtype="fp32"))
quant_config.set_local(".*output_layer", AutoRoundConfig(dtype="fp32"))
quant_config.set_local(".*embed_out", AutoRoundConfig(dtype="fp32"))
if config.llm_int8_skip_modules != []:
for module in config.llm_int8_skip_modules:
module_name = ".*" + module
quant_config.set_local(module_name, AutoRoundConfig(dtype="fp32"))
logger.info(f"Do AutoRound algorithm with config {quant_config}")
dataloader = get_autoround_dataloader(tokenizer=config.tokenizer,
seqlen=config.seq_len,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def build_woq_model(model, quantization_config):
from neural_compressor.adaptor.torch_utils.util import set_module
weight_dtype = quantization_config.weight_dtype
for n, m in model.named_modules():
if "lm_head" in n or "output_layer" in n or "embed_out" in n:
if n in quantization_config.llm_int8_skip_modules:
continue
if isinstance(m, torch.nn.Linear):
zp = getattr(
Expand Down Expand Up @@ -883,6 +883,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
hasattr(torch, "xpu") and torch.xpu.is_available()
), "There is no xpu device in this system!"
quantization_config.update(**{"device": "xpu"})
quantization_config.post_init_xpu()
if (
not torch.cuda.is_available()
or device_map == "cpu"
Expand Down
12 changes: 7 additions & 5 deletions intel_extension_for_transformers/transformers/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def __init__(
self.double_quant_bits = double_quant_bits
self.double_quant_use_sym = double_quant_use_sym
self.double_quant_group_size = double_quant_group_size
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
self.use_ggml = use_ggml
self.use_quant = use_quant
self.use_neural_speed = use_neural_speed
Expand Down Expand Up @@ -911,7 +911,7 @@ def __init__(
self.true_sequential = true_sequential
self.layer_wise = layer_wise
self.seq_len = seq_len
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
self.use_ggml = use_ggml
self.use_quant = use_quant
self.use_neural_speed = use_neural_speed
Expand Down Expand Up @@ -1009,7 +1009,7 @@ def __init__(
self.seq_len = seq_len
self.use_double_quant = use_double_quant
self.double_quant_scale_dtype = double_quant_scale_dtype
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
self.use_ggml = use_ggml
self.use_quant = use_quant
self.use_neural_speed = use_neural_speed
Expand Down Expand Up @@ -1078,7 +1078,7 @@ def __init__(
self.seq_len = seq_len
self.use_double_quant = use_double_quant
self.double_quant_scale_dtype = double_quant_scale_dtype
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
self.use_ggml = use_ggml
self.use_neural_speed = use_neural_speed
self.device = kwargs.get("device", "auto")
Expand Down Expand Up @@ -1154,7 +1154,9 @@ def __init__(
self.iters = iters
self.seq_len = seq_len
self.quant_lm_head = quant_lm_head
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", [])
self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
if self.quant_lm_head:
self.llm_int8_skip_modules = []
self.use_ggml = use_ggml
self.use_neural_speed = use_neural_speed
self.batch_size = kwargs.pop("batch_size", 8)
Expand Down

0 comments on commit b400cb9

Please sign in to comment.