# Source code for nnsa.artefacts.artefact_detection

"""
Module for exclusion of artefacts.
"""
import os
import warnings

import numpy as np
import scipy.signal

from nnsa.utils.arrays import moving_std, moving_mean, do_for_axis, moving_mad, moving_median, check_eeg_is_wide

# Public API of this module (the names exported by `from ... import *`).
# NOTE(review): several public (non-underscore) helpers defined in this module, e.g.
# check_non_nan_length, default_bp_sample_quality_criteria, detect_neighborhood_artefacts,
# detect_outliers, detect_values, remove_outliers and remove_values, are not listed here --
# confirm whether that is intentional.
__all__ = [
    'check_amplitude',
    'check_amplitude_mad',
    'check_max_diff',
    'default_eeg_signal_quality_criteria',
    'default_eeg_sample_quality_criteria',
    'default_oxygen_sample_quality_criteria',
    'detect_anomalous_channels',
    'detect_artefact_samples',
    'detect_artefact_signals',
    'detect_flatlines',
    'detect_high_amplitudes',
    'remove_flatlines',
    'signal_quality',
]

from nnsa.utils.event_detections import get_onsets_offsets

from nnsa.utils.segmentation import get_all_segments


def check_amplitude(x, n=-1, std_factor=3, axis=-1):
    """
    Return a mask with True at samples whose amplitude lies within mean +- `std_factor` * std.

    The mean and std are either global (one value per signal along `axis`) or computed in a moving
    window of `n` samples.

    Args:
        x (np.ndarray): time series array.
        n (int): window size of moving average/std. If None or -1, the global average/std is taken.
        std_factor (int, optional): the number of stds that the amplitude can differ from the mean.
            Defaults to 3.
        axis (int, optional): time axis. Defaults to -1.

    Returns:
        amp_ok (np.ndarray): boolean mask of same size as `x`, where True means that the amplitude is ok.
    """
    x = np.asarray(x)
    if n is None or n == -1:
        # Global average and std per signal. Keep the reduced axis so the statistics broadcast
        # against `x` for any number of dimensions and any time axis.
        avg = np.nanmean(x, axis=axis, keepdims=True)
        std = np.nanstd(x, axis=axis, keepdims=True)
    else:
        # Moving average and std. Roll forward in time (along the time axis only, so channels are
        # not mixed) with the intent that the window at sample i only uses past samples.
        # NOTE(review): this assumes moving_mean/moving_std use a centred window; np.roll wraps
        # around, so the first ceil(n/2) samples see values from the end of the signal.
        x_rolled = np.roll(x, int(np.ceil(n / 2)), axis=axis)
        avg = moving_mean(x_rolled, axis=axis, n=n)
        std = moving_std(x_rolled, axis=axis, n=n)

    # Amplitude is ok if it lies within mean +- std_factor * std (nan samples are never ok).
    amp_ok = np.logical_and(avg - std_factor * std <= x, x <= avg + std_factor * std)
    return amp_ok
def check_amplitude_mad(x, threshold=3, n=-1, axis=-1):
    """
    Check amplitude based on median absolute deviation (MAD).

    The anomaly score of a sample is its absolute deviation from the (running) median divided by
    the (running) MAD. Samples with a score above `threshold` are classified as anomalies.
    See https://www.influxdata.com/blog/anomaly-detection-with-median-absolute-deviation/

    Args:
        x (np.ndarray): time series array.
        threshold (float, optional): maximum accepted anomaly score. Defaults to 3.
        n (int, optional): window size passed to moving_median() and moving_mad().
            NOTE(review): -1 presumably means global statistics, as in the other checks -- confirm.
        axis (int, optional): time axis. Defaults to -1.

    Returns:
        amp_ok (np.ndarray): boolean mask of same size as `x`, where True means that the amplitude is ok.
    """
    # Compute running median and MAD.
    med = moving_median(x, n=n, axis=axis)
    mad = moving_mad(x, n=n, axis=axis)

    # Normalized anomaly value (if higher than a threshold (e.g. 3), classify as anomaly).
    deviation = np.abs(x - med)
    with np.errstate(divide='ignore', invalid='ignore'):
        anomaly_score = deviation / mad

    # A MAD of 0 (e.g. quantized or locally constant signals) would give 0/0 = nan for samples
    # that equal the median, flagging them as anomalies although they do not deviate at all.
    # Samples that differ from the median while MAD is 0 keep an infinite score (anomaly).
    anomaly_score = np.where(deviation == 0, 0., anomaly_score)

    amp_ok = anomaly_score <= threshold
    return amp_ok
def check_max_diff(x, n, fs=1, std_factor=3, which='std', min_max_diff=None, axis=-1):
    """
    Flag samples whose jump to the previous or next sample is too large.

    A sample is ok when both its absolute difference with the previous sample and with the next
    sample (multiplied by `fs`) are at most `std_factor` times a (moving) dispersion measure of `x`.

    Args:
        x (np.ndarray): time series array.
        n (int): window size of moving std. If None or -1, the global std is taken.
        fs (float): sample frequency of `x`; the differences are multiplied by it.
        std_factor (int, optional): the number of stds that the diff can be. Defaults to 3.
        which (str, optional): which metric to use for std (case insensitive). Choose from:
            - 'std' or 'SD': for standard deviation.
            - 'mad' or 'MAD': for median absolute deviation (alternative to std robust to outliers).
              Only recommended for stationary series.
        min_max_diff (float): lower limit on the maximum diff. E.g. to prevent the threshold from
            going towards zero during a flatline.
        axis (int, optional): time axis. Defaults to -1.

    Returns:
        diff_ok (np.ndarray): boolean mask of same size as `x`, where True means that the diff is ok.

    Raises:
        ValueError: if `which` is not one of the supported options.
    """
    # Absolute difference between consecutive samples, scaled by fs.
    step = np.abs(np.diff(x, axis=axis)) * fs

    # A zero pad gives the first sample a 'previous' diff and the last sample a 'next' diff.
    pad_shape = list(x.shape)
    pad_shape[axis] = 1
    pad = np.zeros(pad_shape)

    # nan diffs are treated as 0, so that nans alone do not make neighbouring samples fail.
    prev_step = np.nan_to_num(np.concatenate((pad, step), axis=axis))
    next_step = np.nan_to_num(np.concatenate((step, pad), axis=axis))

    # Dispersion estimate that sets the threshold.
    which = which.lower()
    if which in ('std', 'sd'):
        spread = moving_std(x, n=n, axis=axis)
    elif which == 'mad':
        spread = moving_mad(x, n=n, axis=axis)
    else:
        raise ValueError('Invalid option "{}" for `which`. Choose from {}.'
                         .format(which, ['SD', 'MAD']))

    threshold = std_factor * spread
    if min_max_diff is not None:
        threshold = np.clip(threshold, min_max_diff, np.inf)

    return (prev_step <= threshold) & (next_step <= threshold)
