# Source code for nnsa.preprocessing.data_cleaning

"""
Author: Tim Hermans (tim-hermans@hotmail.com).
"""
import numpy as np
from scipy.signal import convolve2d

from nnsa.utils.arrays import moving_mean
from nnsa.utils.mathematics import closest_multiple


def _segment_borders(is_af, fs, fs_grid):
    """
    Return the sorted segment borders (sample indices, including 0 and len(is_af)).

    A new segment starts wherever any channel of `is_af` switches between artefact and clean.
    If `fs_grid` is not None, the borders are snapped onto a grid with grid size 1/fs_grid.
    """
    n_switches = np.sum(np.abs(np.diff(is_af.astype(int), axis=0)), axis=-1)
    idx_change = np.where(n_switches > 0.5)[0] + 1
    idx_change = np.concatenate([[0], idx_change, [len(is_af)]])

    if fs_grid is not None:
        # Grid size in samples; a grid finer than one sample means "no snapping" (size 1),
        # which also avoids a zero grid size when fs_grid > 2 * fs.
        grid_size = max(1, int(np.round(fs / fs_grid)))
        idx_change = np.unique(closest_multiple(idx_change, grid_size))
    return idx_change


def _transition_weights(n, transition):
    """
    Return cross-fade weights with shape (n, 1).

    The weights are 1 in the middle, rise along `transition` at the start and fall along the
    reversed `transition` at the end. Where both ramps overlap, the falling ramp wins. An empty
    `transition` (transition_window=0) or n == 0 is handled without error.
    """
    weights = np.ones((n, 1))
    n_t = min(len(transition), n)
    if n_t > 0:
        weights[:n_t, 0] = transition[:n_t]
        # Explicit start index: `weights[-0:]` would select the whole array.
        weights[n - n_t:, 0] = transition[:n_t][::-1]
    return weights


