Skip to content

Commit

Permalink
MRG: Update repository structure (#21)
Browse files Browse the repository at this point in the history
* more iclabel to a 'icabel' submodule and add '_' to all private function/class

* add entrypoints to iclabel and to mne_icalabel

* black

* fix missed

* rename to label_components

* add tests

* fix resample_poly that requires up and down as integers

* sort imports with isort

* run sort and black

* fix for resampling

* run black

* simpler
  • Loading branch information
mscheltienne authored Apr 29, 2022
1 parent 7bf7956 commit 32edd7d
Show file tree
Hide file tree
Showing 62 changed files with 287 additions and 117 deletions.
2 changes: 2 additions & 0 deletions mne_icalabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
# License: BSD (3-clause)

__version__ = "0.1dev0"

from .label_components import label_components # noqa: F401
8 changes: 8 additions & 0 deletions mne_icalabel/iclabel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""ICLabel - An automated electroencephalographic independent component
classifier, dataset, and website.
This is a python implementation of the EEGLAB plugin 'ICLabel'."""

from .features import get_features # noqa: F401
from .label_components import label_components # noqa: F401
from .network import ICLabelNet, run_iclabel # noqa: F401
File renamed without changes.
59 changes: 37 additions & 22 deletions mne_icalabel/features.py → mne_icalabel/iclabel/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.typing import NDArray
from scipy.signal import resample_poly

from .utils import _next_power_of_2, gdatav4, mne_to_eeglab_locs, pol2cart
from .utils import _gdatav4, _mne_to_eeglab_locs, _next_power_of_2, _pol2cart


def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
Expand All @@ -21,23 +21,23 @@ def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
ica : ICA
MNE ICA decomposition.
"""
icawinv, _ = retrieve_eeglab_icawinv(ica)
icaact = compute_ica_activations(inst, ica)
icawinv, _ = _retrieve_eeglab_icawinv(ica)
icaact = _compute_ica_activations(inst, ica)

# compute topographic feature (float32)
topo = eeg_topoplot(inst, icawinv)
topo = _eeg_topoplot(inst, icawinv)

# compute psd feature (float32)
psd = eeg_rpsd(inst, ica, icaact)
psd = _eeg_rpsd(inst, ica, icaact)

# compute autocorr feature (float32)
if isinstance(inst, BaseRaw):
if 5 < inst.times.size / inst.info["sfreq"]:
autocorr = eeg_autocorr_welch(inst, ica, icaact)
autocorr = _eeg_autocorr_welch(inst, ica, icaact)
else:
autocorr = eeg_autocorr(inst, ica, icaact)
autocorr = _eeg_autocorr(inst, ica, icaact)
else:
autocorr = eeg_autocorr_fftw(inst, ica, icaact)
autocorr = _eeg_autocorr_fftw(inst, ica, icaact)

# scale by 0.99
topo *= 0.99
Expand All @@ -47,7 +47,7 @@ def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
return topo, psd, autocorr


def retrieve_eeglab_icawinv(
def _retrieve_eeglab_icawinv(
ica: ICA,
) -> Tuple[NDArray[float], NDArray[float]]:
"""
Expand Down Expand Up @@ -75,7 +75,7 @@ def retrieve_eeglab_icawinv(
return icawinv, weights


def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArray[float]:
def _compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArray[float]:
"""Compute the ICA activations 'icaact' variable from an MNE ICA instance.
Parameters
Expand Down Expand Up @@ -105,7 +105,7 @@ def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArr
assumed that 'common', 'average' and 'averef' are all denoting a common
average reference.
"""
icawinv, weights = retrieve_eeglab_icawinv(ica)
icawinv, weights = _retrieve_eeglab_icawinv(ica)
icasphere = np.eye(icawinv.shape[0])
data = inst.get_data(picks=ica.ch_names) * 1e6
icaact = (weights[0 : ica.n_components_, :] @ icasphere) @ data
Expand All @@ -117,12 +117,12 @@ def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArr


# ----------------------------------------------------------------------------
def eeg_topoplot(inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float]) -> NDArray[float]:
def _eeg_topoplot(inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float]) -> NDArray[float]:
"""Topoplot feature."""
# TODO: Selection of channels is missing.
ncomp = icawinv.shape[-1]
topo = np.zeros((32, 32, 1, ncomp))
rd, th = mne_to_eeglab_locs(inst)
rd, th = _mne_to_eeglab_locs(inst)
th = np.pi / 180 * th # convert degrees to radians
for it in range(ncomp):
temp_topo = _topoplotFast(icawinv[:, it], rd, th)
Expand All @@ -139,7 +139,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]
rmax = 0.5 # actual head radius

# convert electrode locations from polar to cartesian coordinates
x, y = pol2cart(th, rd)
x, y = _pol2cart(th, rd)

# prepare coordinates
# Comments in MATLAB (L750:753) are:
Expand Down Expand Up @@ -176,7 +176,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]
yi = np.linspace(ymin, ymax, GRID_SCALE).astype(np.float64).reshape((1, -1))
# additional step for gdatav4 compared to MATLAB: linspace to meshgrid
XQ, YQ = np.meshgrid(xi, yi)
Xi, Yi, Zi = gdatav4(x, y, values.reshape((-1, 1)), XQ, YQ)
Xi, Yi, Zi = _gdatav4(x, y, values.reshape((-1, 1)), XQ, YQ)
# additional step for gdatav4 compared to MATLAB: transpose
Zi = Zi.T

Expand All @@ -188,7 +188,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]


# ----------------------------------------------------------------------------
def eeg_rpsd(inst: Union[BaseRaw, BaseEpochs], ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_rpsd(inst: Union[BaseRaw, BaseEpochs], ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""PSD feature."""
assert isinstance(inst, (BaseRaw, BaseEpochs)) # sanity-check
constants = _eeg_rpsd_constants(inst, ica)
Expand Down Expand Up @@ -316,7 +316,7 @@ def _eeg_rpsd_format(
return psd[:, :, np.newaxis, np.newaxis].transpose([2, 1, 3, 0]).astype(np.float32)


def eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on raw object with at least 5 * fs
samples (5 seconds).
MATLAB: 'eeg_autocorr_welch.m'."""
Expand Down Expand Up @@ -387,12 +387,17 @@ def eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArra
ac = np.divide(ac, den)

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, raw.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(raw.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return np.real(resamp).astype(np.float32)


def eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on raw object that do not have enough
sampes for eeg_autocorr_welch.
MATLAB: 'eeg_autocorr.m'."""
Expand Down Expand Up @@ -421,12 +426,17 @@ def eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[floa
ac = np.divide(ac.T, ac[:, 0]).T

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, raw.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(raw.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return resamp.astype(np.float32)


def eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on epoch object.
MATLAB: 'eeg_autocorr_fftw.m'."""
assert isinstance(epochs, BaseEpochs) # sanity-check
Expand All @@ -452,6 +462,11 @@ def eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> N
ac = np.divide(ac.T, ac[:, 0]).T

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, epochs.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(epochs.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return np.real(resamp).astype(np.float32)
32 changes: 32 additions & 0 deletions mne_icalabel/iclabel/label_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Union

from mne import BaseEpochs
from mne.io import BaseRaw
from mne.preprocessing import ICA

from .features import get_features
from .network import run_iclabel


def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
"""
Label the provided ICA components with the ICLabel neural network. This
network uses 3 features:
- Topographic maps, based on the ICA decomposition.
- Power Spectral Density (PSD), based on the ICA decomposition and the
provided instance.
- Autocorrelation, based on the ICA decomposition and the provided
instance.
Parameters
----------
inst : Raw | Epochs
Instance used to fit the ICA decomposition. The instance should be
referenced to a common average and bandpass filtered between 1 and
100 Hz.
ica : ICA
ICA decomposition of the provided instance.
"""
features = get_features(inst, ica)
labels = run_iclabel(*features)
return labels
20 changes: 10 additions & 10 deletions mne_icalabel/network.py → mne_icalabel/iclabel/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.typing import ArrayLike


class ICLabelNetImg(nn.Module):
class _ICLabelNetImg(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -45,7 +45,7 @@ def forward(self, x):
return self.sequential(x)


class ICLabelNetPSDS(nn.Module):
class _ICLabelNetPSDS(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -81,7 +81,7 @@ def forward(self, x):
return self.sequential(x)


class ICLabelNetAutocorr(nn.Module):
class _ICLabelNetAutocorr(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -121,9 +121,9 @@ class ICLabelNet(nn.Module):
def __init__(self):
super().__init__()

self.img_conv = ICLabelNetImg()
self.psds_conv = ICLabelNetPSDS()
self.autocorr_conv = ICLabelNetAutocorr()
self.img_conv = _ICLabelNetImg()
self.psds_conv = _ICLabelNetPSDS()
self.autocorr_conv = _ICLabelNetAutocorr()

self.conv = nn.Conv2d(
in_channels=712,
Expand Down Expand Up @@ -175,7 +175,7 @@ def forward(
return labels


def format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
def _format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
"""Replicate the input formatting in EEGLAB -ICLabel.
.. code-block:: matlab
Expand All @@ -194,7 +194,7 @@ def format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
return formatted_topo, formatted_psd, formatted_autocorr


def format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
def _format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
"""Format the features to the correct shape and type for pytorch."""
topo = np.transpose(topo, (3, 2, 0, 1))
psd = np.transpose(psd, (3, 2, 0, 1))
Expand All @@ -210,12 +210,12 @@ def format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike)
def run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> ArrayLike:
"""Run the ICLabel network on the provided set of features. The features
are un-formatted and are as-returned by ``get_features``."""
ica_network_file = files("mne_icalabel").joinpath("assets/iclabelNet.pt")
ica_network_file = files("mne_icalabel.iclabel").joinpath("assets/iclabelNet.pt")

# Get network and load weights
iclabel_net = ICLabelNet()
iclabel_net.load_state_dict(torch.load(ica_network_file))

# Format input and get labels
labels = iclabel_net(*format_input_for_torch(*format_input(images, psds, autocorr)))
labels = iclabel_net(*_format_input_for_torch(*_format_input(images, psds, autocorr)))
return labels.detach().numpy()
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 32edd7d

Please sign in to comment.