Source code for nnsa.feature_extraction.common

"""
Common functions for the feature extraction algorithms.
"""
import heapq

import numpy as np

from nnsa.artefacts.artefact_detection import detect_artefact_signals
from nnsa.preprocessing.filter import filter_signal

# Names exported by `from nnsa.feature_extraction.common import *`.
# NOTE(review): `print_init_message` is listed here but its definition is not in the
# reviewed part of the file, and `select_segments` is defined in this module but not
# listed - confirm both are intentional.
__all__ = [
    'aggregate_channel_features',
    'aggregate_channel_events',
    'aggregate_channel_segment_features',
    'aggregate_segment_features',
    'baseline_correction_min',
    'get_channel_index',
    'check_multichannel_data_matrix',
    'local_to_global_features',
    'prepare_postfix',
    'preprocess_segment',
    'print_init_message',
]

from nnsa.utils.config import HORIZONTAL_RULE


def aggregate_channel_features(data, axis=0):
    """
    Collapse the channel axis of a feature array into one global feature.

    Most feature extraction classes use this function as their default channel aggregator, so a change
    here affects most features.

    Args:
        data (np.ndarray): feature array in which one axis corresponds to channels.
        axis (int, optional): the axis in `data` that corresponds to the channels.
            Defaults to 0.

    Returns:
        (np.ndarray): NaN-ignoring median of `data` along `axis`. The shape equals the shape of `data`,
            except for `axis`, which has been removed.
    """
    aggregated = np.nanmedian(data, axis=axis)
    return aggregated
def aggregate_channel_events(detected, min_channels=None, min_channels_elong=None):
    """
    Merge per-channel event masks into a single global event mask.

    A sample is marked as a global event if at least `min_channels` channels detected the event at that
    sample. Samples where all channels are np.nan become np.nan in the output. Optionally, each global
    event is elongated in both directions for as long as at least `min_channels_elong` channels detect
    the event.

    Args:
        detected (np.ndarray): array with dimensions (channels, time) containing 1s at samples with
            detected events and 0s at samples without the event (e.g. burst mask).
        min_channels (int, optional): minimum number of channels that must detect the event in order to
            consider the event detected globally (after aggregation).
            If None, the value is set to the total number of channels available in the input array.
            Defaults to None.
        min_channels_elong (int, optional): minimum number of channels that must detect the event when
            elongating the globally detected events. Only used if it is lower than `min_channels`.
            If None, the detected global events are not elongated.
            Defaults to None.

    Returns:
        detected_agg (np.ndarray): array with shape (detected.shape[1]) containing 1s at samples with
            detected global events (across multiple channels) and 0s at samples without the event.
            The type is float, since it may contain np.nan values (if present in `detected` input).

    Raises:
        ValueError: if `min_channels` exceeds the number of channels, or if elongation is requested with
            `min_channels_elong` <= 0.

    Examples:
        >>> detected = np.array([[1, 1, 0, 0, 0, 1, 1, 1, 0, 0], [1, 1, 0, 1, 0, 0, 1, 0, 1, 0]])
        >>> print(detected)
        [[1 1 0 0 0 1 1 1 0 0]
         [1 1 0 1 0 0 1 0 1 0]]

        >>> aggregate_channel_events(detected, min_channels=2)
        array([1., 1., 0., 0., 0., 0., 1., 0., 0., 0.])

        >>> aggregate_channel_events(detected, min_channels=2, min_channels_elong=1)
        array([1., 1., 0., 0., 0., 1., 1., 1., 1., 0.])
    """
    n_channels = detected.shape[0]
    if min_channels is None:
        # By default, the event must be detected in all channels.
        min_channels = n_channels
    elif min_channels > n_channels:
        raise ValueError('Minimum number of channels (min_channels={}) exceeds the number of available channels ({}).'
                         .format(min_channels, n_channels))

    # Count the detecting channels per sample (NaNs count as no detection) and threshold.
    counts = np.nansum(detected, axis=0)
    detected_agg = (counts >= min_channels).astype(float)
    all_nan = np.all(np.isnan(detected), axis=0)
    detected_agg[all_nan] = np.nan

    # Without a (lower) elongation threshold we are done.
    if min_channels_elong is None or not min_channels_elong < min_channels:
        return detected_agg

    if min_channels_elong <= 0:
        raise ValueError('min_channels_elong must be > 0. Otherwise an array of ones is created. '
                         'Got min_channels_elong={}.'.format(min_channels_elong))

    # Locate the borders of the global events before any elongation is applied.
    steps = np.diff(detected_agg)
    onsets = np.where(steps == 1)[0] + 1  # Index of the first 1 of each event.
    offsets = np.where(steps == -1)[0]  # Index of the last 1 of each event.
    n_samples = len(detected_agg)

    # Grow each event backwards from its onset.
    for onset in onsets:
        pos = onset - 1
        while pos >= 0 and detected_agg[pos] == 0 and counts[pos] >= min_channels_elong:
            detected_agg[pos] = 1
            pos -= 1

    # Grow each event forwards from its offset.
    for offset in offsets:
        pos = offset + 1
        while pos < n_samples and detected_agg[pos] == 0 and counts[pos] >= min_channels_elong:
            detected_agg[pos] = 1
            pos += 1

    return detected_agg
