Skip to content

Commit

Permalink
cimport _build pruned tree
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Jun 21, 2024
1 parent f0f69be commit d455aa1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 34 deletions.
45 changes: 11 additions & 34 deletions sklearn/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ cdef class Splitter(BaseSplitter):
This is typically a metric that is cheaply computed given the
current proposed split, which is stored as the `current_split`
argument.
Returns 1 if not a valid split, and 0 if it is.
"""
cdef intp_t min_samples_leaf = self.min_samples_leaf
cdef intp_t end_non_missing = self.end - n_missing
Expand Down Expand Up @@ -418,8 +420,6 @@ cdef inline intp_t node_split_best(
Criterion criterion,
SplitRecord* split,
ParentInfo* parent_record,
# bint with_monotonic_cst,
# const int8_t[:] monotonic_cst,
) except -1 nogil:
"""Find the best split on node samples[start:end]
Expand Down Expand Up @@ -566,25 +566,7 @@ cdef inline intp_t node_split_best(

current_split.pos = p

# Reject if monotonicity constraints are not satisfied
if (
with_monotonic_cst and
monotonic_cst[current_split.feature] != 0 and
not criterion.check_monotonicity(
monotonic_cst[current_split.feature],
lower_bound,
upper_bound,
)
):
continue

# Reject if min_samples_leaf is not guaranteed
if missing_go_to_left:
n_left = current_split.pos - splitter.start + n_missing
n_right = end_non_missing - current_split.pos
else:
n_left = current_split.pos - splitter.start
n_right = end_non_missing - current_split.pos + n_missing
if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1:
continue

Expand Down Expand Up @@ -624,6 +606,13 @@ cdef inline intp_t node_split_best(

current_split.n_missing = n_missing
if n_missing == 0:
if missing_go_to_left:
n_left = current_split.pos - splitter.start + n_missing
n_right = end_non_missing - current_split.pos
else:
n_left = current_split.pos - splitter.start
n_right = end_non_missing - current_split.pos + n_missing

current_split.missing_go_to_left = n_left > n_right
else:
current_split.missing_go_to_left = missing_go_to_left
Expand Down Expand Up @@ -938,10 +927,6 @@ cdef inline int node_split_random(
criterion.reset()
criterion.update(current_split.pos)

# Reject if min_weight_leaf is not satisfied
if splitter.check_postsplit_conditions() == 1:
continue

# Reject if monotonicity constraints are not satisfied
if (
with_monotonic_cst and
Expand All @@ -954,16 +939,8 @@ cdef inline int node_split_random(
):
continue

# Reject if monotonicity constraints are not satisfied
if (
with_monotonic_cst and
monotonic_cst[current_split.feature] != 0 and
not criterion.check_monotonicity(
monotonic_cst[current_split.feature],
lower_bound,
upper_bound,
)
):
# Reject if min_weight_leaf is not satisfied
if splitter.check_postsplit_conditions() == 1:
continue

current_proxy_improvement = criterion.proxy_impurity_improvement()
Expand Down
8 changes: 8 additions & 0 deletions sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,11 @@ cdef class TreeBuilder:
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
)


# Declaration of _build_pruned_tree in the .pxd so that other Cython modules
# can cimport it. No return type is given, so Cython treats the function as
# returning a Python object.
# NOTE(review): parameter descriptions below are inferred from names only;
# confirm against the implementation in the matching .pyx file.
cdef _build_pruned_tree(
Tree tree, # OUT
# presumably the already-built tree that pruning starts from -- TODO confirm
Tree orig_tree,
# presumably a per-node mask marking nodes that become leaves -- TODO confirm
const unsigned char[:] leaves_in_subtree,
# presumably the node capacity to reserve in `tree` -- TODO confirm
intp_t capacity
)

0 comments on commit d455aa1

Please sign in to comment.