diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index c5b177220..7fdea9a9c 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -171,11 +171,11 @@ def one_subject_decoding( clf = make_pipeline( *preproc_steps, csp, - LogReg( + LinearModel(LogReg( solver="liblinear", # much faster than the default random_state=cfg.random_state, n_jobs=1, - ), + )), ) cv = StratifiedKFold( n_splits=cfg.decoding_n_splits, @@ -244,9 +244,6 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Get the data for all time points X = epochs_filt.get_data() - - - cv_scores = cross_val_score( estimator=clf, X=X, @@ -258,14 +255,14 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non freq_decoding_table.loc[idx, "mean_crossval_score"] = cv_scores.mean() freq_decoding_table.at[idx, "scores"] = cv_scores - # PATTERNS - csp.fit_transform(X, y) - sensor_pattern_csp = csp.patterns_ - # COEFS clf.fit(X, y) weights_csp = mne.decoding.get_coef(clf, 'patterns_', inverse_transform=True) + # PATTERNS + csp.fit_transform(X, y) + sensor_pattern_csp = csp.patterns_ + # save scores # XXX right now this saves in working directory csp_fname = cond1 + cond2 + str(fmin) + str(fmax) @@ -345,14 +342,14 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg += f": {cfg.decoding_metric}={score:0.3f}" logger.info(**gen_log_kwargs(msg)) - # PATTERNS - csp.fit_transform(X, y) - sensor_pattern_csp = csp.patterns_ - # COEFS clf.fit(X, y) weights_csp = mne.decoding.get_coef(clf, 'patterns_', inverse_transform=True) + # PATTERNS + csp.fit_transform(X, y) + sensor_pattern_csp = csp.patterns_ + # save scores # XXX right now this saves in working directory csp_fname = cond1 + cond2 + str(fmin) + str(fmax) + str(tmin) + str(tmax)