Apply CNN to array
This script demonstrates how to apply the convolutional neural network (CNN) for artefact detection to (neonatal) EEG data, with the EEG stored in array format.
References: T. Hermans et al., “A multi-task and multi-channel convolutional neural network for semi-supervised neonatal artefact detection,” Journal of Neural Engineering, vol. 20, no. 2, p. 026013, Mar. 2023, doi: 10.1088/1741-2552/acbc4b. https://pubmed.ncbi.nlm.nih.gov/36791462/
Author: Tim Hermans (tim-hermans@hotmail.com).
Link to script: artefacts/apply_cnn_to_array.py
import os
import numpy as np
import matplotlib.pyplot as plt
from nnsa import EegDataset, CleanDetectorCnn
Simulate random data. It looks nothing like EEG, so all of it will be classified as artefact.
fs = 250
channel_labels = ['EEG Fp1', 'EEG Fp2', 'EEG C3', 'EEG C4', 'EEG T3', 'EEG T4', 'EEG O1', 'EEG O2']
eeg = (np.random.rand(fs*40, len(channel_labels)) - 0.5)*300
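The model expects a time-major array with shape (n_time, n_channels), as in the simulated data above. Many tools return channel-first arrays (n_channels, n_time) instead, so a quick layout check before prediction can save confusion. The helper below, `to_time_major`, is a hypothetical sketch (not part of nnsa) that assumes a 2-D array in which exactly one axis matches the number of channels.

```python
import numpy as np


def to_time_major(x, n_channels):
    """Return x with shape (n_time, n_channels), transposing if it is channel-first.

    Hypothetical helper (not part of nnsa): assumes x is 2-D and that one of its
    axes has length n_channels. If both axes match, x is returned unchanged.
    """
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {x.shape}")
    if x.shape[1] == n_channels:
        return x  # already (n_time, n_channels)
    if x.shape[0] == n_channels:
        return x.T  # (n_channels, n_time) -> (n_time, n_channels)
    raise ValueError(f"no axis of shape {x.shape} matches {n_channels} channels")


fs = 250
channel_first = np.zeros((8, fs * 40))  # e.g. an array laid out as (n_channels, n_time)
eeg_tm = to_time_major(channel_first, n_channels=8)
print(eeg_tm.shape)  # (10000, 8)
```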
Now we initialize the object that will load and apply the model. Here we can specify some options, such as whether to load the multi-channel model or the single-channel model.
# Note that the multi-channel model requires 8-channel EEG referenced to Cz (see CleanDetectorCnn.data_requirements).
# The single-channel model (multi_channel=False), on the other hand, works on any montage,
# but still assumes that the data is referenced to Cz.
# Initialize the clean detector; multi_channel selects the multi-channel (True) or single-channel (False) model.
cd = CleanDetectorCnn(multi_channel=False)
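Both models assume EEG referenced to Cz. If a recording uses another common reference but includes a Cz channel, one way to re-reference is to subtract the Cz signal from every other channel. The sketch below assumes that all channels (including Cz) were recorded against the same reference; `rereference_to_cz` is a hypothetical helper, not an nnsa function.

```python
import numpy as np


def rereference_to_cz(x, labels, cz_label="Cz"):
    """Re-reference time-major EEG (n_time, n_channels) to the Cz channel.

    Hypothetical helper (not part of nnsa): assumes every channel, including Cz,
    shares one common reference. Returns the data and labels without the Cz column,
    since Cz referenced to itself is identically zero.
    """
    x = np.asarray(x, dtype=float)
    cz_idx = labels.index(cz_label)
    referenced = x - x[:, [cz_idx]]  # broadcast: subtract Cz from every column
    keep = [i for i in range(len(labels)) if i != cz_idx]
    return referenced[:, keep], [labels[i] for i in keep]


labels = ["Fp1", "Cz", "O1"]
x = np.array([[10.0, 4.0, -2.0],
              [12.0, 5.0, 1.0]])
y, new_labels = rereference_to_cz(x, labels)
print(new_labels)  # ['Fp1', 'O1']
print(y)           # values: [[6, -6], [7, -4]]
```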
Now we can apply the model to the EEG data. The input format is important:
eeg (np.ndarray): multichannel EEG referenced to Cz, as an array with shape (n_time, n_channels). If using the multi-channel model, the channels must be in the order ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2'].
The method returns two arrays. The first has the same shape as eeg and contains 1s where the data is CLEAN and 0s at artefacts. The second contains the probabilities that the data is clean, at a fixed resolution of one value per second (whereas the first array is upsampled to match the sampling frequency of the input).
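For the multi-channel model the column order matters. If the labels in a recording are in a different order, or carry a prefix such as 'EEG ' as in channel_labels above, the columns can be reordered before calling predict. This is a sketch: `reorder_channels` is a hypothetical helper (not part of nnsa), and `str.removeprefix` requires Python 3.9 or later.

```python
import numpy as np

# Channel order expected by the multi-channel model (as stated in the text above).
MODEL_ORDER = ["Fp1", "Fp2", "C3", "C4", "T3", "T4", "O1", "O2"]


def reorder_channels(x, labels, order=MODEL_ORDER, prefix="EEG "):
    """Reorder the columns of time-major EEG (n_time, n_channels) to match `order`.

    Hypothetical helper (not part of nnsa): strips an optional label prefix before
    matching and raises KeyError if any required channel is missing.
    """
    lookup = {lab.removeprefix(prefix): i for i, lab in enumerate(labels)}
    missing = [ch for ch in order if ch not in lookup]
    if missing:
        raise KeyError(f"missing channels: {missing}")
    return np.asarray(x)[:, [lookup[ch] for ch in order]]


labels = ["EEG O2", "EEG O1", "EEG T4", "EEG T3", "EEG C4", "EEG C3", "EEG Fp2", "EEG Fp1"]
x = np.tile(np.arange(8), (5, 1))  # column j holds the value j
y = reorder_channels(x, labels)
print(y[0])  # [7 6 5 4 3 2 1 0]
```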
clean_mask, probs = cd.predict(
eeg, fs=fs,
preprocess=True, # Preprocessing consists of filtering and resampling. This can be set to False if the data is already properly preprocessed (see CleanDetectorCnn.preprocess_eeg()).
)
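To make the relation between the two outputs concrete: probs has one value per second, whereas clean_mask has one value per input sample. The sketch below shows one way a per-second decision could be mapped to sample resolution, by thresholding and holding each decision for fs samples. The 0.5 threshold, the sample-and-hold upsampling and the helper name `probs_to_sample_mask` are all assumptions for illustration, not necessarily what CleanDetectorCnn.predict does internally.

```python
import numpy as np


def probs_to_sample_mask(probs, fs, n_time, threshold=0.5):
    """Turn per-second clean probabilities into a per-sample clean mask.

    Illustration only (hypothetical helper): the threshold and the
    sample-and-hold upsampling are assumptions, not the nnsa implementation.
    Works for probs of shape (n_seconds,) or (n_seconds, n_channels).
    """
    probs = np.asarray(probs)
    clean_per_second = (probs >= threshold).astype(int)  # 1 = clean, 0 = artefact
    mask = np.repeat(clean_per_second, fs, axis=0)  # hold each decision for fs samples
    if mask.shape[0] < n_time:
        # Extend with the last decision if the input has a trailing partial second.
        pad = [(0, n_time - mask.shape[0])] + [(0, 0)] * (mask.ndim - 1)
        mask = np.pad(mask, pad, mode="edge")
    return mask[:n_time]


probs = np.array([[0.9, 0.2],
                  [0.4, 0.8],
                  [0.7, 0.6]])  # 3 seconds, 2 channels
mask = probs_to_sample_mask(probs, fs=4, n_time=12)
print(mask[:, 0])  # [1 1 1 1 0 0 0 0 1 1 1 1]
```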
If needed, we can convert the clean mask to an artefact (af) mask as follows:
af_mask = clean_mask == 0
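A boolean artefact mask is easy to summarize. The sketch below finds contiguous artefact segments in one channel and reports their start and end times in seconds. `artefact_segments` is a hypothetical helper (not part of nnsa); it expects one column of af_mask, e.g. af_mask[:, 0].

```python
import numpy as np


def artefact_segments(af_channel, fs):
    """Return (start, end) times in seconds of contiguous True runs in a 1-D mask.

    Hypothetical helper (not part of nnsa); `end` is exclusive.
    """
    m = np.asarray(af_channel, dtype=bool).astype(np.int8)
    # Pad with zeros so runs touching either edge still yield a rising and a falling edge.
    d = np.diff(np.concatenate(([0], m, [0])))
    starts = np.flatnonzero(d == 1)   # sample index where a run begins
    ends = np.flatnonzero(d == -1)    # sample index just after a run ends
    return [(s / fs, e / fs) for s, e in zip(starts, ends)]


af_channel = np.array([0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1], dtype=bool)
segments = artefact_segments(af_channel, fs=4)
print([(float(s), float(e)) for s, e in segments])  # [(0.5, 1.25), (2.0, 3.0)]
```

Because af_mask has shape (n_time, n_channels), af_mask.mean(axis=0) gives the fraction of samples flagged as artefact in each channel.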
Finally, we can plot the results.
# We use the EegDataset for easier plotting.
eeg_ds = EegDataset.from_array(eeg=eeg, fs=fs, channel_labels=channel_labels)
# Plot EEG and highlight predicted artefacts in red (since we are not using real data, the result is meaningless).
fig, ax = plt.subplots(1, 1)
eeg_ds.plot(ax=ax, color='k', alpha=0.8)
eeg_ds.plot_mask(af_mask, color='r', ax=ax)
ax.set_title('Raw input and detected artefacts')
plt.show()