Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: Update repository structure #21

Merged
merged 13 commits into from
Apr 29, 2022
2 changes: 2 additions & 0 deletions mne_icalabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
# License: BSD (3-clause)

__version__ = "0.1dev0"

from .label_components import label_components # noqa: F401
8 changes: 8 additions & 0 deletions mne_icalabel/iclabel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""ICLabel - An automated electroencephalographic independent component
classifier, dataset, and website.
This is a python implementation of the EEGLAB plugin 'ICLabel'."""

from .features import get_features # noqa: F401
from .label_components import label_components # noqa: F401
from .network import ICLabelNet, run_iclabel # noqa: F401
59 changes: 37 additions & 22 deletions mne_icalabel/features.py → mne_icalabel/iclabel/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.typing import NDArray
from scipy.signal import resample_poly

from .utils import _next_power_of_2, gdatav4, mne_to_eeglab_locs, pol2cart
from .utils import _gdatav4, _mne_to_eeglab_locs, _next_power_of_2, _pol2cart


def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking maybe we even rename this to get_iclabel_features, since we'll presumably have other models which also have feature engineering which would have get_<model_X>_features.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, good idea.

Expand All @@ -21,23 +21,23 @@ def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
ica : ICA
MNE ICA decomposition.
"""
icawinv, _ = retrieve_eeglab_icawinv(ica)
icaact = compute_ica_activations(inst, ica)
icawinv, _ = _retrieve_eeglab_icawinv(ica)
icaact = _compute_ica_activations(inst, ica)

# compute topographic feature (float32)
topo = eeg_topoplot(inst, icawinv)
topo = _eeg_topoplot(inst, icawinv)

# compute psd feature (float32)
psd = eeg_rpsd(inst, ica, icaact)
psd = _eeg_rpsd(inst, ica, icaact)

# compute autocorr feature (float32)
if isinstance(inst, BaseRaw):
if 5 < inst.times.size / inst.info["sfreq"]:
autocorr = eeg_autocorr_welch(inst, ica, icaact)
autocorr = _eeg_autocorr_welch(inst, ica, icaact)
else:
autocorr = eeg_autocorr(inst, ica, icaact)
autocorr = _eeg_autocorr(inst, ica, icaact)
else:
autocorr = eeg_autocorr_fftw(inst, ica, icaact)
autocorr = _eeg_autocorr_fftw(inst, ica, icaact)

