Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: refactoring, attributes to parse linear trees as well #50

Merged
merged 5 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 118 additions & 59 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,62 @@ def _decompress_booster_state(compressed_state: dict):
return state


# Regexes that slice a LightGBM v3 model string into three parts:
# the key=value header before the first "Tree=..." section,
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
# everything from the "end of trees" sentinel to the end of the string,
BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
# and one match group pair per tree: ("Tree=<i>\n", "<feature lines>").
TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"

# Compact dtypes used to re-encode the per-node / per-leaf arrays.
# NOTE(review): int16 assumes < 32768 features and < 32768 nodes per tree —
# presumably safe for practical boosters; verify against model limits.
SPLIT_FEATURE_DTYPE = np.int16
THRESHOLD_DTYPE = np.float64
DECISION_TYPE_DTYPE = np.int8
LEFT_CHILD_DTYPE = np.int16
RIGHT_CHILD_DTYPE = LEFT_CHILD_DTYPE
LEAF_VALUE_DTYPE = np.float64


def _extract_feature(feature_line: str) -> Tuple[str, List[str]]:
feat_name, values_str = feature_line.split("=")
return feat_name, values_str.split(" ")


def _validate_feature_lengths(feats_map: dict):
# features on tree-level
assert len(feats_map["num_leaves"]) == 1
assert len(feats_map["num_cat"]) == 1
assert len(feats_map["is_linear"]) == 1
assert len(feats_map["shrinkage"]) == 1

# features on node-level
num_leaves = int(feats_map["num_leaves"][0])
assert len(feats_map["split_feature"]) == num_leaves - 1
assert len(feats_map["threshold"]) == num_leaves - 1
assert len(feats_map["decision_type"]) == num_leaves - 1
assert len(feats_map["left_child"]) == num_leaves - 1
assert len(feats_map["right_child"]) == num_leaves - 1

# features on leaf-level
num_leaves = int(feats_map["num_leaves"][0])
assert len(feats_map["leaf_value"]) == num_leaves


def parse(str_list, dtype):
    """Convert a list of value strings into a numpy array of ``dtype``.

    Integral targets are first parsed as int64 and then narrowed through
    ``safe_cast`` so that an out-of-range value fails loudly instead of
    silently wrapping; everything else must be float-castable.
    """
    is_integral = np.can_cast(dtype, np.int64)
    if not is_integral:
        assert np.can_cast(dtype, np.float64)
        return np.array(str_list, dtype=dtype)
    wide = np.array(str_list, dtype=np.int64)
    return safe_cast(wide, dtype)


def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
if not model_string.startswith("tree\nversion=v3"):
raise ValueError("Only v3 is supported for the booster string format.")
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"

def _extract_feature(feature_line):
feat_name, values_str = feature_line.split("=")
return feat_name, values_str.split(" ")

front_str_match = re.search(FRONT_STRING_REGEX, model_string)
if front_str_match is None:
raise ValueError("Could not find front string.")
front_str = front_str_match.group()

# delete tree_sizes line since this messes up the tree parsing by LightGBM if not set correctly
# todo calculate correct tree_sizes
front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str)
front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str_match.group())

back_str_match = re.search(BACK_STRING_REGEX, model_string)
if back_str_match is None:
Expand All @@ -101,66 +139,70 @@ def _extract_feature(feature_line):
# extract features -- filter out empty ones
features = [f for f in features_list.split("\n") if "=" in f]
feats_map = dict(_extract_feature(fl) for fl in features)
_validate_feature_lengths(feats_map)

tree_values = {
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"split_feature": parse(feats_map["split_feature"], SPLIT_FEATURE_DTYPE),
"threshold": compress_half_int_float_array(
parse(feats_map["threshold"], THRESHOLD_DTYPE)
),
"decision_type": parse(feats_map["decision_type"], DECISION_TYPE_DTYPE),
"left_child": parse(feats_map["left_child"], LEFT_CHILD_DTYPE),
"right_child": parse(feats_map["right_child"], RIGHT_CHILD_DTYPE),
"leaf_value": parse(feats_map["leaf_value"], LEAF_VALUE_DTYPE),
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}

# if tree is linear, add additional features
if int(feats_map["is_linear"][0]):
# attributes: leaf_features, leaf_coeff, leaf_const, num_features
# TODO: not all of these attributes might be needed.
tree_values["num_features"] = parse(feats_map["num_features"], np.int32)
tree_values["leaf_const"] = parse(feats_map["leaf_const"], LEAF_VALUE_DTYPE)
tree_values["leaf_features"] = parse(
[s if s else -1 for s in feats_map["leaf_features"]],
np.int16,
)
tree_values["leaf_coeff"] = parse(
[s if s else None for s in feats_map["leaf_coeff"]], np.float64
)

# at last
trees.append(tree_values)