def substitute_bad_channels(
        x, af_mask, fs=1, fs_grid=1, axis=0, aci_window=3, aci_threshold=50,
        max_channels=3, min_duration=0, transition_window=1, verbose=1):
    """
    Identify transient bad EEG channels and substitute them by the mean of the good channels.

    The steps of the algorithm are:
        1. Compute the artefact contamination index (ACI; percentage of time where `af_mask` is True)
           per channel in a moving time window (`aci_window`).
        2. Mark samples whose moving ACI is greater than `aci_threshold` as artefact.
        3. Keep only the time points at which at most `max_channels` channels are marked as artefact.
        4. Segment the signal wherever the set of artefact channels changes.
        5. For every segment that lasts at least `min_duration`, replace the artefact channels with
           the mean of the non-artefact channels.

    To ensure a smooth transition, each segment is elongated left and right by `transition_window`.
    In these elongated parts, the original and the substituted signals linearly cross-fade, so the
    transition happens outside the original segment. Substitution is performed on channel-demeaned
    data; the channel means are added back at the end.

    Args:
        x (np.ndarray or None): 2D array containing multi-channel data with shape (n_time, n_channels).
            If None, an all-zero array shaped like `af_mask` is used.
        af_mask (np.ndarray): boolean mask with shape (n_time, n_channels) containing True at
            locations of artefacts.
        fs (float): sampling frequency, scales the windows.
        fs_grid (float or None): sampling frequency of the grid onto which to snap the adaptive
            segments. If not None, the segment borders are snapped onto a grid with grid size
            1/fs_grid, and a channel counts as artefact in a segment if it is marked as artefact
            for at least 50% of that segment.
        axis (int): time axis. If 1 or -1, the input shape is expected to be (n_channels, n_time).
        aci_window (float): time window in seconds in which to compute the moving ACI.
        aci_threshold (float): threshold in % for channel-ACI to consider the channel bad.
        max_channels (int): maximum number of channels that can be bad at the same time. If more
            channels are bad, no substitution is done. If <= 0, copies of the inputs are returned.
        min_duration (float): minimum duration in seconds of an artefact segment to substitute.
            None is no longer supported and raises a DeprecationWarning.
        transition_window (float): time window in seconds used to cross-fade between the original
            channel data and the substitute data. May be 0 (no cross-fade).
        verbose (int): verbosity level. If non-zero, the percentage of substituted samples is printed.

    Returns:
        x (np.ndarray): new data, with the same shape as the input `x`.
        af_mask (np.ndarray): new af_mask, with the same shape as the input `af_mask`.

    Raises:
        ValueError: if `af_mask` is not 2D, or if `x` and `af_mask` have different shapes.
        DeprecationWarning: if `min_duration` is None.
    """
    if af_mask.ndim != 2:
        raise ValueError('`af_mask` should be 2D. Got array with shape {}.'.format(af_mask.shape))
    if x is None:
        x = np.zeros_like(af_mask)
    x_shape = x.shape
    if x.shape != af_mask.shape:
        raise ValueError('`x` and `af_mask` should have the same shape. Got shapes {} and {}.'
                         .format(x.shape, af_mask.shape))
    if min_duration is None:
        # The former default was max(aci_window / 2, 2 * transition_window); None is no longer accepted.
        raise DeprecationWarning('min_duration=None is deprecated.')

    transpose = axis in [1, -1]
    if transpose:
        # Work internally with time along the first axis.
        x = x.T
        af_mask = af_mask.T

    # Pre-compute the cross-fade ramp (rising from 0 to 1).
    transition = np.linspace(0, 1, int(fs * transition_window))
    n_transition = len(transition)

    if max_channels > 0:
        # Compute ACI (%) per channel in moving window.
        moving_aci = moving_mean(af_mask.astype(float), n=int(aci_window * fs), axis=0) * 100

        # A sample is artefact if its moving ACI exceeds the threshold, but only at time points where
        # at most `max_channels` channels are artefact (otherwise no substitution is possible).
        is_af = moving_aci > aci_threshold
        subs_possible = np.sum(is_af, axis=1, keepdims=True) <= max_channels
        is_af = is_af & subs_possible

        idx_change = _segment_borders(is_af, fs, fs_grid)

        # Work on demeaned channels; the means are added back at the end.
        x_means = np.nanmean(x, axis=0)
        x_cleaned = x - x_means
        af_cleaned = af_mask.copy()

        samples_substituted = 0
        for idx_start, idx_stop in zip(idx_change[:-1], idx_change[1:]):
            # Skip segments that are too short.
            if (idx_stop - idx_start) / fs < min_duration:
                continue

            is_af_seg = is_af[idx_start:idx_stop]
            if not np.any(is_af_seg):
                # Nothing to substitute in this segment.
                continue

            if fs_grid is not None:
                # Snapped borders may cut through changes: a channel is artefact if it is artefact
                # for at least 50% of the segment.
                is_af_seg = np.nanmean(is_af_seg, axis=0, keepdims=True) >= 0.5
                n_bad = np.sum(is_af_seg)
                if n_bad > max_channels or n_bad == 0:
                    continue

            # Artefact channels must be constant throughout the segment.
            if not np.all(np.diff(is_af_seg.astype(float), axis=0) == 0):
                raise AssertionError("If code reaches here, there's a bug.")
            bad = is_af_seg[0]  # per-channel artefact flags for this segment

            # Enlarge the segment so that the cross-fade happens outside the original segment.
            idx_start = int(max(0, idx_start - n_transition))
            idx_stop = int(min(len(x_cleaned), idx_stop + n_transition))
            n_seg = idx_stop - idx_start
            weights = _transition_weights(n_seg, transition)

            # Substitute signal: mean of the healthy channels, cross-faded with the original data.
            x_seg = x_cleaned[idx_start:idx_stop]
            x_healthy = x_seg.copy()
            x_healthy[:, bad] = np.nan
            x_subs = np.nanmean(x_healthy, axis=-1, keepdims=True)
            x_blend = x_subs * weights + (1 - weights) * x_seg
            x_cleaned[idx_start:idx_stop, bad] = x_blend[:, bad]

            # Substitute artefact mask the same way. Cast to float first: assigning NaN into a
            # boolean array stores True, which would count the bad channels as healthy artefacts.
            af_seg = af_mask[idx_start:idx_stop]
            af_healthy = af_seg.astype(float)
            af_healthy[:, bad] = np.nan
            af_subs = np.nanmean(af_healthy, axis=-1, keepdims=True) > 0.5
            af_blend = af_subs * weights + (1 - weights) * af_seg
            # Assignment casts back to the dtype of af_mask (for booleans: any non-zero -> True).
            af_cleaned[idx_start:idx_stop, bad] = af_blend[:, bad]

            samples_substituted += n_seg * int(np.sum(bad))

        # Add mean back.
        x_cleaned += x_means

        if verbose:
            print(f'{(samples_substituted/x.size)*100:.2f} % samples substituted.')
    else:
        x_cleaned = x.copy()
        af_cleaned = af_mask.copy()

    if transpose:
        # Transpose back to the caller's layout.
        x_cleaned = x_cleaned.T
        af_cleaned = af_cleaned.T

    assert x_cleaned.shape == x_shape
    return x_cleaned, af_cleaned