Skip to content

Commit

Permalink
Add curated schema for Birch
Browse files Browse the repository at this point in the history
  • Loading branch information
shinnar committed Apr 11, 2021
1 parent d4f2937 commit 76b5d85
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lale/lib/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
* lale.lib.sklearn. `BaggingClassifier`_
* lale.lib.sklearn. `BernoulliNB`_
* lale.lib.sklearn. `BernoulliRBM`_
* lale.lib.sklearn. `Birch`_
* lale.lib.sklearn. `DecisionTreeClassifier`_
* lale.lib.sklearn. `DummyClassifier`_
* lale.lib.sklearn. `ExtraTreesClassifier`_
Expand Down Expand Up @@ -96,6 +97,7 @@
.. _`BernoulliNB`: lale.lib.sklearn.bernoulli_nb.html
.. _`BernoulliRBM`: lale.lib.sklearn.bernoulli_rbm.html
.. _`Binarizer`: lale.lib.sklearn.binarizer.html
.. _`Birch`: lale.lib.sklearn.birch.html
.. _`ColumnTransformer`: lale.lib.sklearn.column_transformer.html
.. _`DecisionTreeClassifier`: lale.lib.sklearn.decision_tree_classifier.html
.. _`DecisionTreeRegressor`: lale.lib.sklearn.decision_tree_regressor.html
Expand Down Expand Up @@ -156,6 +158,7 @@
from .bernoulli_nb import BernoulliNB
from .bernoulli_rbm import BernoulliRBM
from .binarizer import Binarizer
from .birch import Birch
from .column_transformer import ColumnTransformer
from .decision_tree_classifier import DecisionTreeClassifier
from .decision_tree_regressor import DecisionTreeRegressor
Expand Down
148 changes: 148 additions & 0 deletions lale/lib/sklearn/birch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from sklearn.cluster import Birch as Op

from lale.docstrings import set_docstrings
from lale.operators import make_operator

