Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: regex-free evaluation, first try #11

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions pickle_compression/lgbm_booster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copyreg
import os
import pickle
import re
import sys
from typing import Any, BinaryIO, List, Tuple

Expand Down Expand Up @@ -61,28 +60,80 @@ def _decompress_booster_state(compressed_state: dict):
def _compress_booster_handle(model_string: str) -> Tuple[str, List[dict], str]:
if not model_string.startswith("tree\nversion=v3"):
raise ValueError("Only v3 is supported for the booster string format.")
FRONT_STRING_REGEX = r"(?:\w+(?:=.*)?\n)*\n(?=Tree)"
BACK_STRING_REGEX = r"end of trees(?:\n)+(?:.|\n)*"
TREE_GROUP_REGEX = r"(Tree=\d+\n+)((?:.+\n)*)\n\n"

def _extract_feature(feature_line):
feat_name, values_str = feature_line.split("=")
return feat_name, values_str.split(" ")

def parse(str_list, dtype):
return np.array(str_list, dtype=dtype)

def _extract_tree(tree_lines) -> dict:
# features_list = "\n".join(tree_lines[1:])
features = [f for f in tree_lines[1:] if "=" in f]
feats_map = dict(_extract_feature(fl) for fl in features)

# TODO: why is this even here?
split_feature_dtype = np.int16
threshold_dtype = np.float64
decision_type_dtype = np.int8
left_child_dtype = np.int16
right_child_dtype = left_child_dtype
leaf_value_dtype = np.float64

assert len(feats_map["num_leaves"]) == 1
assert len(feats_map["num_cat"]) == 1
assert len(feats_map["is_linear"]) == 1
assert len(feats_map["shrinkage"]) == 1

return {
"num_leaves": int(feats_map["num_leaves"][0]),
"num_cat": int(feats_map["num_cat"][0]),
"split_feature": parse(feats_map["split_feature"], split_feature_dtype),
"threshold": compress_half_int_float_array(
parse(feats_map["threshold"], threshold_dtype)
),
"decision_type": parse(feats_map["decision_type"], decision_type_dtype),
"left_child": parse(feats_map["left_child"], left_child_dtype),
"right_child": parse(feats_map["right_child"], right_child_dtype),
"leaf_value": parse(feats_map["leaf_value"], leaf_value_dtype),
"is_linear": int(feats_map["is_linear"][0]),
"shrinkage": float(feats_map["shrinkage"][0]),
}

model_lines = np.array(model_string.split("\n"))
front_str = "\n".join(model_lines[:9]) + "\n\n"
# ; front_str_alt = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str_alt)
back_str = ""
trees = []
idx = 11
while idx < len(model_lines):
line = model_lines[idx]
if line.startswith("Tree="):
trees.append(_extract_tree(model_lines[idx : idx + 18]))
idx += 19
elif line.startswith("end of trees"):
back_str = "\n".join(model_lines[idx:])
break # finished, rest is back str
else:
idx += 1 # oops, something wrong

"""
front_str_match = re.search(FRONT_STRING_REGEX, model_string)
if front_str_match is None:
raise ValueError("Could not find front string.")
front_str = front_str_match.group()
# delete tree_sizes line since this messes up the tree parsing by LightGBM if not set correctly
# todo calculate correct tree_sizes
front_str = re.sub(r"tree_sizes=(?:\d+ )*\d+\n", "", front_str)
front_str = re.sub(r"tree_sizes=(?:\\d+ )*\\d+\n", "", front_str)

back_str_match = re.search(BACK_STRING_REGEX, model_string)
if back_str_match is None:
raise ValueError("Could not find back string.")
back_str = back_str_match.group()
tree_matches = re.findall(TREE_GROUP_REGEX, model_string)
trees: List[dict] = []

for i, tree_match in enumerate(tree_matches):
tree_name, features_list = tree_match
_, tree_idx = tree_name.replace("\n", "").split("=")
Expand Down Expand Up @@ -122,6 +173,7 @@ def parse(str_list, dtype):
"shrinkage": float(feats_map["shrinkage"][0]),
}
)
"""
return front_str, trees, back_str


Expand Down