# scale by 0.99
topo *= 0.99
Expand All @@ -47,7 +47,7 @@ def get_features(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
return topo, psd, autocorr


def retrieve_eeglab_icawinv(
def _retrieve_eeglab_icawinv(
ica: ICA,
) -> Tuple[NDArray[float], NDArray[float]]:
"""
Expand Down Expand Up @@ -75,7 +75,7 @@ def retrieve_eeglab_icawinv(
return icawinv, weights


def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArray[float]:
def _compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArray[float]:
"""Compute the ICA activations 'icaact' variable from an MNE ICA instance.
Parameters
Expand Down Expand Up @@ -105,7 +105,7 @@ def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArr
assumed that 'common', 'average' and 'averef' are all denoting a common
average reference.
"""
icawinv, weights = retrieve_eeglab_icawinv(ica)
icawinv, weights = _retrieve_eeglab_icawinv(ica)
icasphere = np.eye(icawinv.shape[0])
data = inst.get_data(picks=ica.ch_names) * 1e6
icaact = (weights[0 : ica.n_components_, :] @ icasphere) @ data
Expand All @@ -117,12 +117,12 @@ def compute_ica_activations(inst: Union[BaseRaw, BaseEpochs], ica: ICA) -> NDArr


# ----------------------------------------------------------------------------
def eeg_topoplot(inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float]) -> NDArray[float]:
def _eeg_topoplot(inst: Union[BaseRaw, BaseEpochs], icawinv: NDArray[float]) -> NDArray[float]:
"""Topoplot feature."""
# TODO: Selection of channels is missing.
ncomp = icawinv.shape[-1]
topo = np.zeros((32, 32, 1, ncomp))
rd, th = mne_to_eeglab_locs(inst)
rd, th = _mne_to_eeglab_locs(inst)
th = np.pi / 180 * th # convert degrees to radians
for it in range(ncomp):
temp_topo = _topoplotFast(icawinv[:, it], rd, th)
Expand All @@ -139,7 +139,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]
rmax = 0.5 # actual head radius

# convert electrode locations from polar to cartesian coordinates
x, y = pol2cart(th, rd)
x, y = _pol2cart(th, rd)

# prepare coordinates
# Comments in MATLAB (L750:753) are:
Expand Down Expand Up @@ -176,7 +176,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]
yi = np.linspace(ymin, ymax, GRID_SCALE).astype(np.float64).reshape((1, -1))
# additional step for gdatav4 compared to MATLAB: linspace to meshgrid
XQ, YQ = np.meshgrid(xi, yi)
Xi, Yi, Zi = gdatav4(x, y, values.reshape((-1, 1)), XQ, YQ)
Xi, Yi, Zi = _gdatav4(x, y, values.reshape((-1, 1)), XQ, YQ)
# additional step for gdatav4 compared to MATLAB: transpose
Zi = Zi.T

Expand All @@ -188,7 +188,7 @@ def _topoplotFast(values: NDArray[float], rd: NDArray[float], th: NDArray[float]


# ----------------------------------------------------------------------------
def eeg_rpsd(inst: Union[BaseRaw, BaseEpochs], ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_rpsd(inst: Union[BaseRaw, BaseEpochs], ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""PSD feature."""
assert isinstance(inst, (BaseRaw, BaseEpochs)) # sanity-check
constants = _eeg_rpsd_constants(inst, ica)
Expand Down Expand Up @@ -316,7 +316,7 @@ def _eeg_rpsd_format(
return psd[:, :, np.newaxis, np.newaxis].transpose([2, 1, 3, 0]).astype(np.float32)


def eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on raw object with at least 5 * fs
samples (5 seconds).
MATLAB: 'eeg_autocorr_welch.m'."""
Expand Down Expand Up @@ -387,12 +387,17 @@ def eeg_autocorr_welch(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArra
ac = np.divide(ac, den)

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, raw.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(raw.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return np.real(resamp).astype(np.float32)


def eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on raw object that do not have enough
sampes for eeg_autocorr_welch.
MATLAB: 'eeg_autocorr.m'."""
Expand Down Expand Up @@ -421,12 +426,17 @@ def eeg_autocorr(raw: BaseRaw, ica: ICA, icaact: NDArray[float]) -> NDArray[floa
ac = np.divide(ac.T, ac[:, 0]).T

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, raw.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(raw.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return resamp.astype(np.float32)


def eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
def _eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> NDArray[float]:
"""Autocorrelation feature applied on epoch object.
MATLAB: 'eeg_autocorr_fftw.m'."""
assert isinstance(epochs, BaseEpochs) # sanity-check
Expand All @@ -452,6 +462,11 @@ def eeg_autocorr_fftw(epochs: BaseEpochs, ica: ICA, icaact: NDArray[float]) -> N
ac = np.divide(ac.T, ac[:, 0]).T

# resample to 1 second at 100 samples/sec
resamp = resample_poly(ac.T, 100, epochs.info["sfreq"]).T
# i.e. the resampling must output an array of shape (components, 101), thus
# respecting '100 < ac.T.shape[0] * 100 / down <= 101'.
down = int(epochs.info["sfreq"])
if 101 < ac.shape[1] * 100 / down:
down += 1
resamp = resample_poly(ac.T, 100, down).T
resamp = resamp[:, 1:, np.newaxis, np.newaxis].transpose([2, 1, 3, 0])
return np.real(resamp).astype(np.float32)
32 changes: 32 additions & 0 deletions mne_icalabel/iclabel/label_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Union

from mne import BaseEpochs
from mne.io import BaseRaw
from mne.preprocessing import ICA

from .features import get_features
from .network import run_iclabel


def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA):
"""
Label the provided ICA components with the ICLabel neural network. This
network uses 3 features:
- Topographic maps, based on the ICA decomposition.
- Power Spectral Density (PSD), based on the ICA decomposition and the
provided instance.
- Autocorrelation, based on the ICA decomposition and the provided
instance.
Parameters
----------
inst : Raw | Epochs
Instance used to fit the ICA decomposition. The instance should be
referenced to a common average and bandpass filtered between 1 and
100 Hz.
ica : ICA
ICA decomposition of the provided instance.
"""
features = get_features(inst, ica)
labels = run_iclabel(*features)
return labels
20 changes: 10 additions & 10 deletions mne_icalabel/network.py → mne_icalabel/iclabel/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.typing import ArrayLike


class ICLabelNetImg(nn.Module):
class _ICLabelNetImg(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -45,7 +45,7 @@ def forward(self, x):
return self.sequential(x)


class ICLabelNetPSDS(nn.Module):
class _ICLabelNetPSDS(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -81,7 +81,7 @@ def forward(self, x):
return self.sequential(x)


class ICLabelNetAutocorr(nn.Module):
class _ICLabelNetAutocorr(nn.Module):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -121,9 +121,9 @@ class ICLabelNet(nn.Module):
def __init__(self):
super().__init__()

self.img_conv = ICLabelNetImg()
self.psds_conv = ICLabelNetPSDS()
self.autocorr_conv = ICLabelNetAutocorr()
self.img_conv = _ICLabelNetImg()
self.psds_conv = _ICLabelNetPSDS()
self.autocorr_conv = _ICLabelNetAutocorr()

self.conv = nn.Conv2d(
in_channels=712,
Expand Down Expand Up @@ -175,7 +175,7 @@ def forward(
return labels


def format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
def _format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
"""Replicate the input formatting in EEGLAB -ICLabel.
.. code-block:: matlab
Expand All @@ -194,7 +194,7 @@ def format_input(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
return formatted_topo, formatted_psd, formatted_autocorr


def format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
def _format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike):
"""Format the features to the correct shape and type for pytorch."""
topo = np.transpose(topo, (3, 2, 0, 1))
psd = np.transpose(psd, (3, 2, 0, 1))
Expand All @@ -210,12 +210,12 @@ def format_input_for_torch(topo: ArrayLike, psd: ArrayLike, autocorr: ArrayLike)
def run_iclabel(images: ArrayLike, psds: ArrayLike, autocorr: ArrayLike) -> ArrayLike:
"""Run the ICLabel network on the provided set of features. The features
are un-formatted and are as-returned by ``get_features``."""
ica_network_file = files("mne_icalabel").joinpath("assets/iclabelNet.pt")
ica_network_file = files("mne_icalabel.iclabel").joinpath("assets/iclabelNet.pt")

# Get network and load weights
iclabel_net = ICLabelNet()
iclabel_net.load_state_dict(torch.load(ica_network_file))

# Format input and get labels
labels = iclabel_net(*format_input_for_torch(*format_input(images, psds, autocorr)))
labels = iclabel_net(*_format_input_for_torch(*_format_input(images, psds, autocorr)))
return labels.detach().numpy()
Empty file.
Loading