Skip to content

Commit

Permalink
Simplify cython partition api
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Jun 20, 2024
1 parent ba18c4d commit 08658c6
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions sklearn/tree/_splitter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,17 @@ cdef inline intp_t node_split_best(
Criterion criterion,
SplitRecord* split,
ParentInfo* parent_record,
bint with_monotonic_cst,
const int8_t[:] monotonic_cst,
# bint with_monotonic_cst,
# const int8_t[:] monotonic_cst,
) except -1 nogil:
"""Find the best split on node samples[start:end]
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
cdef bint with_monotonic_cst = splitter.with_monotonic_cst

# Find the best split
cdef intp_t start = splitter.start
cdef intp_t end = splitter.end
Expand Down Expand Up @@ -809,14 +812,15 @@ cdef inline int node_split_random(
Criterion criterion,
SplitRecord* split,
ParentInfo* parent_record,
bint with_monotonic_cst,
const int8_t[:] monotonic_cst,
) except -1 nogil:
"""Find the best random split on node samples[start:end]
Returns -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef const int8_t[:] monotonic_cst = splitter.monotonic_cst
cdef bint with_monotonic_cst = splitter.with_monotonic_cst

# Draw random splits and pick the best
cdef intp_t start = splitter.start
cdef intp_t end = splitter.end
Expand Down Expand Up @@ -1662,8 +1666,6 @@ cdef class BestSplitter(Splitter):
self.criterion,
split,
parent_record,
self.with_monotonic_cst,
self.monotonic_cst,
)

cdef class BestSparseSplitter(Splitter):
Expand Down Expand Up @@ -1692,8 +1694,6 @@ cdef class BestSparseSplitter(Splitter):
self.criterion,
split,
parent_record,
self.with_monotonic_cst,
self.monotonic_cst,
)

cdef class RandomSplitter(Splitter):
Expand Down Expand Up @@ -1722,8 +1722,6 @@ cdef class RandomSplitter(Splitter):
self.criterion,
split,
parent_record,
self.with_monotonic_cst,
self.monotonic_cst,
)

cdef class RandomSparseSplitter(Splitter):
Expand Down Expand Up @@ -1751,6 +1749,4 @@ cdef class RandomSparseSplitter(Splitter):
self.criterion,
split,
parent_record,
self.with_monotonic_cst,
self.monotonic_cst,
)

0 comments on commit 08658c6

Please sign in to comment.