import os
import sys
import warnings
import numpy as np
import pandas as pd
import pyprind
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.config import MODEL_DATA_DIR
from nnsa.preprocessing.filter import NotchIIR, Butterworth
from nnsa.preprocessing.resample import resample_by_interpolation
from nnsa.utils.arrays import moving_std, moving_max, check_eeg_is_long
[docs]class CleanDetectorCnn(ClassWithParameters):
"""
Interface for applying clean EEG detection using a trained Convolutional Neural Network.
References:
T. Hermans et al.,
“A multi-task and multi-channel convolutional neural network for semi-supervised neonatal artefact detection,”
Journal of Neural Engineering, vol. 20, no. 2, p. 26013, Mar. 2023, doi: 10.1088/1741-2552/acbc4b.
https://pubmed.ncbi.nlm.nih.gov/36791462/
Args:
multi_channel (bool): if True, use the multi-channel model,
which requires a specific montage and order of channels (see self.data_requirements).
If False, use the single-channel model (no specific montage required).
chunk_length (int, float): number of seconds of EEG to process at once (the EEG is divided into chunks
with maximal length chunk_length, which are predicted separately and then merged together).
This prevents memory problems. If the RAM is large enough, you can set chunk_length to None
to process the EEG in one piece.
Examples:
# TODO Load some raw eeg data, referenced to Cz.
# Initiate artefact detector object.
cd = CleanDetectorCnn()
# Use the predict method to predict the clean mask for the eeg array.
clean_mask, clean_prob = cd.predict(eeg, fs=fs)
"""
def __init__(self, multi_channel=True, chunk_length=3600, **kwargs):
    """
    Set up the detector and locate the saved model file on disk.

    Args:
        multi_channel (bool): if True, select the multi-channel model file,
            otherwise the single-channel one.
        chunk_length (int, float, None): forwarded to the parent constructor
            together with `multi_channel`.
        **kwargs: additional keyword arguments forwarded to the parent constructor.

    Raises:
        FileNotFoundError: if the selected model file does not exist.
    """
    super().__init__(multi_channel=multi_channel, chunk_length=chunk_length, **kwargs)

    # Pick the model that matches the requested mode.
    model_name = 'cnn_multi_chan' if multi_channel else 'cnn_single_chan'

    # Only append the default extension if the name does not carry one already.
    filename = model_name if '.' in model_name else model_name + '.h5'
    model_filepath = os.path.join(MODEL_DATA_DIR, 'clean_det', filename)

    # Fail early if the model file is missing.
    if not os.path.exists(model_filepath):
        raise FileNotFoundError('Cannot find the HDF5 file with the saved model. '
                                'File "{}" does not exist.'.format(model_filepath))
    self.model_filepath = model_filepath

    # Caches for the lazily loaded model, normalization parameters and settings.
    self._model = self._normpars = self._model_settings = None
@staticmethod
def default_parameters():
    """
    Build the default parameter set of the detector.

    Returns:
        (Parameters): default parameters with the following entries:
            multi_channel: use the multi-channel (True) or single-channel (False) model.
            chunk_length: length (seconds) of the chunks in which data is fed to the model.
            boundary_length: number of seconds to ignore near chunk boundaries.
    """
    return Parameters(
        multi_channel=True,
        chunk_length=3600,
        boundary_length=5,
    )
@property
def data_requirements(self):
    """
    Describe what the input data must look like.

    Returns:
        (dict): with keys
            'reference_channel': label of the reference channel.
            'channel_order': required channel order for the multi-channel model,
                None for the single-channel model.
            'fs': required sampling frequency of the input (Hz).
    """
    # Only the multi-channel model prescribes a channel order.
    channel_order = None
    if self.parameters['multi_channel']:
        channel_order = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']

    return {
        'reference_channel': 'Cz',
        'channel_order': channel_order,
        'fs': 128,
    }
@property
def fs_output(self):
    """
    Sampling frequency of the model output.

    Returns:
        (int): output sampling frequency for the known models.

    Raises:
        NotImplementedError: if the model file name (without extension) is not a known model.
    """
    # Model name = file name without directory and extension.
    model_name, _ext = os.path.splitext(os.path.basename(self.model_filepath))
    if model_name not in ('cnn_single_chan', 'cnn_multi_chan'):
        raise NotImplementedError('Not implemented for model_name = "{}".'.format(model_name))
    return 1