def check_non_nan_length(x, min_length, axis=-1):
    """
    Mark samples that belong to a sufficiently long run of non-nan values.

    nan samples themselves are always marked True; non-nan samples are marked True only if the
    uninterrupted non-nan segment they belong to has at least `min_length` samples.

    Args:
        x (np.ndarray): time series array.
        min_length (int): minimum length of non-nan segment (in samples).
        axis (int, optional): time axis. Defaults to -1.

    Returns:
        length_ok (np.ndarray): boolean mask of same size as `x`, where True means that the length
            of the non-nan segment is ok.

    Examples:
        >>> x = np.array([np.nan, 0., 0., np.nan, np.nan, 0., np.nan, 0., 0., 0.])
        >>> check_non_nan_length(x, min_length=3)
        array([ True, False, False,  True,  True, False,  True,  True,  True,  True])
        >>> check_non_nan_length(x, min_length=2)
        array([ True,  True,  True,  True,  True, False,  True,  True,  True,  True])
        >>> x = np.array([0., 0., 0., np.nan, 0., np.nan, np.nan, 0., 0., 0.])
        >>> check_non_nan_length(x, min_length=3)
        array([ True,  True,  True,  True, False,  True,  True,  True,  True,  True])
    """
    def _segment_length_ok(a):
        """Run the length check on a single 1D signal."""
        if a.ndim != 1:
            raise ValueError('`x` should be 1-dimensional.')

        is_nan = np.isnan(a).astype(int)
        transitions = np.diff(is_nan)

        # A nan -> value transition starts a segment; a value -> nan transition ends one.
        starts = np.flatnonzero(transitions == -1) + 1
        stops = np.flatnonzero(transitions == 1) + 1
        if is_nan[0] == 0:
            starts = np.concatenate(([0], starts))
        if is_nan[-1] == 0:
            stops = np.concatenate((stops, [len(is_nan)]))

        # nan samples keep `min_length` so they always pass; non-nan samples get their segment length.
        lengths = np.full(len(is_nan), fill_value=min_length)
        for begin, end in zip(starts, stops):
            lengths[begin:end] = end - begin

        return lengths >= min_length

    return do_for_axis(x, fun=_segment_length_ok, axis=axis)
def default_eeg_signal_quality_criteria(fs=250):
    """
    Return dictionary with default criteria for EEG signal quality.

    The keys correspond to keyword arguments of detect_artefact_signals().

    Args:
        fs (float, optional): sample frequency of the EEG in Hz. Only used to scale 'max_diff'
            (which equals 50 at 250 Hz). Defaults to 250.

    Returns:
        criteria (dict): dictionary with default criteria for signal quality.
    """
    return dict(
        # Accepted range of the standard deviation of a signal.
        min_std=0.001,
        max_std=50,
        # Maximum accepted amplitude in signal.
        max_amp=200,
        # Maximum accepted absolute difference between consecutive samples in signal.
        max_diff=50 * 250 / fs,
        # Maximum fraction of nan samples (in time dimension).
        max_nan_frac=1e-12,
        # Maximum accepted fraction of channels/signals that may be artefacted (if more channels
        # are artefacted, all channels are classified as artefacts).
        max_fraction_of_artefact_channels=0.5,
    )
def default_bp_sample_quality_criteria():
    """
    Return dictionary with default criteria for blood pressure sample quality.

    The keys correspond to keyword arguments of detect_artefact_samples().

    Returns:
        criteria (dict): dictionary with default criteria for sample quality.
    """
    def _max_diff_global(x, axis):
        # Difference between consecutive samples checked against 5 times the global std.
        return check_max_diff(x, n=-1, std_factor=5, axis=axis)

    return {
        # Accepted amplitude range.
        'min_amp': 5,
        'max_amp': 125,
        # Number of samples around the current sample (including itself) that need to be artefact
        # free to consider the current sample artefact free.
        'num_artefact_free_neighbouring_samples': 3,
        # Custom functions (yield True for samples that are ok).
        'custom_functions': [_max_diff_global],
    }
