"""Clean EEG detection with a trained CNN (source code for nnsa.artefacts.clean_detector_cnn)."""

import os
import sys
import warnings

import numpy as np
import pandas as pd
import pyprind

from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.config import MODEL_DATA_DIR
from nnsa.preprocessing.filter import NotchIIR, Butterworth
from nnsa.preprocessing.resample import resample_by_interpolation
from nnsa.utils.arrays import moving_std, moving_max, check_eeg_is_long


class CleanDetectorCnn(ClassWithParameters):
    """
    Interface for applying clean EEG detection using a trained Convolutional Neural Network.

    References:
        T. Hermans et al., “A multi-task and multi-channel convolutional neural network for
        semi-supervised neonatal artefact detection,” Journal of Neural Engineering, vol. 20, no. 2,
        p. 26013, Mar. 2023, doi: 10.1088/1741-2552/acbc4b.
        https://pubmed.ncbi.nlm.nih.gov/36791462/

    Args:
        multi_channel (bool): if True, use the multi-channel model, which requires a specific montage
            and order of channels (see self.data_requirements). If False, use the single-channel model
            (no specific montage required).
        chunk_length (int, float): number of seconds of EEG to process at once (the EEG is divided
            into chunks of at most `chunk_length` seconds, which are predicted separately and then
            merged together). This prevents memory problems. If the RAM is large enough, you can set
            chunk_length to None to process the whole recording in one go. When not None, it must be
            larger than 2 * boundary_length (the overlap between consecutive chunks).
        **kwargs: passed on to ClassWithParameters.__init__ together with the arguments above
            (presumably to override other entries of default_parameters(), e.g. boundary_length).

    Raises:
        FileNotFoundError: if the HDF5 file with the saved model does not exist.

    Examples:
        # TODO Load some raw eeg data, referenced to Cz.

        # Initiate artefact detector object.
        cd = CleanDetectorCnn()

        # Use the predict method to predict the clean mask for the eeg array.
        clean_mask, clean_prob = cd.predict(eeg, fs=fs)
    """
    def __init__(self, multi_channel=True, chunk_length=3600, **kwargs):
        super().__init__(multi_channel=multi_channel, chunk_length=chunk_length, **kwargs)

        # Select model file.
        if multi_channel:
            model_name = 'cnn_multi_chan'
        else:
            model_name = 'cnn_single_chan'
        model_filepath = os.path.join(MODEL_DATA_DIR, 'clean_det',
                                      model_name + '.h5' if '.' not in model_name else model_name)

        # Check if the model file exists.
        if not os.path.exists(model_filepath):
            raise FileNotFoundError('Cannot find the HDF5 file with the saved model. '
                                    'File "{}" does not exist.'.format(model_filepath))

        self.model_filepath = model_filepath

        # Loaded lazily on first access (see the model, normpars and model_settings properties).
        self._model = None
        self._normpars = None
        self._model_settings = None

    @staticmethod
    def default_parameters():
        """
        Return the default parameters as a dictionary.

        Returns:
            (dict or Parameters): a default set of parameters for the object.
        """
        pars = dict(
            # Use multi_channel (True) or single channel (False) model.
            multi_channel=True,
            # Length of chunks in which data is fed to the model (seconds).
            chunk_length=3600,
            # Ignore a number of seconds near the boundaries.
            boundary_length=5,
        )

        return Parameters(**pars)

    @property
    def data_requirements(self):
        """
        Return the requirements on the input data of the model.

        Returns:
            (dict): with keys 'reference_channel' (str), 'channel_order' (list of str, or None for the
                single-channel model) and 'fs' (sampling frequency of the model input in Hz).
        """
        multi_channel = self.parameters['multi_channel']
        if multi_channel:
            channel_order = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']
        else:
            channel_order = None

        data_requirements = dict(
            # Reference channel.
            reference_channel='Cz',
            # Order of channels in input.
            channel_order=channel_order,
            # Sampling frequency of input (Hz).
            fs=128,
        )
        return data_requirements

    @property
    def fs_output(self):
        """
        Return the sampling frequency (Hz) of the model output.

        Raises:
            NotImplementedError: if the model file is not one of the known models.
        """
        model_name = os.path.splitext(os.path.basename(self.model_filepath))[0]
        if model_name in ['cnn_single_chan', 'cnn_multi_chan']:
            fs_output = 1
        else:
            raise NotImplementedError('Not implemented for model_name = "{}".'.format(model_name))
        return fs_output

    @property
    def model(self):
        """
        Return the keras model (loaded from file on first access).
        """
        if self._model is None:
            self._model = self._load_model()
        return self._model

    @property
    def model_settings(self):
        """
        Return a dict with the model settings (loaded from file on first access).
        """
        if self._model_settings is None:
            self._model_settings = self._load_model_settings()
        return self._model_settings

    @property
    def normpars(self):
        """
        Return a tuple with the normalization mean, sd and channel labels (loaded on first access).
        """
        if self._normpars is None:
            self._normpars = self._load_normpars()
        return self._normpars

    @property
    def is_clean_detector(self):
        """
        Returns:
            (bool): True if the model is a clean EEG detector (instead of an artefact detector).
        """
        return True

    def predict(self, eeg, fs, preprocess=None, detect_flats=False, detect_peaks=False, verbose=1,
                **predict_kwargs):
        """
        Detect clean parts in Cz-referenced EEG using the CNN model.

        Args:
            eeg (np.ndarray): multichannel EEG referenced to Cz. Array with shape (n_time, n_channels)
                (a (n_channels, n_time) array is transposed). If using the multi-channel model, the
                order of the channels should be: ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2'].
            fs (float): sampling frequency of `eeg` in Hz.
            preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled).
                Set to True if `eeg` is raw data (but note that it should still be referenced to Cz).
                If not specified, preprocessing is done if `fs` is not 128, otherwise not.
            detect_flats (bool): if True, computes the moving std in short windows and if it is below a
                threshold, the sample is marked as artefact (since the CNN might not catch this).
            detect_peaks (bool): if True, computes the moving max abs amplitude in short windows and if
                it is above a threshold, the sample is marked as artefact (since the CNN might not
                catch this).
            verbose (int): verbosity level (> 0: show a progress bar, > 1: also print status messages).
            **predict_kwargs: kwargs for self.model.predict().

        Returns:
            mask (np.ndarray): integer array with shape (n_time, n_channels) containing 1 at locations
                of clean EEG and 0 at artefacts.
            prob (np.ndarray): array with shape (n_samples, n_channels, n_classes) containing
                probabilities for artefact and clean. Note that n_samples depends on the output
                frequency (it is not the same as n_time of `eeg`).
        """
        # Check input shape (n_time, n_channels), transposes if needed.
        eeg = check_eeg_is_long(eeg=eeg, mode='transpose')

        # Preprocessed data has a specific sampling frequency, so by default preprocess only if
        # `fs` differs from it.
        if preprocess is None:
            preprocess = fs != self.data_requirements['fs']

        # Keep the original dimensions/rate to map the output back onto the input time axis.
        original_len, _ = eeg.shape
        original_fs = fs

        if preprocess:
            # Do the preprocessing (filtering, resampling). Pass the verbosity level itself:
            # preprocess_eeg compares it against 1, so passing a bool would silence it.
            eeg, fs = self.preprocess_eeg(eeg, fs, axis=0, verbose=verbose)

        # Prepare data (normalize, chunking).
        eeg_chunks = self._prepare_data(eeg, fs, verbose=verbose)

        # Predict.
        if verbose > 1:
            print('Detecting clean samples and artefacts...')
        y_chunks = self._predict(eeg_chunks, verbose=verbose, **predict_kwargs)

        # Postprocess.
        mask, prob = self._postprocess(y_chunks=y_chunks, original_fs=original_fs,
                                       original_len=original_len)

        # Optionally mark low-amplitude (flat) segments as artefact.
        if detect_flats:
            flat_mask = self._predict_flat_mask(eeg=eeg, fs=fs, original_fs=original_fs,
                                                original_len=original_len)
            mask[flat_mask == 1] = 0

        # Optionally mark high-amplitude (peak) segments as artefact.
        if detect_peaks:
            peak_mask = self._predict_peak_mask(eeg=eeg, fs=fs, original_fs=original_fs,
                                                original_len=original_len)
            mask[peak_mask == 1] = 0

        # To integers.
        mask = np.round(mask).astype(int)

        return mask, prob

    @staticmethod
    def preprocess_eeg(eeg, fs, axis=0, verbose=1):
        """
        Preprocess EEG data: 50 Hz notch, 0.27-30 Hz 1st-order Butterworth bandpass and resampling
        to 128 Hz. The mean (ignoring NaNs) is removed after each filtering step.

        Args:
            eeg (np.ndarray): EEG data.
            fs (float, int): sampling frequency of `eeg` in Hz.
            axis (int): time axis of the `eeg` array.
            verbose (int): verbosity level (prints a status message if > 1).

        Returns:
            eeg_out (np.ndarray): preprocessed EEG data.
            fs_out (int): sampling frequency of preprocessed EEG.
        """
        if verbose > 1:
            print('Preprocessing for artefact detection...')

        # Notch for power line interference (needed, because the Butterworth filter below does
        # not remove this power line noise enough).
        filt_notch = NotchIIR(f0=50, fs=fs)

        # Katrien used a 0.27-30 Hz bandpass filter when scoring (BrainRT software).
        # I contacted the OSG firm of BrainRT to ask for the filter specification, which is a 1st order
        # Butterworth.
        filt = Butterworth(fn=[0.27, 30], order=1, fs=fs)

        # Resampling frequency.
        fs_res = 128

        # Filter. keepdims=True keeps the mean broadcastable against the data for any `axis`.
        eeg_out = filt_notch.filtfilt(eeg, axis=axis)
        eeg_out -= np.nanmean(eeg_out, axis=axis, keepdims=True)
        eeg_out = filt.filter(eeg_out, axis=axis)
        eeg_out -= np.nanmean(eeg_out, axis=axis, keepdims=True)

        # Construct time array corresponding to current fs.
        time = np.arange(eeg.shape[axis]) / fs

        # Resample the signal.
        eeg_out = resample_by_interpolation(x=eeg_out, t=time, fs_new=fs_res, axis=axis)

        return eeg_out, fs_res

    def _load_model(self):
        """
        Load and return the keras model from self.model_filepath (without compiling it).
        """
        try:
            # Original (private) TensorFlow module path.
            from tensorflow.python.keras.saving.save import load_model
        except ImportError:
            # The private module no longer exists in newer TensorFlow releases; use the public API.
            from tensorflow.keras.models import load_model
        return load_model(self.model_filepath, compile=False)

    def _load_model_settings(self):
        """
        Load the model settings from the '<model>_settings.xlsx' file next to the model file.

        Returns:
            (dict): settings read from the first row of the spreadsheet.

        Raises:
            FileNotFoundError: if no settings file is found.
        """
        # Find file with model settings (only swap the extension, not any '.h5' in the directory).
        filepath_settings = os.path.splitext(self.model_filepath)[0] + '_settings.xlsx'
        if not os.path.exists(filepath_settings):
            # Try removing the first element from the name.
            filename = '_'.join(os.path.basename(filepath_settings).split('_')[1:])
            filepath_settings = os.path.join(os.path.dirname(filepath_settings), filename)
        if not os.path.exists(filepath_settings):
            raise FileNotFoundError('Cannot find settings.xlsx file. File "{}" does not exist.'
                                    .format(filepath_settings))

        # Load settings.
        df = pd.read_excel(filepath_settings)

        # All settings are saved in first row. Convert to dict.
        return df.iloc[0, :].to_dict()

    def _load_normpars(self):
        """
        Load the normalization parameters from the '<model>_normpars.csv' file next to the model file.

        Returns:
            mean: scalar (one row in the file) or np.ndarray (one value per channel).
            sd: scalar or np.ndarray, like `mean`.
            channel_labels: single label (one row) or pandas Series of labels.

        Raises:
            FileNotFoundError: if the normpars file does not exist.
            AssertionError: if the number of rows disagrees with the 'normalize_per_channel' setting.
        """
        # Find file with normalization parameters (only swap the extension).
        filepath_normpars = os.path.splitext(self.model_filepath)[0] + '_normpars.csv'
        if not os.path.exists(filepath_normpars):
            raise FileNotFoundError('Cannot find normpars.csv file. File "{}" does not exist.'
                                    .format(filepath_normpars))

        # Load parameters.
        df = pd.read_csv(filepath_normpars)
        mean = df['mean']
        sd = df['sd']
        channel_labels = df['channel_label']

        # The number of rows must agree with the per-channel normalization setting.
        if len(df) == 1:
            if self.model_settings.get('normalize_per_channel', False):
                raise AssertionError('In settings it says normalized per channel, '
                                     'but the normalization parameters contain just one channel.')
        elif len(df) > 1:
            if not self.model_settings['normalize_per_channel']:
                raise AssertionError('In settings it says not normalized per channel, '
                                     'but the normalization parameters contain multiple channels.')

        if len(df) == 1:
            # Extract first row.
            mean = mean.iloc[0]
            sd = sd.iloc[0]
            channel_labels = channel_labels.iloc[0]
        else:
            mean = mean.values
            sd = sd.values

        return mean, sd, channel_labels

    def _normalize(self, eeg):
        """
        Normalize EEG with the model-specific normalization parameters: (eeg - mean) / sd.
        """
        mean, sd, _ = self.normpars
        return (eeg - mean) / sd

    @staticmethod
    def _upsample_mask(mask, fs_mask, original_fs, original_len):
        """
        Resample a (n_samples, n_channels) mask onto the time axis of the original EEG.

        Each mask sample is placed at the centre of its sampling interval, then the mask is linearly
        interpolated at the original sample times and rounded. Times outside the mask's range take
        the first/last mask value.

        Args:
            mask (np.ndarray): mask with shape (n_samples, n_channels) sampled at `fs_mask`.
            fs_mask (float): sampling frequency of `mask` in Hz.
            original_fs (float): sampling frequency of the original EEG in Hz.
            original_len (int): number of time samples of the original EEG.

        Returns:
            (np.ndarray): float array with shape (original_len, n_channels) containing 0s and 1s.
        """
        t_mask = np.arange(len(mask)) / fs_mask + 1 / fs_mask / 2
        t_eeg = np.arange(original_len) / original_fs
        new_columns = [np.round(np.interp(t_eeg, t_mask, column, left=column[0], right=column[-1]))
                       for column in mask.T]
        return np.vstack(new_columns).T

    def _postprocess(self, y_chunks, original_fs, original_len):
        """
        Merge the chunk predictions and convert them to a clean mask at the original sample rate.

        Args:
            y_chunks (list): list with predictions for each data chunk (see self._predict()).
            original_fs (float): sampling frequency of original data in Hz.
            original_len (int): length of original EEG data (number of time samples).

        Returns:
            mask (np.ndarray): array with shape (n_time, n_channels) containing 1 at locations of clean
                EEG and 0 at artefacts. Here n_time is in the fs of original EEG, so it can be used as a
                boolean mask for the EEG data.
            prob (np.ndarray): array with shape (n_samples, n_channels, n_classes) containing
                probabilities for artefact and clean. Here n_samples is in the fs of the model output,
                so it is probably a smaller array than the mask.
        """
        # Cut off the boundary (which overlaps with the neighbouring chunks) of each chunk.
        boundary_length = self.parameters['boundary_length']
        fs_y = self.fs_output
        skip = int(boundary_length * fs_y)
        y_chunks_cut = [yi[0, skip:(yi.shape[1] - skip)] for yi in y_chunks]

        # Paste chunks together (n_time, n_channels, n_classes).
        prob = np.vstack(y_chunks_cut)
        n_time, _, _ = prob.shape

        # Sanity check (should have as many seconds in output as in input).
        assert (original_len / original_fs - n_time / fs_y) < 1 / fs_y

        # Create clean mask from probabilities (class index 1 is clean).
        mask = (np.argmax(prob, axis=-1) == 1).astype(float)

        # Upsample mask to original sample rate.
        mask = self._upsample_mask(mask, fs_mask=fs_y, original_fs=original_fs,
                                   original_len=original_len)

        return mask, prob

    def _predict(self, eeg_chunks, verbose=1, **kwargs):
        """
        Run the model on each chunk.

        Args:
            eeg_chunks (list): list with EEG chunks (see self._prepare_data).
            verbose (int): verbosity level (> 0: progress bar, > 1: verbose keras predict).
            **kwargs: passed to self.model.predict().

        Returns:
            y_chunks (list): predictions, i.e. probabilities for each chunk, e.g. list of arrays with
                shape (1, n_seconds, n_channels, n_classes).
        """
        # Only create the progress bar when it is going to be shown.
        bar = pyprind.ProgBar(len(eeg_chunks), stream=sys.stdout) if verbose > 0 else None

        y_chunks = []
        for eeg_seg in eeg_chunks:
            # Shape (n_samples, n_seconds, n_channels, n_classes).
            y = self.model.predict(eeg_seg, verbose=verbose > 1, **kwargs)

            # Check shape.
            assert y.ndim == 4
            assert y.shape[2] == eeg_seg.shape[2]  # Channel dimension.

            y_chunks.append(y)
            if bar is not None:
                bar.update()

        return y_chunks

    @staticmethod
    def _predict_flat_mask(eeg, fs, original_fs=None, original_len=None):
        """
        Predict a mask that is 1 (True) where the moving std over 2 s windows is below a threshold.

        Args:
            eeg (np.ndarray): EEG with shape (n_time, n_channels) sampled at `fs`.
            fs (float): sampling frequency of `eeg` in Hz.
            original_fs (float, optional): if given and different from `fs`, the mask is upsampled
                to this sampling frequency.
            original_len (int, optional): number of samples of the upsampled mask (required when
                upsampling).

        Returns:
            (np.ndarray): mask with shape (n_time, n_channels) (boolean if not upsampled, 0./1. floats
                if upsampled).

        Raises:
            ValueError: if upsampling is needed but `original_len` is None.
        """
        # Check input shape (n_time, n_channels).
        eeg = check_eeg_is_long(eeg=eeg, mode='error')

        # Compute moving std.
        window = 2  # Seconds.
        min_std = 10 ** (-0.5)  # From looking at distribution in resilience dataset.
        eeg_std = moving_std(x=eeg, n=int(fs * window), axis=0)
        flat_mask = eeg_std < min_std

        if original_fs is not None and fs != original_fs:
            # Upsample mask to original sample rate.
            if original_len is None:
                raise ValueError('`original_len` needs to be specified.')
            flat_mask = CleanDetectorCnn._upsample_mask(flat_mask, fs_mask=fs, original_fs=original_fs,
                                                        original_len=original_len)

        return flat_mask

    @staticmethod
    def _predict_peak_mask(eeg, fs, original_fs=None, original_len=None):
        """
        Predict a mask that is 1 (True) where the moving max of |eeg| over 1 s windows exceeds 1000.

        Args:
            eeg (np.ndarray): EEG with shape (n_time, n_channels) sampled at `fs`.
            fs (float): sampling frequency of `eeg` in Hz.
            original_fs (float, optional): if given and different from `fs`, the mask is upsampled
                to this sampling frequency.
            original_len (int, optional): number of samples of the upsampled mask (required when
                upsampling).

        Returns:
            (np.ndarray): mask with shape (n_time, n_channels) (boolean if not upsampled, 0./1. floats
                if upsampled).

        Raises:
            ValueError: if upsampling is needed but `original_len` is None.
        """
        # Check input shape (n_time, n_channels).
        eeg = check_eeg_is_long(eeg=eeg, mode='error')

        # Compute moving max of abs values.
        window = 1  # Seconds.
        max_amp = 1000  # uV.
        eeg_max = moving_max(x=np.abs(eeg), n=round(fs * window), axis=0)
        peak_mask = eeg_max > max_amp

        if original_fs is not None and fs != original_fs:
            # Upsample mask to original sample rate.
            if original_len is None:
                raise ValueError('`original_len` needs to be specified.')
            peak_mask = CleanDetectorCnn._upsample_mask(peak_mask, fs_mask=fs, original_fs=original_fs,
                                                        original_len=original_len)

        return peak_mask

    def _prepare_data(self, eeg, fs=None, verbose=1):
        """
        Helper function to prepare the EEG recording for input to the CNN (normalization, chunking).

        Args:
            eeg (np.ndarray): preprocessed EEG data with shape (n_time, n_channels).
            fs (float): sampling frequency of `eeg` in Hz.
            verbose (int): verbosity level (prints status messages if > 1).

        Returns:
            eeg_chunks (list): normalized and chunked EEG with length number of chunks and each element
                (chunk) is an EEG array with shape (1, n_time, n_channels).

        Raises:
            ValueError: if `fs` does not match the required sampling frequency, or if chunk_length is
                not larger than the chunk overlap (2 * boundary_length).
        """
        if fs != self.data_requirements['fs']:
            raise ValueError(f'fs={fs} does not match the data_requirements: fs={self.data_requirements["fs"]}. '
                             f'Make sure the input has the required sampling frequency.')

        _, n_channels = eeg.shape

        # Normalize.
        if self.model_settings.get('normalize', True):
            if verbose > 1:
                print('Normalizing...')
            eeg = self._normalize(eeg)

        # Nan to num.
        eeg = np.nan_to_num(eeg)

        # Add zeros to start and end since we will cut off some small part for boundary effects.
        boundary_length = self.parameters['boundary_length']
        n_pad = int(boundary_length * fs)
        pad = np.zeros((n_pad, n_channels))
        eeg = np.concatenate([pad, eeg, pad], axis=0)

        # To shape (1, n_time, n_channels).
        n_tot = len(eeg)
        eeg = np.expand_dims(eeg, axis=0)

        # Divide EEG into chunks (if EEG is shorter than chunk size, will just be one chunk).
        chunk_length = self.parameters['chunk_length']
        if verbose > 1:
            print('Dividing data into chunks...')
        if chunk_length is not None:
            chunk_size = int(chunk_length * fs)
            # Consecutive chunks overlap by both boundaries, which are cut off in _postprocess.
            overlap = int(boundary_length * fs) * 2
            stepsize = chunk_size - overlap
            if stepsize <= 0:
                raise ValueError(f'chunk_length ({chunk_length} s) must be larger than '
                                 f'2 * boundary_length ({2 * boundary_length} s).')
            chunk_start_indices = np.arange(0, n_tot, stepsize)
        else:
            chunk_start_indices = [0]
            chunk_size = n_tot

        # The last chunk may be shorter than the rest, so use a list instead of an array.
        eeg_chunks = [eeg[:, start_idx:min(start_idx + chunk_size, n_tot), :]
                      for start_idx in chunk_start_indices]

        return eeg_chunks