Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
YYYasin19 committed Mar 6, 2023
1 parent ef9933c commit 9db0e33
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 45 deletions.
4 changes: 2 additions & 2 deletions model_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pickle
from curses.ascii import isdigit
import re
from typing import Union

import pandas as pd
import pyarrow as pa
from pyarrow import parquet as pq
Expand Down Expand Up @@ -76,7 +76,7 @@ def pyarrow_table_to_bytes(table: pa.Table) -> bytes:

if __name__ == "__main__":
# df = df_from_model_string(open("test_model_string.model", "r").read())
df = df_from_model_string(open("private_/lgb1.txt", "r").read())
df = df_from_model_string(open("private_/lgb1.txt").read())
# TODO: cannot be interpreted directly by parquet (pyarrow & fastparquet) as of now.
# dfo = df_from_model_string(open("lgb1.txt", "r").read(), transform_values=True)
print(df.head(50))
Expand Down
100 changes: 59 additions & 41 deletions slim_trees/lgbm_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
import re
import sys
from typing import Any, BinaryIO, List, Tuple

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

from slim_trees.compression import (
compress_half_int_float_array,
decompress_half_int_float_array,
)
from slim_trees.utils import pyarrow_table_to_bytes, pq_bytes_to_df
from slim_trees.utils import pq_bytes_to_df, pyarrow_table_to_bytes

try:
from lightgbm.basic import Booster
Expand Down Expand Up @@ -117,31 +113,39 @@ def parse(str_list, dtype):
but one of them can be "exploded" later (node level) while the tree level is for meta information
"""

trees.append({
"tree_idx": int(tree_idx),
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
# "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
})

node_features.append({
"tree_idx": int(tree_idx), # TODO: this is new, have to recover this as well
"node_idx": list(range(int(feats_map["num_leaves"][0]) - 1)),
# all the upcoming attributes have length num_leaves - 1
"split_feature": parse(feats_map["split_feature"], split_feature_dtype),
"threshold": parse(feats_map["threshold"], threshold_dtype),
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
# "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
})

leaf_values.append({
"tree_idx": int(tree_idx),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
})
trees.append(
{
"tree_idx": int(tree_idx),
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
# "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}
)

node_features.append(
{
"tree_idx": int(
tree_idx
), # TODO: this is new, have to recover this as well
"node_idx": list(range(int(feats_map["num_leaves"][0]) - 1)),
# all the upcoming attributes have length num_leaves - 1
"split_feature": parse(feats_map["split_feature"], split_feature_dtype),
"threshold": parse(feats_map["threshold"], threshold_dtype),
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
# "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
}
)

leaf_values.append(
{
"tree_idx": int(tree_idx),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
}
)

trees_df = pd.DataFrame(trees)
trees_table = pa.Table.from_pandas(trees_df)
Expand Down Expand Up @@ -170,17 +174,27 @@ def parse(str_list, dtype):
return front_str, trees_df_bytes, nodes_df_bytes, leaf_values_bytes, back_str


def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes, str]) -> str:
front_str, trees_df_bytes, nodes_df_bytes, leaf_value_bytes, back_str = compressed_state
def _decompress_booster_handle(
compressed_state: Tuple[str, bytes, bytes, bytes, str]
) -> str:
(
front_str,
trees_df_bytes,
nodes_df_bytes,
leaf_value_bytes,
back_str,
) = compressed_state
assert type(front_str) == str
assert type(back_str) == str

trees_df = pq_bytes_to_df(trees_df_bytes)
nodes_df = pq_bytes_to_df(nodes_df_bytes).groupby('tree_idx').agg(lambda x: list(x))
leaf_values_df = pq_bytes_to_df(leaf_value_bytes).groupby('tree_idx')['leaf_value'].apply(list)
nodes_df = pq_bytes_to_df(nodes_df_bytes).groupby("tree_idx").agg(lambda x: list(x))
leaf_values_df = (
pq_bytes_to_df(leaf_value_bytes).groupby("tree_idx")["leaf_value"].apply(list)
)
# merge trees_df, nodes_df, and leaf_values_df on tree_idx
trees_df = trees_df.merge(nodes_df, on='tree_idx')
trees_df = trees_df.merge(leaf_values_df, on='tree_idx')
trees_df = trees_df.merge(nodes_df, on="tree_idx")
trees_df = trees_df.merge(leaf_values_df, on="tree_idx")

handle = front_str

Expand All @@ -206,10 +220,14 @@ def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes,
num_nodes = num_leaves - 1 # len(tree["split_feature"])

tree_str = f"Tree={i}\n"
tree_str += f"num_leaves={num_leaves}\nnum_cat={tree['num_cat']}\nsplit_feature="
tree_str += (
f"num_leaves={num_leaves}\nnum_cat={tree['num_cat']}\nsplit_feature="
)
tree_str += " ".join([str(x) for x in tree["split_feature"]])
tree_str += "\nsplit_gain=" + ("0 " * num_nodes)[:-1]
threshold = tree["threshold"] # decompress_half_int_float_array(tree["threshold"])
threshold = tree[
"threshold"
] # decompress_half_int_float_array(tree["threshold"])
tree_str += "\nthreshold=" + " ".join([str(x) for x in threshold])
tree_str += "\ndecision_type=" + " ".join(
[str(x) for x in tree["decision_type"]]
Expand Down Expand Up @@ -238,7 +256,7 @@ def compress_handle_parquet(trees: List[dict]) -> bytes:

# step 1: turn features into pyarrow arrays
# loop over all tree dicts in trees and create one dict per node with all features
for tree in trees:
for _tree in trees:
pass
# step 2: create pyarrow table
# step 3: write table to parquet
Expand Down
6 changes: 4 additions & 2 deletions slim_trees/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import io

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd


def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
"""
Given a pyarrow Table, return a .parquet file as bytes.
"""
stream = pa.BufferOutputStream()
pq.write_table(table, stream, compression="lz4") # TODO: investigate different effects of compression
pq.write_table(
table, stream, compression="lz4"
) # TODO: investigate different effects of compression
return stream.getvalue().to_pybytes()


Expand Down

0 comments on commit 9db0e33

Please sign in to comment.