def default_eeg_sample_quality_criteria(fs=250):
    """
    Return dictionary with default criteria for EEG sample quality.

    The keys correspond to keyword arguments of detect_artefact_samples().

    Args:
        fs (float, optional): sample frequency of EEG in Hz. Defaults to 250.

    Returns:
        criteria (dict): dictionary with default criteria for sample quality.
    """
    # Window (in samples) for moving averages and sds. Only referenced by the disabled custom
    # functions below.
    n = int(180 * fs)

    # Optional custom checks (each yields True for samples that are ok). None are enabled by default.
    custom_functions = [
        # lambda x, axis: ~detect_flatlines(x, n_flatline='auto', axis=axis, tol=1e-14),
        # Min length of non artefact part.
        # lambda x, axis: check_non_nan_length(x, min_length=1 * fs, axis=axis),
        # Min and max amplitude.
        # lambda x, axis: check_amplitude(x, n=n, axis=axis),
        # Min and max amplitude 2.
        # lambda x, axis: check_amplitude(x, n=-1, std_factor=5, axis=axis),
        # Max diff.
        # lambda x, axis: check_max_diff(x, n=-1, axis=axis),
        # Min length of non artefact part.
        # lambda x, axis: check_non_nan_length(x, min_length=int(1*fs), axis=axis),
    ]

    return {
        # Accepted amplitude range.
        'min_amp': -200,
        'max_amp': 200,
        # Maximum difference between consecutive samples (equals 50 at 250 Hz).
        'max_diff': 50 * 250 / fs,
        # Number of samples around the current sample (including itself) that need to be artefact
        # free to consider the current sample artefact free.
        'num_artefact_free_neighbouring_samples': 3,
        'custom_functions': custom_functions,
    }
def default_oxygen_sample_quality_criteria(fs=1):
    """
    Return dictionary with default criteria for oxygen signal sample quality.

    The keys correspond to keyword arguments of detect_artefact_samples().

    Args:
        fs (float, optional): sample frequency of the oxygen signal in Hz. Defaults to 1.

    Returns:
        criteria (dict): dictionary with default criteria for sample quality.
    """
    custom_functions = [
        # Flatlines of at least 3600*fs samples (one hour) with a value <= 70 are artefacts.
        lambda x, axis: ~detect_flatlines(x, n_flatline=int(3600*fs), flatline_range=[-np.inf, 70], axis=axis),
        # Min and max amplitude.
        # lambda x, axis: check_amplitude(x, n=5*60*fs, std_factor=4, axis=axis),
        # Max diff against the global std.
        lambda x, axis: check_max_diff(x, n=-1, std_factor=4, min_max_diff=5/fs, which='std', axis=axis),
        # Max diff against a local (600*fs samples) MAD.
        lambda x, axis: check_max_diff(x, n=600 * fs, std_factor=4, min_max_diff=5 / fs, which='MAD', axis=axis),
        # Min length of non artefact part.
        # lambda x, axis: check_non_nan_length(x, min_length=int(30*fs), axis=axis),
    ]

    return {
        # Accepted amplitude range.
        'min_amp': 15,
        'max_amp': 100.01,
        # Maximum accepted absolute difference with neighbouring sample.
        'max_diff': 25/fs,
        # Number of samples around the current sample (including itself) that need to be artefact
        # free to consider the current sample artefact free.
        'num_artefact_free_neighbouring_samples': 3,
        # Custom functions (yield True for samples that are ok).
        'custom_functions': custom_functions,
    }
def detect_anomalous_channels(x, window, fs=1, std_factor=8, p_trim=0.25, shape_mode='error'):
    """
    Simple method to detect anomalous high-frequency/amplitude channels by setting a threshold on the line length.

    Per window, the log line length of each channel is computed. The mean and std of the
    (1-`p_trim`)*n_channels channels with the lowest line lengths are used as reference, and any
    channel that exceeds mean + `std_factor`*std is flagged as artefact in that window.

    Args:
        x (np.ndarray): multichannel signal with dimensions (n_channels, n_samples).
        window (float): the window length (in seconds if `fs` is given, else in samples) in which to
            compute line length.
        fs (float): sampling frequency of `x`.
        std_factor (float): the factor for std determining the threshold (mean + `std_factor`*std).
        p_trim (float): the relative amount of channels to trim before computing the mean and std.
            The channels with the highest line lengths are trimmed/excluded. Must be lower than 1.
        shape_mode (str): what to do with unexpected input shape (see check_eeg_is_wide()).

    Returns:
        mask_anomaly (np.ndarray): boolean array with same shape as `x` containing True at locations
            of anomalies.

    Raises:
        ValueError: if `p_trim` >= 1 or if fewer than 2 channels remain after trimming.

    Examples:
        >>> # Create dummy signal with 8 channels and an anomaly somewhere in the third channel.
        >>> rng = np.random.RandomState(43)
        >>> x = rng.random((8, 1000))
        >>> x[2, 301:600] *= 10
        >>> # Detect anomaly
        >>> mask = detect_anomalous_channels(x, window=20)
        >>> print(np.where(np.any(mask, axis=1))[0])
        [2]
        >>> onsets, offsets = get_onsets_offsets(mask[2])
        >>> print(onsets)
        [301]
        >>> print(offsets)
        [600]
    """
    # Check input shape.
    original_shape = np.asarray(x).shape
    x = check_eeg_is_wide(eeg=x, mode=shape_mode)
    n_channels, n_samples = x.shape

    # Split into non-overlapping windows of length `window`.
    # NOTE(review): the code below assumes x_seg has shape (n_segments, n_channels, n_window_samples).
    x_seg = get_all_segments(
        x, segment_length=window, overlap=0, axis=-1, fs=fs)

    # Log line length (mean absolute first difference) per window. Shape (n_segments, n_channels).
    line_length = np.log(np.mean(np.abs(np.diff(x_seg, axis=-1)), axis=-1))
    assert line_length.shape[1] == n_channels

    # Per window, keep only the channels with the lowest line length (the cleanest ones) to
    # estimate the reference statistics.
    if p_trim:
        if p_trim >= 1:
            raise ValueError(f'`p_trim` should be lower than 1. Got p_trim={p_trim}.')
        n_trim = int(np.round(n_channels * p_trim))
        n_keep = n_channels - n_trim
        if n_keep < 2:
            raise ValueError(f'Not enough channels left with `p_trim`={p_trim}. Decrease the value.')
        # Slice with an explicit stop (not [:, :-n_trim]), which would be empty when n_trim == 0.
        line_length_trimmed = np.sort(line_length, axis=1)[:, :n_keep]
    else:
        # No sorting and trimming needed.
        line_length_trimmed = line_length

    # Mean and std over the kept channels, per window.
    mean_ = np.nanmean(line_length_trimmed, axis=1, keepdims=True)
    std_ = np.nanstd(line_length_trimmed, axis=1, keepdims=True)

    # Set a minimum to the std (sometimes, by chance the std can become very small).
    min_std = np.mean(std_)
    std_ = np.clip(std_, a_min=min_std, a_max=np.inf)

    # Anomaly mask (per window/segment). Shape (n_segments, n_channels).
    mask_anomaly_seg = line_length > mean_ + std_factor * std_

    # Upsample the per-window mask to the original time axis (nearest-window decision via
    # linear interpolation of the 0/1 values at the window centres, thresholded at 0.5).
    mask_anomaly = []
    t_orig = np.arange(n_samples) / fs
    t_mask = np.arange(mask_anomaly_seg.shape[0]) * window + window / 2
    for m in mask_anomaly_seg.T:
        mask = np.interp(x=t_orig, xp=t_mask, fp=m.astype(float), left=m[0], right=m[-1]) > 0.5
        mask_anomaly.append(mask)

    # To shape (n_channels, n_samples).
    mask_anomaly = np.vstack(mask_anomaly)

    if original_shape != x.shape:
        # Assume x was transposed, transpose the mask too.
        mask_anomaly = mask_anomaly.T

    # Make sure the dimensions are right.
    if mask_anomaly.shape != original_shape:
        raise AssertionError(f'Unexpected result: the mask does not have the same shape ({mask_anomaly.shape}) '
                             f'as the original input ({original_shape}). Check your input dimensions.')

    return mask_anomaly