@property
def model(self):
    """
    The model object, loaded on first access and cached afterwards.

    Returns:
        whatever self._load_model() returns (a keras model according to that method).
    """
    loaded = self._model
    if loaded is None:
        # First access: load once and remember the result.
        loaded = self._load_model()
        self._model = loaded
    return loaded
@property
def model_settings(self):
    """
    The model settings, loaded on first access and cached afterwards.

    Returns:
        (dict): settings as returned by self._load_model_settings().
    """
    settings = self._model_settings
    if settings is None:
        # First access: load once and remember the result.
        settings = self._load_model_settings()
        self._model_settings = settings
    return settings
@property
def normpars(self):
    """
    The normalization parameters, loaded on first access and cached afterwards.

    Returns:
        (tuple): mean, sd and channel labels as returned by self._load_normpars().
    """
    pars = self._normpars
    if pars is None:
        # First access: load once and remember the result.
        pars = self._load_normpars()
        self._normpars = pars
    return pars
@property
def is_clean_detector(self):
    """
    Flag telling that this model detects clean EEG (rather than artefacts).

    Returns:
        (bool): always True for this class.
    """
    return True
def predict(self, eeg, fs, preprocess=None,
            detect_flats=False, detect_peaks=False,
            verbose=1, **predict_kwargs):
    """
    Detect clean parts in Cz-referenced EEG using the CNN model.

    Args:
        eeg (np.ndarray): multichannel EEG referenced to Cz. Array with shape (n_time, n_channels).
            If using the multi-channel model, the order of the channels should be:
            ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2'].
        fs (float): sampling frequency of `eeg` in Hz.
        preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled).
            Set to True if `eeg` is raw data (but note that it should still be referenced to Cz).
            If not specified, preprocessing is done if `fs` differs from the sampling frequency
            in self.data_requirements, otherwise not.
        detect_flats (bool): if True, computes moving std in short windows
            and if it is below a threshold, the sample is marked as artefact
            (since the CNN might not catch this).
        detect_peaks (bool): if True, computes moving max abs amplitude in short windows
            and if it is above a threshold, the sample is marked as artefact
            (since the CNN might not catch this).
        verbose (int): verbosity level. Values > 0 show a progress bar,
            values > 1 additionally print progress messages.
        **predict_kwargs: kwargs for self.model.predict().

    Returns:
        mask (np.ndarray): integer array with shape (n_time, n_channels) containing 1 at locations
            of clean EEG and 0 at artefacts.
        prob (np.ndarray): array with shape (n_samples, n_channels, n_classes) containing probabilities
            for artefact and clean. Note that n_samples depends on the output frequency
            (it is not the same as n_time of `eeg`).
    """
    # Check input shape (n_time, n_channels), transposes if needed.
    eeg = check_eeg_is_long(eeg=eeg, mode='transpose')

    # Preprocessed data has a specific sampling frequency, so by default preprocess
    # whenever the input deviates from it.
    if preprocess is None:
        preprocess = fs != self.data_requirements['fs']

    # Remember the original dimensions: the mask is returned at the original sample rate.
    original_len, _ = eeg.shape
    original_fs = fs

    if preprocess:
        # Filtering and resampling. Pass the verbosity level itself (not a boolean), because
        # the callee compares it against 1 to decide whether to print.
        eeg, fs = self.preprocess_eeg(eeg, fs, axis=0, verbose=verbose)

    # Prepare data (normalize, chunking).
    eeg_chunks = self._prepare_data(eeg, fs, verbose=verbose)

    # Predict.
    if verbose > 1:
        print('Detecting clean samples and artefacts...')
    y_chunks = self._predict(eeg_chunks, verbose=verbose, **predict_kwargs)

    # Postprocess (merge chunks, upsample mask to the original sample rate).
    mask, prob = self._postprocess(y_chunks=y_chunks, original_fs=original_fs, original_len=original_len)

    # Optionally check for low amplitudes (flats) and mark them as not clean.
    if detect_flats:
        flat_mask = self._predict_flat_mask(eeg=eeg, fs=fs, original_fs=original_fs,
                                            original_len=original_len)
        mask[flat_mask == 1] = 0

    # Optionally check for high amplitudes (peaks) and mark them as not clean.
    if detect_peaks:
        peak_mask = self._predict_peak_mask(eeg=eeg, fs=fs, original_fs=original_fs,
                                            original_len=original_len)
        mask[peak_mask == 1] = 0

    # To integers.
    mask = np.round(mask).astype(int)
    return mask, prob