def parse(str_list, dtype):
if np.can_cast(dtype, np.int64):
int64_array = np.array(str_list, dtype=np.int64)
return safe_cast(int64_array, dtype)
assert np.can_cast(dtype, np.float64)
return np.array(str_list, dtype=dtype)

split_feature_dtype = np.int16
threshold_dtype = np.float64
decision_type_dtype = np.int8
left_child_dtype = np.int16
right_child_dtype = left_child_dtype
leaf_value_dtype = np.float64
assert len(feats_map["num_leaves"]) == 1
assert len(feats_map["num_cat"]) == 1
assert len(feats_map["is_linear"]) == 1
assert len(feats_map["shrinkage"]) == 1

trees.append(
{
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"split_feature": parse(feats_map["split_feature"], split_feature_dtype),
"threshold": compress_half_int_float_array(
parse(feats_map["threshold"], threshold_dtype)
),
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}
)
return front_str, trees, back_str


def _validate_tree_structure(tree: dict) -> bool:
return type(tree) == dict and tree.keys() == {
"num_leaves",
"num_cat",
"split_feature",
"threshold",
"decision_type",
"left_child",
"right_child",
"leaf_value",
"is_linear",
"shrinkage",
}


def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> str:
front_str, trees, back_str = compressed_state
assert type(front_str) == str
assert type(trees) == list
assert type(back_str) == str

handle = front_str
# front_str += "tree_sizes=" + " ".join(["0" for t in trees]) + "\n"
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved

handle = front_str
for i, tree in enumerate(trees):
assert type(tree) == dict
assert tree.keys() == {
"num_leaves",
"num_cat",
"split_feature",
"threshold",
"decision_type",
"left_child",
"right_child",
"leaf_value",
"is_linear",
"shrinkage",
}
_validate_tree_structure(tree)
is_linear = int(tree["is_linear"])

num_leaves = len(tree["leaf_value"])
num_nodes = len(tree["split_feature"])
Expand All @@ -183,9 +225,26 @@ def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) ->
tree_str += "\ninternal_weight=" + ("0 " * num_nodes)[:-1]
tree_str += "\ninternal_count=" + ("0 " * num_nodes)[:-1]
tree_str += f"\nis_linear={tree['is_linear']}"
tree_str += f"\nshrinkage={tree['shrinkage']}"
tree_str += "\n\n\n"
if is_linear:
# TODO: attributes: leaf_features, leaf_coeff, leaf_const, num_features
tree_str += f"\nleaf_const={' '.join(str(x) for x in tree['leaf_const'])}"
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
tree_str += (
f"\nnum_features={' '.join(str(x) for x in tree['num_features'])}"
)
tree_str += (
f"\nleaf_features"
f"={' '.join(['' if f == -1 else str(int(f)) for f in tree['leaf_features']])}"
)

tree_str += (
f"\nleaf_coeff"
f"={' '.join(['' if np.isnan(f) else str(f) for f in tree['leaf_coeff']])}"
)

tree_str += f"\nshrinkage={int(tree['shrinkage'])}"

tree_str += "\n\n\n"
handle += tree_str

handle += back_str
return handle
37 changes: 35 additions & 2 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
)

from slim_trees import dump_lgbm_compressed
from slim_trees.lgbm_booster import _booster_pickle
from slim_trees.lgbm_booster import (
_booster_pickle,
_compress_booster_state,
_decompress_booster_state,
)
from slim_trees.pickling import dump_compressed, load_compressed


Expand All @@ -20,7 +24,36 @@ def lgbm_regressor(rng):
return LGBMRegressor(random_state=rng)


def test_compresed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path):
@pytest.fixture
def lgbm_regressor_linear(rng):
    """An ``LGBMRegressor`` with linear trees enabled, seeded from the shared rng."""
    regressor = LGBMRegressor(random_state=rng, linear_trees=True)
    return regressor


@pytest.mark.xfail(reason="reconstructed model string is not expected to be equal")
def test_model_string_equality(diabetes_toy_df, lgbm_regressor):
    """
    This test should fail because the model string
    will not be equivalent since we're omitting values.
    Nevertheless, for required features, some string rows should be equal.
    This helps in debugging that.
    """
    lgbm_regressor.fit(*diabetes_toy_df)

    # get regular model string
    model_str = lgbm_regressor.booster_.__reduce__()[2]["handle"]

    # get our reconstructed model string: compress then decompress the booster
    # state and compare the round-tripped handle against the original
    compressed_state = _compress_booster_state(lgbm_regressor.booster_.__reduce__()[2])
    reconstructed_model_str = _decompress_booster_state(compressed_state)["handle"]

    assert model_str == reconstructed_model_str


@pytest.mark.parametrize(
"lgbm_regressor", [lgbm_regressor, lgbm_regressor_linear], indirect=True
)
def test_compressed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path):
X, y = diabetes_toy_df
lgbm_regressor.fit(X, y)

Expand Down