def aggregate_segment_features(data, axis=-1):
    """
    Collapse the segment axis of a feature array into one global feature.

    Most feature extraction classes use this function as their default segment aggregator, so a change
    here affects most features.

    Args:
        data (np.ndarray): feature array in which one axis corresponds to segments.
        axis (int, optional): the axis in `data` that corresponds to the segments.
            Defaults to -1.

    Returns:
        (np.ndarray): NaN-ignoring median of `data` along `axis`. The shape equals the shape of `data`,
            except for `axis`, which has been removed.
    """
    aggregated = np.nanmedian(data, axis=axis)
    return aggregated
def aggregate_channel_segment_features(data, feature_name, aggregate_segments=aggregate_segment_features,
                                       aggregate_channels=aggregate_channel_features, channel_labels=None,
                                       postfix=None):
    """
    Helper function to aggregate features consisting of (channels, segments).

    The segments of each channel are aggregated first. The resulting channel values are then either
    aggregated into one value or returned per channel.

    Args:
        data (np.ndarray): array with dimensions corresponding to (channels, segments).
        feature_name (str): name of the feature.
        aggregate_segments (function, optional): function that takes an array of segment features as input
            and returns one aggregate value.
            Defaults to aggregate_segment_features.
        aggregate_channels (function or None, optional): function that takes an array of channel features
            and returns one aggregate value. If None, the channels are not aggregated, i.e. the feature
            value of each channel is returned per channel.
            Defaults to aggregate_channel_features.
        channel_labels (list or str or None, optional): list of strings representing labels corresponding
            to the first dimension of data. A single string is treated as a list with one label.
            Can only be None if aggregate_channels is not None.
            Defaults to None.
        postfix (str, optional): postfix for the feature name. If None, no postfix will be added.
            Defaults to None.

    Returns:
        features (dict): dictionary with aggregated features. The keys include the `feature_name` and
            optionally the channel label (if `aggregate_channels` is None, with a leading 'EEG ' prefix
            removed from the label) and `postfix` if not None.

    Raises:
        ValueError: if `aggregate_channels` is None and `channel_labels` is None or does not have the same
            length as `data`.
    """
    # Combine segments.
    channel_features = np.array([aggregate_segments(ch) for ch in data])

    # Prepare postfix.
    postfix = prepare_postfix(postfix)

    if aggregate_channels is not None:
        # Combine the feature values across channels.
        return {'{}{}'.format(feature_name, postfix): aggregate_channels(channel_features)}

    # Per-channel output: labels are required to build the keys.
    if channel_labels is None:
        raise ValueError('`channel_labels` must be specified if `aggregate_channels` is None.')
    if isinstance(channel_labels, str):
        channel_labels = [channel_labels]
    if len(channel_labels) != len(data):
        raise ValueError('Length of `channel_labels` ({}) is not equal to length of `data` ({}).'
                         .format(len(channel_labels), len(data)))

    # Save the feature values of each channel.
    eeg_prefix = 'EEG '
    features = dict()
    for channel_label, value in zip(channel_labels, channel_features):
        # Remove the 'EEG ' prefix as a prefix. str.lstrip('EEG ') would strip any leading 'E', 'G' or
        # space characters and thereby mangle labels such as 'ECG' or 'EOG'.
        short_label = channel_label.lstrip()
        if short_label.startswith(eeg_prefix):
            short_label = short_label[len(eeg_prefix):].lstrip()
        features['{}_{}{}'.format(feature_name, short_label, postfix)] = value

    return features
