Apply CNN to EegDataset

This script demonstrates how to apply the convolutional neural network (CNN) for artefact detection to (neonatal) EEG data, with the EEG in nnsa.EegDataset format.

References: T. Hermans et al., "A multi-task and multi-channel convolutional neural network for semi-supervised neonatal artefact detection," Journal of Neural Engineering, vol. 20, no. 2, p. 26013, Mar. 2023, doi: 10.1088/1741-2552/acbc4b. https://pubmed.ncbi.nlm.nih.gov/36791462/

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: artefacts/apply_cnn_to_eegdataset.py

import os

import numpy as np
import matplotlib.pyplot as plt

from nnsa import EdfReader, Butterworth, EegDataset

Specify the path to an EDF file to process.

# Filepath to .EDF file with EEG.
fp_edf = r'C:/data_temp/test.edf'

Read EEG from the EDF file (if the file exists, otherwise create some dummy data).

if os.path.exists(fp_edf):
    # Load the EEG from the EDF.
    with EdfReader(fp_edf) as r:
        # Returns a nnsa.EegDataset
        eeg_ds = r.read_eeg_dataset()
else:
    print(f'File {fp_edf} not found. Creating dummy data... '
          'Note that the artefact detection model will recognize that it is fake and flag everything as artefact.')
    fs = 250
    channel_labels = ['EEG Fp1', 'EEG Fp2', 'EEG C3', 'EEG C4', 'EEG Cz', 'EEG T3', 'EEG T4', 'EEG O1', 'EEG O2']
    eeg = (np.random.rand(len(channel_labels), fs*40) - 0.5)*300
    eeg_ds = EegDataset.from_array(eeg=eeg, fs=fs, channel_labels=channel_labels)
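
The dummy signal above changes on every run because it draws from NumPy's global random state. If reproducible dummy data is preferred, a seeded generator can be used instead. The following is a minimal sketch in plain NumPy; the helper name make_dummy_eeg and the seed value are illustrative assumptions and not part of nnsa.

```python
import numpy as np


def make_dummy_eeg(n_channels, fs, duration_s, amplitude=300.0, seed=0):
    """Return uniform noise of shape (n_channels, fs * duration_s), centred on zero.

    Hypothetical helper, not part of nnsa.
    Values lie in [-amplitude / 2, amplitude / 2).
    """
    rng = np.random.default_rng(seed)  # seeded, so repeated calls give the same data
    n_samples = int(fs * duration_s)
    return (rng.random((n_channels, n_samples)) - 0.5) * amplitude


eeg_dummy = make_dummy_eeg(n_channels=9, fs=250, duration_s=40)
print(eeg_dummy.shape)  # (9, 10000)
```

The resulting array can be passed to EegDataset.from_array() in the same way as eeg above.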

Apply the artefact detection model. The EegDataset class provides detect_artefacts_cnn() for this. It returns a new EegDataset that is True at the locations in eeg_ds where artefacts are detected and False where the data is considered clean.

Note that the multi-channel model requires 8-channel EEG referenced to Cz (see CleanDetectorCnn.data_requirements). The re-referencing is done automatically if all required channels are available, which means that the output can contain different channels than the input.

The single-channel model (multi_channel=False) works on any montage, but also requires referencing to Cz. If Cz is present in eeg_ds, it is used to reference all other channels and is then removed. If Cz is not found in eeg_ds, no re-referencing is done and the data is assumed to be referenced to Cz already.

# We can apply the model with a single line, using the method built into EegDataset.
af_ds = eeg_ds.detect_artefacts_cnn(multi_channel=False)

# Compare the input channels with the channels in the artefact output.
print('EEG channels:', eeg_ds.channel_labels)
print('Artefacts detected in:', af_ds.channel_labels)
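
The Cz re-referencing described above explains why the two channel lists can differ. As a rough illustration of that step, the sketch below subtracts the Cz signal from every other channel and then drops Cz, using plain NumPy arrays. It is not the nnsa implementation; the function name rereference_to_cz and the default label 'EEG Cz' are assumptions made for this example.

```python
import numpy as np


def rereference_to_cz(eeg, channel_labels, ref_label='EEG Cz'):
    """Subtract the reference channel from all other channels and drop it.

    Illustrative helper, not the nnsa implementation. If ref_label is absent,
    the data is returned unchanged (assumed to be referenced already).
    """
    eeg = np.asarray(eeg, dtype=float)
    labels = list(channel_labels)
    if ref_label not in labels:
        return eeg.copy(), labels
    ref_idx = labels.index(ref_label)
    keep = [i for i in range(len(labels)) if i != ref_idx]
    # Broadcasting subtracts the 1-D reference signal from every kept channel.
    eeg_ref = eeg[keep] - eeg[ref_idx]
    return eeg_ref, [labels[i] for i in keep]


labels_in = ['EEG C3', 'EEG C4', 'EEG Cz']
eeg_in = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0],
                   [1.0, 1.0, 1.0]])
eeg_out, labels_out = rereference_to_cz(eeg_in, labels_in)
print(labels_out)  # ['EEG C3', 'EEG C4']
```

In this toy example the output has one channel fewer than the input, which mirrors the difference between the two printed channel lists when Cz is present in eeg_ds.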

Finally, we can plot the results.

# Get only the EEG channels for which we have the artefact output.
eeg_ds_plot = eeg_ds.extract_channels(af_ds.channel_labels)

# Plot raw EEG and a mask indicating the artefacts.
fig, ax = plt.subplots(1, 1)
eeg_ds_plot.plot(ax=ax, color='k')
eeg_ds_plot.plot_mask(mask=af_ds, color='r')
ax.set_title('Raw input and detected artefacts')

# Plot filtered EEG and highlight predicted artefacts in red.
fig, ax = plt.subplots(1, 1)
eeg_ds_filt = eeg_ds_plot.notch_filt(f0=50).filter(Butterworth(fn=[0.27, 30], order=1))
eeg_ds_filt.plot(ax=ax, color='k', alpha=0.8)
eeg_ds_plot.plot_mask(mask=af_ds, color='r')
ax.set_title('Filtered input and detected artefacts')
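
Besides plotting, the artefact mask can be summarised numerically. The sketch below assumes the mask is available as a boolean NumPy array of shape (n_channels, n_samples); how to extract such an array from af_ds depends on the nnsa API and is not shown here. The function names artefact_fraction and artefact_segments are hypothetical.

```python
import numpy as np


def artefact_fraction(mask):
    """Fraction of samples flagged as artefact, per channel."""
    return np.asarray(mask, dtype=bool).mean(axis=1)


def artefact_segments(mask_1d, fs):
    """Return (start_s, end_s) pairs for each run of True in a 1-D mask."""
    m = np.asarray(mask_1d, dtype=bool).astype(int)
    # Pad with zeros so that runs touching either edge are detected too.
    edges = np.diff(np.concatenate(([0], m, [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)  # exclusive end index
    return [(float(s / fs), float(e / fs)) for s, e in zip(starts, ends)]


fs_demo = 2  # Hz, toy value for illustration only
mask_demo = np.array([[0, 1, 1, 0, 0, 1],
                      [0, 0, 0, 0, 0, 0]], dtype=bool)
fractions = artefact_fraction(mask_demo)
segments = artefact_segments(mask_demo[0], fs_demo)
print(segments)  # [(0.5, 1.5), (2.5, 3.0)]
```

A per-channel fraction close to 1, as expected for the dummy data, indicates that nearly the whole recording was flagged.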