Skip to content

Commit

Permalink
treeple-compatibility tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelCarliles3 committed Sep 8, 2024
1 parent c565d65 commit 2316e4c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
18 changes: 15 additions & 3 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2481,9 +2481,15 @@ class labels (multi-output problem).
dict,
list,
None,
],
]
}
_parameter_constraints.pop("splitter")
_parameter_constraints.pop("max_samples")
_parameter_constraints["max_samples"] = [
None,
Interval(RealNotInt, 0.0, None, closed="right"),
Interval(Integral, 1, None, closed="left"),
]

def __init__(
self,
Expand All @@ -2509,7 +2515,9 @@ def __init__(
max_samples=None,
max_bins=None,
store_leaf_values=False,
monotonic_cst=None
monotonic_cst=None,
stratify=False,
honest_prior="ignore"
):
self.target_tree_kwargs = {
"criterion": criterion,
Expand All @@ -2528,7 +2536,9 @@ def __init__(
super().__init__(
estimator=HonestDecisionTree(
target_tree_class=target_tree_class,
target_tree_kwargs=self.target_tree_kwargs
target_tree_kwargs=self.target_tree_kwargs,
stratify=stratify,
honest_prior=honest_prior
),
n_estimators=n_estimators,
estimator_params=(
Expand Down Expand Up @@ -2572,6 +2582,8 @@ def __init__(
self.monotonic_cst = monotonic_cst
self.ccp_alpha = ccp_alpha
self.target_tree_class = target_tree_class
self.stratify = stratify
self.honest_prior = honest_prior


class RandomForestRegressor(ForestRegressor):
Expand Down
14 changes: 7 additions & 7 deletions sklearn/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
class HonestDecisionTree(BaseDecisionTree):
_parameter_constraints: dict = {
**BaseDecisionTree._parameter_constraints,
"target_tree_class": [BaseDecisionTree],
"target_tree_class": "no_validation",
"target_tree_kwargs": [dict],
"honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")],
"honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="both")],
"honest_prior": [StrOptions({"empirical", "uniform", "ignore"})],
"stratify": ["boolean"],
}
Expand Down Expand Up @@ -221,7 +221,7 @@ def fit(

# fingers crossed sklearn.utils.validation.check_is_fitted doesn't
# change its behavior
print(f"n_classes = {target_bta.n_classes}")
#print(f"n_classes = {target_bta.n_classes}")
self.tree_ = HonestTree(
self.target_tree.n_features_in_,
target_bta.n_classes,
Expand All @@ -231,8 +231,8 @@ def fit(
self.honesty.resize_tree(self.tree_, self.honesty.get_node_count())
self.tree_.node_count = self.honesty.get_node_count()

print(f"dishonest node count = {self.target_tree.tree_.node_count}")
print(f"honest node count = {self.tree_.node_count}")
#print(f"dishonest node count = {self.target_tree.tree_.node_count}")
#print(f"honest node count = {self.tree_.node_count}")

criterion = BaseDecisionTree._create_criterion(
self.target_tree,
Expand All @@ -250,8 +250,8 @@ def fit(

for i in range(self.honesty.get_node_count()):
start, end = self.honesty.get_node_range(i)
print(f"setting sample range for node {i}: ({start}, {end})")
print(f"node {i} is leaf: {self.honesty.is_leaf(i)}")
#print(f"setting sample range for node {i}: ({start}, {end})")
#print(f"node {i} is leaf: {self.honesty.is_leaf(i)}")
self.honesty.set_sample_pointers(criterion, start, end)

if missing_values_in_feature_mask is not None:
Expand Down

0 comments on commit 2316e4c

Please sign in to comment.