diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py
index 39c728f..241ffc4 100644
--- a/slim_trees/lgbm_booster.py
+++ b/slim_trees/lgbm_booster.py
@@ -68,24 +68,62 @@ def _decompress_booster_state(compressed_state: dict):
     return state
 
 
+FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
+BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
+TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"
+
+SPLIT_FEATURE_DTYPE = np.int16
+THRESHOLD_DTYPE = np.float64
+DECISION_TYPE_DTYPE = np.int8
+LEFT_CHILD_DTYPE = np.int16
+RIGHT_CHILD_DTYPE = LEFT_CHILD_DTYPE
+LEAF_VALUE_DTYPE = np.float64
+
+
+def _extract_feature(feature_line: str) -> Tuple[str, List[str]]:
+    feat_name, values_str = feature_line.split("=")
+    return feat_name, values_str.split(" ")
+
+
+def _validate_feature_lengths(feats_map: dict):
+    # features on tree-level
+    assert len(feats_map["num_leaves"]) == 1
+    assert len(feats_map["num_cat"]) == 1
+    assert len(feats_map["is_linear"]) == 1
+    assert len(feats_map["shrinkage"]) == 1
+
+    # features on node-level
+    num_leaves = int(feats_map["num_leaves"][0])
+    assert len(feats_map["split_feature"]) == num_leaves - 1
+    assert len(feats_map["threshold"]) == num_leaves - 1
+    assert len(feats_map["decision_type"]) == num_leaves - 1
+    assert len(feats_map["left_child"]) == num_leaves - 1
+    assert len(feats_map["right_child"]) == num_leaves - 1
+
+    # features on leaf-level
+    num_leaves = int(feats_map["num_leaves"][0])
+    assert len(feats_map["leaf_value"]) == num_leaves
+
+
+def parse(str_list, dtype):
+    if np.can_cast(dtype, np.int64):
+        int64_array = np.array(str_list, dtype=np.int64)
+        return safe_cast(int64_array, dtype)
+    assert np.can_cast(dtype, np.float64)
+    return np.array(str_list, dtype=dtype)
+
+
 def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
     if not model_string.startswith("tree\nversion=v3"):
         raise ValueError("Only v3 is supported for the booster string format.")
-    FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
-    BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
-    TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"
-
-    def _extract_feature(feature_line):
-        feat_name, values_str = feature_line.split("=")
-        return feat_name, values_str.split(" ")
 
     front_str_match = re.search(FRONT_STRING_REGEX, model_string)
     if front_str_match is None:
         raise ValueError("Could not find front string.")
-    front_str = front_str_match.group()
+
     # delete tree_sizes line since this messes up the tree parsing by LightGBM if not set correctly
     # todo calculate correct tree_sizes
-    front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str)
+    front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str_match.group())
 
     back_str_match = re.search(BACK_STRING_REGEX, model_string)
     if back_str_match is None:
@@ -101,44 +139,58 @@ def _extract_feature(feature_line):
         # extract features -- filter out empty ones
         features = [f for f in features_list.split("\n") if "=" in f]
         feats_map = dict(_extract_feature(fl) for fl in features)
+        _validate_feature_lengths(feats_map)
+
+        tree_values = {
+            "num_leaves": int(feats_map["num_leaves"][0]),
+            "num_cat": int(feats_map["num_cat"][0]),
+            "split_feature": parse(feats_map["split_feature"], SPLIT_FEATURE_DTYPE),
+            "threshold": compress_half_int_float_array(
+                parse(feats_map["threshold"], THRESHOLD_DTYPE)
+            ),
+            "decision_type": parse(feats_map["decision_type"], DECISION_TYPE_DTYPE),
+            "left_child": parse(feats_map["left_child"], LEFT_CHILD_DTYPE),
+            "right_child": parse(feats_map["right_child"], RIGHT_CHILD_DTYPE),
"leaf_value": parse(feats_map["leaf_value"], LEAF_VALUE_DTYPE), + "is_linear": int(feats_map["is_linear"][0]), + "shrinkage": float(feats_map["shrinkage"][0]), + } + + # if tree is linear, add additional features + if int(feats_map["is_linear"][0]): + # attributes: leaf_features, leaf_coeff, leaf_const, num_features + # TODO: not all of these attributes might be needed. + tree_values["num_features"] = parse(feats_map["num_features"], np.int32) + tree_values["leaf_const"] = parse(feats_map["leaf_const"], LEAF_VALUE_DTYPE) + tree_values["leaf_features"] = parse( + [s if s else -1 for s in feats_map["leaf_features"]], + np.int16, + ) + tree_values["leaf_coeff"] = parse( + [s if s else None for s in feats_map["leaf_coeff"]], np.float64 + ) + + # at last + trees.append(tree_values) - def parse(str_list, dtype): - if np.can_cast(dtype, np.int64): - int64_array = np.array(str_list, dtype=np.int64) - return safe_cast(int64_array, dtype) - assert np.can_cast(dtype, np.float64) - return np.array(str_list, dtype=dtype) - - split_feature_dtype = np.int16 - threshold_dtype = np.float64 - decision_type_dtype = np.int8 - left_child_dtype = np.int16 - right_child_dtype = left_child_dtype - leaf_value_dtype = np.float64 - assert len(feats_map["num_leaves"]) == 1 - assert len(feats_map["num_cat"]) == 1 - assert len(feats_map["is_linear"]) == 1 - assert len(feats_map["shrinkage"]) == 1 - - trees.append( - { - "num_leaves": int(feats_map["num_leaves"][0]), - "num_cat": int(feats_map["num_cat"][0]), - "split_feature": parse(feats_map["split_feature"], split_feature_dtype), - "threshold": compress_half_int_float_array( - parse(feats_map["threshold"], threshold_dtype) - ), - "decision_type": parse(feats_map["decision_type"], decision_type_dtype), - "left_child": parse(feats_map["left_child"], left_child_dtype), - "right_child": parse(feats_map["right_child"], right_child_dtype), - "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype), - "is_linear": int(feats_map["is_linear"][0]), - "shrinkage": float(feats_map["shrinkage"][0]), - } - ) return front_str, trees, back_str +def _validate_tree_structure(tree: dict) -> bool: + return type(tree) == dict and tree.keys() == { + "num_leaves", + "num_cat", + "split_feature", + "threshold", + "decision_type", + "left_child", + "right_child", + "leaf_value", + "is_linear", + "shrinkage", + } + + def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> str: front_str, trees, back_str = compressed_state assert type(front_str) == str @@ -146,21 +198,9 @@ def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> assert type(back_str) == str handle = front_str - for i, tree in enumerate(trees): - assert type(tree) == dict - assert tree.keys() == { - "num_leaves", - "num_cat", - "split_feature", - "threshold", - "decision_type", - "left_child", - "right_child", - "leaf_value", - "is_linear", - "shrinkage", - } + _validate_tree_structure(tree) + is_linear = int(tree["is_linear"]) num_leaves = len(tree["leaf_value"]) num_nodes = len(tree["split_feature"]) @@ -183,9 +223,22 @@ def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> tree_str += "\ninternal_weight=" + ("0 " * num_nodes)[:-1] tree_str += "\ninternal_count=" + ("0 " * num_nodes)[:-1] tree_str += f"\nis_linear={tree['is_linear']}" - tree_str += f"\nshrinkage={tree['shrinkage']}" - tree_str += "\n\n\n" + if is_linear: + tree_str += "\nleaf_const=" + " ".join(str(x) for x in tree["leaf_const"]) + tree_str += "\nnum_features=" + " ".join( + 
str(x) for x in tree["num_features"] + ) + tree_str += "\nleaf_features=" + " ".join( + "" if f == -1 else str(int(f)) for f in tree["leaf_features"] + ) + tree_str += "\nleaf_coeff=" + " ".join( + "" if np.isnan(f) else str(f) for f in tree["leaf_coeff"] + ) + + tree_str += f"\nshrinkage={int(tree['shrinkage'])}" + tree_str += "\n\n\n" handle += tree_str + handle += back_str return handle diff --git a/tests/test_lgbm_compression.py b/tests/test_lgbm_compression.py index bc5915e..e2d4f44 100644 --- a/tests/test_lgbm_compression.py +++ b/tests/test_lgbm_compression.py @@ -20,7 +20,15 @@ def lgbm_regressor(rng): return LGBMRegressor(random_state=rng) -def test_compresed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path): +@pytest.fixture +def lgbm_regressor_linear(rng): + return LGBMRegressor(random_state=rng, linear_trees=True) + + +@pytest.mark.parametrize( + "lgbm_regressor", [lgbm_regressor, lgbm_regressor_linear], indirect=True +) +def test_compressed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path): X, y = diabetes_toy_df lgbm_regressor.fit(X, y)