diff --git a/.circleci/config.yml b/.circleci/config.yml index 3384ba89..4619a212 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -43,7 +43,7 @@ jobs: name: Set BASH_ENV command: | set -e - ./scripts/setup_circleci.sh + ./scripts/setup_xvfb.sh sudo apt install -qq graphviz optipng python3.8-venv python3-venv libxft2 ffmpeg python3.8 -m venv ~/python_env echo "set -e" >> $BASH_ENV @@ -92,7 +92,11 @@ jobs: pip install --progress-bar off . pip install --upgrade --progress-bar off -r requirements_testing.txt pip install --upgrade --progress-bar off -r requirements_doc.txt - + pip install --upgrade --progress-bar off PyQt5 + python -m pip uninstall -yq sphinx-gallery mne-qt-browser + # TODO: Revert to upstream/main once https://github.com/mne-tools/mne-qt-browser/pull/105 is merged + python -m pip install --upgrade --progress-bar off https://github.com/mne-tools/mne-qt-browser/zipball/main https://github.com/sphinx-gallery/sphinx-gallery/zipball/master + - save_cache: key: pip-cache paths: @@ -104,6 +108,10 @@ jobs: - ~/.local/lib/python3.8/site-packages - ~/.local/bin + # - run: + # name: Check Qt + # command: LD_DEBUG=libs python -c "from PySide6.QtWidgets import QApplication, QWidget; app = QApplication([])" + # Look at what we have and fail early if there is some library conflict - run: name: Check installation diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 01c5f2a8..bb74e1f3 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -52,6 +52,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade --progress-bar off pip setuptools wheel + pip install $STD_ARGS --only-binary ":all:" PyQt6 PyQt6-sip PyQt6-Qt6 # build with sdist directly - uses: actions/checkout@v3 @@ -118,6 +119,10 @@ jobs: - uses: actions/checkout@v3 + - name: 'Setup xvfb' + if: "matrix.os == 'ubuntu-latest'" + run: ./scripts/setup_xvfb.sh + - name: Install mne-icalabel run: | pip install --upgrade --progress-bar off pip setuptools wheel @@ -149,10 +154,15 @@ jobs: shell: bash -el {0} run: mne sys_info + - shell: bash -el {0} + run: | + QT_QPA_PLATFORM=xcb LIBGL_DEBUG=verbose LD_DEBUG=libs + name: 'Check Qt GL' + - name: Run pytest shell: bash run: | - python -m pytest ./mne_icalabel --cov=mne_icalabel --cov-report=xml --cov-config=setup.cfg --verbose --ignore mne-python + python -m pytest ./mne_icalabel --cov=mne_icalabel --cov-report=xml --cov-config=setup.cfg --verbose --ignore mne-python -vv mne_icalabel/gui - name: Upload coverage stats to codecov if: "matrix.os == 'ubuntu-latest'" diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 81af8ea8..a0d640eb 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,7 +24,7 @@ Version 0.3 (Unreleased) Enhancements ~~~~~~~~~~~~ -- +- Adding a GUI to facilitate the labeling of ICA components, by `Adam Li`_ and `Mathieu Scheltienne`_ (:gh:`66`) Bug ~~~ @@ -41,6 +41,7 @@ Authors * `Mathieu Scheltienne`_ * `Anand Saini`_ +* `Adam Li`_ :doc:`Find out what was new in previous releases ` diff --git a/examples/label_components.py b/examples/label_components.py new file mode 100644 index 00000000..a158177f --- /dev/null +++ b/examples/label_components.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +""" +.. _tut-label-ica-components: + +Labeling ICA components with a GUI +================================== + +This tutorial covers how to label ICA components with a GUI. +""" + +# %% + +import os + +import mne +from mne.preprocessing import ICA + +from mne_icalabel.gui import label_ica_components + +# %% +# Load in some sample data + +sample_data_folder = mne.datasets.sample.data_path() +sample_data_raw_file = os.path.join( + sample_data_folder, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif" +) +raw = mne.io.read_raw_fif(sample_data_raw_file) + +# Here we'll crop to 60 seconds and drop gradiometer channels for speed +raw.crop(tmax=60.0).pick_types(meg="mag", eeg=True, stim=True, eog=True) +raw.load_data() + +# %% +# Preprocess and run ICA on the data +# ---------------------------------- +# Before labeling components with the GUI, one needs to filter the data +# and then fit the ICA instance. Afterwards, one can run the GUI using the +# ``Raw`` data object and the fitted ``ICA`` instance. The GUI will modify +# the ICA instance in place, and add the labels of each component to +# the ``labels_`` attribute. + +# high-pass filter the data and then perform ICA +filt_raw = raw.copy().filter(l_freq=1.0, h_freq=None) +ica = ICA(n_components=15, max_iter="auto", random_state=97) +ica.fit(filt_raw) + +# now label the components using a GUI +mne.set_log_level("DEBUG") +gui = label_ica_components(raw, ica) + +# The `ica` object is modified to contain the component labels +# after closing the GUI and can now be saved +# gui.close() # typically you close when done + +# Now, we can take a look at the components, which can be +# saved into the BIDs directory. +print(ica.labels_) + +# %% +# Save the labeled components +# --------------------------- +# After the GUI labels, save the components using the ``write_components_tsv`` +# function. This will save the ICA annotations to disc in BIDS-Derivative for +# EEG data format. +# +# Note: BIDS-EEG-Derivatives is not fully specified, so this functionality +# may change in the future without notice. + +# fname = '' +# write_components_tsv(ica, fname) diff --git a/mne_icalabel/__init__.py b/mne_icalabel/__init__.py index 471e7f1d..9f4eef5c 100644 --- a/mne_icalabel/__init__.py +++ b/mne_icalabel/__init__.py @@ -7,4 +7,5 @@ __version__ = "0.3.dev0" +from . import gui from .label_components import label_components # noqa: F401 diff --git a/mne_icalabel/annotation/bids.py b/mne_icalabel/annotation/bids.py index d1f44759..d820442f 100644 --- a/mne_icalabel/annotation/bids.py +++ b/mne_icalabel/annotation/bids.py @@ -4,6 +4,7 @@ from mne.preprocessing import ICA from mne.utils import _check_pandas_installed +from ..config import ICLABEL_LABELS_TO_MNE from ..iclabel.config import ICLABEL_STRING_TO_NUMERICAL @@ -41,17 +42,31 @@ def write_components_tsv(ica: ICA, fname): if not isinstance(fname, BIDSPath): fname = get_bids_path_from_fname(fname) + # initialize status, description and IC type + status = ["good"] * ica.n_components_ + status_description = ["n/a"] * ica.n_components_ + ic_type = ["n/a"] * ica.n_components_ + + # extract the component labels if they are present in the ICA instance + if ica.labels_: + for label, comps in ica.labels_.items(): + this_status = "good" if label == "brain" else "bad" + if label in ICLABEL_LABELS_TO_MNE.values(): + for comp in comps: + status[comp] = this_status + ic_type[comp] = label + # Create TSV. tsv_data = pd.DataFrame( dict( component=list(range(ica.n_components_)), type=["ica"] * ica.n_components_, description=["Independent Component"] * ica.n_components_, - status=["good"] * ica.n_components_, - status_description=["n/a"] * ica.n_components_, + status=status, + status_description=status_description, annotate_method=["n/a"] * ica.n_components_, annotate_author=["n/a"] * ica.n_components_, - ic_type=["n/a"] * ica.n_components_, + ic_type=ic_type, ) ) # make sure parent directories exist diff --git a/mne_icalabel/annotation/tests/test_bids.py b/mne_icalabel/annotation/tests/test_bids.py index 6807d0ee..a480e763 100644 --- a/mne_icalabel/annotation/tests/test_bids.py +++ b/mne_icalabel/annotation/tests/test_bids.py @@ -52,6 +52,9 @@ def test_write_channels_tsv(_ica, _tmp_bids_path): suffix="channels", extension=".tsv", ) + _ica = _ica.copy() + _ica.labels_["ecg"] = [0] + write_components_tsv(_ica, deriv_fname) assert deriv_fname.fpath.exists() @@ -59,7 +62,9 @@ def test_write_channels_tsv(_ica, _tmp_bids_path): assert expected_json.fpath.exists() ch_tsv = pd.read_csv(deriv_fname, sep="\t") - assert all(status == "good" for status in ch_tsv["status"]) + assert all(status == "good" for status in ch_tsv["status"][1:]) + assert ch_tsv["status"][0] == "bad" + assert ch_tsv["ic_type"].values[0] == "ecg" def test_mark_components(_ica, _tmp_bids_path): diff --git a/mne_icalabel/commands/__init__.py b/mne_icalabel/commands/__init__.py new file mode 100644 index 00000000..be522b8f --- /dev/null +++ b/mne_icalabel/commands/__init__.py @@ -0,0 +1 @@ +"""Entry-points for mne-icalabel commands.""" diff --git a/mne_icalabel/commands/mne_gui_ic_annotation.py b/mne_icalabel/commands/mne_gui_ic_annotation.py new file mode 100644 index 00000000..133276d9 --- /dev/null +++ b/mne_icalabel/commands/mne_gui_ic_annotation.py @@ -0,0 +1,36 @@ +import argparse + +from qtpy.QtWidgets import QApplication + +from mne_icalabel.gui._label_components import ICAComponentLabeler + + +def main(): + """Entry point for mne_gui_ic_annotation.""" + parser = argparse.ArgumentParser(prog="mne-icalabel", description="IC annotation GUI") + parser.add_argument("--dev", help="loads a sample dataset.", action="store_true") + args = parser.parse_args() + + if not args.dev: + raise NotImplementedError + else: + from mne.datasets import sample + from mne.io import read_raw + from mne.preprocessing import ICA + + directory = sample.data_path() / "MEG" / "sample" + raw = read_raw(directory / "sample_audvis_raw.fif", preload=False) + raw.crop(0, 10).pick_types(eeg=True, exclude="bads") + raw.load_data() + # preprocess + raw.filter(l_freq=1.0, h_freq=100.0) + raw.set_eeg_reference("average") + + n_components = 15 + ica = ICA(n_components=n_components, method="picard") + ica.fit(raw) + + app = QApplication([]) + window = ICAComponentLabeler(inst=raw, ica=ica) + window.show() + app.exec() diff --git a/mne_icalabel/config.py b/mne_icalabel/config.py index fb4e801d..3b662248 100644 --- a/mne_icalabel/config.py +++ b/mne_icalabel/config.py @@ -4,3 +4,14 @@ "iclabel": iclabel_label_components, "manual": None, } + +# map ICLabel labels to MNE str format +ICLABEL_LABELS_TO_MNE = { + "Brain": "brain", + "Muscle": "muscle", + "Eye": "eog", + "Heart": "ecg", + "Line Noise": "line_noise", + "Channel Noise": "ch_noise", + "Other": "other", +} diff --git a/mne_icalabel/conftest.py b/mne_icalabel/conftest.py index eac9d8c4..0e752a82 100644 --- a/mne_icalabel/conftest.py +++ b/mne_icalabel/conftest.py @@ -2,7 +2,6 @@ # Author: Eric Larson # # License: BSD-3-Clause - import warnings import pytest diff --git a/mne_icalabel/gui/__init__.py b/mne_icalabel/gui/__init__.py new file mode 100644 index 00000000..7bc75eb0 --- /dev/null +++ b/mne_icalabel/gui/__init__.py @@ -0,0 +1,32 @@ +from mne.preprocessing import ICA + + +def label_ica_components(inst, ica: ICA, show: bool = True, block: bool = False): + """Launch the IC labelling GUI. + + Parameters + ---------- + inst : : Raw | Epochs + `~mne.io.Raw` or `~mne.Epochs` instance used to fit the `~mne.preprocessing.ICA` decomposition. + ica : ICA + The ICA object fitted on `inst`. + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + + Returns + ------- + gui : instance of ICAComponentLabeler + The graphical user interface (GUI) window. + """ + from mne.viz.backends._utils import _init_mne_qtapp, _qt_app_exec + + from ._label_components import ICAComponentLabeler + + # get application + app = _init_mne_qtapp() + gui = ICAComponentLabeler(inst=inst, ica=ica, show=show) + if block: + _qt_app_exec(app) + return gui diff --git a/mne_icalabel/gui/_label_components.py b/mne_icalabel/gui/_label_components.py new file mode 100644 index 00000000..630226b4 --- /dev/null +++ b/mne_icalabel/gui/_label_components.py @@ -0,0 +1,274 @@ +from typing import Dict, List, Union + +from matplotlib import pyplot as plt +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from mne import BaseEpochs +from mne.io import BaseRaw +from mne.preprocessing import ICA +from mne.utils import _validate_type +from mne.viz import set_browser_backend +from qtpy.QtCore import Qt, Slot +from qtpy.QtWidgets import ( + QAbstractItemView, + QButtonGroup, + QGridLayout, + QListWidget, + QMainWindow, + QPushButton, + QVBoxLayout, + QWidget, +) + +from mne_icalabel.config import ICLABEL_LABELS_TO_MNE + + +class ICAComponentLabeler(QMainWindow): + """Qt GUI to annotate components. + + Parameters + ---------- + inst : Raw | Epochs + ica : ICA + """ + + def __init__(self, inst: Union[BaseRaw, BaseEpochs], ica: ICA, show: bool = True) -> None: + ICAComponentLabeler._check_inst_ica(inst, ica) + super().__init__() # initialize the QMainwindow + set_browser_backend("qt") # force MNE to use the QT Browser + + # keep an internal pointer to the instance and to the ICA + self._inst = inst + self._ica = ica + # define valid labels + self._labels = list(ICLABEL_LABELS_TO_MNE.keys()) + # prepare the GUI + self._load_ui() + + # dictionary to remember selected labels, with the key as the 'indice' + # of the component and the value as the 'label'. + self.selected_labels: Dict[int, str] = dict() + + # connect signal to slots + self._connect_signals_to_slots() + + # select first IC + self._selected_component = 0 + self._components_listWidget.setCurrentRow(0) # emit signal + + if show: + self.show() + + def _save_labels(self) -> None: + """Save the selected labels to the ICA instance.""" + # convert the dict[int, str] to dict[str, List[int]] with the key as + # 'label' and value as a list of component indices. + labels2save: Dict[str, List[int]] = {key: [] for key in self.labels} + for component, label in self.selected_labels.items(): + labels2save[label].append(component) + # sanity-check: uniqueness + assert all(len(elt) == len(set(elt)) for elt in labels2save.values()) + + for label, comp_list in labels2save.items(): + mne_label = ICLABEL_LABELS_TO_MNE[label] + if mne_label not in self._ica.labels_: + self._ica.labels_[mne_label] = comp_list + continue + for comp in comp_list: + if comp not in self._ica.labels_[mne_label]: + self._ica.labels_[mne_label].append(comp) + + # - UI -------------------------------------------------------------------- + def _load_ui(self) -> None: + """Prepare the GUI. + + Widgets + ------- + self._components_listWidget + self._labels_buttonGroup + self._mpl_widgets (dict) + - topomap + - psd + self._timeSeries_widget + + Matplotlib figures + ------------------ + self._mpl_figures (dict) + - topomap + - psd + """ + self.setWindowTitle("ICA Component Labeler") + self.setContextMenuPolicy(Qt.NoContextMenu) + + # create central widget and main layout + self._central_widget = QWidget(self) + self._central_widget.setObjectName("central_widget") + grid_layout = QGridLayout() + self._central_widget.setLayout(grid_layout) + self.setCentralWidget(self._central_widget) + + # QListWidget with the components' names. + self._components_listWidget = QListWidget(self._central_widget) + self._components_listWidget.setSelectionMode(QAbstractItemView.SingleSelection) + self._components_listWidget.addItems( + [f"ICA{str(k).zfill(3)}" for k in range(self.n_components_)] + ) + grid_layout.addWidget(self._components_listWidget, 0, 0, 2, 1) + + # buttons to select labels + self._labels_buttonGroup = QButtonGroup(self._central_widget) + buttonGroup_layout = QVBoxLayout() + self._labels_buttonGroup.setExclusive(True) + for k, label in enumerate(self.labels + ["Reset"]): + pushButton = QPushButton(self._central_widget) + pushButton.setObjectName(f"pushButton_{label.lower().replace(' ', '_')}") + pushButton.setText(label) + pushButton.setCheckable(True) + pushButton.setChecked(False) + pushButton.setEnabled(False) + # buttons are ordered in the same order as labels + self._labels_buttonGroup.addButton(pushButton, k) + buttonGroup_layout.addWidget(pushButton) + grid_layout.addLayout(buttonGroup_layout, 0, 1, 2, 1) + + # matplotlib figures + self._mpl_figures = dict() + self._mpl_widgets = dict() + + # topographic map + fig, _ = plt.subplots(1, 1, figsize=(4, 4), dpi=100) + fig.subplots_adjust(bottom=0, left=0, right=1, top=1, wspace=0, hspace=0) + self._mpl_figures["topomap"] = fig + self._mpl_widgets["topomap"] = FigureCanvasQTAgg(fig) + grid_layout.addWidget(self._mpl_widgets["topomap"], 0, 2) + + # PSD + fig, _ = plt.subplots(1, 1, figsize=(4, 4), dpi=100) + fig.subplots_adjust(bottom=0, left=0, right=1, top=1, wspace=0, hspace=0) + self._mpl_figures["psd"] = fig + self._mpl_widgets["psd"] = FigureCanvasQTAgg(fig) + grid_layout.addWidget(self._mpl_widgets["psd"], 0, 3) + + # time-series, initialized with an empty widget. + # TODO: When the browser supports changing the instance displayed, this + # should be initialized to a browser with the first IC. + self._timeSeries_widget = QWidget() + grid_layout.addWidget(self._timeSeries_widget, 1, 2, 1, 2) + + # - Checkers -------------------------------------------------------------- + @staticmethod + def _check_inst_ica(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> None: + """Check if the ICA was fitted.""" + _validate_type(inst, (BaseRaw, BaseEpochs), "inst", "raw or epochs") + _validate_type(ica, ICA, "ica", "ICA") + if ica.current_fit == "unfitted": + raise ValueError( + "ICA instance should be fit on the raw data before " + "running the ICA labeling GUI. Run `ica.fit(inst)`." + ) + + # - Properties ------------------------------------------------------------ + @property + def inst(self) -> Union[BaseRaw, BaseEpochs]: + """Instance on which the ICA has been fitted.""" + return self._inst + + @property + def ica(self) -> ICA: + """Fitted ICA decomposition.""" + return self._ica + + @property + def n_components_(self) -> int: + """The number of fitted components.""" + return self._ica.n_components_ + + @property + def labels(self) -> List[str]: + """List of valid labels.""" + return self._labels + + @property + def selected_component(self) -> int: + """IC selected and displayed.""" + return self._selected_component + + # - Slots ----------------------------------------------------------------- + def _connect_signals_to_slots(self) -> None: + """Connect all the signals and slots of the GUI.""" + self._components_listWidget.currentRowChanged.connect(self._components_listWidget_clicked) + self._labels_buttonGroup.buttons()[-1].clicked.connect(self._reset) + + @Slot() + def _components_listWidget_clicked(self) -> None: + """Update the plots and the saved labels accordingly.""" + self._update_selected_labels() + self._reset_buttons() + + # update selected IC + self._selected_component = self._components_listWidget.currentRow() + + # reset matplotlib figures + for fig in self._mpl_figures.values(): + fig.axes[0].clear() + # create dummy figure and axes to hold the unused plots from plot_properties + dummy_fig, dummy_axes = plt.subplots(3) + # create axes argument provided to plot_properties + axes = [ + self._mpl_figures["topomap"].axes[0], + dummy_axes[0], + dummy_axes[1], + self._mpl_figures["psd"].axes[0], + dummy_axes[2], + ] + # update matplotlib plots with plot_properties + self.ica.plot_properties(self.inst, axes=axes, picks=self.selected_component, show=False) + del dummy_fig + # remove title from topomap axes + self._mpl_figures["topomap"].axes[0].set_title("") + # update the matplotlib canvas + for fig in self._mpl_figures.values(): + fig.tight_layout() + fig.canvas.draw() + fig.canvas.flush_events() + + # swap timeSeries widget + timeSeries_widget = self.ica.plot_sources(self.inst, picks=[self.selected_component]) + self._central_widget.layout().replaceWidget(self._timeSeries_widget, timeSeries_widget) + self._timeSeries_widget.setParent(None) + self._timeSeries_widget = timeSeries_widget + + # select buttons that were previously selected for this IC + if self.selected_component in self.selected_labels: + idx = self.labels.index(self.selected_labels[self.selected_component]) + self._labels_buttonGroup.button(idx).setChecked(True) + + def _update_selected_labels(self) -> None: + """Update the labels saved.""" + selected = self._labels_buttonGroup.checkedButton() + if selected is not None: + self.selected_labels[self.selected_component] = selected.text() + self._save_labels() # updates the ICA instance every time + + @Slot() + def _reset(self) -> None: # noqa: D401 + """Action of the reset button.""" + self._reset_buttons() + if self.selected_component in self.selected_labels: + del self.selected_labels[self.selected_component] + + def _reset_buttons(self) -> None: + """Reset all buttons.""" + self._labels_buttonGroup.setExclusive(False) + for button in self._labels_buttonGroup.buttons(): + button.setEnabled(True) + button.setChecked(False) + self._labels_buttonGroup.setExclusive(True) + + def closeEvent(self, event) -> None: + """Clean up upon closing the window. + + Update the labels since the user might have selected one for the + currently being displayed IC. + """ + self._update_selected_labels() + event.accept() diff --git a/mne_icalabel/gui/tests/test_label_components.py b/mne_icalabel/gui/tests/test_label_components.py new file mode 100644 index 00000000..7bae8730 --- /dev/null +++ b/mne_icalabel/gui/tests/test_label_components.py @@ -0,0 +1,84 @@ +import os.path as op + +import matplotlib.pyplot as plt +import pytest +from mne.datasets import testing +from mne.io import read_raw_edf +from mne.preprocessing import ICA +from mne.utils import requires_version + +import mne_icalabel + + +@pytest.fixture +def _label_ica_components(): + # Use a fixture to create these classes so we can ensure that they + # are closed at the end of the test + guis = list() + + def fun(*args, **kwargs): + guis.append(mne_icalabel.gui.label_ica_components(*args, **kwargs)) + return guis[-1] + + yield fun + + for gui in guis: + try: + gui.close() + except Exception: + pass + + +@pytest.fixture(scope="module") +def load_raw_and_fit_ica(): + data_path = op.join(testing.data_path(), "EDF") + raw_fname = op.join(data_path, "test_reduced.edf") + raw = read_raw_edf(raw_fname, preload=True) + + # high-pass filter + raw.filter(l_freq=1, h_freq=100) + + # compute ICA + ica = ICA(n_components=15, random_state=12345) + ica.fit(raw) + return raw, ica + + +@pytest.fixture(scope="function") +def _fitted_ica(load_raw_and_fit_ica): + raw, ica = load_raw_and_fit_ica + return raw, ica.copy() + + +@requires_version("mne", "1.1dev0") +@testing.requires_testing_data +def test_label_components_gui_io(_fitted_ica, _label_ica_components): + """Test the input/output of the labeling ICA components GUI.""" + # get the Raw and fitted ICA instance + raw, ica = _fitted_ica + ica_copy = ica.copy() + + with pytest.raises(ValueError, match="ICA instance should be fit on"): + ica_copy.current_fit = "unfitted" + _label_ica_components(raw, ica_copy) + + +@requires_version("mne", "1.1dev0") +@testing.requires_testing_data +def test_label_components_gui_display(_fitted_ica, _label_ica_components): + raw, ica = _fitted_ica + + # test functions + gui = _label_ica_components(raw, ica) + + # test setting the label + assert gui.inst == raw + assert gui.ica == ica + assert gui.n_components_ == ica.n_components_ + + # the initial component should be 0 + assert gui.selected_component == 0 + + # there should be three figures inside the QT window + figs = list(map(plt.figure, plt.get_fignums())) + assert len(figs) == 3 diff --git a/mne_icalabel/iclabel/label_components.py b/mne_icalabel/iclabel/label_components.py index de9a094e..e948cc14 100644 --- a/mne_icalabel/iclabel/label_components.py +++ b/mne_icalabel/iclabel/label_components.py @@ -1,5 +1,6 @@ from typing import Union +import numpy as np from mne import BaseEpochs from mne.io import BaseRaw from mne.preprocessing import ICA @@ -8,7 +9,7 @@ from .network import run_iclabel -def iclabel_label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): +def iclabel_label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, inplace: bool = True): """Label the provided ICA components with the ICLabel neural network. This network uses 3 features: @@ -24,19 +25,22 @@ def iclabel_label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): Parameters ---------- inst : Raw | Epochs - Instance used to fit the ICA decomposition. The instance should be - referenced to a common average and bandpass filtered between 1 and - 100 Hz. + Instance used to fit the ICA decomposition. The instance should be + referenced to a common average and bandpass filtered between 1 and + 100 Hz. ica : ICA - ICA decomposition of the provided instance. + ICA decomposition of the provided instance. + inplace : bool + Whether to modify the ``ica`` instance in place by adding the automatic + annotations to the ``labels_`` property. By default True. Returns ------- labels_pred_proba : numpy.ndarray of shape (n_components, n_classes) - The estimated corresponding predicted probabilities of output classes - for each independent component. Columns are ordered with 'brain', - 'muscle artifact', 'eye blink', 'heart beat', 'line noise', - 'channel noise', 'other'. + The estimated corresponding predicted probabilities of output classes + for each independent component. Columns are ordered with 'brain', + 'muscle artifact', 'eye blink', 'heart beat', 'line noise', + 'channel noise', 'other'. References ---------- @@ -44,4 +48,21 @@ def iclabel_label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA): """ features = get_iclabel_features(inst, ica) labels_pred_proba = run_iclabel(*features) + + if inplace: + from mne_icalabel.config import ICLABEL_LABELS_TO_MNE + + ica.labels_scores_ = labels_pred_proba + argmax_labels = np.argmax(labels_pred_proba, axis=1) + + # add labels to the ICA instance + for idx, (_, mne_label) in enumerate(ICLABEL_LABELS_TO_MNE.items()): + auto_labels = np.argwhere(argmax_labels == idx) + if mne_label not in ica.labels_: + ica.labels_[mne_label] = auto_labels + continue + for comp in auto_labels: + if comp not in ica.labels_[mne_label]: + ica.labels_[mne_label].append(comp) + return labels_pred_proba diff --git a/mne_icalabel/utils/__init__.py b/mne_icalabel/utils/__init__.py index e69de29b..341cd59a 100644 --- a/mne_icalabel/utils/__init__.py +++ b/mne_icalabel/utils/__init__.py @@ -0,0 +1,2 @@ +from ._checks import _check_qt_version, _validate_inst_and_ica +from ._docs import fill_doc diff --git a/mne_icalabel/utils/_checks.py b/mne_icalabel/utils/_checks.py index 42cba946..0bdf35d3 100644 --- a/mne_icalabel/utils/_checks.py +++ b/mne_icalabel/utils/_checks.py @@ -1,9 +1,11 @@ +import sys from typing import Union from mne import BaseEpochs +from mne.fixes import _compare_version from mne.io import BaseRaw from mne.preprocessing import ICA -from mne.utils import _validate_type +from mne.utils import _validate_type, warn def _validate_inst_and_ica(inst: Union[BaseRaw, BaseEpochs], ica: ICA): @@ -16,3 +18,28 @@ def _validate_inst_and_ica(inst: Union[BaseRaw, BaseEpochs], ica: ICA): "The provided ICA instance was not fitted. Please use the '.fit()' method to " "determine the independent components before trying to label them." ) + + +def _check_qt_version(*, return_api=False): + """Check if Qt is installed.""" + try: + from qtpy import API_NAME as api + from qtpy import QtCore + except Exception: + api = version = None + else: + try: # pyside + version = QtCore.__version__ + except AttributeError: + version = QtCore.QT_VERSION_STR + if sys.platform == "darwin" and api in ("PyQt5", "PySide2"): + if not _compare_version(version, ">=", "5.10"): + warn( + f"macOS users should use {api} >= 5.10 for GUIs, " + f"got {version}. Please upgrade e.g. with:\n\n" + f' pip install "{api}>=5.10"\n' + ) + if return_api: + return version, api + else: + return version diff --git a/requirements_testing.txt b/requirements_testing.txt index 0948c3eb..98b516cd 100644 --- a/requirements_testing.txt +++ b/requirements_testing.txt @@ -23,3 +23,6 @@ python-picard joblib scikit-learn pandas +qtpy +PyQt5 +mne-qt-browser \ No newline at end of file diff --git a/scripts/setup_circleci.sh b/scripts/setup_xvfb.sh similarity index 89% rename from scripts/setup_circleci.sh rename to scripts/setup_xvfb.sh index 67d3dc83..040541ea 100755 --- a/scripts/setup_circleci.sh +++ b/scripts/setup_xvfb.sh @@ -11,5 +11,5 @@ done # This also includes the libraries necessary for PyQt5/PyQt6 sudo apt update -sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 +sudo apt install -yqq xvfb libxkbcommon-x11-0 libxcb-icccm4 libxcb-image0 libxcb-keysyms1 libxcb-randr0 libxcb-render-util0 libxcb-xinerama0 libxcb-xfixes0 libopengl0 libegl1 libosmesa6 mesa-utils libxcb-shape0 /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset diff --git a/setup.cfg b/setup.cfg index adaa5667..9651006a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,10 @@ install_requires = packages = find: include_package_data = True +[options.entry_points] +console_scripts = + mne_gui_ic_annotation = mne_icalabel.commands.mne_gui_ic_annotation:main + # Building package [bdist_wheel] universal = true