Skip to content

Commit

Permalink
store leaf_values seperately
Browse files Browse the repository at this point in the history
  • Loading branch information
YYYasin19 committed Mar 5, 2023
1 parent 9461eca commit 4fb17c4
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _decompress_booster_state(compressed_state: dict):
return state


def _compress_booster_handle(model_string: str) -> Tuple[str, bytes, bytes, str]:
def _compress_booster_handle(model_string: str) -> Tuple[str, bytes, bytes, bytes, str]:
if not model_string.startswith("tree\nversion=v3"):
raise ValueError("Only v3 is supported for the booster string format.")
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
Expand All @@ -85,7 +85,8 @@ def _extract_feature(feature_line):
raise ValueError("Could not find back string.")
back_str = back_str_match.group()
tree_matches = re.findall(TREE_GROUP_REGEX, model_string)
nodes: List[dict] = []
node_features: List[dict] = []
leaf_values: List[dict] = []
trees: List[dict] = []
for i, tree_match in enumerate(tree_matches):
tree_name, features_list = tree_match
Expand Down Expand Up @@ -116,17 +117,16 @@ def parse(str_list, dtype):
but one of them can be "exploded" later (node level) while the tree level is for meta information
"""

tree = {
trees.append({
"tree_idx": int(tree_idx),
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
# "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
"is_linear": int(feats_map["is_linear"][0]),
"is_shrinkage": float(feats_map["shrinkage"][0]),
}
trees.append(tree)
})

node = {
node_features.append({
"tree_idx": int(tree_idx), # TODO: this is new, have to recover this as well
"node_idx": list(range(int(feats_map["num_leaves"][0]) - 1)),
# all the upcoming attributes have length num_leaves - 1
Expand All @@ -135,16 +135,20 @@ def parse(str_list, dtype):
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
}
nodes.append(node)
# "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
})

leaf_values.append({
"tree_idx": int(tree_idx),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
})

trees_df = pd.DataFrame(trees)
trees_table = pa.Table.from_pandas(trees_df)
trees_df_bytes = pyarrow_table_to_bytes(trees_table)

# transform nodes_df s.t. each feature is a column
nodes_df = pd.DataFrame(nodes)
nodes_df = pd.DataFrame(node_features)
nodes_df = nodes_df.explode(
[
"node_idx",
Expand All @@ -153,18 +157,21 @@ def parse(str_list, dtype):
"decision_type",
"left_child",
"right_child",
"leaf_value",
# "leaf_value",
]
)
nodes_table = pa.Table.from_pandas(nodes_df)

nodes_df_bytes = pyarrow_table_to_bytes(nodes_table)
leaf_values_df = pd.DataFrame(leaf_values).explode(["leaf_value"])
leaf_values_table = pa.Table.from_pandas(leaf_values_df)
leaf_values_bytes = pyarrow_table_to_bytes(leaf_values_table)

return front_str, trees_df_bytes, nodes_df_bytes, back_str
return front_str, trees_df_bytes, nodes_df_bytes, leaf_values_bytes, back_str


def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, str]) -> str:
front_str, trees_df_bytes, nodes_df_bytes, back_str = compressed_state
def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes, str]) -> str:
front_str, trees_df_bytes, nodes_df_bytes, leaf_value_bytes, back_str = compressed_state
assert type(front_str) == str
# assert type(trees) == list
assert type(back_str) == str
Expand Down

0 comments on commit 4fb17c4

Please sign in to comment.