Skip to content

Commit

Permalink
Add curated schema for CalibratedClassifierCV
Browse files Browse the repository at this point in the history
  • Loading branch information
shinnar committed Apr 12, 2021
1 parent 76b5d85 commit 06aea7c
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lale/lib/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
* lale.lib.sklearn. `BernoulliNB`_
* lale.lib.sklearn. `BernoulliRBM`_
* lale.lib.sklearn. `Birch`_
* lale.lib.sklearn. `CalibratedClassifierCV`_
* lale.lib.sklearn. `DecisionTreeClassifier`_
* lale.lib.sklearn. `DummyClassifier`_
* lale.lib.sklearn. `ExtraTreesClassifier`_
Expand Down Expand Up @@ -98,6 +99,7 @@
.. _`BernoulliRBM`: lale.lib.sklearn.bernoulli_rbm.html
.. _`Binarizer`: lale.lib.sklearn.binarizer.html
.. _`Birch`: lale.lib.sklearn.birch.html
.. _`CalibratedClassifierCV`: lale.lib.sklearn.calibrated_classifier_cv.html
.. _`ColumnTransformer`: lale.lib.sklearn.column_transformer.html
.. _`DecisionTreeClassifier`: lale.lib.sklearn.decision_tree_classifier.html
.. _`DecisionTreeRegressor`: lale.lib.sklearn.decision_tree_regressor.html
Expand Down Expand Up @@ -159,6 +161,7 @@
from .bernoulli_rbm import BernoulliRBM
from .binarizer import Binarizer
from .birch import Birch
from .calibrated_classifier_cv import CalibratedClassifierCV
from .column_transformer import ColumnTransformer
from .decision_tree_classifier import DecisionTreeClassifier
from .decision_tree_regressor import DecisionTreeRegressor
Expand Down
163 changes: 163 additions & 0 deletions lale/lib/sklearn/calibrated_classifier_cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import typing

import sklearn
from sklearn.calibration import CalibratedClassifierCV as Op

import lale.operators
from lale.docstrings import set_docstrings
from lale.operators import make_operator
from lale.schemas import AnyOf, Bool, Int, Null

_hyperparams_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Probability calibration with isotonic regression or sigmoid.",
"allOf": [
{
"type": "object",
"required": ["base_estimator", "method", "cv"],
"relevantToOptimizer": ["method", "cv"],
"additionalProperties": False,
"properties": {
"base_estimator": {
"description": "The classifier whose output decision function needs to be calibrated to offer more accurate predict_proba outputs",
"default": None,
"anyOf": [
{
"description": "None uses the default classifier, LinearSVC.",
"enum": [None],
},
{"laleType": "operator"},
],
},
"method": {
"description": "The method to use for calibration. Can be ‘sigmoid’ which corresponds to Platt’s method (i.e. a logistic regression model) or ‘isotonic’ which is a non-parametric approach. It is not advised to use isotonic calibration with too few calibration samples (<<1000) since it tends to overfit.",
"enum": ["sigmoid", "isotonic"],
"default": "sigmoid",
},
"cv": {
"description": "Determines the cross-validation splitting strategy",
"default": None,
"anyOf": [
{
"description": "use the default 5-fold cross-validation",
"enum": [None],
},
{
"type": "integer",
"minimumForOptimizer": 3,
"maximumForOptimizer": 4,
"distribution": "uniform",
},
{
"laleType": "Any",
"description": "CV splitter or an iterable yielding (train, test) splits as arrays of indices.",
},
],
},
},
}
],
}
_input_fit_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Fit the calibrated model",
"type": "object",
"required": ["X", "y"],
"properties": {
"X": {
"type": "array",
"items": {"type": "array", "items": {"type": "number"}},
"description": "Training data.",
},
"y": {
"type": "array",
"items": {"type": "number"},
"description": "Target values.",
},
"sample_weight": {
"anyOf": [{"type": "array", "items": {"type": "number"}}, {"enum": [None]}],
"description": "Sample weights",
},
},
}
_input_predict_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Predict the target of new samples. Can be different from the",
"type": "object",
"required": ["X"],
"properties": {
"X": {
"type": "array",
"items": {"type": "array", "items": {"type": "number"}},
"description": "The samples.",
}
},
}
_output_predict_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "The predicted class.",
"type": "array",
"items": {"type": "number"},
}
_input_predict_proba_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "Posterior probabilities of classification",
"type": "object",
"required": ["X"],
"properties": {
"X": {
"type": "array",
"items": {"type": "array", "items": {"type": "number"}},
"description": "The samples.",
}
},
}
_output_predict_proba_schema = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "The predicted probas.",
"type": "array",
"items": {"type": "array", "items": {"type": "number"}},
}
_combined_schemas = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": """`CalibratedClassifierCV`_ : Probability calibration with isotonic regression or sigmoid.
.. _`CalibratedClassifierCV`: https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV
""",
"documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.calibrated_classifier_cv.html",
"import_from": "sklearn.calibration",
"type": "object",
"tags": {"pre": [], "op": ["estimator"], "post": []},
"properties": {
"hyperparams": _hyperparams_schema,
"input_fit": _input_fit_schema,
"input_predict": _input_predict_schema,
"output_predict": _output_predict_schema,
"input_predict_proba": _input_predict_proba_schema,
"output_predict_proba": _output_predict_proba_schema,
},
}
CalibratedClassifierCV = make_operator(Op, _combined_schemas)

if sklearn.__version__ >= "0.24":
# old: https://scikit-learn.org/0.23/modules/generated/sklearn.calibration.CalibratedClassifierCV.html
# new: https://scikit-learn.org/0.24/modules/generated/sklearn.calibration.CalibratedClassifierCV.html
CalibratedClassifierCV = typing.cast(
lale.operators.PlannedIndividualOp,
CalibratedClassifierCV.customize_schema(
n_jobs=AnyOf(
types=[
Int(minimum=1),
Int(minimum=-1, maximum=-1, desc="Use all the processors"),
Null(),
],
desc="Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.",
default=None,
),
ensemble=Bool(
default=True, desc="Determines how the calibrator is fitted."
),
),
)

set_docstrings(CalibratedClassifierCV)
1 change: 1 addition & 0 deletions test/test_core_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_classifier(self):
"lale.lib.sklearn.BernoulliNB",
"lale.lib.sklearn.BernoulliRBM",
"lale.lib.sklearn.Birch",
"lale.lib.sklearn.CalibratedClassifierCV",
"lale.lib.sklearn.DummyClassifier",
"lale.lib.sklearn.RandomForestClassifier",
"lale.lib.sklearn.DecisionTreeClassifier",
Expand Down

0 comments on commit 06aea7c

Please sign in to comment.