Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: refactoring, attributes to parse linear trees as well #50

Merged
merged 5 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 112 additions & 59 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,62 @@ def _decompress_booster_state(compressed_state: dict):
return state


# Regexes that carve a LightGBM v3 model string into three parts:
# the header before the first tree, one group per tree, and the footer.
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"  # header: key=value lines up to "Tree"
BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"  # footer: "end of trees" through EOF
TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"  # captures (tree header, feature lines)

# Compact dtypes used when re-encoding per-tree feature arrays for compression.
SPLIT_FEATURE_DTYPE = np.int16
THRESHOLD_DTYPE = np.float64
DECISION_TYPE_DTYPE = np.int8
LEFT_CHILD_DTYPE = np.int16
RIGHT_CHILD_DTYPE = LEFT_CHILD_DTYPE  # left/right node ids share one width
LEAF_VALUE_DTYPE = np.float64


def _extract_feature(feature_line: str) -> Tuple[str, List[str]]:
feat_name, values_str = feature_line.split("=")
return feat_name, values_str.split(" ")


def _validate_feature_lengths(feats_map: dict):
# features on tree-level
assert len(feats_map["num_leaves"]) == 1
assert len(feats_map["num_cat"]) == 1
assert len(feats_map["is_linear"]) == 1
assert len(feats_map["shrinkage"]) == 1

# features on node-level
num_leaves = int(feats_map["num_leaves"][0])
assert len(feats_map["split_feature"]) == num_leaves - 1
assert len(feats_map["threshold"]) == num_leaves - 1
assert len(feats_map["decision_type"]) == num_leaves - 1
assert len(feats_map["left_child"]) == num_leaves - 1
assert len(feats_map["right_child"]) == num_leaves - 1

# features on leaf-level
num_leaves = int(feats_map["num_leaves"][0])
assert len(feats_map["leaf_value"]) == num_leaves


def parse(str_list, dtype):
    """Convert a list of decimal strings into a numpy array of ``dtype``.

    For integer targets the strings are first parsed as int64 and then
    narrowed via ``safe_cast`` (project helper — presumably verifies the
    values fit the narrower dtype). Float targets are parsed directly.
    """
    if not np.can_cast(dtype, np.int64):
        # float path: only float-compatible dtypes are accepted here
        assert np.can_cast(dtype, np.float64)
        return np.array(str_list, dtype=dtype)
    wide = np.array(str_list, dtype=np.int64)
    return safe_cast(wide, dtype)


def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
if not model_string.startswith("tree\nversion=v3"):
raise ValueError("Only v3 is supported for the booster string format.")
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"

def _extract_feature(feature_line):
feat_name, values_str = feature_line.split("=")
return feat_name, values_str.split(" ")

front_str_match = re.search(FRONT_STRING_REGEX, model_string)
if front_str_match is None:
raise ValueError("Could not find front string.")
front_str = front_str_match.group()

# delete tree_sizes line since this messes up the tree parsing by LightGBM if not set correctly
# todo calculate correct tree_sizes
front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str)
front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str_match.group())

back_str_match = re.search(BACK_STRING_REGEX, model_string)
if back_str_match is None:
Expand All @@ -101,66 +139,68 @@ def _extract_feature(feature_line):
# extract features -- filter out empty ones
features = [f for f in features_list.split("\n") if "=" in f]
feats_map = dict(_extract_feature(fl) for fl in features)
_validate_feature_lengths(feats_map)

tree_values = {
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"split_feature": parse(feats_map["split_feature"], SPLIT_FEATURE_DTYPE),
"threshold": compress_half_int_float_array(
parse(feats_map["threshold"], THRESHOLD_DTYPE)
),
"decision_type": parse(feats_map["decision_type"], DECISION_TYPE_DTYPE),
"left_child": parse(feats_map["left_child"], LEFT_CHILD_DTYPE),
"right_child": parse(feats_map["right_child"], RIGHT_CHILD_DTYPE),
"leaf_value": parse(feats_map["leaf_value"], LEAF_VALUE_DTYPE),
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}