@staticmethod
def preprocess_eeg(eeg, fs, axis=0, verbose=1):
    """
    Preprocess EEG data (notch filter, bandpass filter, resampling to 128 Hz).

    Args:
        eeg (np.ndarray): EEG data.
        fs (float, int): sampling frequency of `eeg` in Hz.
        axis (int): time axis of the `eeg` array.
        verbose (int): verbosity level. A message is printed if larger than 1.

    Returns:
        eeg_out (np.ndarray): preprocessed EEG data.
        fs_out (int): sampling frequency of preprocessed EEG.
    """
    if verbose > 1:
        print('Preprocessing for artefact detection...')

    # Notch for power line interference (needed, because the Butterworth filter below does
    # not remove this power line noise enough).
    filt_notch = NotchIIR(f0=50, fs=fs)

    # Katrien used a 0.27-30 Hz bandpass filter when scoring (BrainRT software).
    # I contacted the OSG firm of BrainRT to ask for the filter specification, which is a 1st order
    # Butterworth.
    filt = Butterworth(fn=[0.27, 30], order=1, fs=fs)

    # Resampling frequency.
    fs_res = 128

    # Filter and remove the mean after each filter step. keepdims=True keeps the reduced
    # axis so that the subtraction broadcasts along the time axis for any value of `axis`
    # (without it the mean is only aligned correctly when the time axis is the first one).
    eeg_out = filt_notch.filtfilt(eeg, axis=axis)
    eeg_out = eeg_out - np.nanmean(eeg_out, axis=axis, keepdims=True)
    eeg_out = filt.filter(eeg_out, axis=axis)
    eeg_out = eeg_out - np.nanmean(eeg_out, axis=axis, keepdims=True)

    # Construct time array corresponding to current fs.
    time = np.arange(eeg.shape[axis]) / fs

    # Resample the signal.
    eeg_out = resample_by_interpolation(x=eeg_out, t=time, fs_new=fs_res, axis=axis)

    # The output is at the resampling frequency.
    fs_out = fs_res
    return eeg_out, fs_out
def _load_model(self):
    """
    Load the saved model from self.model_filepath.

    Returns:
        the model returned by keras' load_model (loaded with compile=False).
    """
    try:
        # Original import location (internal TensorFlow module).
        from tensorflow.python.keras.saving.save import load_model
    except ImportError:
        # The internal module is not part of the public API and may be missing in some
        # TensorFlow versions: fall back to the public entry point.
        from tensorflow.keras.models import load_model
    return load_model(self.model_filepath, compile=False)
def _load_model_settings(self):
    """
    Load the settings that belong to the model.

    The settings are expected in an Excel file next to the model file, named
    "<model name>_settings.xlsx". If that file does not exist, the same name without its
    first underscore-separated element is tried.

    Returns:
        (dict): settings (first row of the Excel sheet, column name -> value).

    Raises:
        FileNotFoundError: if no settings file can be found.
    """
    # Swap only the extension of the model file (a plain string replace would also
    # alter directory names that happen to contain '.h5').
    root, _ext = os.path.splitext(self.model_filepath)
    filepath_settings = root + '_settings.xlsx'

    if not os.path.exists(filepath_settings):
        # Try removing the first element from the name.
        directory, filename = os.path.split(filepath_settings)
        filename = '_'.join(filename.split('_')[1:])
        filepath_settings = os.path.join(directory, filename)

    if not os.path.exists(filepath_settings):
        raise FileNotFoundError('Cannot find settings.xlsx file. File "{}" does not exist.'
                                .format(filepath_settings))

    # All settings are saved in the first row. Convert to dict.
    df = pd.read_excel(filepath_settings)
    return df.iloc[0, :].to_dict()
def _load_normpars(self):
    """
    Load the normalization parameters that belong to the model.

    The parameters are expected in a csv file next to the model file, named
    "<model name>_normpars.csv", with columns 'mean', 'sd' and 'channel_label'.

    Returns:
        (tuple): (mean, sd, channel_labels). If the file has a single row, these are the
            values of that row. Otherwise mean and sd are arrays and channel_labels
            is the 'channel_label' column (pandas Series).

    Raises:
        FileNotFoundError: if the normpars file does not exist.
        AssertionError: if the number of rows contradicts the 'normalize_per_channel' model setting.
    """
    # Swap only the extension of the model file (a plain string replace would also
    # alter directory names that happen to contain '.h5').
    root, _ext = os.path.splitext(self.model_filepath)
    filepath_normpars = root + '_normpars.csv'
    if not os.path.exists(filepath_normpars):
        raise FileNotFoundError('Cannot find normpars.csv file. File "{}" does not exist.'
                                .format(filepath_normpars))

    # Load parameters.
    df = pd.read_csv(filepath_normpars)
    mean = df['mean']
    sd = df['sd']
    channel_labels = df['channel_label']

    n_rows = len(df)
    if n_rows == 1:
        # One row means one set of parameters for all channels.
        if self.model_settings.get('normalize_per_channel', False):
            raise AssertionError('In settings it says normalized per channel, '
                                 'but the normalization parameters contain just one channel.')
        return mean.iloc[0], sd.iloc[0], channel_labels.iloc[0]

    if n_rows > 1 and not self.model_settings['normalize_per_channel']:
        raise AssertionError('In settings it says not normalized per channel, '
                             'but the normalization parameters contain multiple channels.')
    return mean.values, sd.values, channel_labels
def _normalize(self, eeg):
    """
    Standardize the EEG with the normalization parameters of the model.

    Args:
        eeg (np.ndarray): EEG data.

    Returns:
        (np.ndarray): (eeg - mean) / sd, with mean and sd taken from self.normpars.
    """
    # The channel labels stored with the parameters are not needed here.
    mean, sd, _labels = self.normpars
    return (eeg - mean) / sd
def _postprocess(self, y_chunks, original_fs, original_len):
    """
    Merge the chunk predictions and derive a clean mask at the original sample rate.

    Args:
        y_chunks (list): list with predictions for each data chunk (see self._predict()).
        original_fs (float): sampling frequency of original data in Hz.
        original_len (int): length of original EEG data (number of time samples).

    Returns:
        mask (np.ndarray): array with shape (n_time, n_channels) containing 1 at locations
            of clean EEG and 0 at artefacts. Here n_time is in the fs of the original EEG, so it
            can be used as a boolean mask for the EEG data.
        prob (np.ndarray): array with shape (n_samples, n_channels, n_classes) containing
            probabilities for artefact and clean. Here n_samples is in the fs of the model output,
            so it is probably a smaller array than the mask.
    """
    # Number of output samples to discard at both ends of each chunk.
    boundary_length = self.parameters['boundary_length']
    fs_out = self.fs_output
    n_skip = int(boundary_length * fs_out)

    # Trim the boundaries and drop the leading singleton dimension of every chunk.
    trimmed = []
    for chunk in y_chunks:
        stop = chunk.shape[1] - n_skip
        trimmed.append(chunk[0, n_skip:stop])

    # Paste chunks together (n_time, n_channels, n_classes).
    prob = np.vstack(trimmed)
    n_time, n_channels, n_classes = prob.shape

    # Sanity check (should have as many seconds in output as in input).
    assert (original_len / original_fs - n_time / fs_out) < 1 / fs_out

    # A sample is clean if class 1 has the highest probability.
    clean = (np.argmax(prob, axis=-1) == 1).astype(float)

    # Time axes of the model output (shifted by half an output sample) and of the original EEG.
    t_prob = np.arange(len(clean)) / fs_out + 1 / fs_out / 2
    t_eeg = np.arange(original_len) / original_fs

    # Upsample each channel to the original sample rate (linear interpolation, then rounding).
    upsampled = [
        np.round(np.interp(t_eeg, t_prob, clean[:, ch], left=clean[0, ch], right=clean[-1, ch]))
        for ch in range(n_channels)
    ]
    mask = np.vstack(upsampled).T
    return mask, prob
