From ccc12b07e1f970e6e60682f28929e8cad014c24e Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Wed, 10 Jul 2024 16:11:29 +0800 Subject: [PATCH 01/18] enhance layer-wise quant and fix bug Signed-off-by: Mengni Wang --- .../llama/quantization/weight_only/main.py | 6 +- .../algorithms/layer_wise/core.py | 65 ++++++---- .../algorithms/weight_only/awq.py | 13 +- .../algorithms/weight_only/gptq.py | 4 +- .../algorithms/weight_only/rtn.py | 2 +- onnx_neural_compressor/onnx_model.py | 121 ++++++++++++------ .../quantization/matmul_nbits_quantizer.py | 15 ++- .../quantization/quantize.py | 2 +- onnx_neural_compressor/quantization/tuning.py | 2 +- .../layer_wise/test_layer_wise.py | 83 +++++++++--- test/quantization/weight_only/test_awq.py | 53 +++++++- test/quantization/weight_only/test_gptq.py | 53 ++++++++ test/quantization/weight_only/test_rtn.py | 38 ++++++ 13 files changed, 352 insertions(+), 105 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py index 572e1f010..174d2ccf1 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py @@ -326,16 +326,16 @@ def rewind(self): if args.tune: model_name = "model.onnx" # require optimum >= 1.14.0 model_path = os.path.join(args.model_path, model_name) - best_model = None if args.algorithm.upper() == "RTN": - algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig() + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=True) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, n_bits=4, block_size=32, is_symmetric=True, algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, ) quant.process() best_model = quant.model @@ -358,7 +358,7 @@ def rewind(self): elif 
args.algorithm.upper() == "GPTQ": calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1) algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig( - calibration_data_reader=calibration_data_reader, + calibration_data_reader=calibration_data_reader, layer_wise_quant=True ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index b4a665f8c..6ff16dce2 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -46,17 +46,15 @@ def layer_wise_quant( Returns: _type_: _description_ """ - # check whether model shape is inferred - if not _check_model_with_infer_shapes(model): - logger.error( - "Before applying layer-wise quantization, please make sure to " - "run symbolic shape inference on your model like follows:\n" - "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n" - "model = onnx.load(your_model_path)\n" - "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n" - "onnx.save(out, infer_shape_model_path)\n" - ) - raise ValueError("Fail to run layer-wise quantization.") + logger.warning( + "Layer-wise quantization requires data_type info for some tensors. " + "We will try to infer the data_type automatically if it doesn't exist. "
+ "You can use model with symbolic shape inference before layer-wise quantization as well like follows:\n" + "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n" + "model = onnx.load(your_model_path)\n" + "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n" + "onnx.save(out, infer_shape_model_path)\n" + ) if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False) @@ -110,6 +108,7 @@ def layer_wise_quant( split_model_part_1, split_model_part_2 = split_model.split_model_with_node( split_node.name, model.model_path, save_both_split_models ) + if not save_both_split_models: # append split_model_part_2 to do next split model_to_split.append(split_model_part_2) @@ -117,9 +116,11 @@ def layer_wise_quant( logger.info("Quantize split model {}".format(split_idx)) if require_data_reader: # process data_reader for current split and next split + current_data_reader = _filter_data_reader_for_current_split_model( - split_model_part_1.model, current_data_reader + split_model_part_1.model, current_data_reader, data_reader ) + # next_data_reader contains split_model_part_1 output data next_data_reader = _prepare_data_reader_for_next_split_model( split_model_part_1.model_path, current_data_reader, providers ) @@ -166,7 +167,7 @@ def layer_wise_quant( # process data_reader for current split current_data_reader = lwq_data_reader.pop(0) current_data_reader = _filter_data_reader_for_current_split_model( - split_model_part_2.model, current_data_reader + split_model_part_2.model, current_data_reader, data_reader ) # perform quantization @@ -204,7 +205,6 @@ def layer_wise_quant( onnx.external_data_helper.load_external_data_for_model( quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path) ) - return quantized_model_merged @@ -222,26 +222,46 @@ def rewind(self): self.iter_next = iter(self.data_list) -def 
_filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: data_reader.CalibrationDataReader): +def _filter_data_reader_for_current_split_model( + model: onnx.ModelProto, + current_data_reader: data_reader.CalibrationDataReader, + data_reader: data_reader.CalibrationDataReader, +): """Filter data reader to remove data that is not in model input. Args: model (onnx.ModelProto): onnx model. - data_reader (data_reader.CalibrationDataReader): data reader. + current_data_reader (data_reader.CalibrationDataReader): data reader of current split model. + data_reader (data_reader.CalibrationDataReader): data reader of the original model. Returns: data_reader.CalibrationDataReader: filtered data reader. """ filter_inputs = [] input_names = [input.name for input in model.graph.input] + current_data_reader.rewind() + data_reader.rewind() + while True: - inputs = data_reader.get_next() + inputs = current_data_reader.get_next() if not inputs: break filter_input = { input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names } filter_inputs.append(filter_input) + + idx = 0 + while True: + inputs = data_reader.get_next() + if not inputs: + break + filter_input = { + input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names + } + if len(filter_input) > 0: + filter_inputs[idx].update(filter_input) + idx += 1 return DataReader(filter_inputs) @@ -275,14 +295,3 @@ def _prepare_data_reader_for_next_split_model( inputs.update({name: value for name, value in zip(output_names, out)}) data_reader_for_next_split_model.append(inputs) return DataReader(data_reader_for_next_split_model) - - -def _check_model_with_infer_shapes(model): - """Check if the model has been shape inferred.""" - if isinstance(model, (pathlib.Path, str)): - model = onnx.load(model, load_external_data=False) - elif isinstance(model, onnx_model.ONNXModel): - model = model.model - if len(model.graph.value_info) > 0: - 
return True - return False diff --git a/onnx_neural_compressor/algorithms/weight_only/awq.py b/onnx_neural_compressor/algorithms/weight_only/awq.py index 889909d03..81d896288 100644 --- a/onnx_neural_compressor/algorithms/weight_only/awq.py +++ b/onnx_neural_compressor/algorithms/weight_only/awq.py @@ -148,8 +148,10 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts): if init_share_num == 1: model.remove_initializer(weight_tensor) + if parent is None: + continue parent = model.get_node(parent) - if parent.name in updated_nodes: + if parent is None or parent.name in updated_nodes: continue if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"] and len( @@ -363,8 +365,9 @@ def awq_quantize( output_name_to_node = model.output_name_to_node() input_name_to_nodes = model.input_name_to_nodes() for input_name in output_names: - parent = output_name_to_node[input_name] - dump_pairs = {parent.name: []} + # input_name maybe the input of graph and there is no parent node + parent = output_name_to_node[input_name].name if input_name in output_name_to_node else None + dump_pairs = {parent: []} for node in input_name_to_nodes[input_name]: # check op_type of node is MatMul @@ -375,9 +378,9 @@ def awq_quantize( and model.get_initializer(node.input[1]) is not None and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32" ): - dump_pairs[parent.name].append(model.get_node(node.name)) + dump_pairs[parent].append(model.get_node(node.name)) - if len(dump_pairs[parent.name]) == 0: # pragma: no cover + if len(dump_pairs[parent]) == 0: # pragma: no cover continue output_dicts = {} diff --git a/onnx_neural_compressor/algorithms/weight_only/gptq.py b/onnx_neural_compressor/algorithms/weight_only/gptq.py index ff650b9db..b780cd81d 100644 --- a/onnx_neural_compressor/algorithms/weight_only/gptq.py +++ b/onnx_neural_compressor/algorithms/weight_only/gptq.py @@ -334,7 +334,7 @@ def gptq_quantize( k_blocks = (org_shape[0] 
+ group_size - 1) // group_size q_weight = quant_utils.pad_tensor(q_weight, group_size, k_blocks) _, _, zp, scale, q_weight = quant_utils.quantize_data( - q_weight.T, + q_weight.T.reshape((-1, group_size)), "uint" + str(num_bits), sym, axis=1, @@ -345,7 +345,7 @@ def gptq_quantize( num_bits=num_bits, group_size=group_size, k_blocks=k_blocks, - q_weight=q_weight.astype("uint8"), + q_weight=q_weight, scale=scale.astype(dtype), zero_point=zp if not sym else None, accuracy_level=accuracy_level, diff --git a/onnx_neural_compressor/algorithms/weight_only/rtn.py b/onnx_neural_compressor/algorithms/weight_only/rtn.py index 18fdc1e47..58ba80a40 100644 --- a/onnx_neural_compressor/algorithms/weight_only/rtn.py +++ b/onnx_neural_compressor/algorithms/weight_only/rtn.py @@ -124,7 +124,7 @@ def rtn_quantize( num_bits=num_bits, group_size=group_size, k_blocks=k_blocks, - q_weight=q_weight.astype("uint8"), + q_weight=q_weight, scale=scale.astype(dtype), zero_point=zp if not sym else None, accuracy_level=accuracy_level, diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 64040f2e9..8d141a464 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -480,21 +480,28 @@ def remove_duplicate_nodes(self): def remove_unused_nodes(self): """Remove unused nodes.""" unused_nodes = [] - nodes = self.nodes() + for node in self.model.graph.node: + # remove constant + if node.op_type == "Constant": + tensor = node.attribute[0].t + tensor.name = node.output[0] + self.add_initializer(tensor) + unused_nodes.append(node) + + # remove identity + if node.op_type == "Identity": + tensor = self.get_initializer(node.input[0]) + if tensor is not None: + new_tensor = copy.deepcopy(tensor) + new_tensor.name = node.output[0] + unused_nodes.append(node) + self.add_initializer(new_tensor) + self.remove_nodes(unused_nodes) if len(self._input_name_to_nodes) == 0: self._input_name_to_nodes = self.input_name_to_nodes() if 
len(self._output_name_to_node) == 0: self._output_name_to_node = self.output_name_to_node() - for node in nodes: - if ( - node.op_type == "Constant" - and node.output[0] not in self.model.graph.output - and node.output[0] not in self._input_name_to_nodes - ): - unused_nodes.append(node) - - self.remove_nodes(unused_nodes) unvalid_nodes = [ i @@ -795,18 +802,72 @@ def find_split_node_for_layer_wise_quantization(self): [None, 0, None, 0, 0], ), ] - if not start_node: - continue - if not any(qkv_nodes_list): - continue - start_nodes.append(start_node) + if qkv_nodes_list is not None and any(qkv_nodes_list): + start_nodes.append(start_node) + + # can't find qkv nodes with above patterns, use Softmax nodes to split model + if len(start_nodes) == 0: + for node in self.model.graph.node: + if node.op_type == "Softmax": + start_nodes.append(node) return start_nodes def find_split_nodes(self): """Find split nodes for layer-wise quantization.""" + self.remove_unused_nodes() split_nodes = self.find_split_node_for_layer_wise_quantization() return split_nodes + def _infer_tensor_dtype(self): + """Infer the elem_type of tensors.""" + initializers = dict([(i.name, i.data_type) for i in self.model.graph.initializer]) + inputs = dict([(i.name, i.type.tensor_type.elem_type) for i in self.model.graph.input]) + value_info = dict([(i.name, i.type.tensor_type.elem_type) for i in self.model.graph.value_info]) + outputs = dict([(i.name, i.type.tensor_type.elem_type) for i in self.model.graph.output]) + for node in self.model.graph.node: + if node.output[0] in value_info: + continue + elem_type = None + if node.op_type in ["And", "Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual", "Or", "Xor"]: + elem_type = onnx.TensorProto.BOOL + elif node.op_type in ["ArgMax", "ArgMin", "NonZero", "Shape"]: + elem_type = onnx.TensorProto.INT64 + elif node.op_type == "Cast" and len(node.attribute) > 0: + elem_type = node.attribute[0].i + elif node.op_type in ["Constant", "ConstantOfShape"] and 
len(node.attribute) > 0: + elem_type = node.attribute[0].t.data_type + elif len(node.input) >= 2: + for inp in node.input[:2]: + if inp in initializers and initializers[inp] != onnx.TensorProto.INT64: + elem_type = initializers[inp] + break + + # output elem_type aligns with input + if elem_type is None and len(node.input) > 0: + inp = node.input[0] + if inp in value_info: + elem_type = value_info[inp] + elif inp in inputs: + elem_type = inputs[inp] + elif inp in outputs: + elem_type = outputs[inp] + if elem_type is not None: + if node.op_type in ["Split", "Slice"]: + for out in node.output: + value_info.update({out: elem_type}) + else: + value_info.update({node.output[0]: elem_type}) + + return value_info + + def _build_input_output_tensor(self, tensor_name, value_info): + if tensor_name in self.input(): + return self.model.graph.input[self.input().index(tensor_name)] + if tensor_name in self.output(): + return self.model.graph.output[self.output().index(tensor_name)] + tensor_type = value_info.get(tensor_name, onnx.TensorProto.FLOAT) + return onnx.helper.make_tensor_value_info(tensor_name, tensor_type, None) + def split_model_with_node(self, split_node_name, path_of_model_to_split, save_both_split_models=True): """Split model into two parts at a given node. @@ -824,7 +885,9 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo # origin model : ... -> node_1 -> split_node -> node_2 -> ... # split model 1: ... -> node_1 -> split_node # split model 2: node_2 -> ... 
- self.remove_unused_nodes() + + # infer elem_type of tensors to make sure layer-wise quant run successfully + value_info = self._infer_tensor_dtype() split_model_part_1 = onnx.ModelProto() split_model_part_1.CopyFrom(self.model) @@ -852,8 +915,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo ) split_tensor_name = split_node_output[0] - split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) - split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) + split_tensor = self._build_input_output_tensor(split_tensor_name, value_info) split_model_part_1.graph.output.append(split_tensor) split_model_part_2.graph.input.append(split_tensor) @@ -869,8 +931,7 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo insert_input_for_model_2 = [] for output in split_model_part_1._output_name_to_node.keys(): if output in split_model_part_2._input_name_to_nodes.keys(): - output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) - output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) + output_tensor = self._build_input_output_tensor(output, value_info) if output_tensor not in split_model_part_1.model.graph.output: insert_output_for_model_1.append(output_tensor) if output_tensor not in split_model_part_2.model.graph.input: @@ -929,26 +990,6 @@ def _save_split_model(self, save_path): convert_attribute=False, ) - def _get_output_type_shape_by_tensor_name(self, tensor_name): - """Get output type and shape with a tensor name. 
- - Args: - tensor_name (str): name of a tensor - - Returns: - tuple: output type and shape - """ - elem_type = onnx.TensorProto.FLOAT - shape = None - for output in self.model.graph.value_info: - if output.name == tensor_name: - elem_type = output.type.tensor_type.elem_type - shape = [ - dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim - ] - break - return elem_type, shape - def _remove_unused_input_output(self): """Remove unused input & output for split model.""" remove_outputs = [] diff --git a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py index ea77b18de..99bf760e9 100644 --- a/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py +++ b/onnx_neural_compressor/quantization/matmul_nbits_quantizer.py @@ -171,13 +171,24 @@ def int4_quant_algo(self): model = self.model opt_tmp_file = tempfile.TemporaryDirectory() + if getattr(self.algo_config, "layer_wise_quant", False) and not isinstance(model, str): + logger.warning("Please use model path for layer-wise quantization.") + # do graph optimization if not layer_wise_quant if ( not getattr(self.algo_config, "layer_wise_quant", False) and self.optimization_level != ort.GraphOptimizationLevel.ORT_DISABLE_ALL ): if not isinstance(model, str): - onnx.save(model, pathlib.Path(opt_tmp_file.name).joinpath("tmp.onnx").as_posix()) + onnx.save_model( + model, + pathlib.Path(opt_tmp_file.name).joinpath("tmp.onnx").as_posix(), + save_as_external_data=True, + all_tensors_to_one_file=True, + location="tmp.onnx_data", + size_threshold=1024, + convert_attribute=False, + ) model = pathlib.Path(opt_tmp_file.name).joinpath("tmp.onnx").as_posix() logger.info("Start graph optimization...") sess_options = ort.SessionOptions() @@ -189,7 +200,7 @@ def int4_quant_algo(self): sess_options.add_session_config_entry( "session.optimized_model_external_initializers_min_size_in_bytes", "1024" ) - session = 
ort.InferenceSession(model, sess_options) + session = ort.InferenceSession(model, sess_options, providers=["CPUExecutionProvider"]) model = sess_options.optimized_model_filepath del session logger.info("Graph optimization done.") diff --git a/onnx_neural_compressor/quantization/quantize.py b/onnx_neural_compressor/quantization/quantize.py index 9fb3dfd41..9f6ffc5d1 100644 --- a/onnx_neural_compressor/quantization/quantize.py +++ b/onnx_neural_compressor/quantization/quantize.py @@ -42,7 +42,7 @@ def quantize( sess_options.add_session_config_entry( "session.optimized_model_external_initializers_min_size_in_bytes", "1024" ) - session = ort.InferenceSession(model_input, sess_options) + session = ort.InferenceSession(model_input, sess_options, providers=["CPUExecutionProvider"]) del session model_input = sess_options.optimized_model_filepath diff --git a/onnx_neural_compressor/quantization/tuning.py b/onnx_neural_compressor/quantization/tuning.py index 5bf2d95d4..385ac63c0 100644 --- a/onnx_neural_compressor/quantization/tuning.py +++ b/onnx_neural_compressor/quantization/tuning.py @@ -501,7 +501,7 @@ def autotune( "session.optimized_model_external_initializers_file_name", "model.onnx_data" ) sess_options.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "1024") - session = ort.InferenceSession(model_input, sess_options) + session = ort.InferenceSession(model_input, sess_options, providers=["CPUExecutionProvider"]) # copy config.json to tmp dir for evaluation, LLMs evaluation may need it if isinstance(model_input, str) and os.path.exists( diff --git a/test/quantization/layer_wise/test_layer_wise.py b/test/quantization/layer_wise/test_layer_wise.py index f1a153c4e..729549c35 100644 --- a/test/quantization/layer_wise/test_layer_wise.py +++ b/test/quantization/layer_wise/test_layer_wise.py @@ -3,6 +3,7 @@ import shutil import unittest +import numpy as np import onnx import onnxruntime as ort import
onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer @@ -26,10 +27,13 @@ def find_onnx_file(folder_path): class DummyNLPDataloader(data_reader.CalibrationDataReader): - def __init__(self, model_name): + def __init__(self, model_name, model_path): self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) self.sequence_a = "intel-extension-for-transformers is based in SH" self.sequence_b = "Where is intel-extension-for-transformers based? NYC or SH" + model = onnx.load(model_path, load_external_data=False) + config = transformers.AutoConfig.from_pretrained(model_name) + inputs_names = [input.name for input in model.graph.input] self.encoded_list = [] encoded_input = dict(self.tokenizer(self.sequence_a, self.sequence_b, return_tensors="pt")) @@ -38,6 +42,14 @@ def __init__(self, model_name): torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) ) + num_attention_heads = config.num_key_value_heads + embed_size_per_head = config.hidden_size // config.num_attention_heads + shape = (1, num_attention_heads, 0, embed_size_per_head) + key_or_value = np.zeros(shape, dtype=np.float32) + for input_name in inputs_names: + if input_name not in encoded_input: + encoded_input[input_name] = key_or_value + # convert torch tensor to numpy for input_name, input_value in encoded_input.items(): if isinstance(input_value, torch.Tensor): @@ -62,7 +74,7 @@ def setUpClass(self): # limit transformers to 4.37.2 # TODO: remove transformers version limitation llama_id = "yujiepan/llama-2-tiny-3layers-random" - main_export(llama_id, output="llama-2-tiny-3layers-random", task="text-generation") + main_export(llama_id, output="llama-2-tiny-3layers-random", task="text-generation-with-past") model_path = find_onnx_file("llama-2-tiny-3layers-random") self.llama = model_path @@ -74,10 +86,10 @@ def setUpClass(self): sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED 
sess_options.optimized_model_filepath = "llama-2-tiny-3layers-random/optimized_model.onnx" - ort.InferenceSession(infer_shape_model_path, sess_options) + ort.InferenceSession(infer_shape_model_path, sess_options, providers=["CPUExecutionProvider"]) self.llama_optimized = "llama-2-tiny-3layers-random/optimized_model.onnx" - self.calibration_data_reader = DummyNLPDataloader(llama_id) + self.calibration_data_reader = DummyNLPDataloader(llama_id, self.llama_optimized) @classmethod def tearDownClass(self): @@ -105,8 +117,8 @@ def _get_quantized_matmul_weight(self, model, matmul_name): weight_init = onnx.numpy_helper.to_array(init) return weight_init - def _apply_quantize(self, quant_config, quant_func, data_reader=None): - fp32_model = copy.deepcopy(self.llama_optimized) + def _apply_quantize(self, model, quant_config, quant_func, data_reader=None): + fp32_model = copy.deepcopy(model) if data_reader is None: qmodel = quant_func(fp32_model, quant_config) else: @@ -115,12 +127,28 @@ def _apply_quantize(self, quant_config, quant_func, data_reader=None): return qmodel def test_rtn_layer_wise(self): + # optimized model rtn_config = config.RTNConfig(layer_wise_quant=True) - qmodel_lwq = self._apply_quantize(rtn_config, algos.rtn_quantize_entry) + qmodel_lwq = self._apply_quantize(self.llama_optimized, rtn_config, algos.rtn_quantize_entry) self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) rtn_config = config.RTNConfig(layer_wise_quant=False) - qmodel = self._apply_quantize(rtn_config, algos.rtn_quantize_entry) + qmodel = self._apply_quantize(self.llama_optimized, rtn_config, algos.rtn_quantize_entry) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") + self.assertIsNotNone(lwq_quantized_weight) + quantized_weight = self._get_quantized_matmul_weight(qmodel, "/lm_head/MatMul_Q4") + self.assertIsNotNone(quantized_weight) + self.assertTrue((lwq_quantized_weight == 
quantized_weight).all()) + + # original model + rtn_config = config.RTNConfig(layer_wise_quant=True) + qmodel_lwq = self._apply_quantize(self.llama, rtn_config, algos.rtn_quantize_entry) + self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) + + rtn_config = config.RTNConfig(layer_wise_quant=False) + qmodel = self._apply_quantize(self.llama, rtn_config, algos.rtn_quantize_entry) self.assertTrue(self._check_model_is_quantized(qmodel)) lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") @@ -162,14 +190,38 @@ def test_rtn_layer_wise_with_ort_like_api(self): self.assertTrue((lwq_quantized_weight == quantized_weight).all()) def test_gptq_layer_wise(self): + # optimized model self.calibration_data_reader.rewind() gptq_config = config.GPTQConfig(layer_wise_quant=True) - qmodel_lwq = self._apply_quantize(gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader) + qmodel_lwq = self._apply_quantize( + self.llama_optimized, gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader + ) self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) self.calibration_data_reader.rewind() gptq_config = config.GPTQConfig(layer_wise_quant=False) - qmodel = self._apply_quantize(gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader) + qmodel = self._apply_quantize( + self.llama_optimized, gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader + ) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") + self.assertIsNotNone(lwq_quantized_weight) + quantized_weight = self._get_quantized_matmul_weight(qmodel, "/lm_head/MatMul_Q4") + self.assertIsNotNone(quantized_weight) + self.assertTrue((lwq_quantized_weight == quantized_weight).all()) + + # original model + self.calibration_data_reader.rewind() + gptq_config = config.GPTQConfig(layer_wise_quant=True) + qmodel_lwq = self._apply_quantize( + 
self.llama, gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader + ) + self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) + + self.calibration_data_reader.rewind() + gptq_config = config.GPTQConfig(layer_wise_quant=False) + qmodel = self._apply_quantize(self.llama, gptq_config, algos.gptq_quantize_entry, self.calibration_data_reader) self.assertTrue(self._check_model_is_quantized(qmodel)) lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") @@ -214,17 +266,6 @@ def test_gptq_layer_wise_with_ort_like_api(self): self.assertIsNotNone(quantized_weight) self.assertTrue((lwq_quantized_weight == quantized_weight).all()) - def test__check_model_with_infer_shapes(self): - from onnx_neural_compressor.algorithms.layer_wise import core as lwq_core - - self.assertFalse(lwq_core._check_model_with_infer_shapes(self.llama)) - self.assertTrue(lwq_core._check_model_with_infer_shapes(self.llama_optimized)) - self.assertTrue( - lwq_core._check_model_with_infer_shapes( - onnx_model.ONNXModel(onnx.load(self.llama_optimized, load_external_data=False)) - ) - ) - if __name__ == "__main__": unittest.main() diff --git a/test/quantization/weight_only/test_awq.py b/test/quantization/weight_only/test_awq.py index 0a574c6db..b7def741b 100644 --- a/test/quantization/weight_only/test_awq.py +++ b/test/quantization/weight_only/test_awq.py @@ -5,8 +5,9 @@ import unittest import numpy as np +import onnx +import onnxruntime as ort import torch -import torch.nn as nn import transformers from optimum.exporters.onnx import main_export @@ -53,6 +54,37 @@ def rewind(self): self.iter_next = iter(self.encoded_list) +class MatMulDataloader(data_reader.CalibrationDataReader): + + def __init__(self): + self.encoded_list = [{"A": torch.randn((1, 11008)).numpy()}] + self.iter_next = iter(self.encoded_list) + + def get_next(self): + return next(self.iter_next, None) + + def rewind(self): + self.iter_next = iter(self.encoded_list) + + +def 
build_matmul_model(): + # MatMul - Add - Add + A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 11008]) + C = onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 1024]) + D = onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 1024]) + + B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [11008, 1024], np.random.random((11008, 1024))) + E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, [1, 1024], np.random.random((1, 1024))) + + matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="Matmul") + add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add") + + graph = onnx.helper.make_graph([matmul_node, add], "test_graph_1", [A], [D], [B_init, E_init]) + model = onnx.helper.make_model(graph) + model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 13)]}) + return model + + class TestAWQQuant(unittest.TestCase): @classmethod @@ -64,6 +96,9 @@ def setUpClass(self): self.gptj = find_onnx_file("./gptj") self.calibration_data_reader = DummyNLPDataloader("hf-internal-testing/tiny-random-gptj") + self.matmul_model = build_matmul_model() + self.matmul_data_reader = MatMulDataloader() + @classmethod def tearDownClass(self): shutil.rmtree("gptj", ignore_errors=True) @@ -275,6 +310,22 @@ def test_awq_config_nbits_with_exclude_node(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=n_bits, group_size=32), 29) + def test_awq_with_specified_matmul(self): + + algo_config = matmul_nbits_quantizer.AWQWeightOnlyQuantConfig(calibration_data_reader=self.matmul_data_reader) + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + copy.deepcopy(self.matmul_model), + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + 
self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/weight_only/test_gptq.py b/test/quantization/weight_only/test_gptq.py index 3dc3114ce..7902371c7 100644 --- a/test/quantization/weight_only/test_gptq.py +++ b/test/quantization/weight_only/test_gptq.py @@ -4,6 +4,9 @@ import shutil import unittest +import numpy as np +import onnx +import onnxruntime as ort import torch import transformers from optimum.exporters.onnx import main_export @@ -51,6 +54,37 @@ def rewind(self): self.iter_next = iter(self.encoded_list) +class MatMulDataloader(data_reader.CalibrationDataReader): + + def __init__(self): + self.encoded_list = [{"A": torch.randn((1, 11008)).numpy()}] + self.iter_next = iter(self.encoded_list) + + def get_next(self): + return next(self.iter_next, None) + + def rewind(self): + self.iter_next = iter(self.encoded_list) + + +def build_matmul_model(): + # MatMul - Add - Add + A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 11008]) + C = onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 1024]) + D = onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 1024]) + + B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [11008, 1024], np.random.random((11008, 1024))) + E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, [1, 1024], np.random.random((1, 1024))) + + matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="Matmul") + add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add") + + graph = onnx.helper.make_graph([matmul_node, add], "test_graph_1", [A], [D], [B_init, E_init]) + model = onnx.helper.make_model(graph) + model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 13)]}) + return model + + class TestGPTQQuant(unittest.TestCase): @classmethod @@ -62,6 +96,9 @@ def setUpClass(self): self.gptj = find_onnx_file("./gptj") 
self.calibration_data_reader = DummyNLPDataloader("hf-internal-testing/tiny-random-gptj") + self.matmul_model = build_matmul_model() + self.matmul_data_reader = MatMulDataloader() + @classmethod def tearDownClass(self): shutil.rmtree("gptj", ignore_errors=True) @@ -270,6 +307,22 @@ def test_gptq_config_nbits_with_exclude_node(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=n_bits, group_size=32), 29) + def test_gptq_with_specified_matmul(self): + + algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=self.matmul_data_reader) + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + copy.deepcopy(self.matmul_model), + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/weight_only/test_rtn.py b/test/quantization/weight_only/test_rtn.py index 62467660e..6f7cea1b8 100644 --- a/test/quantization/weight_only/test_rtn.py +++ b/test/quantization/weight_only/test_rtn.py @@ -4,6 +4,9 @@ import shutil import unittest +import numpy as np +import onnx +import onnxruntime as ort from optimum.exporters.onnx import main_export from onnx_neural_compressor import logger @@ -20,6 +23,24 @@ def find_onnx_file(folder_path): return None +def build_matmul_model(): + # MatMul - Add - Add + A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 11008]) + C = onnx.helper.make_tensor_value_info("C", onnx.TensorProto.FLOAT, [1, 1024]) + D = onnx.helper.make_tensor_value_info("D", onnx.TensorProto.FLOAT, [1, 1024]) + + B_init = onnx.helper.make_tensor("B", onnx.TensorProto.FLOAT, [11008, 1024], np.random.random((11008, 1024))) + E_init = onnx.helper.make_tensor("E", onnx.TensorProto.FLOAT, 
[1, 1024], np.random.random((1, 1024))) + + matmul_node = onnx.helper.make_node("MatMul", ["A", "B"], ["C"], name="Matmul") + add = onnx.helper.make_node("Add", ["C", "E"], ["D"], name="add") + + graph = onnx.helper.make_graph([matmul_node, add], "test_graph_1", [A], [D], [B_init, E_init]) + model = onnx.helper.make_model(graph) + model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 13)]}) + return model + + class TestRTNQuant(unittest.TestCase): @classmethod @@ -29,6 +50,7 @@ def setUpClass(self): output="gptj", ) self.gptj = find_onnx_file("./gptj") + self.matmul_model = build_matmul_model() @classmethod def tearDownClass(self): @@ -229,6 +251,22 @@ def test_rtn_config_nbits_with_exclude_node(self): self.assertIsNotNone(quant.model) self.assertEqual(self._count_woq_matmul(quant.model, bits=n_bits, group_size=32), 29) + def test_rtn_with_specified_matmul(self): + + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig() + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( + copy.deepcopy(self.matmul_model), + n_bits=4, + block_size=32, + is_symmetric=False, + algo_config=algo_config, + optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + quant.process() + self.assertIsNotNone(quant.model) + self.assertEqual(self._count_woq_matmul(quant.model, bits=4, group_size=32), 1) + if __name__ == "__main__": unittest.main() From 7700ee49ae20059d3e807af56e14d220d6857a7e Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 18 Jul 2024 10:50:35 +0800 Subject: [PATCH 02/18] enable more llm example Signed-off-by: yuwenzho --- .../quantization/weight_only/README.md | 3 +- .../weight_only/evaluation/__init__.py | 0 .../weight_only/evaluation/accuracy.py | 0 .../weight_only/evaluation/evaluator.py | 0 .../weight_only/evaluation/models/__init__.py | 0 .../evaluation/models/huggingface.py | 0 .../weight_only/evaluation/utils.py | 0 .../quantization/weight_only/main.py | 31 ++++++++++++++----- 
.../quantization/weight_only/prepare_model.py | 9 ++++-- .../quantization/weight_only/requirements.txt | 0 .../quantization/weight_only/run_benchmark.sh | 0 .../quantization/weight_only/run_quant.sh | 14 +++++++-- 12 files changed, 43 insertions(+), 14 deletions(-) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/README.md (96%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/__init__.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/accuracy.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/evaluator.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/models/__init__.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/models/huggingface.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/evaluation/utils.py (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/main.py (92%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/prepare_model.py (78%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/requirements.txt (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/run_benchmark.sh (100%) rename examples/nlp/huggingface_model/text_generation/{llama => }/quantization/weight_only/run_quant.sh (77%) diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/README.md b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md similarity index 96% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/README.md rename to 
examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md index 9ddbc7f2c..ef507875f 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/README.md +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md @@ -28,7 +28,6 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are Export to ONNX model: ```bash python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \ - --output_model="./llama-2-7b-hf" \ --task=text-generation-with-past \ # or text-generation ``` @@ -53,7 +52,7 @@ Accuracy: ```bash bash run_benchmark.sh --input_model=path/to/model \ # folder path of onnx model - --batch_size=batch_size \ # optional + --batch_size=batch_size \ # optional --mode=accuracy \ --tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer --tasks=lambada_openai diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/__init__.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/__init__.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/__init__.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/accuracy.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/accuracy.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/accuracy.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py 
b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/evaluator.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/evaluator.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/evaluator.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/models/__init__.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/__init__.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/models/__init__.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/models/huggingface.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/models/huggingface.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/models/huggingface.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/utils.py similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/evaluation/utils.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/evaluation/utils.py diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py similarity index 92% rename from 
examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index 174d2ccf1..247757576 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -44,7 +44,7 @@ parser.add_argument("--model_path", type=str, help="Folder path of pre-trained onnx model") parser.add_argument("--benchmark", action="store_true", default=False) parser.add_argument("--tune", action="store_true", default=False, help="whether quantize the model") -parser.add_argument("--output_model", type=str, default=None, help="output model path") +parser.add_argument("--output_model", type=str, default=None, help="path of output dircectory") parser.add_argument( "--batch_size", default=1, @@ -92,11 +92,19 @@ parser.add_argument("--mode", type=str, help="benchmark mode of performance or accuracy") parser.add_argument("--intra_op_num_threads", type=int, default=24) parser.add_argument("--trust_remote_code", type=bool, default=False) +parser.add_argument("--layer_wise", action="store_true", default=False) +parser.add_argument("--quantize_lm_head", action="store_true", default=False, + help="language modelling head will not be quantized by default. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") +parser.add_argument("--nodes_to_exclude", nargs="+", default=[], + help="nodes that will not be quantized. 
Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") args = parser.parse_args() +if not os.path.exists(args.output_model): + os.makedirs(args.output_model) + # load model -tokenizer = transformers.LlamaTokenizer.from_pretrained(args.tokenizer) -model_config = transformers.LlamaConfig.from_pretrained(args.model_path) +tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer) +model_config = transformers.AutoConfig.from_pretrained(args.model_path) def tokenize_function(examples): @@ -110,7 +118,8 @@ def replace_architectures(json_path): # refer to https://github.com/huggingface/transformers/issues/22222#issuecomment-1477171703 with open(json_path, "r") as file: data = json.load(file) - data["architectures"] = ["LlamaForCausalLM"] + if data["architectures"] == ["LLaMATokenizer"]: + data["architectures"] = ["LlamaForCausalLM"] with open(json_path, "w") as file: json.dump(data, file, indent=4) @@ -327,15 +336,20 @@ def rewind(self): model_name = "model.onnx" # require optimum >= 1.14.0 model_path = os.path.join(args.model_path, model_name) best_model = None + + nodes_to_exclude=["/lm_head/MatMul"] if not args.quantize_lm_head else [] + print(nodes_to_exclude, args.nodes_to_exclude) + nodes_to_exclude = list(set(args.nodes_to_exclude + nodes_to_exclude)) if args.algorithm.upper() == "RTN": - algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=True) + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig( + layer_wise_quant=args.layer_wise) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, n_bits=4, block_size=32, is_symmetric=True, algo_config=algo_config, - optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + nodes_to_exclude=nodes_to_exclude, ) quant.process() best_model = quant.model @@ -351,6 +365,7 @@ def rewind(self): block_size=32, is_symmetric=True, algo_config=algo_config, + nodes_to_exclude=nodes_to_exclude, ) quant.process() best_model = quant.model @@ -358,7 +373,8 @@ def rewind(self): 
elif args.algorithm.upper() == "GPTQ": calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1) algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig( - calibration_data_reader=calibration_data_reader, layer_wise_quant=True + calibration_data_reader=calibration_data_reader, + layer_wise_quant=args.layer_wise ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, @@ -366,6 +382,7 @@ def rewind(self): block_size=32, is_symmetric=False, algo_config=algo_config, + nodes_to_exclude=nodes_to_exclude, ) quant.process() best_model = quant.model diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py similarity index 78% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py index 3af820943..708c631a0 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/prepare_model.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py @@ -10,8 +10,8 @@ def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("--input_model", type=str, required=False, default="") - parser.add_argument("--output_model", type=str, required=True) + parser.add_argument("--input_model", type=str, required=True, default="") + parser.add_argument("--output_model", type=str, required=False, default=None) parser.add_argument( "--task", type=str, @@ -19,7 +19,10 @@ def parse_arguments(): default="text-generation-with-past", choices=["text-generation-with-past", "text-generation"], ) - return parser.parse_args() + args = parser.parse_args() + if args.output_model is None: + args.output_model = os.path.basename(args.input_model) + "-onnx" + return args def 
prepare_model(input_model, output_model, task): diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/requirements.txt b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/requirements.txt similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/requirements.txt rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/requirements.txt diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/run_benchmark.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh similarity index 100% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/run_benchmark.sh rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/run_quant.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh similarity index 77% rename from examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/run_quant.sh rename to examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh index 295b47249..3a50967c8 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/run_quant.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh @@ -56,6 +56,15 @@ function run_tuning { echo "Created directory $output_model" fi + if [[ "${input_model}" =~ "Phi-3-mini-128k-instruct" ]]; then + nodes_to_exclude="/model/layers./self_attn/qkv_proj/MatMul /model/layers./mlp/down_proj/MatMul" + extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" + fi + if [[ "${input_model}" =~ "Meta-Llama-3-8B" ]]; then + nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul" + extra_cmd="--nodes_to_exclude 
${nodes_to_exclude}" + fi + python main.py \ --model_path ${input_model} \ --tokenizer ${tokenizer-meta-llama/Llama-2-7b-hf} \ @@ -64,8 +73,9 @@ function run_tuning { --dataset ${dataset-NeelNanda/pile-10k} \ --algorithm ${algorithm-WOQ_TUNE} \ --tasks ${tasks-lambada_openai} \ - --tune + --layer_wise \ + --tune \ + ${extra_cmd} } main "$@" - From bf1891a726011e3bbee506aff1bd554280859fc0 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 18 Jul 2024 17:06:04 +0800 Subject: [PATCH 03/18] update code Signed-off-by: yuwenzho --- .../quantization/weight_only/main.py | 1 - .../algorithms/layer_wise/core.py | 21 +------------------ onnx_neural_compressor/onnx_model.py | 4 ++-- 3 files changed, 3 insertions(+), 23 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index 247757576..800827e5a 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -338,7 +338,6 @@ def rewind(self): best_model = None nodes_to_exclude=["/lm_head/MatMul"] if not args.quantize_lm_head else [] - print(nodes_to_exclude, args.nodes_to_exclude) nodes_to_exclude = list(set(args.nodes_to_exclude + nodes_to_exclude)) if args.algorithm.upper() == "RTN": algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig( diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 6ff16dce2..054e228d4 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -46,16 +46,6 @@ def layer_wise_quant( Returns: _type_: _description_ """ - logger.warning( - "Layer-wise quantization requires data_type info for some tensors. " - "We will try to infer the data_type automatically if it doesn't exist." 
- "You can use model with symbolic shape inference before layer-wise quantization as well like follows:\n" - "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n" - "model = onnx.load(your_model_path)\n" - "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n" - "onnx.save(out, infer_shape_model_path)\n" - ) - if not isinstance(model, onnx_model.ONNXModel): model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False) @@ -66,16 +56,7 @@ def layer_wise_quant( # get and check split nodes split_nodes = origin_model.find_split_nodes() if len(split_nodes) == 0: - logger.error( - "Can't find split nodes for layer-wise quantization. " - "We recommend applying graph optimization for your model like follows: \n" - "import onnxruntime as ort \n" - "sess_options = ort.SessionOptions() \n" - "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED " - "# or ORT_ENABLE_BASIC \n" - "sess_options.optimized_model_filepath = 'optimized_model_path' \n" - "ort.InferenceSession(infer_shape_model_path, sess_options)" - ) + logger.error("Can't find split nodes for layer-wise quantization.") raise ValueError("Fail to run layer-wise quantization.") logger.info( "Will split model into {} parts to do layer-wise quantization".format( diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 8d141a464..50d7b048e 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -241,7 +241,7 @@ def save(self, root): root, save_as_external_data=True, all_tensors_to_one_file=True, - location=root.split("/")[-1] + "_data", + location=os.path.basename(root) + "_data", size_threshold=1024, convert_attribute=False, ) @@ -985,7 +985,7 @@ def _save_split_model(self, save_path): save_path, save_as_external_data=True, all_tensors_to_one_file=True, - location=save_path.split("/")[-1] + "_data", + location=os.path.basename(save_path) 
+ "_data", size_threshold=1024, convert_attribute=False, ) From 902b4fdf60498339d8e2aa4fea98f294b6db3b40 Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Thu, 18 Jul 2024 20:01:35 -0700 Subject: [PATCH 04/18] fix onnx_model Signed-off-by: Mengni Wang --- .../algorithms/layer_wise/core.py | 50 ++++++++----------- onnx_neural_compressor/onnx_model.py | 50 ++++++++++++------- 2 files changed, 54 insertions(+), 46 deletions(-) diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 6ff16dce2..8d421ac4a 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -99,7 +99,7 @@ def layer_wise_quant( split_model = model_to_split.pop(0) split_node = split_nodes.pop(0) if require_data_reader: - current_data_reader = lwq_data_reader.pop(0) + complete_data_reader = lwq_data_reader.pop(0) # if no remaining split nodes, it means this is the last split, and the two split models will be saved. 
save_both_split_models = True if len(split_nodes) == 0 else False @@ -114,17 +114,19 @@ def layer_wise_quant( model_to_split.append(split_model_part_2) logger.info("Quantize split model {}".format(split_idx)) + if require_data_reader: # process data_reader for current split and next split - current_data_reader = _filter_data_reader_for_current_split_model( - split_model_part_1.model, current_data_reader, data_reader + split_model_part_1.model, complete_data_reader ) + # next_data_reader contains split_model_part_1 output data - next_data_reader = _prepare_data_reader_for_next_split_model( - split_model_part_1.model_path, current_data_reader, providers + complete_data_reader = _prepare_data_reader_for_next_split_model( + split_model_part_1.model_path, [i.name for i in split_model_part_2.model.graph.input], complete_data_reader, providers ) - lwq_data_reader.append(next_data_reader) + + lwq_data_reader.append(complete_data_reader) # perform quantization split_model_part_1_quantized = quant_func( @@ -142,7 +144,7 @@ def layer_wise_quant( # check split model is valid try: - ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers) + ort.InferenceSession(split_model_part_1_quantized.model_path, providers=providers) except Exception as e: logger.error( "Layer-wise quantized model {} can't be inferred correctly. 
" @@ -167,7 +169,7 @@ def layer_wise_quant( # process data_reader for current split current_data_reader = lwq_data_reader.pop(0) current_data_reader = _filter_data_reader_for_current_split_model( - split_model_part_2.model, current_data_reader, data_reader + split_model_part_2.model, complete_data_reader ) # perform quantization @@ -186,7 +188,7 @@ def layer_wise_quant( # check split model is valid try: - ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers) + ort.InferenceSession(split_model_part_2_quantized.model_path, providers=providers) except Exception as e: logger.error( "Layer-wise quantized model {} can't be inferred correctly. " @@ -225,14 +227,12 @@ def rewind(self): def _filter_data_reader_for_current_split_model( model: onnx.ModelProto, current_data_reader: data_reader.CalibrationDataReader, - data_reader: data_reader.CalibrationDataReader, ): """Filter data reader to remove data that is not in model input. Args: model (onnx.ModelProto): onnx model. current_data_reader (data_reader.CalibrationDataReader): data reader of current split model. - data_reader (data_reader.CalibrationDataReader): data reader of the original model. Returns: data_reader.CalibrationDataReader: filtered data reader. 
@@ -240,7 +240,6 @@ def _filter_data_reader_for_current_split_model( filter_inputs = [] input_names = [input.name for input in model.graph.input] current_data_reader.rewind() - data_reader.rewind() while True: inputs = current_data_reader.get_next() @@ -251,22 +250,12 @@ def _filter_data_reader_for_current_split_model( } filter_inputs.append(filter_input) - idx = 0 - while True: - inputs = data_reader.get_next() - if not inputs: - break - filter_input = { - input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names - } - if len(filter_input) > 0: - filter_inputs[idx].update(filter_input) - idx += 1 return DataReader(filter_inputs) def _prepare_data_reader_for_next_split_model( model_path: str, + next_model_input_names: list, data_reader: data_reader.CalibrationDataReader, providers: List[str] = ["CPUExecutionProvider"], ): @@ -282,16 +271,21 @@ def _prepare_data_reader_for_next_split_model( Returns: data_reader.CalibrationDataReader: data reader for next split model. 
""" - data_reader = copy.deepcopy(data_reader) - + data_reader.rewind() data_reader_for_next_split_model = [] session = ort.InferenceSession(model_path, providers=providers) output_names = [output.name for output in session.get_outputs()] + input_names = [input.name for input in session.get_inputs()] while True: inputs = data_reader.get_next() if not inputs: break - out = session.run(None, inputs) - inputs.update({name: value for name, value in zip(output_names, out)}) - data_reader_for_next_split_model.append(inputs) + out = session.run(None, {name: inputs[name] for name in input_names}) + filter_input = { + name: value for name, value in zip(output_names, out) + } + for name, value in inputs.items(): + if name in next_model_input_names and name not in filter_input: + filter_input[name] = value + data_reader_for_next_split_model.append(filter_input) return DataReader(data_reader_for_next_split_model) diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 8d141a464..0d11ea15b 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -233,7 +233,7 @@ def is_graph_output(self, name): def save(self, root): """Save ONNX model.""" if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]): - raise ValueError('"root" directory does not exists.') + os.mkdir(os.path.split(root)[0]) if self.is_large_model: # pragma: no cover onnx.external_data_helper.load_external_data_for_model(self.model, os.path.split(self._model_path)[0]) onnx.save_model( @@ -897,30 +897,44 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo split_model_part_2.CopyFrom(self.model) split_model_part_2.graph.ClearField("node") - split_node_output = None - part_idx = 1 + split_node = None + nodes = [] for node in self.model.graph.node: - if part_idx == 1: - split_model_part_1.graph.node.append(node) - elif part_idx == 2: - split_model_part_2.graph.node.append(node) + 
nodes.append(node) if node.name == split_node_name: - split_node_output = node.output - part_idx = 2 + split_node = node + break - assert len(split_node_output) == 1, ( + assert len(split_node.output) == 1, ( "Only support split at node with 1 output tensor, while " - "current split node {} has {} output tensors".format(split_node_name, len(split_node_output)) + "current split node {} has {} output tensors".format(split_node_name, len(split_node.output)) ) - split_tensor_name = split_node_output[0] + split_tensor_name = split_node.output[0] split_tensor = self._build_input_output_tensor(split_tensor_name, value_info) + split_model_part_1.graph.node.extend(nodes) split_model_part_1.graph.output.append(split_tensor) - split_model_part_2.graph.input.append(split_tensor) - split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True) + + # remove isolated graphs which are not related to the split_node + output_name_to_node = split_model_part_1.output_name_to_node() + valid_nodes = [split_node] + while len(valid_nodes) > 0: + node = valid_nodes.pop(0) + for inp in node.input: + if inp in output_name_to_node: + valid_nodes.append(output_name_to_node[inp]) + if node in nodes: + nodes.remove(node) + split_model_part_1.remove_nodes(nodes) + + for node in self.model.graph.node: + if node not in split_model_part_1.nodes(): + split_model_part_2.graph.node.append(node) + + split_model_part_2.graph.input.append(split_tensor) split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True) # remove unused input & output @@ -994,14 +1008,14 @@ def _remove_unused_input_output(self): """Remove unused input & output for split model.""" remove_outputs = [] remove_inputs = [] - if len(self._input_name_to_nodes) == 0: - self._input_name_to_nodes = self.input_name_to_nodes() + input_name_to_nodes = self.input_name_to_nodes() + output_name_to_node = self.output_name_to_node() for output in self.model.graph.output: - if output.name not in self._output_name_to_node.keys(): 
+ if output.name not in output_name_to_node.keys(): remove_outputs.append(output) for input in self.model.graph.input: - if input.name not in self._input_name_to_nodes.keys(): + if input.name not in input_name_to_nodes.keys(): remove_inputs.append(input) for output in remove_outputs: From 099001119f53672592af5538c4e24b1063602530 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Fri, 19 Jul 2024 11:06:08 +0800 Subject: [PATCH 05/18] Update main.py Signed-off-by: Wang, Mengni --- .../text_generation/llama/quantization/weight_only/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py index 174d2ccf1..37c971534 100644 --- a/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py @@ -335,7 +335,6 @@ def rewind(self): block_size=32, is_symmetric=True, algo_config=algo_config, - optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, ) quant.process() best_model = quant.model From 07d188d1a0e594ffeae9e1bd3b6276e54e7657fe Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Thu, 18 Jul 2024 20:08:43 -0700 Subject: [PATCH 06/18] fix CI Signed-off-by: Mengni Wang --- onnx_neural_compressor/algorithms/layer_wise/core.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 8d421ac4a..1667e6d91 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -53,7 +53,7 @@ def layer_wise_quant( "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n" "model = onnx.load(your_model_path)\n" "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, 
auto_merge=True)\n" - "onnx.save(out, infer_shape_model_path)\n" + "onnx.save_model(out, infer_shape_model_path, save_as_external_data=True)\n" ) if not isinstance(model, onnx_model.ONNXModel): @@ -123,7 +123,10 @@ def layer_wise_quant( # next_data_reader contains split_model_part_1 output data complete_data_reader = _prepare_data_reader_for_next_split_model( - split_model_part_1.model_path, [i.name for i in split_model_part_2.model.graph.input], complete_data_reader, providers + split_model_part_1.model_path, + [i.name for i in split_model_part_2.model.graph.input], + complete_data_reader, + providers, ) lwq_data_reader.append(complete_data_reader) @@ -281,9 +284,7 @@ def _prepare_data_reader_for_next_split_model( if not inputs: break out = session.run(None, {name: inputs[name] for name in input_names}) - filter_input = { - name: value for name, value in zip(output_names, out) - } + filter_input = {name: value for name, value in zip(output_names, out)} for name, value in inputs.items(): if name in next_model_input_names and name not in filter_input: filter_input[name] = value From f528f03be31435f263035265b8e4d6ce9306789d Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Mon, 22 Jul 2024 19:12:40 -0700 Subject: [PATCH 07/18] fix bug Signed-off-by: Mengni Wang --- onnx_neural_compressor/algorithms/weight_only/gptq.py | 1 + onnx_neural_compressor/algorithms/weight_only/rtn.py | 1 + onnx_neural_compressor/onnx_model.py | 4 +++- test/utils/test_onnx_model.py | 6 +----- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnx_neural_compressor/algorithms/weight_only/gptq.py b/onnx_neural_compressor/algorithms/weight_only/gptq.py index b780cd81d..4a7b35b31 100644 --- a/onnx_neural_compressor/algorithms/weight_only/gptq.py +++ b/onnx_neural_compressor/algorithms/weight_only/gptq.py @@ -380,6 +380,7 @@ def gptq_quantize( if return_modelproto: return model.model else: + model.save(model.model_path + "_quant.onnx") return model diff --git 
a/onnx_neural_compressor/algorithms/weight_only/rtn.py b/onnx_neural_compressor/algorithms/weight_only/rtn.py index 58ba80a40..d4ca7e55e 100644 --- a/onnx_neural_compressor/algorithms/weight_only/rtn.py +++ b/onnx_neural_compressor/algorithms/weight_only/rtn.py @@ -167,6 +167,7 @@ def rtn_quantize( if return_modelproto: return model.model else: + model.save(model.model_path + "_quant.onnx") return model diff --git a/onnx_neural_compressor/onnx_model.py b/onnx_neural_compressor/onnx_model.py index 0d11ea15b..fe51eac92 100644 --- a/onnx_neural_compressor/onnx_model.py +++ b/onnx_neural_compressor/onnx_model.py @@ -248,7 +248,9 @@ def save(self, root): else: onnx.save(self.model, root) - if self._config is not None: + self._model_path = root + + if self._config is not None and not os.path.exists(os.path.join(os.path.split(root)[0], "config.json")): model_type = "" if not hasattr(self._config, "model_type") else getattr(self._config, "model_type") setattr(self._config.__class__, "model_type", model_type) output_config_file = pathlib.Path(root).parent.joinpath("config.json").as_posix() diff --git a/test/utils/test_onnx_model.py b/test/utils/test_onnx_model.py index ffa1c02ec..f27f64e1f 100644 --- a/test/utils/test_onnx_model.py +++ b/test/utils/test_onnx_model.py @@ -146,11 +146,7 @@ def test_save(self): save_path = ".large_model_save.onnx" model.save(save_path) - # test save path does not exist - with self.assertRaises(ValueError) as cm: - save_path = "./gptj_output/test.onnx" - model.save(save_path) - self.assertEqual(str(cm.exception), '"root" directory does not exists.') + self.assertEqual(model.model_path, ".large_model_save.onnx") def test_get_initializer_share_num(self): model = onnx_model.ONNXModel(self.matmul_add_model) From 237073aaf9e689509f9b8395bb55d98d64c3ab23 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 23 Jul 2024 13:00:54 +0800 Subject: [PATCH 08/18] update code Signed-off-by: yuwenzho --- .../text_generation/quantization/weight_only/README.md | 6 
+++++- .../text_generation/quantization/weight_only/main.py | 4 ++-- .../text_generation/quantization/weight_only/run_quant.sh | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md index ef507875f..3c9160a63 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md @@ -14,7 +14,7 @@ pip install -r requirements.txt ## 2. Prepare Model -Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are other models available that can be used for weight-only quantization. The following table shows a few models' configurations: +Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. We verified weight-only quantization on other models as follows. | Model | Num Hidden Layers| Num Attention Heads | Hidden Size | | --- | --- | --- | --- | @@ -24,6 +24,9 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are | [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) | 40 | 40 | 5120 | | [meta-llama/Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf) | 80 | 64 | 8192 | | [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 80 | 64 | 8192 | +| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 32 | 32 | 4096 | +| [Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) | 32 | 32 | 3072 | +| [Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) | 80 | 64 | 8192 | Export to ONNX model: ```bash @@ -31,6 +34,7 @@ python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \ --task=text-generation-with-past \ # or text-generation ``` + # Run ## 1. 
Quantization diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index 800827e5a..7739f3eee 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -99,12 +99,12 @@ help="nodes that will not be quantized. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") args = parser.parse_args() -if not os.path.exists(args.output_model): +if args.tune and not os.path.exists(args.output_model): os.makedirs(args.output_model) # load model tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer) -model_config = transformers.AutoConfig.from_pretrained(args.model_path) +model_config = transformers.AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) def tokenize_function(examples): diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh index 3a50967c8..07b2177eb 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh @@ -57,7 +57,7 @@ function run_tuning { fi if [[ "${input_model}" =~ "Phi-3-mini-128k-instruct" ]]; then - nodes_to_exclude="/model/layers./self_attn/qkv_proj/MatMul /model/layers./mlp/down_proj/MatMul" + nodes_to_exclude="/model/layers.*/self_attn/qkv_proj/MatMul /model/layers.*/mlp/down_proj/MatMul" extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" fi if [[ "${input_model}" =~ "Meta-Llama-3-8B" ]]; then From be276c692b0874e6e53b953c513cc2f30c154709 Mon Sep 17 00:00:00 2001 From: "Wang, Mengni" Date: Tue, 23 Jul 2024 13:29:30 +0800 Subject: [PATCH 09/18] Update core.py Signed-off-by: Wang, Mengni --- 
onnx_neural_compressor/algorithms/layer_wise/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_neural_compressor/algorithms/layer_wise/core.py b/onnx_neural_compressor/algorithms/layer_wise/core.py index 1667e6d91..fe07a3a77 100644 --- a/onnx_neural_compressor/algorithms/layer_wise/core.py +++ b/onnx_neural_compressor/algorithms/layer_wise/core.py @@ -121,7 +121,7 @@ def layer_wise_quant( split_model_part_1.model, complete_data_reader ) - # next_data_reader contains split_model_part_1 output data + # complete_data_reader contains split_model_part_1 output data complete_data_reader = _prepare_data_reader_for_next_split_model( split_model_part_1.model_path, [i.name for i in split_model_part_2.model.graph.input], From f9c3c108e80d24868788d6f4edef7fe475026f8b Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 23 Jul 2024 13:17:33 +0800 Subject: [PATCH 10/18] fix typo Signed-off-by: yuwenzho --- .../text_generation/quantization/weight_only/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index 7739f3eee..f65e7111f 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -94,7 +94,7 @@ parser.add_argument("--trust_remote_code", type=bool, default=False) parser.add_argument("--layer_wise", action="store_true", default=False) parser.add_argument("--quantize_lm_head", action="store_true", default=False, - help="language modelling head will not be quantized by default. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") + help="language modeling head will not be quantized by default. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") parser.add_argument("--nodes_to_exclude", nargs="+", default=[], help="nodes that will not be quantized. 
Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") args = parser.parse_args() From 9ff12ec413104bd14f43983fa310aee4f301ff24 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 23 Jul 2024 21:13:37 +0800 Subject: [PATCH 11/18] update code Signed-off-by: yuwenzho --- examples/.config/model_params_onnxrt.json | 37 +++++++++++++++---- .../quantization/weight_only/main.py | 24 +++++++----- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index 085c7ef6c..06e1c0566 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -1,61 +1,82 @@ { "onnxrt": { "llama-2-7b-rtn": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-rtn-with-past": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-awq": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-awq-with-past": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": 
"main.py", "batch_size": 1 }, "llama-2-7b-gptq": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-gptq-with-past": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-woq_tune": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", "batch_size": 1 }, "llama-2-7b-woq_tune-with-past": { - "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only", + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", "batch_size": 1 }, + "llama-3-8b-gptq": { + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Meta-Llama-3-8B-onnx", + "main_script": "main.py", + "batch_size": 1 + }, + "phi-3-mini-128k-instruct-rtn-with-past": { + "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Phi-3-mini-128k-instruct-onnx", + "main_script": "main.py", + "batch_size": 1 + }, + "qwen2-7b-instruct-rtn-with-past": { + "model_src_dir": 
"nlp/huggingface_model/text_generation/quantization/weight_only", + "dataset_location": "", + "input_model": "/tf_dataset2/models/onnx/Qwen2-7B-Instruct-onnx", + "main_script": "main.py", + "batch_size": 1 + }, "bert_base_MRPC": { "model_src_dir": "nlp/bert/quantization/ptq_static", "dataset_location": "/tf_dataset/pytorch/glue_data/MRPC", diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index f65e7111f..b856ebc3d 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -93,10 +93,18 @@ parser.add_argument("--intra_op_num_threads", type=int, default=24) parser.add_argument("--trust_remote_code", type=bool, default=False) parser.add_argument("--layer_wise", action="store_true", default=False) -parser.add_argument("--quantize_lm_head", action="store_true", default=False, - help="language modeling head will not be quantized by default. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") -parser.add_argument("--nodes_to_exclude", nargs="+", default=[], - help="nodes that will not be quantized. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'") +parser.add_argument( + "--quantize_lm_head", + action="store_true", + default=False, + help="language modeling head will not be quantized by default. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'", +) +parser.add_argument( + "--nodes_to_exclude", + nargs="+", + default=[], + help="nodes that will not be quantized. 
Doesn't take effect when 'algorithm' is 'WOQ_TUNE'", +) args = parser.parse_args() if args.tune and not os.path.exists(args.output_model): @@ -337,11 +345,10 @@ def rewind(self): model_path = os.path.join(args.model_path, model_name) best_model = None - nodes_to_exclude=["/lm_head/MatMul"] if not args.quantize_lm_head else [] + nodes_to_exclude = ["/lm_head/MatMul"] if not args.quantize_lm_head else [] nodes_to_exclude = list(set(args.nodes_to_exclude + nodes_to_exclude)) if args.algorithm.upper() == "RTN": - algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig( - layer_wise_quant=args.layer_wise) + algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=args.layer_wise) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, n_bits=4, @@ -372,8 +379,7 @@ def rewind(self): elif args.algorithm.upper() == "GPTQ": calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1) algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig( - calibration_data_reader=calibration_data_reader, - layer_wise_quant=args.layer_wise + calibration_data_reader=calibration_data_reader, layer_wise_quant=args.layer_wise ) quant = matmul_nbits_quantizer.MatMulNBitsQuantizer( model_path, From d28255a02b4118f6a1948ffa28e16e7fc69dcf34 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 24 Jul 2024 08:23:09 +0800 Subject: [PATCH 12/18] update code Signed-off-by: yuwenzho --- examples/.config/model_params_onnxrt.json | 10 +++++----- .../text_generation/quantization/weight_only/README.md | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index 06e1c0566..f9fbd6fd8 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -56,24 +56,24 @@ "main_script": "main.py", "batch_size": 1 }, - "llama-3-8b-gptq": { + "llama-3-8b-gptq-with-past": { "model_src_dir": 
"nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", - "input_model": "/tf_dataset2/models/onnx/Meta-Llama-3-8B-onnx", + "input_model": "/tf_dataset2/models/onnx/Meta-Llama-3-8B-with-past", "main_script": "main.py", "batch_size": 1 }, "phi-3-mini-128k-instruct-rtn-with-past": { "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", - "input_model": "/tf_dataset2/models/onnx/Phi-3-mini-128k-instruct-onnx", + "input_model": "/tf_dataset2/models/onnx/Phi-3-mini-128k-instruct-with-past", "main_script": "main.py", "batch_size": 1 }, "qwen2-7b-instruct-rtn-with-past": { "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", - "input_model": "/tf_dataset2/models/onnx/Qwen2-7B-Instruct-onnx", + "input_model": "/tf_dataset2/models/onnx/Qwen2-7B-Instruct-with-past", "main_script": "main.py", "batch_size": 1 }, @@ -104,6 +104,6 @@ "input_model": "/tf_dataset2/models/onnx/resnet50-v1-12/resnet50-v1-12.onnx", "main_script": "main.py", "batch_size": 1 - }, + } } } diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md index 3c9160a63..79b17c73f 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/README.md @@ -25,8 +25,8 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. 
We verifie | [meta-llama/Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b-hf) | 80 | 64 | 8192 | | [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 80 | 64 | 8192 | | [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 32 | 32 | 4096 | -| [Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) | 32 | 32 | 3072 | -| [Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) | 80 | 64 | 8192 | +| [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) | 32 | 32 | 3072 | +| [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | 28 | 28 | 3584 | Export to ONNX model: ```bash From efed353a3415884c1eb7fbe1bbdbad7da85c5f13 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 24 Jul 2024 09:09:01 +0800 Subject: [PATCH 13/18] update example config Signed-off-by: yuwenzho --- examples/.config/model_params_onnxrt.json | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index f9fbd6fd8..ee802fb61 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -1,6 +1,7 @@ { "onnxrt": { "llama-2-7b-rtn": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", @@ -8,6 +9,7 @@ "batch_size": 1 }, "llama-2-7b-rtn-with-past": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", @@ -15,6 +17,7 @@ "batch_size": 1 }, "llama-2-7b-awq": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": 
"nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", @@ -22,6 +25,7 @@ "batch_size": 1 }, "llama-2-7b-awq-with-past": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", @@ -29,6 +33,7 @@ "batch_size": 1 }, "llama-2-7b-gptq": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", @@ -36,6 +41,7 @@ "batch_size": 1 }, "llama-2-7b-gptq-with-past": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", @@ -43,6 +49,7 @@ "batch_size": 1 }, "llama-2-7b-woq_tune": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", @@ -50,6 +57,7 @@ "batch_size": 1 }, "llama-2-7b-woq_tune-with-past": { + "model_name": "meta-llama/Llama-2-7b-hf", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", @@ -57,6 +65,7 @@ "batch_size": 1 }, "llama-3-8b-gptq-with-past": { + "model_name": "meta-llama/Meta-Llama-3-8B", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Meta-Llama-3-8B-with-past", @@ -64,6 +73,7 @@ "batch_size": 1 }, "phi-3-mini-128k-instruct-rtn-with-past": { + "model_name": "microsoft/Phi-3-mini-128k-instruct", "model_src_dir": 
"nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Phi-3-mini-128k-instruct-with-past", @@ -71,6 +81,7 @@ "batch_size": 1 }, "qwen2-7b-instruct-rtn-with-past": { + "model_name": "Qwen/Qwen2-7B-Instruct", "model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only", "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Qwen2-7B-Instruct-with-past", From f1440e9bb5597466602cf86fd48242f930b47957 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Wed, 24 Jul 2024 09:33:48 +0800 Subject: [PATCH 14/18] enhance code Signed-off-by: yuwenzho --- .../text_generation/quantization/weight_only/prepare_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py index 708c631a0..4d5d357da 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/prepare_model.py @@ -40,6 +40,7 @@ def prepare_model(input_model, output_model, task): "--task", task, f"{output_model}", + "--trust-remote-code", ], stdout=subprocess.PIPE, text=True, From 00812a9e94b6fe22bbd6d4e1e608093e211164c0 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 25 Jul 2024 08:52:59 +0800 Subject: [PATCH 15/18] update example config Signed-off-by: yuwenzho --- .../text_generation/quantization/weight_only/main.py | 2 +- .../quantization/weight_only/run_quant.sh | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py index b856ebc3d..e327aa827 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py +++ 
b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/main.py @@ -112,7 +112,7 @@ # load model tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer) -model_config = transformers.AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) +model_config = transformers.AutoConfig.from_pretrained(args.model_path, trust_remote_code=args.trust_remote_code) def tokenize_function(examples): diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh index 07b2177eb..9c9f63606 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh @@ -56,14 +56,18 @@ function run_tuning { echo "Created directory $output_model" fi - if [[ "${input_model}" =~ "Phi-3-mini-128k-instruct" ]]; then + if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then nodes_to_exclude="/model/layers.*/self_attn/qkv_proj/MatMul /model/layers.*/mlp/down_proj/MatMul" - extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" + extra_cmd="--nodes_to_exclude ${nodes_to_exclude} --trust_remote_code True" fi - if [[ "${input_model}" =~ "Meta-Llama-3-8B" ]]; then + if [[ "${tokenizer}" =~ "Llama-3-8B" ]]; then nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul" extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" fi + if [[ "${tokenizer}" =~ "Qwen2-7B" ]]; then + nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul /model/layers.*/mlp/up_proj/MatMul" + extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" + fi python main.py \ --model_path ${input_model} \ From 22c1f41ce002d77fac2d8518be8f5bfce326db15 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 25 Jul 2024 09:05:31 +0800 Subject: [PATCH 16/18] update example config Signed-off-by: yuwenzho --- examples/.config/model_params_onnxrt.json | 33 +++++++++++++++-------- 1 file 
changed, 22 insertions(+), 11 deletions(-) diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index ee802fb61..4ade34f75 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -6,7 +6,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "RTN" }, "llama-2-7b-rtn-with-past": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -14,7 +15,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "RTN" }, "llama-2-7b-awq": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -22,7 +24,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "AWQ" }, "llama-2-7b-awq-with-past": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -30,7 +33,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "AWQ" }, "llama-2-7b-gptq": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -38,7 +42,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "GPTQ" }, "llama-2-7b-gptq-with-past": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -46,7 +51,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "GPTQ" }, "llama-2-7b-woq_tune": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -54,7 +60,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + 
"algorithm": "WOQ_TUNE" }, "llama-2-7b-woq_tune-with-past": { "model_name": "meta-llama/Llama-2-7b-hf", @@ -62,7 +69,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "WOQ_TUNE" }, "llama-3-8b-gptq-with-past": { "model_name": "meta-llama/Meta-Llama-3-8B", @@ -70,7 +78,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Meta-Llama-3-8B-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "GPTQ" }, "phi-3-mini-128k-instruct-rtn-with-past": { "model_name": "microsoft/Phi-3-mini-128k-instruct", @@ -78,7 +87,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Phi-3-mini-128k-instruct-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "RTN" }, "qwen2-7b-instruct-rtn-with-past": { "model_name": "Qwen/Qwen2-7B-Instruct", @@ -86,7 +96,8 @@ "dataset_location": "", "input_model": "/tf_dataset2/models/onnx/Qwen2-7B-Instruct-with-past", "main_script": "main.py", - "batch_size": 1 + "batch_size": 1, + "algorithm": "RTN" }, "bert_base_MRPC": { "model_src_dir": "nlp/bert/quantization/ptq_static", From f82aed3fd21b1c235d0727e8ced0d0854dea78c3 Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Thu, 25 Jul 2024 15:49:00 +0800 Subject: [PATCH 17/18] fix shell Signed-off-by: yuwenzho --- .../text_generation/quantization/weight_only/run_quant.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh index 9c9f63606..4198da9a8 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_quant.sh @@ -69,7 +69,7 @@ function run_tuning { 
extra_cmd="--nodes_to_exclude ${nodes_to_exclude}" fi - python main.py \ + eval "python main.py \ --model_path ${input_model} \ --tokenizer ${tokenizer-meta-llama/Llama-2-7b-hf} \ --output_model ${output_model} \ @@ -79,7 +79,7 @@ function run_tuning { --tasks ${tasks-lambada_openai} \ --layer_wise \ --tune \ - ${extra_cmd} + ${extra_cmd}" } main "$@" From 539d01adec8e7b51c072f6128f8e5937a0d8b56e Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Fri, 26 Jul 2024 08:54:40 +0800 Subject: [PATCH 18/18] fix bug Signed-off-by: yuwenzho --- .../quantization/weight_only/run_benchmark.sh | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh index 1f728c0f1..72348427c 100644 --- a/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh +++ b/examples/nlp/huggingface_model/text_generation/quantization/weight_only/run_benchmark.sh @@ -35,23 +35,27 @@ function init_params { # run_benchmark function run_benchmark { - + # Check if the input_model ends with the filename extension ".onnx" if [[ $input_model =~ \.onnx$ ]]; then # If the string ends with the filename extension, get the path of the file input_model=$(dirname "$input_model") fi - python main.py \ + if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then + extra_cmd="--trust_remote_code True" + fi + + eval "python main.py \ --model_path ${input_model} \ --batch_size=${batch_size-1} \ --tokenizer=${tokenizer-meta-llama/Llama-2-7b-hf} \ --tasks=${tasks-lambada_openai} \ --mode=${mode} \ --intra_op_num_threads=${intra_op_num_threads-24} \ - --benchmark - + --benchmark \ + ${extra_cmd}" + } main "$@" -