# if tree is linear, add additional features
if int(feats_map["is_linear"][0]):
# attributes: leaf_features, leaf_coeff, leaf_const, num_features
# TODO: not all of these attributes might be needed.
tree_values["num_features"] = parse(feats_map["num_features"], np.int32)
tree_values["leaf_const"] = parse(feats_map["leaf_const"], LEAF_VALUE_DTYPE)
tree_values["leaf_features"] = parse(
[s if s else -1 for s in feats_map["leaf_features"]],
np.int16,
)
tree_values["leaf_coeff"] = parse(
[s if s else None for s in feats_map["leaf_coeff"]], np.float64
)

# at last
trees.append(tree_values)

def parse(str_list, dtype):
if np.can_cast(dtype, np.int64):
int64_array = np.array(str_list, dtype=np.int64)
return safe_cast(int64_array, dtype)
assert np.can_cast(dtype, np.float64)
return np.array(str_list, dtype=dtype)

split_feature_dtype = np.int16
threshold_dtype = np.float64
decision_type_dtype = np.int8
left_child_dtype = np.int16
right_child_dtype = left_child_dtype
leaf_value_dtype = np.float64
assert len(feats_map["num_leaves"]) == 1
assert len(feats_map["num_cat"]) == 1
assert len(feats_map["is_linear"]) == 1
assert len(feats_map["shrinkage"]) == 1

trees.append(
{
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"split_feature": parse(feats_map["split_feature"], split_feature_dtype),
"threshold": compress_half_int_float_array(
parse(feats_map["threshold"], threshold_dtype)
),
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}
)
return front_str, trees, back_str


def _validate_tree_structure(tree: dict) -> bool:
return type(tree) == dict and tree.keys() == {
"num_leaves",
"num_cat",
"split_feature",
"threshold",
"decision_type",
"left_child",
"right_child",
"leaf_value",
"is_linear",
"shrinkage",
}


def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) -> str:
front_str, trees, back_str = compressed_state
assert type(front_str) == str
assert type(trees) == list
assert type(back_str) == str

handle = front_str

for i, tree in enumerate(trees):
assert type(tree) == dict
assert tree.keys() == {
"num_leaves",
"num_cat",
"split_feature",
"threshold",
"decision_type",
"left_child",
"right_child",
"leaf_value",
"is_linear",
"shrinkage",
}
_validate_tree_structure(tree)
is_linear = int(tree["is_linear"])

num_leaves = len(tree["leaf_value"])
num_nodes = len(tree["split_feature"])
Expand All @@ -183,9 +223,22 @@ def _decompress_booster_handle(compressed_state: Tuple[str, List[dict], str]) ->
tree_str += "\ninternal_weight=" + ("0 " * num_nodes)[:-1]
tree_str += "\ninternal_count=" + ("0 " * num_nodes)[:-1]
tree_str += f"\nis_linear={tree['is_linear']}"
tree_str += f"\nshrinkage={tree['shrinkage']}"
tree_str += "\n\n\n"
if is_linear:
tree_str += "\nleaf_const=" + " ".join(str(x) for x in tree["leaf_const"])
tree_str += "\nnum_features=" + " ".join(
str(x) for x in tree["num_features"]
)
tree_str += "\nleaf_features=" + " ".join(
"" if f == -1 else str(int(f)) for f in tree["leaf_features"]
)
tree_str += "\nleaf_coeff=" + " ".join(
"" if np.isnan(f) else str(f) for f in tree["leaf_coeff"]
)

tree_str += f"\nshrinkage={int(tree['shrinkage'])}"

tree_str += "\n\n\n"
handle += tree_str

handle += back_str
return handle
10 changes: 9 additions & 1 deletion tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@ def lgbm_regressor(rng):
return LGBMRegressor(random_state=rng)


def test_compresed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path):
@pytest.fixture
def lgbm_regressor_linear(rng):
    """An unfitted LGBMRegressor configured to grow linear trees."""
    return LGBMRegressor(random_state=rng, linear_trees=True)


@pytest.mark.parametrize(
"lgbm_regressor", [lgbm_regressor, lgbm_regressor_linear], indirect=True
)
def test_compressed_predictions(diabetes_toy_df, lgbm_regressor, tmp_path):
X, y = diabetes_toy_df
lgbm_regressor.fit(X, y)

Expand Down