def baseline_correction_min(x, window_length):
    """
    Do a baseline correction on x, by subtracting the minimum value of x in a trailing window.

    x_corrected[i] = x[i] - min{x[j] | j in {i - window_length + 1, ..., i}}.

    For the first window_length samples, the minimum value in this segment is used as baseline.

    Args:
        x (np.ndarray): 1D array with signal.
        window_length (int): length of the window (in samples, including the current sample) to look for
            the minimum. Must be at least 1 and at most len(x).

    Returns:
        x_corrected (np.ndarray): 1D array of the same size as x.

    Raises:
        ValueError: if `window_length` is smaller than 1 or larger than the length of `x`.

    Examples:
        >>> x = np.array([5, 4, 6, 2, 4, 3, 5, 6, 1, 0, 8, 9, 7, 9])
        >>> print(x)
        [5 4 6 2 4 3 5 6 1 0 8 9 7 9]

        >>> x_corrected = baseline_correction_min(x, window_length=4)
        >>> baseline = x - x_corrected
        >>> print(baseline)
        [2 2 2 2 2 2 2 3 1 0 0 0 0 7]

        >>> print(x_corrected)
        [3 2 4 0 2 1 3 3 0 0 8 9 7 2]
    """
    window_length = int(window_length)
    if window_length < 1:
        raise ValueError('window_length ({}) must be at least 1.'.format(window_length))
    if window_length > len(x):
        raise ValueError('window_length ({}) must be lower than the length of x ({}).'.format(window_length, len(x)))

    baseline = np.zeros_like(x)

    # Min-heap of (value, index) tuples. The top of the heap is the smallest value seen that may still
    # lie inside the window.
    heap = []

    for i, x_i in enumerate(x):
        if heap and x_i < heap[0][0]:
            # The current value is lower than everything in the heap. All those entries are older and
            # larger, so they can never be the window minimum again: reset the heap.
            heap = []
        else:
            # Kick out minima that have fallen out of the window. The `heap and` guard matters when
            # the heap runs empty (e.g. window_length == 1).
            while heap and i - heap[0][1] >= window_length:
                heapq.heappop(heap)

        # Add the current value and use the minimum at the top of the heap as baseline.
        heapq.heappush(heap, (x_i, i))
        baseline[i] = heap[0][0]

    # Samples before the first full window get the baseline value of the first full window.
    baseline[:window_length - 1] = baseline[window_length - 1]

    # Subtract the baseline.
    return x - baseline
