Merged
58 commits
c26ae41
push test for channels interpolation
gcattan Aug 31, 2023
b89c1b4
use mne interpolation
gcattan Sep 6, 2023
4e797ee
fix typing
gcattan Sep 6, 2023
5d7b7e5
fix get_paradigm
gcattan Sep 6, 2023
10bc03f
typo
gcattan Sep 6, 2023
07520bc
constructor parameter `paradigm` can now be removed
gcattan Sep 6, 2023
241c19e
fix filtered dataset
gcattan Sep 6, 2023
d4c4a22
set interpolate_missing_channels to False by default
gcattan Sep 6, 2023
bafaff2
fix assert paradigm not passing
gcattan Sep 6, 2023
3f33cd6
parameter incorrectly named
gcattan Sep 6, 2023
12f77f4
fix paradigm name
gcattan Sep 6, 2023
c4c71cc
exclude stim channel
gcattan Sep 6, 2023
3e44212
debug trace
gcattan Sep 6, 2023
892b3e4
A-B vs B-A set difference problem
gcattan Sep 6, 2023
60d88eb
fix event_list
gcattan Sep 7, 2023
bad2e3b
fix montage error
gcattan Sep 7, 2023
f0f05dc
fix disabling montage
gcattan Sep 7, 2023
edf2582
pop should be on info not raw object
gcattan Sep 7, 2023
378391d
typo
gcattan Sep 7, 2023
2be4e17
workaround ValueError: lowpass frequency 32.0 must be less than Nyqui…
gcattan Sep 7, 2023
a99de44
do not forget to pick channels
gcattan Sep 7, 2023
6777099
remove fmin/fmax
gcattan Sep 7, 2023
41ea412
fix Nyquist error
gcattan Sep 7, 2023
21a4527
remove finally block
gcattan Sep 7, 2023
678badd
log montage
gcattan Sep 7, 2023
46c309b
use default montage if 1005 is not available
gcattan Sep 7, 2023
99a62f4
fix type
gcattan Sep 7, 2023
91bd43f
add origin
gcattan Sep 7, 2023
07f9345
fix reference error
gcattan Sep 7, 2023
0035b94
debug
gcattan Sep 7, 2023
daa67dc
add new check on epochs length
gcattan Sep 7, 2023
6729952
dataset missing in get_data
gcattan Sep 7, 2023
33533d1
invalid subject given
gcattan Sep 7, 2023
87d3136
debug string
gcattan Sep 7, 2023
c0a3395
missing dereference
gcattan Sep 7, 2023
c7c1173
fix destructuring
gcattan Sep 7, 2023
0227f72
some debug trace
gcattan Sep 7, 2023
6828025
debug string
gcattan Sep 7, 2023
abaa23e
debug string
gcattan Sep 7, 2023
ca85089
additional testing
gcattan Sep 7, 2023
37db330
warn epochs
gcattan Sep 7, 2023
456ad3f
debug info
gcattan Sep 7, 2023
5007bd0
fix syntax
gcattan Sep 7, 2023
c99b5b7
inverse test
gcattan Sep 7, 2023
58a3c63
raw.copy missing
gcattan Sep 7, 2023
a839101
resample not accurate. Debug.
gcattan Sep 8, 2023
4123f76
fix interpolate missing channel True when should be False
gcattan Sep 8, 2023
d97b80c
get_data directly from datasets
gcattan Sep 10, 2023
db45b4d
lint
gcattan Sep 10, 2023
f301d3c
Update bi_illiteracy.py
gcattan Oct 21, 2023
24bff71
Merge branch 'feat/interpolate_channels' of https://github.com/gcatta…
gcattan Oct 21, 2023
3386bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2023
bb660e9
missing modification
gcattan Oct 21, 2023
a95ddea
Merge branch 'feat/interpolate_channels' of https://github.com/gcatta…
gcattan Oct 21, 2023
062d24d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2023
64a82f7
Merge branch 'develop' into feat/interpolate_channels
gcattan Nov 6, 2023
74e1654
Merge branch 'develop' into feat/interpolate_channels
bruAristimunha Nov 10, 2023
b3f3b7f
Update whats_new.rst
gcattan Nov 10, 2023
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -19,6 +19,7 @@ Enhancements
~~~~~~~~~~~~

- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_)
- Option to interpolate channels in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)

Bugs
~~~~
2 changes: 1 addition & 1 deletion moabb/datasets/compound_dataset/__init__.py
@@ -8,7 +8,7 @@
BI_Il,
Cattan2019_VR_Il,
)
from .utils import _init_compound_dataset_list
from .utils import _init_compound_dataset_list, compound # noqa: F401


_init_compound_dataset_list()
26 changes: 23 additions & 3 deletions moabb/datasets/compound_dataset/base.py
@@ -36,13 +36,12 @@ class CompoundDataset(BaseDataset):
interval: list with 2 entries
See `BaseDataset`.

paradigm: ['p300','imagery', 'ssvep', 'rstate']
Defines what sort of dataset this is
"""

def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str):
def __init__(self, subjects_list: list, code: str, interval: list):
self._set_subjects_list(subjects_list)
dataset, _, _, _ = self.subjects_list[0]
paradigm = self._get_paradigm()
super().__init__(
subjects=list(range(1, self.count + 1)),
sessions_per_subject=self._get_sessions_per_subject(),
@@ -52,6 +51,17 @@ def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str
paradigm=paradigm,
)

@property
def datasets(self):
all_datasets = [entry[0] for entry in self.subjects_list]
found_flags = set()
filtered_dataset = []
for dataset in all_datasets:
if dataset.code not in found_flags:
filtered_dataset.append(dataset)
found_flags.add(dataset.code)
return filtered_dataset

@property
def count(self):
return len(self.subjects_list)
@@ -78,6 +88,16 @@ def _set_subjects_list(self, subjects_list: list):
for compoundDataset in subjects_list:
self.subjects_list.extend(compoundDataset.subjects_list)

def _get_paradigm(self):
dataset, _, _, _ = self.subjects_list[0]
paradigm = dataset.paradigm
# Check that all of the datasets have the same paradigm
for i in range(1, len(self.subjects_list)):
entry = self.subjects_list[i]
dataset = entry[0]
assert dataset.paradigm == paradigm
return paradigm

def _with_data_origin(self, data: dict, shopped_subject):
data_origin = self.subjects_list[shopped_subject - 1]

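With this change, `CompoundDataset` no longer takes a `paradigm` constructor argument (it is inferred via `_get_paradigm`), and the new `datasets` property lists the distinct source datasets. A minimal usage sketch, assuming a `FakeDataset` as a stand-in source; the values in the comments follow from the code above and are not part of the PR:

```python
from moabb.datasets.compound_dataset.base import CompoundDataset
from moabb.datasets.fake import FakeDataset

# One fake P300 dataset acts as the source; parameters are illustrative
source = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"], n_subjects=2)
subjects_list = [(source, subject, None, None) for subject in source.subject_list]

compound_ds = CompoundDataset(
    subjects_list=subjects_list,
    code="CompoundDataset-example",
    interval=[0, 1.0],
    # no `paradigm=...` anymore: it is read from the source dataset
)

print(compound_ds.paradigm)       # "p300", inferred by _get_paradigm()
print(len(compound_ds.datasets))  # 1 -- duplicate source datasets are filtered by code
```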
1 change: 0 additions & 1 deletion moabb/datasets/compound_dataset/bi_illiteracy.py
@@ -15,7 +15,6 @@ def __init__(self, subjects_list, dataset=None, code=None):
subjects_list=subjects_list,
code=code,
interval=[0, 1.0],
paradigm="p300",
)


15 changes: 15 additions & 0 deletions moabb/datasets/compound_dataset/utils.py
@@ -1,6 +1,8 @@
import inspect
from typing import List

import moabb.datasets.compound_dataset as db
from moabb.datasets.base import BaseDataset
from moabb.datasets.compound_dataset.base import CompoundDataset


@@ -11,3 +13,16 @@ def _init_compound_dataset_list():
for ds in inspect.getmembers(db, inspect.isclass):
if issubclass(ds[1], CompoundDataset) and not ds[0] == "CompoundDataset":
compound_dataset_list.append(ds[1])


def compound(*datasets: List[BaseDataset], interval=[0, 1.0]):
subjects_list = [
(d, subject, None, None) for d in datasets for subject in d.subject_list
]
code = "".join([d.code for d in datasets])
ret = CompoundDataset(
subjects_list=subjects_list,
code=code,
interval=interval,
)
return ret
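A hedged sketch of the new `compound()` helper; the two `FakeDataset` instances and their parameters are assumptions used only for illustration:

```python
from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset

d1 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"], n_subjects=2)
d2 = FakeDataset(paradigm="p300", event_list=["Target", "NonTarget"], n_subjects=3, sfreq=256)

# Every subject of every source dataset becomes one entry of the compound dataset;
# the compound code is the concatenation of the source dataset codes.
merged = compound(d1, d2)

print(merged.count)                       # 5 subjects in total (2 + 3)
print([d.code for d in merged.datasets])  # distinct source datasets, deduplicated by code
```

Note that `_get_paradigm()` asserts that all source datasets share the same paradigm, so mixing, e.g., a P300 and an imagery dataset would fail.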
37 changes: 36 additions & 1 deletion moabb/datasets/preprocessing.py
@@ -2,6 +2,7 @@
from collections import OrderedDict
from operator import methodcaller
from typing import Dict, List, Tuple, Union
from warnings import warn

import mne
import numpy as np
@@ -199,13 +200,15 @@ def __init__(
tmax: float,
baseline: Tuple[float, float],
channels: List[str] = None,
interpolate_missing_channels: bool = False,
):
assert isinstance(event_id, dict) # not None
self.event_id = event_id
self.tmin = tmin
self.tmax = tmax
self.baseline = baseline
self.channels = channels
self.interpolate_missing_channels = interpolate_missing_channels

def transform(self, X, y=None):
raw = X["raw"]
@@ -218,9 +221,40 @@ def transform(self, X, y=None):
if self.channels is None:
picks = mne.pick_types(raw.info, eeg=True, stim=False)
else:
available_channels = raw.info["ch_names"]
if self.interpolate_missing_channels:
missing_channels = list(set(self.channels).difference(available_channels))

# add missing channels (contains only zeros by default)
try:
raw.add_reference_channels(missing_channels)
except IndexError:
# An IndexError can occur if the channels we add are not part of this epoch's montage
# Then log a warning
montage = raw.info["dig"]
warn(
f"Montage disabled as one of these channels, {missing_channels}, is not part of the montage {montage}"
)
# and disable the montage
raw.info.pop("dig")
# run again with montage disabled
raw.add_reference_channels(missing_channels)

# Trick: mark these channels as bad
raw.info["bads"].extend(missing_channels)
# ...and use MNE bad-channel interpolation to generate the values of the missing channels
try:
raw.interpolate_bads(origin="auto")
except ValueError:
# use default origin if montage info not available
raw.interpolate_bads(origin=(0, 0, 0.04))
# update the name of the available channels
available_channels = self.channels

picks = mne.pick_channels(
raw.info["ch_names"], include=self.channels, ordered=True
available_channels, include=self.channels, ordered=True
)
assert len(picks) == len(self.channels)

epochs = mne.Epochs(
raw,
Expand All @@ -236,6 +270,7 @@ def transform(self, X, y=None):
event_repeated="drop",
on_missing="ignore",
)
warn(f"warnEpochs {epochs}")
return epochs


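The interpolation branch above relies on a standard MNE trick: add the missing channels as flat reference channels, mark them as bad, and let `interpolate_bads` reconstruct them from the surrounding electrodes. A self-contained sketch on synthetic data; the channel names, montage, and random signal are illustrative assumptions, not taken from the PR:

```python
import numpy as np
import mne

# Toy recording with eight EEG channels at 128 Hz
ch_names = ["C3", "Cz", "Fz", "Pz", "O1", "O2", "F3", "F4"]
rng = np.random.default_rng(42)
info = mne.create_info(ch_names, sfreq=128.0, ch_types="eeg")
raw = mne.io.RawArray(rng.standard_normal((len(ch_names), 1280)), info)

# The target channel set contains "C4", which this recording lacks
target_channels = ch_names + ["C4"]
missing = list(set(target_channels) - set(raw.info["ch_names"]))

# 1. Add the missing channels; they are filled with zeros
raw.add_reference_channels(missing)

# 2. Give every channel a position so spherical-spline interpolation can run
raw.set_montage("standard_1020")

# 3. Mark the added channels as bad and let MNE interpolate their values
raw.info["bads"].extend(missing)
raw.interpolate_bads(origin="auto")

print(raw.info["ch_names"])  # now includes an interpolated "C4"
```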
49 changes: 38 additions & 11 deletions moabb/paradigms/base.py
@@ -83,6 +83,7 @@ def __init__(
self.resample = resample
self.tmin = tmin
self.tmax = tmax
self.interpolate_missing_channels = False

@property
@abc.abstractmethod
@@ -399,6 +400,7 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
tmax=bmax,
baseline=baseline,
channels=self.channels,
interpolate_missing_channels=self.interpolate_missing_channels,
),
),
)
@@ -429,7 +431,13 @@ def _get_array_pipeline(
return None
return Pipeline(steps)

def match_all(self, datasets: List[BaseDataset], shift=-0.5):
def match_all(
self,
datasets: List[BaseDataset],
shift=-0.5,
channel_merge_strategy: str = "intersect",
ignore=["stim"],
):
"""
Initialize this paradigm to match all the given datasets:
- `self.resample` is set to match the minimum frequency in all datasets, minus `shift`.
@@ -442,29 +450,48 @@ def match_all(self, datasets: List[BaseDataset], shift=-0.5):
----------
datasets: List[BaseDataset]
A list of dataset instances.
shift: float (default: -0.5)
Shift the sampling frequency by this value.
E.g.: if sfreq=128 and shift=-0.5, the resulting rate is 127.5 Hz
channel_merge_strategy: str (default: 'intersect')
Accepts two values:
- 'intersect': keep only channels common to all datasets
- 'union': keep all channels from all datasets, removing duplicates; channels missing from a dataset are interpolated
ignore: List[string]
A list of channels to ignore

.. versionadded:: 0.6.0
"""
resample = None
channels = None
channels: set = None
for dataset in datasets:
X, _, _ = self.get_data(
dataset, subjects=[dataset.subject_list[0]], return_epochs=True
)
first_subject = dataset.subject_list[0]
data = dataset.get_data(subjects=[first_subject])[first_subject]
first_session = list(data.keys())[0]
session = data[first_session]
first_run = list(session.keys())[0]
X = session[first_run]
info = X.info
sfreq = info["sfreq"]
ch_names = info["ch_names"]
# get the minimum sampling frequency between all datasets
resample = sfreq if resample is None else min(resample, sfreq)
# get the channels common to all datasets
channels = (
set(ch_names)
if channels is None
else set(channels).intersection(ch_names)
)
if channels is None:
channels = set(ch_names)
elif channel_merge_strategy == "intersect":
channels = channels.intersection(ch_names)
self.interpolate_missing_channels = False
else:
channels = channels.union(ch_names)
self.interpolate_missing_channels = True
# If resample=128 for example, then MNE can return 128 or 129 samples
# depending on the dataset, even if the length of the epochs is 1s
# `shift=-0.5` solves this particular issue.
self.resample = resample + shift
self.channels = list(channels)

# exclude ignored channels
self.channels = list(channels.difference(ignore))

@abc.abstractmethod
def _get_events_pipeline(self, dataset):
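A hedged usage sketch of the extended `match_all`, mirroring `test_compound_dataset` further down: two fake imagery datasets with different channel sets and sampling rates are aligned with the 'union' strategy so that missing channels get interpolated. Dataset parameters and printed values are illustrative assumptions:

```python
from moabb.datasets.fake import FakeDataset
from moabb.paradigms import LeftRightImagery

dataset1 = FakeDataset(
    paradigm="imagery", event_list=["left_hand", "right_hand"],
    channels=["C3", "Cz", "Fz"], sfreq=128,
)
dataset2 = FakeDataset(
    paradigm="imagery", event_list=["left_hand", "right_hand"],
    channels=["C3", "C4", "Cz"], sfreq=256,
)

paradigm = LeftRightImagery()
paradigm.match_all([dataset1, dataset2], channel_merge_strategy="union")

print(paradigm.channels)  # union of both channel sets, minus the ignored "stim"
print(paradigm.resample)  # 127.5: min(128, 256) plus the default shift of -0.5

# Epochs from either dataset now share the same channel set and sampling rate;
# channels absent from a recording are interpolated by the epoching pipeline.
X, labels, metadata = paradigm.get_data(dataset2, subjects=[1], return_epochs=True)
```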
15 changes: 8 additions & 7 deletions moabb/tests/datasets.py
@@ -351,7 +351,6 @@ def test_fake_dataset(self):
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

data = compound_data.get_data()
@@ -385,7 +384,6 @@ def test_compound_dataset_composition(self):
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

# Add it two times to a subjects_list
@@ -394,9 +392,11 @@ def test_compound_dataset_composition(self):
subjects_list,
code="CompoundDataset-test",
interval=[0, 1],
paradigm=self.paradigm,
)

# Assert there is only one source dataset in the compound dataset
self.assertEqual(len(compound_data.datasets), 1)

# Assert that the compound dataset has twice as many subjects as the original one.
data = compound_data.get_data()
self.assertEqual(len(data), 2)
@@ -408,7 +408,7 @@ def test_get_sessions_per_subject(self):
n_runs=self.n_runs,
n_subjects=self.n_subjects,
event_list=["Target", "NonTarget"],
paradigm=self.paradigm,
paradigm=self.ds.paradigm,
)

# Add the two datasets to a CompoundDataset
@@ -417,9 +417,11 @@ def test_get_sessions_per_subject(self):
subjects_list,
code="CompoundDataset",
interval=[0, 1],
paradigm=self.paradigm,
)

# Assert there are two source datasets (ds and ds2) in the compound dataset
self.assertEqual(len(compound_dataset.datasets), 2)

# Test that the private method _get_sessions_per_subject returns the minimum number of sessions per subject
self.assertEqual(compound_dataset._get_sessions_per_subject(), self.n_sessions)

@@ -430,7 +432,7 @@ def test_event_id_correctly_updated(self):
n_runs=self.n_runs,
n_subjects=self.n_subjects,
event_list=["Target2", "NonTarget2"],
paradigm=self.paradigm,
paradigm=self.ds.paradigm,
)

# Add the two datasets to a CompoundDataset
@@ -440,7 +442,6 @@
subjects_list,
code="CompoundDataset",
interval=[0, 1],
paradigm=self.paradigm,
)

# Check that the event_id of the compound_dataset is the same as that of the first dataset
36 changes: 36 additions & 0 deletions moabb/tests/evaluations.py
@@ -14,6 +14,7 @@
from sklearn.pipeline import FunctionTransformer, Pipeline, make_pipeline

from moabb.analysis.results import get_string_rep
from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset
from moabb.evaluations import evaluations as ev
from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list
@@ -82,6 +83,41 @@ def test_eval_results(self):
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_compound_dataset(self):
ch1 = ["C3", "Cz", "Fz"]
dataset1 = FakeDataset(
paradigm="imagery",
event_list=["left_hand", "right_hand"],
channels=ch1,
sfreq=128,
)
ch2 = ["C3", "C4", "Cz"]
dataset2 = FakeDataset(
paradigm="imagery",
event_list=["left_hand", "right_hand"],
channels=ch2,
sfreq=256,
)
merged_dataset = compound(dataset1, dataset2)

# We want to interpolate channels that are not in common between the two datasets
self.eval.paradigm.match_all(
merged_dataset.datasets, channel_merge_strategy="union"
)

process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]
results = [
r
for r in self.eval.evaluate(
dataset, pipelines, param_grid=None, process_pipeline=process_pipeline
)
]

# We should get 4 results, 2 sessions 2 subjects
self.assertEqual(len(results), 4)
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_eval_grid_search(self):
# Test grid search
param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}