def _predict(self, eeg_chunks, verbose=1, **kwargs):
    """
    Run the model on every EEG chunk.

    Args:
        eeg_chunks (list): list with EEG chunks (see self._prepare_data()).
        verbose (int): verbosity level. A progress bar is shown if larger than 0;
            the model's own verbose output is enabled if larger than 1.
        **kwargs: additional keyword arguments for self.model.predict().

    Returns:
        y_chunks (list): predictions, i.e. probabilities for each chunk,
            e.g. list of arrays with shape (1, n_seconds, n_channels, n_classes).
    """
    # Only create the progress bar when it is going to be shown
    # (presumably constructing it already writes to the stream - TODO confirm).
    bar = pyprind.ProgBar(len(eeg_chunks), stream=sys.stdout) if verbose > 0 else None

    y_chunks = []
    for eeg_seg in eeg_chunks:
        # Shape (n_samples, n_seconds, n_channels, n_classes).
        y = self.model.predict(eeg_seg, verbose=verbose > 1, **kwargs)

        # Sanity checks on the output shape.
        assert y.ndim == 4
        assert y.shape[2] == eeg_seg.shape[2]  # Channel dimension.

        y_chunks.append(y)
        if bar is not None:
            bar.update()
    return y_chunks
@staticmethod
def _predict_flat_mask(eeg, fs, original_fs=None, original_len=None):
    """
    Mark samples where the moving standard deviation of the EEG is below a threshold.

    Args:
        eeg (np.ndarray): EEG with shape (n_time, n_channels).
        fs (float): sampling frequency of `eeg` in Hz.
        original_fs (float, optional): sampling frequency (Hz) to which the mask is resampled.
            No resampling is done if None or equal to `fs`.
        original_len (int, optional): number of time samples of the resampled mask.
            Required if resampling is done.

    Returns:
        flat_mask (np.ndarray): mask that is nonzero/True at flat samples. Boolean if no
            resampling was done, otherwise a float array with shape (original_len, n_channels).

    Raises:
        ValueError: if resampling is needed but `original_len` is not specified.
    """
    # Check input shape (n_time, n_channels).
    eeg = check_eeg_is_long(eeg=eeg, mode='error')

    # Moving std in short windows, compared to a minimal std.
    window = 2
    min_std = 10 ** (-0.5)  # From looking at distribution in resilience dataset.
    eeg_std = moving_std(x=eeg, n=int(fs * window), axis=0)
    flat_mask = eeg_std < min_std

    # Done if the mask is already at the requested sample rate.
    if original_fs is None or fs == original_fs:
        return flat_mask

    # Upsample mask to original sample rate.
    if original_len is None:
        raise ValueError('`original_len` needs to be specified.')

    # Time arrays of the mask (shifted by half a sample) and of the original EEG.
    t_flat = np.arange(len(flat_mask)) / fs + 1 / fs / 2
    t_eeg = np.arange(original_len) / original_fs

    # Resample every channel with linear interpolation, followed by rounding.
    upsampled = [
        np.round(np.interp(t_eeg, t_flat, channel_mask, left=channel_mask[0], right=channel_mask[-1]))
        for channel_mask in flat_mask.T
    ]
    return np.vstack(upsampled).T
