Conversation
…Please improve on spline channel interpolation and handling of events. Please add dataset augmentation and load balancing if necessary.
gcattan
left a comment
Thank you for your contribution!
We have a way to do transfer learning across datasets in MOABB, using the compound dataset.
I see a different use case here, however:
- The `compound_dataset` considers all subjects identically (so a split can contain a mix of subjects from different datasets). You therefore cannot run an evaluation that trains only on subjects from one dataset and tests on another dataset.
- Conversely, the new cross-dataset evaluation is agnostic of the number of subjects/sessions/runs.
I guess the main point is rather how to align with the new splitter API.
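To make the distinction concrete, here is a toy sketch in plain Python (hypothetical subject IDs, not the MOABB API) of the two split strategies:

```python
# Toy illustration (not MOABB code): a compound-dataset split can mix
# subjects from different datasets, while a cross-dataset split keeps
# whole datasets on either side of the train/test boundary.
dataset_a = ["A-s1", "A-s2", "A-s3"]  # hypothetical subject IDs
dataset_b = ["B-s1", "B-s2"]

# Compound-dataset style: pool all subjects, then split -> mixed origins.
pooled = dataset_a + dataset_b
compound_train, compound_test = pooled[::2], pooled[1::2]

# Cross-dataset style: one dataset trains, the other tests, never mixed.
cross_train, cross_test = dataset_a, dataset_b

print(compound_train, compound_test)
print(cross_train, cross_test)
```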
examples/cross_dataset.py
Outdated
```python
logging.getLogger("mne").setLevel(logging.ERROR)


def get_common_channels(datasets: List[Any]) -> List[str]:
```
There is a `match_all` method in the base paradigm:
Line 429 in 357cd12
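Conceptually, `match_all` aligns a list of datasets so the paradigm can work on a common channel set, which is what the hand-rolled `get_common_channels` helpers reimplement. A toy stand-in for just the channel-intersection part (plain Python, not the real MOABB signature):

```python
# Toy stand-in for the channel-alignment step (hypothetical helper, not
# the actual match_all implementation from moabb.paradigms.base).
def common_channels(channel_lists):
    """Return the sorted intersection of several channel-name lists."""
    common = set(channel_lists[0])
    for chs in channel_lists[1:]:
        common &= set(chs)
    return sorted(common)

print(common_channels([["C3", "C4", "Cz", "Fz"], ["C3", "Cz", "Pz"]]))
# -> ['C3', 'Cz']
```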
```python
logging.basicConfig(level=logging.WARNING)


def get_common_channels(train_dataset, test_dataset):
```
Same here (`match_all` method).
```python
    return event_id


def interpolate_missing_channels(
```
```python
    return len(dataset.subject_list) > 1


class CrossDatasetEvaluation(BaseEvaluation):
```
Hm. I think there is a plan to refactor the existing evaluation.
The recommended way to go will be to use the new Splitter API (see: #612 (comment)).
@bruAristimunha can probably advise you better than me what refactoring is necessary in this case.
Hey @gcattan, thanks for all your feedback!
@bruAristimunha - if you could comment on the best way to move forward :)
We can implement this one and migrate later.
```python
    train_dataset : Dataset or list of Dataset
        Dataset(s) to use for training
    test_dataset : Dataset or list of Dataset
        Dataset(s) to use for testing
```
Probably you want to have a cross-evaluation.
So provide a list of datasets, keep one for training and the others for testing, and then rotate.
@ali-sehar @EazyAl Please implement this suggestion too.
- Pass a list of datasets
- And implement cross-validation
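The rotation suggested above can be sketched as a leave-one-dataset-out loop (plain Python with hypothetical dataset names, not the final splitter API):

```python
# Leave-one-dataset-out rotation (sketch): each dataset takes a turn as
# the held-out test set while the remaining ones form the training pool.
def leave_one_dataset_out(datasets):
    for i, test_ds in enumerate(datasets):
        train_ds = datasets[:i] + datasets[i + 1:]
        yield train_ds, test_ds

folds = list(leave_one_dataset_out(["BNCI2014_001", "Zhou2016", "Weibo2014"]))
for train, test in folds:
    print("train:", train, "| test:", test)
```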
```python
    model = clone(pipeline).fit(train_X[0], train_y)
    score = model.score(test_X, test_y)
```
OK, so you train on the whole subjects/sessions/runs, and then test on the whole subjects/sessions/runs of the second dataset?
```python
# Get the list of channels from each dataset before matching
print("\nChannels before matching:")
for ds_name, ds in datasets_dict.items():
    try:
        # Load data for the first subject to get channel information
        data = ds.get_data([ds.subject_list[0]])
        first_subject = list(data.keys())[0]
        first_session = list(data[first_subject].keys())[0]
        first_run = list(data[first_subject][first_session].keys())[0]
        run_data = data[first_subject][first_session][first_run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            # Assume the channels are stored on the dataset class after loading
            channels = ds.channels
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name}: {str(e)}")
```
```python
# Get channels from all datasets after matching to ensure we have the
# correct intersection
all_channels_after_matching = []
print("\nChannels after matching:")
for i, (ds_name, _) in enumerate(datasets_dict.items()):
    ds = all_datasets[i]  # the matched dataset
    try:
        data = ds.get_data([ds.subject_list[0]])
        subject = list(data.keys())[0]
        session = list(data[subject].keys())[0]
        run = list(data[subject][session].keys())[0]
        run_data = data[subject][session][run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            channels = ds.channels
        all_channels_after_matching.append(set(channels))
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name} after matching: {str(e)}")

# Get the intersection of all channel sets
common_channels = sorted(set.intersection(*all_channels_after_matching))
print(f"\nCommon channels after matching: {common_channels}")
print(f"Number of common channels: {len(common_channels)}")

# Update datasets_dict with the matched datasets
for i, (name, _) in enumerate(datasets_dict.items()):
    datasets_dict[name] = all_datasets[i]

train_dataset = datasets_dict["train_dataset"]
test_dataset = datasets_dict["test_dataset"]

# Initialize the paradigm with the common channels
paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32)
```
Remove this.
`match_all` doesn't change the number of channels in the dataset; it just automatically sets the filter in the paradigm.
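In other words, the channel restriction can live in the paradigm rather than in the dataset objects. A minimal sketch of that division of responsibility, using hypothetical toy classes rather than MOABB's real ones:

```python
# Sketch: the dataset keeps its full montage; the paradigm carries the
# channel filter and applies it only when data is requested.
class ToyDataset:
    def __init__(self, channels):
        self.channels = channels  # full montage, never modified


class ToyParadigm:
    def __init__(self, channels=None):
        self.channels = channels  # filter, set e.g. by a match_all-like step

    def get_data(self, dataset):
        picks = self.channels or dataset.channels
        return [ch for ch in dataset.channels if ch in picks]


ds = ToyDataset(["C3", "C4", "Cz", "Fz"])
paradigm = ToyParadigm(channels=["C3", "Cz"])
print(paradigm.get_data(ds))  # filtered view
print(ds.channels)            # dataset itself unchanged
```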
```
@@ -0,0 +1,691 @@
"""
```
Same comments about `match_all` here, please apply.
```python
    train_dataset : Dataset or list of Dataset
        Dataset(s) to use for training
    test_dataset : Dataset or list of Dataset
        Dataset(s) to use for testing
```
@ali-sehar @EazyAl Please implement this suggestion too.
- Pass a list of datasets
- And implement cross-validation
This adds a new type of evaluation, making it possible to validate models across several datasets. This is particularly relevant for deep learning models, as it allows MOABB to be used for benchmarking transfer learning.
Some examples are also added, one of which uses braindecode.