# JSON schema for the constructor arguments (hyperparameters) of
# sklearn.cluster.Birch.
#
# Besides standard JSON-schema keywords, the schema carries the non-standard
# keywords "relevantToOptimizer", "minimumForOptimizer",
# "maximumForOptimizer", "distribution", "forOptimizer" and "laleType"
# (presumably consumed by lale's optimizer integration and type checker).
#
# Hard bounds ("minimum") mirror what scikit-learn itself rejects at fit
# time, so invalid values are reported when the operator is configured
# instead of failing later inside ``fit``.
_hyperparams_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Implements the Birch clustering algorithm.",
    "allOf": [
        {
            "type": "object",
            "required": [
                "threshold",
                "branching_factor",
                "n_clusters",
                "compute_labels",
                "copy",
            ],
            "relevantToOptimizer": [
                "branching_factor",
                "n_clusters",
                "compute_labels",
            ],
            "additionalProperties": False,
            "properties": {
                "threshold": {
                    "type": "number",
                    "default": 0.5,
                    "description": "The radius of the subcluster obtained by merging a new sample and the closest subcluster should be lesser than the threshold",
                },
                "branching_factor": {
                    "type": "integer",
                    # scikit-learn requires branching_factor > 1.
                    "minimum": 2,
                    # NOTE(review): the 50..51 search range looks like an
                    # auto-generated placeholder around the default and makes
                    # this hyperparameter effectively untunable; left as-is
                    # because widening it changes the search space.
                    "minimumForOptimizer": 50,
                    "maximumForOptimizer": 51,
                    "distribution": "uniform",
                    "default": 50,
                    "description": "Maximum number of CF subclusters in each node",
                },
                "n_clusters": {
                    "description": "Number of clusters after the final clustering step, which treats the subclusters from the leaves as new samples",
                    "default": 3,
                    "anyOf": [
                        {
                            "description": "The model fit is AgglomerativeClustering with n_clusters set to be equal to the int.",
                            "type": "integer",
                            # A clustering needs at least one cluster.
                            "minimum": 1,
                            "minimumForOptimizer": 2,
                            "maximumForOptimizer": 8,
                            "distribution": "uniform",
                        },
                        {
                            # A user-supplied clusterer is allowed but is
                            # excluded from the optimizer's search.
                            "forOptimizer": False,
                            "description": "sklearn.cluster Estimator: If a model is provided, the model is fit treating the subclusters as new samples and the initial data is mapped to the label of the closest subcluster.",
                            "laleType": "operator",
                        },
                        {
                            "enum": [None],
                            "description": "The final clustering step is not performed and the subclusters are returned as they are.",
                        },
                    ],
                },
                "compute_labels": {
                    "type": "boolean",
                    "default": True,
                    "description": "Whether or not to compute labels for each fit.",
                },
                "copy": {
                    "type": "boolean",
                    "default": True,
                    "description": "Whether or not to make a copy of the given data",
                },
            },
        }
    ],
}
# Schema for the arguments of ``fit``.
#
# Birch is an unsupervised clusterer: scikit-learn's ``Birch.fit(X, y=None)``
# ignores ``y``.  Only ``X`` is therefore required; ``y`` remains an accepted,
# unconstrained property so callers that pass labels keep working.
_input_fit_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Build a CF Tree for the input data.",
    "type": "object",
    "required": ["X"],
    "properties": {
        # 2-D numeric array (a list of numeric rows).
        "X": {
            "type": "array",
            "items": {"type": "array", "items": {"type": "number"}},
            "description": "Input data.",
        },
        # Empty schema: any value is accepted.
        "y": {},
    },
}
# Schema for the arguments of ``transform``: a single required ``X`` that
# must be an array of numeric arrays.
_input_transform_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Transform X into subcluster centroids dimension.",
    "type": "object",
    "required": ["X"],
    "properties": {
        "X": {
            "type": "array",
            "items": {
                "type": "array",
                "items": {"type": "number"},
            },
            "description": "Input data.",
        },
    },
}
# Schema for the value returned by ``transform``: an array of numeric arrays.
_output_transform_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Transformed data.",
    "type": "array",
    "items": {
        "type": "array",
        "items": {"type": "number"},
    },
}
# Schema for the arguments of ``predict``: a single required ``X`` that must
# be an array of numeric arrays.
_input_predict_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Predict data using the ``centroids_`` of subclusters.",
    "type": "object",
    "required": ["X"],
    "properties": {
        "X": {
            "type": "array",
            "items": {
                "type": "array",
                "items": {"type": "number"},
            },
            "description": "Input data.",
        },
    },
}
# Schema for the value returned by ``predict``.
#
# scikit-learn documents the result of ``Birch.predict`` as an ndarray of
# shape (n_samples,) holding cluster labels, so it is described as a 1-D
# array of integers.  This replaces an auto-generated placeholder that
# accepted any value and carried a stray TODO key.
_output_predict_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "description": "Labelled data.",
    "type": "array",
    "items": {"type": "integer"},
}
# Top-level schema bundle handed to ``make_operator``: operator metadata plus
# the hyperparameter schema and the per-method input/output schemas defined
# above in this module.
_combined_schemas = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    # The description is reStructuredText.  The blank line before the
    # hyperlink target is required: without it the ".. _`Birch`:" line is
    # parsed as part of the paragraph and the `Birch`_ reference is left
    # unresolved.  The continuation lines stay at column 0 so that no
    # indentation leaks into the string.
    "description": """`Birch`_ clustering algorithm.

.. _`Birch`: https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html
""",
    "documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.birch.html",
    "import_from": "sklearn.cluster",
    "type": "object",
    "tags": {"pre": [], "op": ["transformer", "estimator"], "post": []},
    "properties": {
        "hyperparams": _hyperparams_schema,
        "input_fit": _input_fit_schema,
        "input_transform": _input_transform_schema,
        "output_transform": _output_transform_schema,
        "input_predict": _input_predict_schema,
        "output_predict": _output_predict_schema,
    },
}
# Public operator: built by ``make_operator`` from the scikit-learn
# implementation (imported above as ``Op``) and ``_combined_schemas``.
Birch = make_operator(Op, _combined_schemas)

# NOTE(review): presumably attaches docstrings derived from the schemas to
# the operator — confirm against lale.docstrings.set_docstrings.
set_docstrings(Birch)
1 change: 1 addition & 0 deletions test/test_core_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_classifier(self):
classifiers = [
"lale.lib.sklearn.BernoulliNB",
"lale.lib.sklearn.BernoulliRBM",
"lale.lib.sklearn.Birch",
"lale.lib.sklearn.DummyClassifier",
"lale.lib.sklearn.RandomForestClassifier",
"lale.lib.sklearn.DecisionTreeClassifier",
Expand Down

0 comments on commit 76b5d85

Please sign in to comment.