def detect_artefact_samples(x, min_amp=-np.inf, max_amp=np.inf, max_diff=np.inf,
                            num_artefact_free_neighbouring_samples=1, custom_functions=None, axis=-1,
                            demean=False):
    """
    Detect samples in x along the specified axis that do not meet the sample quality criteria.

    Args:
        x (np.ndarray): array containing signal(s) along the specified axis. It is not modified.
        min_amp (float, optional): minimum accepted amplitude.
        max_amp (float, optional): maximum accepted amplitude.
        max_diff (float, optional): maximum accepted absolute difference with neighbouring sample.
        num_artefact_free_neighbouring_samples (int, optional): number of samples around the current
            sample that need to be artefact-free to consider the current sample artefact free. The
            number includes the current sample, so if it is 1, no neighbours are checked. If it is 3,
            the previous and next sample are checked, etc. E.g. if this parameter is set to 5, the 2
            samples before the current sample, the current sample itself and the 2 samples after the
            current sample all need to be artefact-free to consider the current sample artefact free.
        custom_functions (list, optional): list with functions `fun(x, axis)` that yield True for
            samples that are ok (i.e. non-artefact). They are applied after nans have been inserted
            (in an internal copy of `x`) at samples that violate min_amp, max_amp or max_diff.
        axis (int, optional): the axis corresponding to the time dimension of the signal(s) in x.
        demean (bool, optional): if True, subtracts the mean of the signal(s) from the signal(s)
            before assessing the sample quality. If False, does not subtract the mean.
            Defaults to False.

    Returns:
        artefact_mask (bool or np.ndarray): boolean array where True entries correspond to samples
            that are artefacts, False entries correspond to non artefact samples.
            The shape of the array is the same as the shape of x.
    """
    # Work on a copy with a dtype that can hold nan: artefacts are replaced by nan below, which
    # must neither modify the caller's array nor fail for integer input.
    x = np.array(x, copy=True)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    if demean:
        x = x - np.nanmean(x, axis=axis, keepdims=True)

    # Threshold on amplitude.
    amp_ok = np.logical_and(x >= min_amp, x <= max_amp)

    # Compute difference with previous and next sample.
    diff = np.abs(np.diff(x, axis=axis))
    zeros_shape = np.array(x.shape)
    zeros_shape[axis] = 1
    zeros = np.zeros(zeros_shape)
    diff_prev = np.concatenate((zeros, diff), axis=axis)  # Diff with previous sample.
    diff_next = np.concatenate((diff, zeros), axis=axis)  # Diff with next sample.

    # Threshold on diff.
    diff_ok = np.logical_and(diff_prev <= max_diff, diff_next <= max_diff)

    # Sample is ok only if all criteria are met.
    ok_mask = np.logical_and(amp_ok, diff_ok)

    # Replace artefacts by nan before evaluating custom functions.
    x[~ok_mask] = np.nan

    # Evaluate custom functions, each one seeing the nans inserted by the previous checks.
    if custom_functions is not None:
        for fun in custom_functions:
            ok_mask = np.logical_and(ok_mask, fun(x, axis))
            x[~ok_mask] = np.nan

    # Artefact sample if sample is not ok.
    artefact_mask = np.logical_not(ok_mask)

    # Artefact free neighbouring samples: as documented above, every sample in the neighbourhood
    # must be artefact free, i.e. no artefact fraction is tolerated (max_af_frac=0).
    if num_artefact_free_neighbouring_samples > 1:
        artefact_mask = detect_neighborhood_artefacts(
            artefact_mask, size=num_artefact_free_neighbouring_samples, max_af_frac=0, axis=axis)

    return artefact_mask
