From 4b9354349f6717d9a9134d272da51f4d8fb624be Mon Sep 17 00:00:00 2001 From: Adam Li Date: Sun, 1 May 2022 19:15:35 -0400 Subject: [PATCH 1/8] Adding details to docs --- doc/conf.py | 2 +- mne_icalabel/iclabel/label_components.py | 9 +++++---- mne_icalabel/iclabel/network.py | 7 ++++++- mne_icalabel/label_components.py | 16 ++++++++++++++-- 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index b8b1089a..d25e839f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -88,7 +88,7 @@ 'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs', 'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes', 'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer', - 'n_ytimes', 'n_ychannels', 'n_events', 'n_components', + 'n_ytimes', 'n_ychannels', 'n_events', 'n_components', 'n_classes', } numpydoc_xref_aliases = { # Python diff --git a/mne_icalabel/iclabel/label_components.py b/mne_icalabel/iclabel/label_components.py index 2857d502..c930a307 100644 --- a/mne_icalabel/iclabel/label_components.py +++ b/mne_icalabel/iclabel/label_components.py @@ -19,7 +19,7 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): - Autocorrelation, based on the ICA decomposition and the provided instance. - For more information, see :footcite:`iclabel2019` + For more information, see :footcite:`iclabel2019`. Parameters ---------- @@ -32,9 +32,10 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): Returns ------- - labels : numpy.ndarray of shape (n_components,) - The estimated corresponding numerical labels for each independent - component. + labels : numpy.ndarray of shape (n_components, n_classes) + The estimated corresponding predicted probabilities of output classes + for each independent component. Columns are ordered with 'Brain', + 'Muscle', 'Eye', 'Heart', 'Line Noise', 'Channel Noise', and 'Other'. References ---------- diff --git a/mne_icalabel/iclabel/network.py b/mne_icalabel/iclabel/network.py index 3b70d45e..1934ab2d 100644 --- a/mne_icalabel/iclabel/network.py +++ b/mne_icalabel/iclabel/network.py @@ -225,8 +225,10 @@ def run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike): Returns ------- - labels : np.ndarray of shape (n_components) + labels : np.ndarray of shape (n_components, n_classes) The predicted numerical probability values for all labels in ICLabel output. + Columns are ordered with 'Brain', 'Muscle', 'Eye', 'Heart', + 'Line Noise', 'Channel Noise', and 'Other'. """ ica_network_file = files("mne_icalabel.iclabel").joinpath("assets/iclabelNet.pt") @@ -237,4 +239,7 @@ def run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike): # Format input and get labels labels = iclabel_net(*_format_input_for_torch(*_format_input(images, psds, autocorr))) labels = labels.detach().numpy() + + # outputs are: + # ordered as in https://github.com/sccn/ICLabel/blob/e8abc99e0c371ff49eff115cf7955fafc7f7969a/iclabel.m#L60-L62 return labels diff --git a/mne_icalabel/label_components.py b/mne_icalabel/label_components.py index e959b71e..88d9afac 100644 --- a/mne_icalabel/label_components.py +++ b/mne_icalabel/label_components.py @@ -30,8 +30,20 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): Returns ------- - labels : np.ndarray of shape (n_components,) or (n_components, n_class) - The estimated numerical labels of each ICA component. + labels : np.ndarray of shape (n_components,) or (n_components, n_classes) + The estimated corresponding predicted probabilities of output classes + for each independent component. + + Notes + ----- + For ICLabel model, the output classes are ordered: + - 'Brain' + - 'Muscle' + - 'Eye' + - 'Heart' + - 'Line Noise' + - 'Channel Noise' + - 'Other' """ _validate_type(method, str, "method") _check_option("method", method, methods) From 8964a5e50435d22653b426cb45b2bba6a031c7cd Mon Sep 17 00:00:00 2001 From: Adam Li Date: Sun, 1 May 2022 19:24:22 -0400 Subject: [PATCH 2/8] Exp w/ transformer --- mne_icalabel/label_components.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mne_icalabel/label_components.py b/mne_icalabel/label_components.py index 88d9afac..c685ee62 100644 --- a/mne_icalabel/label_components.py +++ b/mne_icalabel/label_components.py @@ -1,10 +1,12 @@ from typing import Union +import numpy as np from mne import BaseEpochs from mne.io import BaseRaw from mne.preprocessing import ICA from mne.utils import _validate_type from mne.utils.check import _check_option +from sklearn.base import TransformerMixin, BaseEstimator from .iclabel import label_components as label_components_iclabel from .utils import _validate_inst_and_ica @@ -13,6 +15,26 @@ "iclabel": label_components_iclabel, } +class AutoLabelICA(TransformerMixin): + def __init__(self, method:str ='iclabel') -> None: + self.method = method + + def fit(self, X, y): + pass + + def transform(self, raw, ica): + ic_labels = label_components(raw, ica, method=self.method) + + # Afterwards, we can hard threshold the probability values to assign + # each component to be kept or not (i.e. it is part of brain signal). + # The first component was visually an artifact, which was captured + # for certain. + not_brain_index = np.argmax(ic_labels, axis=1) != 0 + exclude_idx = np.argwhere(not_brain_index).squeeze() + + ica.apply(raw, exclude=exclude_idx) + return raw + def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): """ From 218556cd14a05cf3958db0d5b4e9a29199f41f62 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 2 May 2022 12:17:45 -0400 Subject: [PATCH 3/8] Fixing API --- ...label_automatic_artifact_correction_ica.py | 14 +++--- mne_icalabel/iclabel/__init__.py | 2 +- mne_icalabel/iclabel/config.py | 9 ++++ mne_icalabel/iclabel/label_components.py | 8 ++-- mne_icalabel/label_components.py | 48 +++++++++---------- 5 files changed, 42 insertions(+), 39 deletions(-) create mode 100644 mne_icalabel/iclabel/config.py diff --git a/examples/iclabel_automatic_artifact_correction_ica.py b/examples/iclabel_automatic_artifact_correction_ica.py index a5af1317..cbf42294 100644 --- a/examples/iclabel_automatic_artifact_correction_ica.py +++ b/examples/iclabel_automatic_artifact_correction_ica.py @@ -203,14 +203,12 @@ # See :footcite:`iclabel2019` for full details. ic_labels = label_components(raw, ica) -print(np.round(ic_labels, 2)) - -# Afterwards, we can hard threshold the probability values to assign -# each component to be kept or not (i.e. it is part of brain signal). -# The first component was visually an artifact, which was captured -# for certain. -not_brain_index = np.argmax(ic_labels, axis=1) != 0 -exclude_idx = np.argwhere(not_brain_index).squeeze() +print(ic_labels) + +# We can extract the labels of each component and exclude +# non-brain classified components. +labels = ic_labels["labels"] +exclude_idx = np.argwhere(labels != "brain").squeeze() print(f"Excluding these ICA components: {exclude_idx}") # %% diff --git a/mne_icalabel/iclabel/__init__.py b/mne_icalabel/iclabel/__init__.py index 51affc3b..50e282ec 100644 --- a/mne_icalabel/iclabel/__init__.py +++ b/mne_icalabel/iclabel/__init__.py @@ -4,5 +4,5 @@ This is a python implementation of the EEGLAB plugin 'ICLabel'.""" from .features import get_iclabel_features # noqa: F401 -from .label_components import label_components # noqa: F401 +from .label_components import iclabel_label_components # noqa: F401 from .network import ICLabelNet, run_iclabel # noqa: F401 diff --git a/mne_icalabel/iclabel/config.py b/mne_icalabel/iclabel/config.py new file mode 100644 index 00000000..6ded6e3f --- /dev/null +++ b/mne_icalabel/iclabel/config.py @@ -0,0 +1,9 @@ +ICLABEL_NUMERICAL_TO_STRING = { + 0: "brain", + 1: "muscle artifact", + 2: "eye blink", + 3: "heart beat", + 4: "line noise", + 5: "channel noise", + 6: "other", +} diff --git a/mne_icalabel/iclabel/label_components.py b/mne_icalabel/iclabel/label_components.py index c930a307..c87dc172 100644 --- a/mne_icalabel/iclabel/label_components.py +++ b/mne_icalabel/iclabel/label_components.py @@ -8,7 +8,7 @@ from .network import run_iclabel -def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): +def iclabel_label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): """Label the provided ICA components with the ICLabel neural network. This network uses 3 features: @@ -32,7 +32,7 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): Returns ------- - labels : numpy.ndarray of shape (n_components, n_classes) + labels_pred_proba : numpy.ndarray of shape (n_components, n_classes) The estimated corresponding predicted probabilities of output classes for each independent component. Columns are ordered with 'Brain', 'Muscle', 'Eye', 'Heart', 'Line Noise', 'Channel Noise', and 'Other'. @@ -42,5 +42,5 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): .. footbibliography:: """ features = get_iclabel_features(inst, ica) - labels = run_iclabel(*features) - return labels + labels_pred_proba = run_iclabel(*features) + return labels_pred_proba diff --git a/mne_icalabel/label_components.py b/mne_icalabel/label_components.py index c685ee62..152995c6 100644 --- a/mne_icalabel/label_components.py +++ b/mne_icalabel/label_components.py @@ -6,35 +6,15 @@ from mne.preprocessing import ICA from mne.utils import _validate_type from mne.utils.check import _check_option -from sklearn.base import TransformerMixin, BaseEstimator -from .iclabel import label_components as label_components_iclabel +from .iclabel import iclabel_label_components +from .iclabel.config import ICLABEL_NUMERICAL_TO_STRING from .utils import _validate_inst_and_ica methods = { - "iclabel": label_components_iclabel, + "iclabel": iclabel_label_components, } -class AutoLabelICA(TransformerMixin): - def __init__(self, method:str ='iclabel') -> None: - self.method = method - - def fit(self, X, y): - pass - - def transform(self, raw, ica): - ic_labels = label_components(raw, ica, method=self.method) - - # Afterwards, we can hard threshold the probability values to assign - # each component to be kept or not (i.e. it is part of brain signal). - # The first component was visually an artifact, which was captured - # for certain. - not_brain_index = np.argmax(ic_labels, axis=1) != 0 - exclude_idx = np.argwhere(not_brain_index).squeeze() - - ica.apply(raw, exclude=exclude_idx) - return raw - def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): """ @@ -52,9 +32,16 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): Returns ------- - labels : np.ndarray of shape (n_components,) or (n_components, n_classes) - The estimated corresponding predicted probabilities of output classes + component_dict : dict + A dictionary with the following output: + - 'y_pred_proba' : np.ndarray of shape (n_components, n_classes) + Estimated corresponding predicted probabilities of output classes for each independent component. + - 'y_pred' : list of shape (n_components,) + The corresponding numerical label of the class with the highest + predicted probability. + - 'labels': list of shape (n_components,) + The corresponding string label of each class in 'y_pred'. Notes ----- @@ -70,4 +57,13 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): _validate_type(method, str, "method") _check_option("method", method, methods) _validate_inst_and_ica(inst, ica) - return methods[method](inst, ica) + labels_pred_proba = methods[method](inst, ica) + labels_pred = np.argmax(labels_pred_proba, axis=1) + labels = [ICLABEL_NUMERICAL_TO_STRING[label] for label in labels_pred] + + component_dict = { + "y_pred_proba": labels_pred_proba, + "y_pred": labels_pred, + "labels": labels, + } + return component_dict From 7c798e0afa0c222247d93dd65a0d8a0e53ea055d Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Mon, 2 May 2022 19:48:04 +0200 Subject: [PATCH 4/8] fix import --- mne_icalabel/iclabel/tests/test_label_components.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_icalabel/iclabel/tests/test_label_components.py b/mne_icalabel/iclabel/tests/test_label_components.py index b75825e7..2a4445aa 100644 --- a/mne_icalabel/iclabel/tests/test_label_components.py +++ b/mne_icalabel/iclabel/tests/test_label_components.py @@ -3,7 +3,7 @@ from mne.io import read_raw from mne.preprocessing import ICA -from mne_icalabel.iclabel import label_components +from mne_icalabel.iclabel import iclabel_label_components directory = sample.data_path() / "MEG" / "sample" raw = read_raw(directory / "sample_audvis_raw.fif", preload=False) @@ -20,5 +20,5 @@ @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_label_components(): """Simple test to check that label_components runs without raising.""" - labels = label_components(raw, ica) + labels = iclabel_label_components(raw, ica) assert labels is not None From 01f11284e5a947176b364122afde8215be1058db Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 2 May 2022 15:10:51 -0400 Subject: [PATCH 5/8] Fix docs --- examples/iclabel_automatic_artifact_correction_ica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/iclabel_automatic_artifact_correction_ica.py b/examples/iclabel_automatic_artifact_correction_ica.py index cbf42294..c9e7dbc4 100644 --- a/examples/iclabel_automatic_artifact_correction_ica.py +++ b/examples/iclabel_automatic_artifact_correction_ica.py @@ -27,7 +27,7 @@ import numpy as np from mne.preprocessing import ICA -from mne_icalabel.iclabel import label_components +from mne_icalabel import label_components sample_data_folder = mne.datasets.sample.data_path() sample_data_raw_file = os.path.join( From 26a2123cbca2b3f580e3d456e9e3f502709f3efc Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 2 May 2022 15:49:57 -0400 Subject: [PATCH 6/8] Fix example --- mne_icalabel/label_components.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_icalabel/label_components.py b/mne_icalabel/label_components.py index 152995c6..6fa86cc2 100644 --- a/mne_icalabel/label_components.py +++ b/mne_icalabel/label_components.py @@ -17,8 +17,7 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): - """ - Automatically label the ICA components with the selected method. + """Automatically label the ICA components with the selected method. Parameters ---------- From 144c6ad8f8647ac2c1398a3811c754a70caaad18 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 2 May 2022 17:14:32 -0400 Subject: [PATCH 7/8] Fix example --- examples/iclabel_automatic_artifact_correction_ica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/iclabel_automatic_artifact_correction_ica.py b/examples/iclabel_automatic_artifact_correction_ica.py index c9e7dbc4..fd9dda11 100644 --- a/examples/iclabel_automatic_artifact_correction_ica.py +++ b/examples/iclabel_automatic_artifact_correction_ica.py @@ -202,7 +202,7 @@ # into a 3-head neural network that has been pretrained. # See :footcite:`iclabel2019` for full details. -ic_labels = label_components(raw, ica) +ic_labels = label_components(raw, ica, method='iclabel') print(ic_labels) # We can extract the labels of each component and exclude From f83f0db29359e3189acb1ff9686fb09a99ded600 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 2 May 2022 18:00:23 -0400 Subject: [PATCH 8/8] Fixing example and style --- examples/iclabel_automatic_artifact_correction_ica.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/iclabel_automatic_artifact_correction_ica.py b/examples/iclabel_automatic_artifact_correction_ica.py index fd9dda11..b7922fea 100644 --- a/examples/iclabel_automatic_artifact_correction_ica.py +++ b/examples/iclabel_automatic_artifact_correction_ica.py @@ -24,7 +24,6 @@ import os import mne -import numpy as np from mne.preprocessing import ICA from mne_icalabel import label_components @@ -202,13 +201,15 @@ # into a 3-head neural network that has been pretrained. # See :footcite:`iclabel2019` for full details. -ic_labels = label_components(raw, ica, method='iclabel') +ic_labels = label_components(raw, ica, method="iclabel") print(ic_labels) # We can extract the labels of each component and exclude -# non-brain classified components. +# non-brain classified components, keeping 'brain' and 'other'. +# "Other" is a catch-all that for non-classifiable components. +# We will ere on the side of caution and assume we cannot blindly remove these. labels = ic_labels["labels"] -exclude_idx = np.argwhere(labels != "brain").squeeze() +exclude_idx = [idx for idx, label in enumerate(labels) if label not in ["brain", "other"]] print(f"Excluding these ICA components: {exclude_idx}") # %%