def get_channel_index(channel_labels, channels):
    """
    Return the row index/indices of the specified channel(s) in channel_labels.

    If `channels` is a list or tuple, this function is called recursively on each element.

    Args:
        channel_labels (list): list of channel labels that define the index.
        channels (str or int or list or tuple or None): (sequence of) label(s) or index/indices of the
            channel(s) to get the index of. Integer indices may be Python or numpy integers. Negative
            indices are accepted as long as they are valid indices into `channel_labels`.
            If None, the indices of all channels are returned.

    Returns:
        idx (int or list): (list of) index/indices corresponding to the specified channel(s).

    Raises:
        ValueError: if a label is not found in `channel_labels` (also not after standardizing it), or if
            an integer index is out of range.
        TypeError: if `channels` (or one of its elements) has an unsupported type.
    """
    if channels is None:
        # Return all indices.
        return list(range(len(channel_labels)))

    if isinstance(channels, str):
        # Channel is the channel label.
        if channels not in channel_labels:
            # Try standardizing it. The helper is imported locally, and only when it is needed.
            from nnsa.edfreadpy.io.utils import standardize_and_check_eeg_label
            channels = standardize_and_check_eeg_label(channels)[0]
            if channels not in channel_labels:
                raise ValueError('Invalid channel "{}". Choose from {}.'.format(channels, channel_labels))
        return channel_labels.index(channels)

    # bool is a subclass of int, but True/False are not meaningful channel indices.
    if isinstance(channels, (int, np.integer)) and not isinstance(channels, bool):
        # Channel is the channel index.
        n_channels = len(channel_labels)
        if channels >= n_channels:
            raise ValueError('Channel index {} too large. Only got {} channels.'
                             .format(channels, n_channels))
        if channels < -n_channels:
            raise ValueError('Channel index {} out of range. Only got {} channels.'
                             .format(channels, n_channels))
        return int(channels)

    if isinstance(channels, (list, tuple)):
        # Create a list with indices of the channels in the sequence.
        return [get_channel_index(channel_labels, item) for item in channels]

    raise TypeError('Invalid input type "{}" for channel.'.format(type(channels)))
def check_multichannel_data_matrix(data_matrix, channel_labels=None):
    """
    Validate (and if needed reshape) the multichannel input of the feature extraction methods.

    Args:
        data_matrix (np.ndarray): 2D array with dimensions (channels, time) containing multichannel data.
            A 1D array is automatically converted to shape (1, len(data_matrix)).
        channel_labels (list of str, optional): optional list of labels for the channels in data_matrix.
            A single string is converted to a list with one label.
            Defaults to None.

    Returns:
        data_matrix (np.ndarray): data_matrix, possibly reshaped.
        channel_labels (list or None): channel_labels, possibly converted from string to list.

    Raises:
        ValueError:
            - if data_matrix.ndim is not 1 or 2.
            - if the size of the first axis exceeds the size of the second axis in data_matrix
              (more time points than channels are expected).
            - if the number of channel labels specified does not correspond with the size of the
              first axis of data_matrix.
    """
    # Bring the data to two dimensions, or reject it.
    if data_matrix.ndim == 1:
        data_matrix = data_matrix.reshape(1, -1)
    elif data_matrix.ndim != 2:
        raise ValueError('Invalid array shape. Array of 2 dimensions expected. '
                         'Got array with shape: {}.'.format(data_matrix.shape))

    # More channels than time points most likely means the axes are swapped.
    n_channels, n_samples = data_matrix.shape
    if n_samples < n_channels:
        raise ValueError('Unexpected shape {}. Verify the dimension correspond to (channels, time).'
                         .format(data_matrix.shape))

    # Nothing more to check without labels.
    if channel_labels is None:
        return data_matrix, channel_labels

    if type(channel_labels) is str:
        # A single label: wrap it in a list.
        channel_labels = [channel_labels]

    n_labels = len(channel_labels)
    if n_labels != n_channels:
        raise ValueError('Length of channel_labels ({}) is not equal to number of channels in data_matrix ({}).'
                         .format(n_labels, n_channels))

    return data_matrix, channel_labels
