diff --git a/model_parser.py b/model_parser.py
index 3b4d9b0..1b2d786 100644
--- a/model_parser.py
+++ b/model_parser.py
@@ -1,7 +1,7 @@
 import pickle
-from curses.ascii import isdigit
 import re
 from typing import Union
+
 import pandas as pd
 import pyarrow as pa
 from pyarrow import parquet as pq
@@ -76,7 +76,7 @@ def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
 
 if __name__ == "__main__":
     # df = df_from_model_string(open("test_model_string.model", "r").read())
-    df = df_from_model_string(open("private_/lgb1.txt", "r").read())
+    df = df_from_model_string(open("private_/lgb1.txt").read())
     # TODO: cannot be interpreted directly by parquet (pyarrow & fastparquet) as of now.
     # dfo = df_from_model_string(open("lgb1.txt", "r").read(), transform_values=True)
     print(df.head(50))
diff --git a/slim_trees/lgbm_booster.py b/slim_trees/lgbm_booster.py
index e9f7b9a..045c718 100644
--- a/slim_trees/lgbm_booster.py
+++ b/slim_trees/lgbm_booster.py
@@ -4,16 +4,12 @@
 import re
 import sys
 from typing import Any, BinaryIO, List, Tuple
+
+import numpy as np
 import pandas as pd
 import pyarrow as pa
-import pyarrow.parquet as pq
-import numpy as np
-from slim_trees.compression import (
-    compress_half_int_float_array,
-    decompress_half_int_float_array,
-)
-from slim_trees.utils import pyarrow_table_to_bytes, pq_bytes_to_df
+from slim_trees.utils import pq_bytes_to_df, pyarrow_table_to_bytes
 
 
 try:
     from lightgbm.basic import Booster
@@ -117,31 +113,39 @@ def parse(str_list, dtype):
         but one of them can be "exploded" later (node level)
         while the tree level is for meta information
         """
-        trees.append({
-            "tree_idx": int(tree_idx),
-            "num_leaves": int(feats_map["num_leaves"][0]),
-            "num_cat": int(feats_map["num_cat"][0]),
-            # "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
-            "is_linear": int(feats_map["is_linear"][0]),
-            "shrinkage": float(feats_map["shrinkage"][0]),
-        })
-
-        node_features.append({
-            "tree_idx": int(tree_idx),  # TODO: this is new, have to recover this as well
-            "node_idx": list(range(int(feats_map["num_leaves"][0]) - 1)),
-            # all the upcoming attributes have length num_leaves - 1
-            "split_feature": parse(feats_map["split_feature"], split_feature_dtype),
-            "threshold": parse(feats_map["threshold"], threshold_dtype),
-            "decision_type": parse(feats_map["decision_type"], decision_type_dtype),
-            "left_child": parse(feats_map["left_child"], left_child_dtype),
-            "right_child": parse(feats_map["right_child"], right_child_dtype),
-            # "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
-        })
-
-        leaf_values.append({
-            "tree_idx": int(tree_idx),
-            "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
-        })
+        trees.append(
+            {
+                "tree_idx": int(tree_idx),
+                "num_leaves": int(feats_map["num_leaves"][0]),
+                "num_cat": int(feats_map["num_cat"][0]),
+                # "last_leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[-1],
+                "is_linear": int(feats_map["is_linear"][0]),
+                "shrinkage": float(feats_map["shrinkage"][0]),
+            }
+        )
+
+        node_features.append(
+            {
+                "tree_idx": int(
+                    tree_idx
+                ),  # TODO: this is new, have to recover this as well
+                "node_idx": list(range(int(feats_map["num_leaves"][0]) - 1)),
+                # all the upcoming attributes have length num_leaves - 1
+                "split_feature": parse(feats_map["split_feature"], split_feature_dtype),
+                "threshold": parse(feats_map["threshold"], threshold_dtype),
+                "decision_type": parse(feats_map["decision_type"], decision_type_dtype),
+                "left_child": parse(feats_map["left_child"], left_child_dtype),
+                "right_child": parse(feats_map["right_child"], right_child_dtype),
+                # "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype)[:-1],
+            }
+        )
+
+        leaf_values.append(
+            {
+                "tree_idx": int(tree_idx),
+                "leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
+            }
+        )
 
     trees_df = pd.DataFrame(trees)
     trees_table = pa.Table.from_pandas(trees_df)
@@ -170,17 +174,27 @@ def parse(str_list, dtype):
     return front_str, trees_df_bytes, nodes_df_bytes, leaf_values_bytes, back_str
 
 
-def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes, str]) -> str:
-    front_str, trees_df_bytes, nodes_df_bytes, leaf_value_bytes, back_str = compressed_state
+def _decompress_booster_handle(
+    compressed_state: Tuple[str, bytes, bytes, bytes, str]
+) -> str:
+    (
+        front_str,
+        trees_df_bytes,
+        nodes_df_bytes,
+        leaf_value_bytes,
+        back_str,
+    ) = compressed_state
     assert type(front_str) == str
     assert type(back_str) == str
 
     trees_df = pq_bytes_to_df(trees_df_bytes)
-    nodes_df = pq_bytes_to_df(nodes_df_bytes).groupby('tree_idx').agg(lambda x: list(x))
-    leaf_values_df = pq_bytes_to_df(leaf_value_bytes).groupby('tree_idx')['leaf_value'].apply(list)
+    nodes_df = pq_bytes_to_df(nodes_df_bytes).groupby("tree_idx").agg(lambda x: list(x))
+    leaf_values_df = (
+        pq_bytes_to_df(leaf_value_bytes).groupby("tree_idx")["leaf_value"].apply(list)
+    )
 
     # merge trees_df, nodes_df, and leaf_values_df on tree_idx
-    trees_df = trees_df.merge(nodes_df, on='tree_idx')
-    trees_df = trees_df.merge(leaf_values_df, on='tree_idx')
+    trees_df = trees_df.merge(nodes_df, on="tree_idx")
+    trees_df = trees_df.merge(leaf_values_df, on="tree_idx")
 
     handle = front_str
@@ -206,10 +220,14 @@ def _decompress_booster_handle(compressed_state: Tuple[str, bytes, bytes, bytes,
         num_nodes = num_leaves - 1  # len(tree["split_feature"])
         tree_str = f"Tree={i}\n"
-        tree_str += f"num_leaves={num_leaves}\nnum_cat={tree['num_cat']}\nsplit_feature="
+        tree_str += (
+            f"num_leaves={num_leaves}\nnum_cat={tree['num_cat']}\nsplit_feature="
+        )
         tree_str += " ".join([str(x) for x in tree["split_feature"]])
         tree_str += "\nsplit_gain=" + ("0 " * num_nodes)[:-1]
-        threshold = tree["threshold"]  # decompress_half_int_float_array(tree["threshold"])
+        threshold = tree[
+            "threshold"
+        ]  # decompress_half_int_float_array(tree["threshold"])
         tree_str += "\nthreshold=" + " ".join([str(x) for x in threshold])
         tree_str += "\ndecision_type=" + " ".join(
             [str(x) for x in tree["decision_type"]]
         )
@@ -238,7 +256,7 @@ def compress_handle_parquet(trees: List[dict]) -> bytes:
 
     # step 1: turn features into pyarrow arrays
     # loop over all tree dicts in trees and create one dict per node with all features
-    for tree in trees:
+    for _tree in trees:
         pass
     # step 2: create pyarrow table
     # step 3: write table to parquet
diff --git a/slim_trees/utils.py b/slim_trees/utils.py
index 69a81fd..374005f 100644
--- a/slim_trees/utils.py
+++ b/slim_trees/utils.py
@@ -1,8 +1,8 @@
 import io
 
+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
-import pandas as pd
 
 
 def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
@@ -10,7 +10,9 @@ def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
     Given a pyarrow Table, return a .parquet file as bytes.
     """
     stream = pa.BufferOutputStream()
-    pq.write_table(table, stream, compression="lz4")  # TODO: investigate different effects of compression
+    pq.write_table(
+        table, stream, compression="lz4"
+    )  # TODO: investigate different effects of compression
     return stream.getvalue().to_pybytes()
 
 
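Note (outside the patch): the compression path above round-trips the tree tables through parquet bytes via pyarrow_table_to_bytes and pq_bytes_to_df. The standalone sketch below illustrates that round trip; the body of pq_bytes_to_df is not shown in this diff, so the version here is only an assumed reconstruction that reads the buffer back with pyarrow.

import io

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq


def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
    # Same approach as slim_trees/utils.py: write the table into an in-memory
    # buffer as a parquet file and return the raw bytes.
    stream = pa.BufferOutputStream()
    pq.write_table(table, stream, compression="lz4")
    return stream.getvalue().to_pybytes()


def pq_bytes_to_df(data: bytes) -> pd.DataFrame:
    # Assumed counterpart (its real implementation is not part of this diff):
    # read the parquet bytes back into a pandas DataFrame.
    return pq.read_table(io.BytesIO(data)).to_pandas()


if __name__ == "__main__":
    # Tiny round-trip check with a toy tree-level table.
    df = pd.DataFrame({"tree_idx": [0, 1], "num_leaves": [31, 15]})
    payload = pyarrow_table_to_bytes(pa.Table.from_pandas(df))
    print(pq_bytes_to_df(payload))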