def detect_artefact_signals(x, min_std=0, max_std=np.inf, min_amp=-np.inf, max_amp=np.inf, max_diff=np.inf,
                            max_nan_frac=1, max_fraction_of_artefact_channels=1, axis=-1, channel_axis=0,
                            demean=False, keepdims=False):
    """
    Detect signal(s) in x along the specified axis that do not meet the signal quality criteria.

    Args:
        x (np.ndarray): array containing signal(s) along the specified axis. For each signal in x a
            bool will be computed specifying whether the signal is an artefact or not.
        min_std (float, optional): minimum accepted standard deviation in signal.
        max_std (float, optional): maximum accepted standard deviation in signal.
        min_amp (float, optional): minimum accepted value of the smallest absolute amplitude in signal.
        max_amp (float, optional): maximum accepted absolute amplitude in signal.
        max_diff (float or callable, optional): maximum accepted absolute difference between
            consecutive samples in signal. If callable, it is called as `max_diff(x, axis)` and its
            result is used as threshold.
        max_nan_frac (float, optional): maximum accepted fraction of nan samples (in time dimension).
        max_fraction_of_artefact_channels (float, optional): maximum accepted fraction of
            channels/signals in data array that may be artefacted (if the number of artefacted
            channels is higher, all channels are classified as artefacts).
        axis (int, optional): the axis corresponding to the time dimension of the signal(s) in x.
        channel_axis (int, optional): the axis corresponding to channels. Only used when
            `max_fraction_of_artefact_channels` < 1.
        demean (bool, optional): if True, subtracts the mean of the signal(s) from the signal(s)
            before assessing the signal quality. If False, does not subtract the mean.
            Defaults to False.
        keepdims (bool, optional): if True, the axes which are reduced are left in the output as
            dimensions with size 1. If False, the dimension corresponding to the specified axis is
            removed in the output. Defaults to False.

    Returns:
        artefact_mask (bool or np.ndarray): boolean array where True entries correspond to signals
            that are artefacts, False entries correspond to non artefact signals. The shape of the
            array is the same as the shape of x, except along the specified axis (this axis is
            removed or reduced to 1, depending on keepdims).

    Raises:
        ValueError: if `channel_axis` refers to the same axis as `axis` while the channel fraction
            criterion is active.
    """
    check_channels = max_fraction_of_artefact_channels < 1
    if check_channels:
        # Normalize (possibly negative) axes, so that e.g. axis=-1 and channel_axis=1 on 2D input
        # are recognized as the same axis. range() indexing raises IndexError when out of bounds.
        ndim = np.ndim(x)
        time_ax = range(ndim)[axis]
        chan_ax = range(ndim)[channel_axis]
        if time_ax == chan_ax:
            raise ValueError('`channel_axis` cannot be equal to `axis`.')

    # Compute signal quality indices.
    sq = signal_quality(x, axis=axis, demean=demean, keepdims=keepdims)

    # Threshold on std.
    std_ok = np.logical_and(min_std <= sq['std'], sq['std'] <= max_std)

    # Threshold on amplitude.
    amp_ok = np.logical_and(min_amp <= sq['min_amp'], sq['max_amp'] <= max_amp)

    # Threshold on diff (threshold may be computed from the data).
    if callable(max_diff):
        max_diff = max_diff(x, axis)
    diff_ok = sq['max_diff'] <= max_diff

    # Threshold on fraction of nans.
    nan_ok = sq['nan_frac'] <= max_nan_frac

    # Signal is ok only if all criteria are met.
    ok_mask = std_ok & amp_ok & diff_ok & nan_ok

    # Artefact signal if signal is not ok.
    artefact_mask = np.logical_not(ok_mask)

    # If too many channels are artefacts, classify all channels as artefacts.
    if check_channels:
        # Without keepdims the time axis is removed from the mask, which shifts the channel axis
        # down by one if the time axis came before it.
        mask_chan_ax = chan_ax - 1 if (not keepdims and time_ax < chan_ax) else chan_ax
        n_channels = np.shape(x)[chan_ax]
        max_channels_exclude = round(max_fraction_of_artefact_channels * n_channels)
        too_many = np.sum(artefact_mask, axis=mask_chan_ax, keepdims=True) > max_channels_exclude
        artefact_mask = np.logical_or(artefact_mask, too_many)

    return artefact_mask
def detect_flatlines(x, n_flatline='auto', flatline_range=None, fs=None, axis=-1, tol=1e-14):
    """
    Detect flatlines, i.e. signal segments where the signal does not change for at least n_flatline consecutive samples.

    Args:
        x (np.array): array with signal(s).
        n_flatline (int or str, optional): minimal number of consecutive non-changing samples to be
            considered a flatline segment. If fs is given, this is in unit of seconds.
            If 'auto', a sensible number of samples is automatically determined per signal.
        flatline_range (list, optional): [min, max] range of signal values that can be considered a
            potential flatline. Only flatlines with values inside this range (inclusive) are
            considered real flatlines. If None, all values can be a flatline. Defaults to None.
        fs (float, optional): sampling frequency in Hz. If given, n_flatline is in seconds.
            If fs is not specified, n_flatline is in samples.
        axis (int, optional): axis of the time dimension in `x`. Defaults to -1.
        tol (float, optional): tolerance for finding non changing samples. Defaults to 1e-14.

    Returns:
        mask (np.array): boolean mask for input `x` where True values correspond to samples that
            belong to a flatline.

    Raises:
        ValueError: if `flatline_range` does not have length 2, or if `n_flatline` is 'auto' and no
            sensible value can be determined.

    Examples:
        >>> x = [0, 0, 0, 1, 2, 3, 3, 2, 10, 10, 10, 10, 2, 2, 2, 2, 2, 1, 0]
        >>> detect_flatlines(x, n_flatline=3, flatline_range=None)
        array([ True,  True,  True, False, False, False, False, False,  True,
                True,  True,  True,  True,  True,  True,  True,  True, False,
               False])
        >>> detect_flatlines(x, n_flatline=3, flatline_range=[0, 5])
        array([ True,  True,  True, False, False, False, False, False, False,
               False, False, False,  True,  True,  True,  True,  True, False,
               False])
        >>> detect_flatlines([1, 2, 5, 5, 5], n_flatline=3)
        array([False, False,  True,  True,  True])
    """
    # Default input.
    if flatline_range is None:
        flatline_range = [-np.inf, np.inf]
    elif len(flatline_range) != 2:
        raise ValueError('Expected input `flatline_range` to be a list of length 2. Got length {}.'
                         .format(len(flatline_range)))

    if fs is None:
        # Number of samples equals number of seconds if fs=1 Hz.
        fs = 1

    # Swap axis and reshape, so that we have a 2D array with the time dimension as the last dimension.
    x = np.asarray(x)
    xn = x.swapaxes(axis, -1)
    interm_shape = xn.shape
    xn = xn.reshape(-1, interm_shape[-1])

    # Initialize mask.
    mask = np.full(xn.shape, fill_value=False)

    for i, xi in enumerate(xn):
        if len(xi) < 2:
            # No consecutive samples to compare, so no flatline is possible.
            continue

        if n_flatline == 'auto':
            # Automatically determine n_flatline from the distribution of run lengths between
            # transitions (outlying run lengths excluded).
            idx_transitions = np.where(np.abs(np.diff(xi)) > tol)[0]
            line_lengths = np.diff(idx_transitions)
            inclusion_mask = np.abs(line_lengths - np.mean(line_lengths)) < 2.5*np.std(line_lengths)
            line_lengths = line_lengths[inclusion_mask]
            n = int(np.ceil(np.mean(line_lengths) + 10*np.std(line_lengths)))
            if n > len(xi) or n == 1:
                raise ValueError('Cannot determine n_flatline automatically, set manually.')
        else:
            n = int(n_flatline*fs)

        # d1[k] == 1 if sample k+1 equals sample k (within tol).
        d1 = (np.abs(np.diff(xi)) < tol).astype(int)
        d2 = np.diff(d1)

        # Onsets: first sample of a run of equal samples. Offsets: one past its last sample.
        idx_onsets = np.where(d2 == 1)[0] + 1
        idx_offsets = np.where(d2 == -1)[0] + 2
        if d1[0] == 1:
            # Signal starts inside a flat run.
            idx_onsets = np.append([0], idx_onsets)
        if d1[-1] == 1:
            # Signal ends inside a flat run, which then includes the very last sample.
            idx_offsets = np.append(idx_offsets, [len(xi)])
        assert len(idx_onsets) == len(idx_offsets)

        for onset, offset in zip(idx_onsets, idx_offsets):
            # If length of flatline is large enough (and its value in range), indicate flatlines by True.
            flat_val = xi[onset]
            if offset - onset >= n and flatline_range[0] <= flat_val <= flatline_range[1]:
                mask[i, onset: offset] = True

    # Reshape and swap back.
    mask = mask.reshape(interm_shape).swapaxes(axis, -1)

    return mask
def detect_neighborhood_artefacts(af_mask, size, max_af_frac=0.5, fs=1, axis=-1):
    """
    Check the neighborhood for artefacts.

    If enough artefacts are found in the neighborhood of sample i, sample i is labeled as artefact.

    Args:
        af_mask (np.ndarray): boolean array with True values at locations of artefacts, and False values elsewhere.
        size (int): size of the neighborhood (including the current sample). In samples if `fs` is
            not specified, or in seconds if `fs` is specified. E.g., with size 3 the neighborhood of
            sample i consists of samples i-1, i and i+1.
        max_af_frac (float): maximum fraction of artefact samples allowed in the neighborhood.
            Sample i is labeled as artefact if the number of artefact samples in its neighborhood
            exceeds round(size * max_af_frac) (numpy rounding, i.e. half to even). With 0, any
            artefact in the neighborhood labels sample i as artefact. Note that with size 3 and the
            default 0.5, all 3 samples must be artefacts, so no new artefacts are added.
            Defaults to 0.5.
        fs (float): sampling frequency in Hz. If not given `size` is in samples. If given, `size` is in seconds.
        axis (int): neighborhood axis (e.g. the time axis).

    Returns:
        af_mask_new (np.ndarray): boolean array with same shape as `af_mask` with the updated
            artefact mask. If the neighborhood is smaller than 2 samples, `af_mask` is returned
            unchanged (with a warning).
    """
    size = int(size * fs)
    if size < 2:
        msg = '\nsize={} is smaller than 2. No neighborhood defined.'.format(size)
        warnings.warn(msg)
        return af_mask

    # Count with integers (bool input cast to 0/1) to prevent floating point effects and dtype
    # issues in the convolution.
    counts_in = np.asarray(af_mask)
    if counts_in.dtype == bool:
        counts_in = counts_in.astype(int)

    # Convolution kernel that counts the number of artefact samples in the neighbourhood.
    # It has the same number of dimensions as the mask, with only the `axis` dimension > 1.
    kernel_shape = np.ones(counts_in.ndim, dtype=int)
    kernel_shape[axis] = size
    kernel = np.ones(kernel_shape, dtype=int)

    # Count using convolution (centred on each sample).
    num_ans = scipy.signal.convolve(counts_in, kernel, mode='same')

    # Round to prevent floating point effects (the convolution may be computed via FFT).
    af_mask_new = np.round(num_ans) > np.round(size * max_af_frac)

    # Update artefact mask.
    af_mask_new = np.logical_or(af_mask, af_mask_new)

    return af_mask_new