def local_to_global_features(features, segment_start_times, segment_end_times, sleep_stages=None,
                             aggregate_segments=aggregate_segment_features,
                             aggregate_channels=aggregate_channel_features,
                             channel_labels=None):
    """
    Helper function to go from local features (per channel, per segment) to global features.

    Without `sleep_stages`, all segments are aggregated and the feature names get the postfix 'ALL'.
    With `sleep_stages`, the 'ALL' features are computed from the segments that have a (non-NaN) sleep
    label, and the aggregate feature values are additionally reported per sleep stage.

    Args:
        features (dict): dict with local features. Keys are the names of the features. Values are arrays
            with feature values with dimension corresponding to (channels, segments).
        segment_start_times (np.ndarray): start times of the segments corresponding to the arrays in
            `features`.
        segment_end_times (np.ndarray): end times of the segments corresponding to the arrays in
            `features`.
        sleep_stages (nnsa.SleepStagesResult, optional): object containing sleep stages result. Used to
            report global feature values per sleep stage.
            If None, all data is used (no distinction is made for sleep stage).
            Defaults to None.
        aggregate_segments (function, optional): function that takes an array of segment features as input
            and returns one aggregate value.
            Defaults to nnsa.aggregate_segment_features.
        aggregate_channels (function or None, optional): function that takes an array of channel features
            and returns one aggregate value. If None, the channels are not aggregated, i.e. the feature
            value of each channel is returned per channel.
            Defaults to nnsa.aggregate_channel_features.
        channel_labels (list or None, optional): list of strings representing labels corresponding to the
            first dimension of the arrays in `features`. Can only be None if aggregate_channels is not None.
            Defaults to None.

    Returns:
        global_features (dict): dictionary containing the feature name and value pairs.
    """
    global_features = dict()

    def _add_features(label, mask=None):
        # Aggregate every feature over the segments selected by `mask` (all segments if None) and
        # store the result under the postfix `label`.
        for feature_name, data in features.items():
            selected = data if mask is None else data[:, mask]
            global_features.update(
                aggregate_channel_segment_features(selected, feature_name,
                                                   aggregate_segments=aggregate_segments,
                                                   aggregate_channels=aggregate_channels,
                                                   channel_labels=channel_labels,
                                                   postfix=label))

    if sleep_stages is None:
        # No sleep stages: just use all data.
        _add_features('ALL')
        return global_features

    # Get the sleep labels (class numbers) of the segments.
    segment_labels = sleep_stages.segment_labels(segment_start_times, segment_end_times)

    # Use all sleep stages, except no_label (NaN), to capture the entire recording.
    _add_features('ALL', mask=~np.isnan(segment_labels))

    # Report the features per sleep stage.
    for sleep_label, label_number in sleep_stages.class_mapping.items():
        if np.isnan(label_number):
            # Skip NaNs (no_label segments).
            continue
        _add_features(sleep_label, mask=segment_labels == label_number)

    return global_features
def prepare_postfix(postfix):
    """
    Helper function to normalize a postfix for a feature name.

    Prepends an underscore if the postfix does not start with one.
    Returns an empty string if postfix is None.

    Args:
        postfix (str or None): postfix.

    Returns:
        postfix (str): postfix starting with an underscore, or an empty string if the input is None.
    """
    if postfix is None:
        return ''
    if postfix.startswith('_'):
        return postfix
    return '_{}'.format(postfix)
