diff --git a/asv_benchmarks/benchmarks/ensemble.py b/asv_benchmarks/benchmarks/ensemble.py index c336d1e5f8805..877fcdb09fe68 100644 --- a/asv_benchmarks/benchmarks/ensemble.py +++ b/asv_benchmarks/benchmarks/ensemble.py @@ -2,6 +2,7 @@ GradientBoostingClassifier, HistGradientBoostingClassifier, RandomForestClassifier, + RandomForestRegressor, ) from .common import Benchmark, Estimator, Predictor @@ -9,8 +10,50 @@ _20newsgroups_highdim_dataset, _20newsgroups_lowdim_dataset, _synth_classification_dataset, + _synth_regression_dataset, + _synth_regression_sparse_dataset, ) -from .utils import make_gen_classif_scorers +from .utils import make_gen_classif_scorers, make_gen_reg_scorers + + +class RandomForestRegressorBenchmark(Predictor, Estimator, Benchmark): + """ + Benchmarks for RandomForestRegressor. + """ + + param_names = ["representation", "n_jobs"] + params = (["dense", "sparse"], Benchmark.n_jobs_vals) + + def setup_cache(self): + super().setup_cache() + + def make_data(self, params): + representation, n_jobs = params + + if representation == "sparse": + data = _synth_regression_sparse_dataset() + else: + data = _synth_regression_dataset() + + return data + + def make_estimator(self, params): + representation, n_jobs = params + + n_estimators = 500 if Benchmark.data_size == "large" else 100 + + estimator = RandomForestRegressor( + n_estimators=n_estimators, + min_samples_split=10, + max_features="log2", + n_jobs=n_jobs, + random_state=0, + ) + + return estimator + + def make_scorers(self): + make_gen_reg_scorers(self) class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark): diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index fe1b239cdeb32..0aeb07c9606d4 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -6,6 +6,7 @@ # Jacob Schreiber # Adam Li # Jong Shin +# Samuel Carliles # # License: BSD 3 clause @@ -14,9 +15,49 @@ from libcpp.vector cimport vector from ._criterion cimport BaseCriterion, 
Criterion from ._tree cimport ParentInfo + from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t +# NICE IDEAS THAT DON'T APPEAR POSSIBLE +# - accessing elements of a memory view of cython extension types in a nogil block/function +# - storing cython extension types in cpp vectors +# +# despite the fact that we can access scalar extension type properties in such a context, +# as for instance node_split_best does with Criterion and Partition, +# and we can access the elements of a memory view of primitive types in such a context +# +# SO WHERE DOES THAT LEAVE US +# - we can transform these into cpp vectors of structs +# and with some minor casting irritations everything else works ok +ctypedef void* SplitConditionParameters +ctypedef bint (*SplitConditionFunction)( + Splitter splitter, + SplitRecord* current_split, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionParameters split_condition_parameters +) noexcept nogil + +cdef struct SplitConditionTuple: + SplitConditionFunction f + SplitConditionParameters p + +cdef class SplitCondition: + cdef SplitConditionTuple t + +cdef class MinSamplesLeafCondition(SplitCondition): + pass + +cdef class MinWeightLeafCondition(SplitCondition): + pass + +cdef class MonotonicConstraintCondition(SplitCondition): + pass + + cdef struct SplitRecord: # Data to track sample split intp_t feature # Which feature to split on. 
@@ -105,6 +146,13 @@ cdef class Splitter(BaseSplitter): cdef const int8_t[:] monotonic_cst cdef bint with_monotonic_cst + cdef SplitCondition min_samples_leaf_condition + cdef SplitCondition min_weight_leaf_condition + cdef SplitCondition monotonic_constraint_condition + + cdef vector[SplitConditionTuple] presplit_conditions + cdef vector[SplitConditionTuple] postsplit_conditions + cdef int init( self, object X, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index d3c8fa1f98e83..ff707817d3d60 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -19,7 +19,8 @@ from cython cimport final from libc.math cimport isnan -from libc.stdlib cimport qsort +from libc.stdint cimport uintptr_t +from libc.stdlib cimport qsort, free from libc.string cimport memcpy from ._criterion cimport Criterion @@ -42,6 +43,155 @@ cdef float32_t FEATURE_THRESHOLD = 1e-7 # in SparsePartitioner cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 + +cdef bint min_sample_leaf_condition( + Splitter splitter, + SplitRecord* current_split, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionParameters split_condition_parameters +) noexcept nogil: + cdef intp_t min_samples_leaf = splitter.min_samples_leaf + cdef intp_t end_non_missing = splitter.end - n_missing + cdef intp_t n_left, n_right + + if missing_go_to_left: + n_left = current_split.pos - splitter.start + n_missing + n_right = end_non_missing - current_split.pos + else: + n_left = current_split.pos - splitter.start + n_right = end_non_missing - current_split.pos + n_missing + + # Reject if min_samples_leaf is not guaranteed + if n_left < min_samples_leaf or n_right < min_samples_leaf: + return False + + return True + +cdef class MinSamplesLeafCondition(SplitCondition): + def __cinit__(self): + self.t.f = min_sample_leaf_condition + self.t.p = NULL # min_samples is stored in splitter, which is already passed to f + +cdef bint 
min_weight_leaf_condition( + Splitter splitter, + SplitRecord* current_split, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionParameters split_condition_parameters +) noexcept nogil: + cdef float64_t min_weight_leaf = splitter.min_weight_leaf + + # Reject if min_weight_leaf is not satisfied + if ((splitter.criterion.weighted_n_left < min_weight_leaf) or + (splitter.criterion.weighted_n_right < min_weight_leaf)): + return False + + return True + +cdef class MinWeightLeafCondition(SplitCondition): + def __cinit__(self): + self.t.f = min_weight_leaf_condition + self.t.p = NULL # min_weight_leaf is stored in splitter, which is already passed to f + +cdef bint monotonic_constraint_condition( + Splitter splitter, + SplitRecord* current_split, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionParameters split_condition_parameters +) noexcept nogil: + if ( + splitter.with_monotonic_cst and + splitter.monotonic_cst[current_split.feature] != 0 and + not splitter.criterion.check_monotonicity( + splitter.monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + return False + + return True + +cdef class MonotonicConstraintCondition(SplitCondition): + def __cinit__(self): + self.t.f = monotonic_constraint_condition + self.t.p = NULL + +# cdef struct HasDataParameters: +# int min_samples + +# cdef bint has_data_condition( +# Splitter splitter, +# SplitRecord* current_split, +# intp_t n_missing, +# bint missing_go_to_left, +# float64_t lower_bound, +# float64_t upper_bound, +# SplitConditionParameters split_condition_parameters +# ) noexcept nogil: +# cdef HasDataParameters* p = split_condition_parameters +# return splitter.n_samples >= p.min_samples + +# cdef class HasDataCondition(SplitCondition): +# def __cinit__(self, int min_samples): +# self.t.f = has_data_condition +# self.t.p = malloc(sizeof(HasDataParameters)) +# 
(<HasDataParameters*>self.t.p).min_samples = min_samples + +# def __dealloc__(self): +# if self.t.p is not NULL: +# free(self.t.p) + +# super.__dealloc__(self) + +# cdef struct AlphaRegularityParameters: +# float64_t alpha + +# cdef bint alpha_regularity_condition( +# Splitter splitter, +# SplitRecord* current_split, +# intp_t n_missing, +# bint missing_go_to_left, +# float64_t lower_bound, +# float64_t upper_bound, +# SplitConditionParameters split_condition_parameters +# ) noexcept nogil: +# cdef AlphaRegularityParameters* p = split_condition_parameters + +# return True + +# cdef class AlphaRegularityCondition(SplitCondition): +# def __cinit__(self, float64_t alpha): +# self.t.f = alpha_regularity_condition +# self.t.p = malloc(sizeof(AlphaRegularityParameters)) +# (<AlphaRegularityParameters*>self.t.p).alpha = alpha + +# def __dealloc__(self): +# if self.t.p is not NULL: +# free(self.t.p) + +# super.__dealloc__(self) + + +# from ._tree cimport Tree +# cdef class FooTree(Tree): +# cdef Splitter splitter + +# def __init__(self): +# self.splitter = Splitter( +# presplit_conditions = [HasDataCondition(10)], +# postsplit_conditions = [AlphaRegularityCondition(0.1)], +# ) + + cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil: self.impurity_left = INFINITY self.impurity_right = INFINITY @@ -148,6 +298,8 @@ cdef class Splitter(BaseSplitter): float64_t min_weight_leaf, object random_state, const int8_t[:] monotonic_cst, + SplitCondition[:] presplit_conditions = None, + SplitCondition[:] postsplit_conditions = None, *argv ): """ @@ -188,6 +340,38 @@ cdef class Splitter(BaseSplitter): self.monotonic_cst = monotonic_cst self.with_monotonic_cst = monotonic_cst is not None + self.min_samples_leaf_condition = MinSamplesLeafCondition() + self.min_weight_leaf_condition = MinWeightLeafCondition() + + self.presplit_conditions.resize( + (len(presplit_conditions) if presplit_conditions is not None else 0) + + (2 if self.with_monotonic_cst else 1) + ) + self.postsplit_conditions.resize( + 
(len(postsplit_conditions) if postsplit_conditions is not None else 0) + + (2 if self.with_monotonic_cst else 1) + ) + + offset = 0 + self.presplit_conditions[offset] = self.min_samples_leaf_condition.t + self.postsplit_conditions[offset] = self.min_weight_leaf_condition.t + offset += 1 + + if(self.with_monotonic_cst): + self.monotonic_constraint_condition = MonotonicConstraintCondition() + self.presplit_conditions[offset] = self.monotonic_constraint_condition.t + self.postsplit_conditions[offset] = self.monotonic_constraint_condition.t + offset += 1 + + if presplit_conditions is not None: + for i in range(len(presplit_conditions)): + self.presplit_conditions[i + offset] = presplit_conditions[i].t + + if postsplit_conditions is not None: + for i in range(len(postsplit_conditions)): + self.postsplit_conditions[i + offset] = postsplit_conditions[i].t + + def __reduce__(self): return (type(self), (self.criterion, self.max_features, @@ -485,6 +669,8 @@ cdef inline intp_t node_split_best( # n_total_constants = n_known_constants + n_found_constants cdef intp_t n_total_constants = n_known_constants + cdef bint conditions_hold = True + _init_split(&best_split, end) partitioner.init_node_split(start, end) @@ -579,46 +765,71 @@ cdef inline intp_t node_split_best( current_split.pos = p - # Reject if monotonicity constraints are not satisfied - if ( - with_monotonic_cst and - monotonic_cst[current_split.feature] != 0 and - not criterion.check_monotonicity( - monotonic_cst[current_split.feature], - lower_bound, - upper_bound, - ) - ): - continue - - # Reject if min_samples_leaf is not guaranteed - if missing_go_to_left: - n_left = current_split.pos - splitter.start + n_missing - n_right = end_non_missing - current_split.pos - else: - n_left = current_split.pos - splitter.start - n_right = end_non_missing - current_split.pos + n_missing - if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1: + # # Reject if monotonicity constraints are not 
satisfied + # if ( + # with_monotonic_cst and + # monotonic_cst[current_split.feature] != 0 and + # not criterion.check_monotonicity( + # monotonic_cst[current_split.feature], + # lower_bound, + # upper_bound, + # ) + # ): + # continue + + # # Reject if min_samples_leaf is not guaranteed + # if missing_go_to_left: + # n_left = current_split.pos - splitter.start + n_missing + # n_right = end_non_missing - current_split.pos + # else: + # n_left = current_split.pos - splitter.start + # n_right = end_non_missing - current_split.pos + n_missing + + conditions_hold = True + for condition in splitter.presplit_conditions: + if not condition.f( + splitter, &current_split, n_missing, missing_go_to_left, + lower_bound, upper_bound, condition.p + ): + conditions_hold = False + break + + if not conditions_hold: continue + # if splitter.check_presplit_conditions(&current_split, n_missing, missing_go_to_left) == 1: + # continue + criterion.update(current_split.pos) - # Reject if monotonicity constraints are not satisfied - if ( - with_monotonic_cst and - monotonic_cst[current_split.feature] != 0 and - not criterion.check_monotonicity( - monotonic_cst[current_split.feature], - lower_bound, - upper_bound, - ) - ): - continue - - # Reject if min_weight_leaf is not satisfied - if splitter.check_postsplit_conditions() == 1: + # # Reject if monotonicity constraints are not satisfied + # if ( + # with_monotonic_cst and + # monotonic_cst[current_split.feature] != 0 and + # not criterion.check_monotonicity( + # monotonic_cst[current_split.feature], + # lower_bound, + # upper_bound, + # ) + # ): + # continue + + conditions_hold = True + for condition in splitter.postsplit_conditions: + if not condition.f( + splitter, &current_split, n_missing, missing_go_to_left, + lower_bound, upper_bound, condition.p + ): + conditions_hold = False + break + + if not conditions_hold: continue - + + # # Reject if min_weight_leaf is not satisfied + # if splitter.check_postsplit_conditions() == 1: + # continue + 
current_proxy_improvement = criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: