Skip to content

Commit

Permalink
add: uncompress tree
Browse files: browse the repository at this point in the history
  • Loading branch information
YYYasin19 committed Mar 5, 2023
1 parent 4fb17c4 commit ef9933c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
22 changes: 14 additions & 8 deletions slim_trees/lgbm_booster.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -123,7 +123,7 @@ def parse(str_list, dtype):
"num_cat": int(feats_map["num_cat"][0]),
# "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
"is_linear": int(feats_map["is_linear"][0]),
"is_shrinkage": float(feats_map["shrinkage"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
})

node_features.append({
Expand Down Expand Up @@ -173,15 +173,20 @@ def parse(str_list, dtype):
def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes, str]) -> str:
front_str, trees_df_bytes, nodes_df_bytes, leaf_value_bytes, back_str = compressed_state
assert type(front_str) == str
# assert type(trees) == list
assert type(back_str) == str

trees_df = pq_bytes_to_df(trees_df_bytes)
nodes_df = pq_bytes_to_df(nodes_df_bytes)
nodes_df = pq_bytes_to_df(nodes_df_bytes).groupby('tree_idx').agg(lambda x: list(x))
leaf_values_df = pq_bytes_to_df(leaf_value_bytes).groupby('tree_idx')['leaf_value'].apply(list)
# merge trees_df, nodes_df, and leaf_values_df on tree_idx
trees_df = trees_df.merge(nodes_df, on='tree_idx')
trees_df = trees_df.merge(leaf_values_df, on='tree_idx')

handle = front_str

# TODO: directly go over trees and nodes
for i, tree in enumerate(trees_df):
for i, tree in trees_df.iterrows():
"""
assert type(tree) == dict
assert tree.keys() == {
"num_leaves",
Expand All @@ -195,15 +200,16 @@ def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes,
"is_linear",
"shrinkage",
}
"""

num_leaves = len(tree["leaf_value"])
num_nodes = len(tree["split_feature"])
num_leaves = int(tree["num_leaves"])
num_nodes = num_leaves - 1 # len(tree["split_feature"])

tree_str = f"Tree={i}\n"
tree_str += f"num_leaves={tree['num_leaves']}\nnum_cat={tree['num_cat']}\nsplit_feature="
tree_str += f"num_leaves={num_leaves}\nnum_cat={tree['num_cat']}\nsplit_feature="
tree_str += " ".join([str(x) for x in tree["split_feature"]])
tree_str += "\nsplit_gain=" + ("0 " * num_nodes)[:-1]
threshold = decompress_half_int_float_array(tree["threshold"])
threshold = tree["threshold"] # decompress_half_int_float_array(tree["threshold"])
tree_str += "\nthreshold=" + " ".join([str(x) for x in threshold])
tree_str += "\ndecision_type=" + " ".join(
[str(x) for x in tree["decision_type"]]
Expand Down
7 changes: 3 additions & 4 deletions slim_trees/utils.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
import io

import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
Expand All @@ -16,7 +18,4 @@ def pq_bytes_to_df(bytes_: bytes) -> pd.DataFrame:
"""
Given a .parquet file as bytes, return a pandas DataFrame.
"""
stream = pa.BufferReader(bytes_)
reader = pa.RecordBatchStreamReader(stream)
table = reader.read_all()
return table.to_pandas()
return pd.read_parquet(io.BytesIO(bytes_))

0 comments on commit ef9933c

Please sign in to comment.