def preprocess_segment(seg, fs, filter_specification=None, artefact_criteria=None, demean=True,
                       preprocess_fun=None):
    """
    Preprocess an EEG segment.

    The steps are applied in this order: filtering, demeaning, artefact detection (channel exclusion)
    and finally the custom `preprocess_fun`. Each step is optional.

    Args:
        seg (np.ndarray): EEG signal with dimensions (channels, time).
        fs (float): sample frequency.
        filter_specification (str, FilterBase or None): passed to filter_signal(). If a string, it is
            interpreted as a filter_name of a saved filter (see nnsa.filter_saved_filter()). If a
            FilterBase object, the object's filtfilt() function is used to do the filtering.
            If None, no filtering is applied to the segment.
        artefact_criteria (dict): parameters for artefact detection, see detect_artefact_signals().
            If None, no artefact detection is performed (exclude mask contains only False).
        demean (bool, optional): if True, the mean of the segment is subtracted (per channel, ignoring
            NaNs). If False, the mean is not subtracted.
            Defaults to True.
        preprocess_fun (function): function evaluated on seg and fs and outputting processed seg.

    Returns:
        seg (np.ndarray): processed EEG signal with dimensions (channels, time).
        exclude_mask (np.ndarray): boolean array of shape (n_channels, ) where True values correspond to
            channels to exclude.
    """
    # Optional filtering along the time axis.
    if filter_specification is not None:
        seg = filter_signal(seg, filter_specification, fs=fs, axis=-1)

    # Optional removal of the per-channel mean.
    if demean:
        channel_means = np.nanmean(seg, axis=-1, keepdims=True)
        seg = seg - channel_means

    # Channel exclusion: evaluated after filtering/demeaning, before the custom preprocessing.
    if artefact_criteria is None:
        exclude_mask = np.full(len(seg), fill_value=False)
    else:
        exclude_mask = detect_artefact_signals(seg, axis=-1, demean=False, keepdims=False,
                                               **artefact_criteria)

    # Optional custom preprocessing.
    if preprocess_fun is not None:
        seg = preprocess_fun(seg, fs)

    return seg, exclude_mask
def select_segments(aci=None, is_novelty=None, novelty_score=None, aci_thres=None, novelty_score_thres=None,
                    verbose=0):
    """
    Select segments based on their ACI and/or novelty score.

    A segment is kept if it passes every criterion that is enabled:
        - ACI criterion (enabled if both `aci` and `aci_thres` are given): keep if aci <= aci_thres.
        - Novelty criterion (enabled if `novelty_score_thres` is given): if the threshold is 'auto',
          keep segments that are not flagged in `is_novelty`; otherwise keep segments with
          novelty_score <= novelty_score_thres.

    Args:
        aci (np.ndarray, optional): ACI value per segment.
        is_novelty (np.ndarray, optional): mask per segment, truthy for segments flagged as novelty.
            Required if `novelty_score_thres` is 'auto'.
        novelty_score (np.ndarray, optional): novelty score per segment. Required if
            `novelty_score_thres` is a number.
        aci_thres (float, optional): maximum ACI for a segment to be kept.
        novelty_score_thres (float or str, optional): maximum novelty score for a segment to be kept,
            or 'auto' to rely on `is_novelty`.
        verbose (int or bool, optional): if truthy, print the percentage of segments rejected by each
            criterion.
            Defaults to 0.

    Returns:
        keep_mask (np.ndarray): boolean array indicating selected segments or None if nothing has
            been selected (no criterion enabled).

    Raises:
        ValueError: if the novelty criterion is enabled but the input it needs (`is_novelty` or
            `novelty_score`) is not given.
    """
    keep_mask = None

    if aci_thres is not None and aci is not None:
        keep_mask = np.asarray(aci) <= aci_thres
        if verbose:
            print('{:.2f}% segments with too high ACI.'.format(100 - np.mean(keep_mask)*100))

    if novelty_score_thres is not None:
        if novelty_score_thres == 'auto':
            if is_novelty is None:
                raise ValueError('Pass `is_novelty` if novelty_score_thres is "auto".')
            # Logical (not bitwise) inversion, so that integer 0/1 masks are handled correctly too.
            keep_i = np.logical_not(np.asarray(is_novelty, dtype=bool))
        else:
            if novelty_score is None:
                raise ValueError('Pass `novelty_score` if novelty_score_thres is a number.')
            keep_i = np.asarray(novelty_score) <= novelty_score_thres

        if verbose:
            print('{:.2f}% segments with too high novelty score.'.format(100 - np.mean(keep_i) * 100))

        # Combine with the ACI criterion, if that one was applied.
        keep_mask = keep_i if keep_mask is None else keep_mask & keep_i

    return keep_mask