def detect_outliers(x, feature='amplitude', feature_window=None, normalize_window=None,
                    min_range=-np.inf, max_range=np.inf, std_factor=4, fs=1, axis=-1):
    """
    Detect outliers in a signal based on a feature of the signal using adaptive data-driven thresholds.

    Computes the `feature`, optionally smooths it in `feature_window`, finds a range for normal
    values using median and MAD computed in windows of `normalize_window`, and then thresholds the
    feature values, giving values outside the estimated normal range an anomaly label.

    Args:
        x (np.ndarray): input array.
        feature (str): feature to compute (case insensitive). Options: 'x', 'amp' or 'amplitude'
            (signal values), 'abs_x' (absolute values), 'diff' or 'dxdt' (first derivative),
            'll', 'line_length' or 'abs_dxdt' (absolute first derivative).
        feature_window (float or None): number of samples/seconds in which to average the feature.
        normalize_window (float or None): number of samples/seconds in which to compute the median
            and MAD of the (averaged) feature. If None, it is passed as None to moving_median() and
            moving_mad().
        min_range (float): lower bound to which the MAD-based spread estimate is clipped.
        max_range (float): upper bound to which the MAD-based spread estimate is clipped.
        std_factor (float): allowed distance from the median in units of the MAD-based spread
            estimate, e.g. 3 or 4.
        fs (float): sampling frequency of `x` in Hz. If not specified, the windows are in samples.
        axis (int): time axis of `x`.

    Returns:
        af_mask (np.ndarray): boolean mask, with True values where anomalies were found and False values elsewhere.

    Raises:
        NotImplementedError: if `feature` is not one of the supported options.
    """
    # Compute feature.
    feature = feature.lower()
    if feature in ['x', 'amp', 'amplitude']:
        # Just the signal values.
        y = x.copy()
    elif feature in ['abs_x']:
        # The absolute signal values.
        y = np.abs(x)
    elif feature in ['diff', 'dxdt']:
        # Compute dxdt.
        y = np.diff(x, prepend=np.nan, axis=axis) * fs
    elif feature in ['ll', 'line_length', 'abs_dxdt']:
        # Compute abs(dxdt).
        y = np.abs(np.diff(x, prepend=np.nan, axis=axis) * fs)
    else:
        raise NotImplementedError('Not implemented for feature="{}".'.format(feature))

    # Smooth feature.
    if feature_window is not None:
        y = moving_mean(y, n=int(feature_window * fs), axis=axis)

    # Find anomalies using MAD and median (MAD serves as a robust surrogate for std).
    if normalize_window is None:
        n_moving = None
    else:
        n_moving = int(normalize_window * fs)
    med = moving_median(y, n=n_moving, axis=axis)
    std = moving_mad(y, n=n_moving, axis=axis)

    # Prevent the spread estimate from getting too small or too large (maintain some expected range).
    std = np.clip(std, min_range, max_range)

    # Detect outliers.
    af_mask = np.abs(y - med) > std_factor * std

    return af_mask


def detect_values(x, min_x=None, max_x=None, max_abs_dxdt=None, fs=1, axis=-1):
    """
    Detect values in a signal based on absolute thresholds.

    Args:
        x (np.ndarray): input array (array-likes are converted with np.asarray).
        min_x (float): minimal value for x. Lower values will be detected.
        max_x (float): maximal value for x. Larger values will be detected.
        max_abs_dxdt (float): maximal value for absolute first time derivative of x. A sample is
            detected if the derivative with its previous or its next sample exceeds this value.
        fs (float): sampling frequency of `x` in Hz. If not specified, assumes 1 Hz.
        axis (int): time axis of `x`.

    Returns:
        af_mask (np.ndarray): boolean mask, with True values where anomalies were found and False values elsewhere.
    """
    x = np.asarray(x)

    # Start with a clean mask.
    af_mask = np.full(x.shape, fill_value=False)

    # Apply thresholds.
    if min_x is not None:
        af_mask = np.logical_or(af_mask, x < min_x)
    if max_x is not None:
        af_mask = np.logical_or(af_mask, x > max_x)
    if max_abs_dxdt is not None:
        abs_dxdt_prev = np.abs(np.diff(x, prepend=np.nan, axis=axis) * fs)  # Diff with prev.
        abs_dxdt_next = np.abs(np.diff(x, append=np.nan, axis=axis) * fs)  # Diff with next.
        # nan (at the edges or from nan samples) counts as 0, i.e. never exceeds the threshold.
        abs_dxdt = np.maximum(np.nan_to_num(abs_dxdt_prev), np.nan_to_num(abs_dxdt_next))
        af_mask = np.logical_or(af_mask, abs_dxdt > max_abs_dxdt)

    return af_mask
def remove_flatlines(x, inplace=False, **kwargs):
    """
    Detect flatlines and replace by np.nan.

    Args:
        x (np.array): array with signal(s). If `inplace`, it must have a dtype that can hold nan.
        inplace (bool): whether to change `x` inplace or not.
        **kwargs (optional): keyword arguments for detect_flatlines().

    Returns:
        xn (np.array or None): if not inplace: a copy of the original array `x` (converted to float
            if its dtype cannot hold nan), but with np.nans at flatline segments. None if inplace.
    """
    # Detect flatlines.
    mask = detect_flatlines(x, **kwargs)

    if inplace:
        # Set the flatlines to np.nan in the caller's array.
        x[mask] = np.nan
        return None

    # Copy (keeping ndarray subclasses) and make sure the copy can hold nan, e.g. for integer input.
    xn = np.array(x, copy=True, subok=True)
    if not np.issubdtype(xn.dtype, np.inexact):
        xn = xn.astype(float)

    # Set the flatlines to np.nan.
    xn[mask] = np.nan

    return xn
def _replace_by_nan(x, mask, inplace):
    """Set `x` to nan where `mask` is True: in place (returns None) or on a nan-capable copy (returned)."""
    if inplace:
        x[mask] = np.nan
        return None

    # Copy (keeping ndarray subclasses) and make sure the copy can hold nan, e.g. for integer input.
    xn = np.array(x, copy=True, subok=True)
    if not np.issubdtype(xn.dtype, np.inexact):
        xn = xn.astype(float)
    xn[mask] = np.nan
    return xn


