diff --git a/docs/source/settings/preprocessing/ssp_ica.md b/docs/source/settings/preprocessing/ssp_ica.md index b132ef4bf..23f510796 100644 --- a/docs/source/settings/preprocessing/ssp_ica.md +++ b/docs/source/settings/preprocessing/ssp_ica.md @@ -24,6 +24,7 @@ tags: - ssp_ecg_channel - ica_reject - ica_algorithm + - ica_use_icalabel - ica_l_freq - ica_max_iterations - ica_n_components diff --git a/docs/source/v1.5.md.inc b/docs/source/v1.5.md.inc index 5522271b1..40a86d747 100644 --- a/docs/source/v1.5.md.inc +++ b/docs/source/v1.5.md.inc @@ -3,6 +3,10 @@ This release contains a number of very important bug fixes that address problems related to decoding, time-frequency analysis, and inverse modeling. All users are encouraged to update. +We also improved logging during parallel processing, added support for finding and repairing bad epochs via +[`autoreject`](https://autoreject.github.io), and included support for automatic labeling of ICA artifacts +via [MNE-ICALabel](https://mne.tools/mne-icalabel). + ### :new: New features & enhancements - Added `deriv_root` argument to CLI (#773 by @vferat) @@ -22,6 +26,7 @@ All users are encouraged to update. - Added support for "local" [`autoreject`](https://autoreject.github.io) to find (and repair) bad channels on a per-epoch basis before submitting them to ICA fitting. This can be enabled by setting [`ica_reject`][mne_bids_pipeline._config.ica_reject] to `"autoreject_local"`. 
(#810 by @hoechenberger) +- Added support for automated labeling of ICA components via [MNE-ICALabel](https://mne.tools/mne-icalabel) (#812 by @hoechenberger) - Website documentation tables can now be sorted (e.g., to find examples that use a specific feature) (#808 by @larsoner) [//]: # (### :warning: Behavior changes) diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 32e0c9735..9d1782585 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1237,7 +1237,7 @@ """ Peak-to-peak amplitude limits to exclude epochs from ICA fitting. This allows you to remove strong transient artifacts from the epochs used for fitting ICA, which could -negatively affect ICA performance. +negatively affect ICA performance. The parameter values are the same as for [`reject`][mne_bids_pipeline._config.reject], but `"autoreject_global"` is not supported. @@ -1262,7 +1262,7 @@ to **not** specify rejection thresholds for EOG and ECG channels here – otherwise, ICA won't be able to "see" these artifacts. -???+ info +???+ info This setting is applied only to the epochs that are used for **fitting** ICA. The goal is to make it easier for ICA to produce a good decomposition. After fitting, ICA is applied to the epochs to be analyzed, usually with one or more components @@ -1367,6 +1367,20 @@ false-alarm rate increases dramatically. """ +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. + +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + h_freq = 100 + ``` +""" + # Rejection based on peak-to-peak amplitude # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1384,7 +1398,7 @@ If `None` (default), do not apply artifact rejection. -If a dictionary, manually specify rejection thresholds (see examples). +If a dictionary, manually specify rejection thresholds (see examples). 
The thresholds provided here must be at least as stringent as those in [`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of `'autoreject_global'`, thresholds for any channel that do not meet this @@ -1443,7 +1457,8 @@ !!! info This setting only takes effect if [`reject`][mne_bids_pipeline._config.reject] has - been set to `"autoreject_local"`. + been set to `"autoreject_local"`. It is not applied when using + `"autoreject_global"`. !!! info Channels marked as globally bad in the BIDS dataset (in `*_channels.tsv)`) will not diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 14a55df2e..81a492d41 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -345,6 +345,20 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N f"but got shape {destination.shape}" ) + # MNE-ICALabel + if config.ica_use_icalabel: + if config.ica_l_freq != 1.0 or config.h_freq != 100.0: + raise ValueError( + f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}" + ) + + if config.eeg_reference != "average": + raise ValueError( + f'When using MNE-ICALabel, you must set eeg_reference="average", but ' + f"got: eeg_reference={config.eeg_reference}" + ) + def _default_factory(key, val): # convert a default to a default factory if needed, having an explicit diff --git a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py old mode 100644 new mode 100755 index efd8bec84..dd6355c05 --- a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py @@ -17,6 +17,7 @@ import pandas as pd import numpy as np import autoreject +from mne_icalabel import label_components import mne from mne.report import Report @@ -135,7 +136,7 @@ def make_ecg_epochs( del raw # Free memory if 
len(ecg_epochs) == 0: - msg = "No ECG events could be found. Not running ECG artifact " "detection." + msg = "No ECG events could be found. Not running ECG artifact detection." logger.info(**gen_log_kwargs(message=msg)) ecg_epochs = None else: @@ -173,7 +174,7 @@ def make_eog_epochs( eog_epochs = create_eog_epochs(raw, ch_name=ch_names, baseline=(None, -0.2)) if len(eog_epochs) == 0: - msg = "No EOG events could be found. Not running EOG artifact " "detection." + msg = "No EOG events could be found. Not running EOG artifact detection." logger.warning(**gen_log_kwargs(message=msg)) eog_epochs = None else: @@ -184,7 +185,7 @@ def make_eog_epochs( return eog_epochs -def detect_bad_components( +def detect_bad_components_mne( *, cfg, which: Literal["eog", "ecg"], @@ -195,7 +196,7 @@ def detect_bad_components( session: Optional[str], ) -> Tuple[List[int], np.ndarray]: artifact = which.upper() - msg = f"Performing automated {artifact} artifact detection …" + msg = f"Performing automated {artifact} artifact detection (MNE) …" logger.info(**gen_log_kwargs(message=msg)) if which == "eog": @@ -224,7 +225,7 @@ def detect_bad_components( logger.warning(**gen_log_kwargs(message=warn)) else: msg = ( - f"Detected {len(inds)} {artifact}-related ICs in " + f"Detected {len(inds)} {artifact}-related independent component(s) in " f"{len(epochs)} {artifact} epochs." ) logger.info(**gen_log_kwargs(message=msg)) @@ -271,6 +272,14 @@ def run_ica( in_files: dict, ) -> dict: """Run ICA.""" + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data filtered between 1 and 100 Hz. 
+ assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.h_freq == 100.0 + assert cfg.eeg_reference == "average" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) out_files = dict() @@ -395,7 +404,18 @@ def run_ica( # Set an EEG reference if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False + if cfg.ica_use_icalabel: + assert cfg.eeg_reference == "average" + projection = False # Avg. ref. needs to be applied for MNE-ICALabel + elif cfg.eeg_reference == "average": + projection = True + else: + projection = False + + if not projection: + msg = "Applying average reference to EEG epochs used for ICA fitting." + logger.info(**gen_log_kwargs(message=msg)) + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) if cfg.ica_reject == "autoreject_local": @@ -446,9 +466,9 @@ def run_ica( if cfg.task is not None: title += f", task-{cfg.task}" - # ECG and EOG component detection + # Run MNE's built-in ECG and EOG component detection if epochs_ecg: - ecg_ics, ecg_scores = detect_bad_components( + ecg_ics, ecg_scores = detect_bad_components_mne( cfg=cfg, which="ecg", epochs=epochs_ecg, @@ -461,7 +481,7 @@ def run_ica( ecg_ics = ecg_scores = [] if epochs_eog: - eog_ics, eog_scores = detect_bad_components( + eog_ics, eog_scores = detect_bad_components_mne( cfg=cfg, which="eog", epochs=epochs_eog, @@ -473,11 +493,34 @@ def run_ica( else: eog_ics = eog_scores = [] + # Run MNE-ICALabel if requested. 
+ if cfg.ica_use_icalabel: + icalabel_ics = [] + icalabel_labels = [] + + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = label_components(inst=epochs, ica=ica, method="iclabel") + for idx, label in enumerate(label_results["labels"]): + if label not in ["brain", "other"]: + icalabel_ics.append(idx) + icalabel_labels.append(label) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs." + ) + logger.info(**gen_log_kwargs(message=msg)) + else: + icalabel_ics = [] + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) + # Save ICA to disk. # We also store the automatically identified ECG- and EOG-related ICs. msg = "Saving ICA solution and detected artifacts to disk." logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) ica.save(out_files["ica"], overwrite=True) _update_for_splits(out_files, "ica") @@ -492,15 +535,28 @@ def run_ica( ) ) - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + if cfg.ica_use_icalabel: + assert len(icalabel_ics) == len(icalabel_labels) + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = f"Auto-detected {label} (MNE-ICALabel)" + else: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected ECG artifact (MNE)" + + for component in eog_ics: + 
row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected EOG artifact (MNE)" tsv_data.to_csv(out_files["components"], sep="\t", index=False) @@ -510,10 +566,16 @@ def run_ica( logger.info(**gen_log_kwargs(message=msg)) report = Report(info_fname=epochs, title=title, verbose=False) + ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() eog_evoked = None if epochs_eog is None else epochs_eog.average() - ecg_scores = None if len(ecg_scores) == 0 else ecg_scores - eog_scores = None if len(eog_scores) == 0 else eog_scores + + if cfg.ica_use_icalabel: + # We didn't run MNE's scoring + ecg_scores = eog_scores = None + else: + ecg_scores = None if len(ecg_scores) == 0 else ecg_scores + eog_scores = None if len(eog_scores) == 0 else eog_scores with _agg_backend(): if cfg.ica_reject == "autoreject_local": @@ -588,10 +650,12 @@ def get_config( ica_reject=config.ica_reject, ica_eog_threshold=config.ica_eog_threshold, ica_ctps_ecg_threshold=config.ica_ctps_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, autoreject_n_interpolate=config.autoreject_n_interpolate, random_state=config.random_state, ch_types=config.ch_types, l_freq=config.l_freq, + h_freq=config.h_freq, epochs_decim=config.epochs_decim, raw_resample_sfreq=config.raw_resample_sfreq, event_repeated=config.event_repeated, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index 47fcb5846..d645adda5 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -71,15 +71,20 @@ t_break_annot_start_after_previous_event = 3.0 t_break_annot_stop_before_next_event = 1.5 +# Settings for autoreject and ICA if task == "N400": # test autoreject local without ICA spatial_filter = None reject = "autoreject_local" autoreject_n_interpolate = [2, 4] -elif task == "N170": # test autoreject 
local before ICA +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_l_freq = 1 + h_freq = 100 ica_reject = "autoreject_local" reject = "autoreject_global" - autoreject_n_interpolate = [2, 4] + autoreject_n_interpolate = [12] # Only for testing! else: spatial_filter = "ica" ica_reject = dict(eeg=350e-6, eog=500e-6) @@ -249,6 +254,7 @@ baseline = (None, 0) conditions = ["stimulus/face/normal", "stimulus/car/normal"] contrasts = [("stimulus/face/normal", "stimulus/car/normal")] + cluster_forming_t_threshold = 1.25 # Only for testing! elif task == "P3": rename_events = { "response/201": "response/correct", diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index 097fc1032..7bb622dc4 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -18,15 +18,19 @@ def test_options_documented(): with open(root_path / "_config.py", "r") as fid: contents = fid.read() contents = ast.parse(contents) - in_config = [ + unannotated = [ + item.targets[0].id for item in contents.body if isinstance(item, ast.Assign) + ] + assert unannotated == [] + _config_py = [ item.target.id for item in contents.body if isinstance(item, ast.AnnAssign) ] - assert len(set(in_config)) == len(in_config) - in_config = set(in_config) + assert len(set(_config_py)) == len(_config_py) + _config_py = set(_config_py) # ensure we clean our namespace correctly config = _get_default_config() - config_names = set(d for d in dir(config) if not d.startswith("_")) - assert in_config == config_names + _get_default_config_names = set(d for d in dir(config) if not d.startswith("_")) + assert _config_py == _get_default_config_names settings_path = root_path.parent / "docs" / "source" / "settings" assert settings_path.is_dir() in_doc = set() @@ -51,8 +55,8 @@ def test_options_documented(): assert val not 
in in_doc, "Duplicate documentation" in_doc.add(val) what = "docs/source/settings doc" - assert in_doc.difference(in_config) == set(), f"Extra values in {what}" - assert in_config.difference(in_doc) == set(), f"Values missing from {what}" + assert in_doc.difference(_config_py) == set(), f"Extra values in {what}" + assert _config_py.difference(in_doc) == set(), f"Values missing from {what}" def test_datasets_in_doc(): diff --git a/pyproject.toml b/pyproject.toml index 5d217e36d..344f0dc8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ dependencies = [ "autoreject", "mne[hdf5] >=1.2", "mne-bids[full]", + "mne-icalabel", + "onnxruntime", # for mne-icalabel "filelock", "setuptools >=65", ]