@staticmethod
def _predict_peak_mask(eeg, fs, original_fs=None, original_len=None):
    """
    Mark samples where the moving maximum of the absolute EEG exceeds a threshold.

    Args:
        eeg (np.ndarray): EEG with shape (n_time, n_channels).
        fs (float): sampling frequency of `eeg` in Hz.
        original_fs (float, optional): sampling frequency (Hz) to which the mask is resampled.
            No resampling is done if None or equal to `fs`.
        original_len (int, optional): number of time samples of the resampled mask.
            Required if resampling is done.

    Returns:
        peak_mask (np.ndarray): mask that is nonzero/True at peak samples. Boolean if no
            resampling was done, otherwise a float array with shape (original_len, n_channels).

    Raises:
        ValueError: if resampling is needed but `original_len` is not specified.
    """
    # Check input shape (n_time, n_channels).
    eeg = check_eeg_is_long(eeg=eeg, mode='error')

    # Moving max of absolute values in short windows, compared to a maximal amplitude.
    window = 1  # Seconds.
    max_amp = 1000  # uV.
    eeg_max = moving_max(x=np.abs(eeg), n=round(fs * window), axis=0)
    peak_mask = eeg_max > max_amp

    # Done if the mask is already at the requested sample rate.
    if original_fs is None or fs == original_fs:
        return peak_mask

    # Upsample mask to original sample rate.
    if original_len is None:
        raise ValueError('`original_len` needs to be specified.')

    # Time arrays of the mask (shifted by half a sample) and of the original EEG.
    t_peak = np.arange(len(peak_mask)) / fs + 1 / fs / 2
    t_eeg = np.arange(original_len) / original_fs

    # Resample every channel with linear interpolation, followed by rounding.
    upsampled = [
        np.round(np.interp(t_eeg, t_peak, channel_mask, left=channel_mask[0], right=channel_mask[-1]))
        for channel_mask in peak_mask.T
    ]
    return np.vstack(upsampled).T
def _prepare_data(self, eeg, fs=None, verbose=1):
    """
    Helper function to prepare the EEG recording for input to CNN (normalization, chunking).

    Args:
        eeg (np.ndarray): preprocessed EEG data with shape (n_time, n_channels).
        fs (float): sampling frequency of `eeg` in Hz. Must equal the sampling frequency
            in self.data_requirements.
        verbose (int): verbosity level. Messages are printed if larger than 1.

    Returns:
        eeg_chunks (list): normalized and chunked EEG with length number of chunks and each
            element (chunk) is an EEG array with shape (1, n_time, n_channels).

    Raises:
        ValueError: if `fs` does not match the required sampling frequency, or if the
            chunk_length parameter is not larger than twice the boundary_length parameter.
    """
    required_fs = self.data_requirements['fs']
    if fs != required_fs:
        raise ValueError(f'fs={fs} does not match the data_requirements: fs={required_fs}. '
                         f'Make sure the input has the required sampling frequency.')

    n_time, n_channels = eeg.shape

    # Normalize.
    if self.model_settings.get('normalize', True):
        if verbose > 1:
            print('Normalizing...')
        eeg = self._normalize(eeg)

    # Nan to num.
    eeg = np.nan_to_num(eeg)

    # Add zeros to start and end since we will cut off some small part for boundary effects.
    boundary_length = self.parameters['boundary_length']
    n_pad = int(boundary_length * fs)
    pad = np.zeros((n_pad, n_channels))
    eeg = np.concatenate([pad, eeg, pad], axis=0)

    # To shape (1, n_time, n_channels).
    n_tot = len(eeg)
    eeg = np.expand_dims(eeg, axis=0)

    # Divide EEG into chunks (if EEG is shorter than chunk size, will just be one chunk).
    chunk_length = self.parameters['chunk_length']
    if verbose > 1:
        print('Dividing data into chunks...')
    if chunk_length is not None:
        chunk_size = int(chunk_length * fs)
        # Consecutive chunks overlap by the boundaries that are cut off on both sides.
        overlap = 2 * n_pad
        stepsize = chunk_size - overlap
        if stepsize <= 0:
            raise ValueError(f'chunk_length={chunk_length} must be larger than twice the '
                             f'boundary_length={boundary_length}.')
        # A chunk only contributes output if it is longer than the two boundaries that are cut
        # off afterwards, i.e. if it starts before n_tot - overlap. Later starts would only
        # produce tiny chunks whose entire output is discarded. Always keep one chunk.
        chunk_start_indices = np.arange(0, max(n_tot - overlap, 1), stepsize)
    else:
        chunk_start_indices = [0]
        chunk_size = n_tot

    # Collect the chunks (the last chunk may not have the same size as the rest,
    # so use a list instead of array).
    eeg_chunks = [
        eeg[:, start_idx:min(start_idx + chunk_size, n_tot), :]
        for start_idx in chunk_start_indices
    ]
    return eeg_chunks