def remove_neighborhood_artefacts(x, *args, inplace=False, **kwargs):
    """
    Detect neighborhood artefacts and replace by np.nan.

    The existing nans in `x` are used as artefact mask for detect_neighborhood_artefacts().

    Args:
        x (np.ndarray): array with signal(s).
        *args: arguments for detect_neighborhood_artefacts().
        inplace (bool): whether to change `x` inplace or not.
        **kwargs (optional): keyword arguments for detect_neighborhood_artefacts().

    Returns:
        xn (np.array or None): if not inplace: a copy of the original array `x` (converted to float
            if needed), but with np.nans at samples with artefacts in their neighborhood.
            None if inplace.
    """
    # Detect samples with artefacts in their neighborhood.
    mask = detect_neighborhood_artefacts(np.isnan(x), *args, **kwargs)
    return _replace_by_nan(x, mask, inplace)


def remove_outliers(x, *args, inplace=False, **kwargs):
    """
    Detect outliers and replace by np.nan.

    Args:
        x (np.ndarray): array with signal(s).
        *args: arguments for detect_outliers().
        inplace (bool): whether to change `x` inplace or not.
        **kwargs (optional): keyword arguments for detect_outliers().

    Returns:
        xn (np.array or None): if not inplace: a copy of the original array `x` (converted to float
            if needed), but with np.nans at outlier samples. None if inplace.
    """
    # Detect outliers.
    mask = detect_outliers(x, *args, **kwargs)
    return _replace_by_nan(x, mask, inplace)


def remove_values(x, *args, inplace=False, **kwargs):
    """
    Detect values and replace by np.nan.

    Args:
        x (np.ndarray): array with signal(s).
        *args: arguments for detect_values().
        inplace (bool): whether to change `x` inplace or not.
        **kwargs (optional): keyword arguments for detect_values().

    Returns:
        xn (np.array or None): if not inplace: a copy of the original array `x` (converted to float
            if needed), but with np.nans at the detected samples. None if inplace.
    """
    # Detect values.
    mask = detect_values(x, *args, **kwargs)
    return _replace_by_nan(x, mask, inplace)
def signal_quality(x, axis=-1, demean=True, keepdims=False):
    """
    Compute simple quality indices for every signal in `x` along `axis`.

    Args:
        x (np.ndarray): array containing signal(s) along the specified axis. For each signal in x
            the quality indices will be computed.
        axis (int, optional): the axis corresponding to the time dimension of the signal(s) in x.
        demean (bool, optional): if True, the (nan-)mean of each signal is subtracted before
            computing the indices. Defaults to True.
        keepdims (bool, optional): if True, the reduced axis is kept with size 1 in the outputs.
            If False, it is removed. Defaults to False.

    Returns:
        quality (dict): maps index name to an array with its value for every signal in x:
            - 'std': nan-ignoring standard deviation.
            - 'min_amp': smallest absolute amplitude (nans ignored).
            - 'max_amp': largest absolute amplitude (nans ignored).
            - 'max_diff': largest absolute difference between successive samples (nans ignored).
            - 'nan_frac': fraction of nan samples.
    """
    if demean:
        x = x - np.nanmean(x, axis=axis, keepdims=True)

    reduce_kwargs = dict(axis=axis, keepdims=keepdims)
    abs_x = abs(x)

    return {
        'std': np.nanstd(x, **reduce_kwargs),
        'min_amp': np.nanmin(abs_x, **reduce_kwargs),
        'max_amp': np.nanmax(abs_x, **reduce_kwargs),
        'max_diff': np.nanmax(abs(np.diff(x, axis=axis)), **reduce_kwargs),
        'nan_frac': np.mean(np.isnan(x), **reduce_kwargs),
    }
def detect_high_amplitudes(amp, af_mask=None):
    """
    Detect high amplitude segments per channel with 2-cluster k-means on the log amplitude.

    For each channel, k-means (k=2) is run on the log amplitudes of all segments. The two cluster
    centres are initialized at the mean log amplitude of the non-artefact segments and at the log
    of twice that amplitude. Segments assigned to the cluster with the highest centre are labeled
    high amplitude.

    Args:
        amp (np.ndarray): (n_segments, n_channels) amplitudes.
        af_mask (np.ndarray): (n_segments, n_channels) boolean artefact mask, used to select the
            segments that define the initial "clean" cluster centre. Defaults to no artefacts.

    Returns:
        is_high_amp (np.ndarray): (n_segments, n_channels) boolean mask of high amplitude segments.
    """
    # Limit OpenMP threads (process-wide environment setting) before importing sklearn.
    os.environ["OMP_NUM_THREADS"] = '1'
    from sklearn.cluster import KMeans

    if af_mask is None:
        af_mask = np.full(amp.shape, fill_value=False)

    # Log amplitude; nan, -inf (zero amplitude) and +inf are mapped to 0.
    log_amp = np.nan_to_num(np.log(amp), neginf=0, posinf=0)

    high_per_channel = []
    for log_amp_ch, is_af_ch in zip(log_amp.T, af_mask.T):
        # Initial centre for the clean cluster (fallback 1 if every segment is an artefact).
        clean_values = log_amp_ch[~is_af_ch]
        center_clean = 1 if len(clean_values) == 0 else np.nanmean(clean_values)
        # Initial centre for the high cluster: twice the clean amplitude.
        center_high = np.log(2 * np.exp(center_clean))

        initial_centers = np.array([center_clean, center_high]).reshape(-1, 1)
        kmeans = KMeans(n_clusters=2, random_state=43, init=initial_centers, n_init=1, algorithm='elkan')
        kmeans.fit(log_amp_ch.reshape(-1, 1))

        highest_cluster = np.argmax(kmeans.cluster_centers_)
        high_per_channel.append(kmeans.labels_ == highest_cluster)

    # (n_channels, n_segments) -> (n_segments, n_channels).
    return np.asarray(high_per_channel).T