# Source code for nnsa.feature_extraction.statistics

"""
Module containing statistics-related functions and classes.
"""
import sys

import h5py
import numpy as np
import pyprind
import scipy.stats

from nnsa.feature_extraction.result import ResultBase
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.feature_extraction.common import check_multichannel_data_matrix, preprocess_segment
from nnsa.utils.segmentation import compute_n_segments, segment_generator
from nnsa.utils.config import HORIZONTAL_RULE
from nnsa.utils.other import enumerate_label


# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['SignalStats', 'SignalStatsResult']


class SignalStats(ClassWithParameters):
    """
    Class for computing a set of statistical parameters of a signal.

    Main method: signal_stats().

    Args:
        see nnsa.ClassWithParameters

    Examples:
        >>> fs = 256
        >>> np.random.seed(43)
        >>> x = np.random.normal(loc=5, scale=10, size=(8, fs*300))
        >>> ss = SignalStats(artefact_criteria=None)
        >>> print(type(ss.parameters).__name__)
        Parameters
        >>> result = ss.signal_stats(x, fs=fs, verbose=0)
        >>> print(type(result).__name__)
        SignalStatsResult

        Some stats of the 3rd channel (index 2), 5th segment (index 4):

        >>> print(result.stats['std'][2, 4])
        9.96795466696633
        >>> print(result.stats['skewness'][2, 4])
        0.02042382947932932
        >>> print(result.stats['kurtosis'][2, 4])
        0.028582535331379777

        Average std across all segments and all channels:

        >>> print(result.stats['std'].mean())
        10.00292371308106
    """
    # Names of the computed statistics. The order matches the list returned by
    # _compute_signal_stats() and determines the key order of the result's stats dict.
    _STAT_NAMES = (
        'std',
        'mean_abs',
        'mean_squared',
        'skewness',
        'kurtosis',
        'max_amp',
        'max_der',
        'line_length',
    )

    @staticmethod
    def default_parameters():
        """
        Return the default parameters as a dictionary.

        Returns:
            (nnsa.Parameters): a default set of parameters for the object.
        """
        # Parameters for segmentation of EEG recording into small time segments/epochs.
        segmentation = Parameters(**{
            # Segment length in seconds:
            'segment_length': 30,
            # Overlap in segments in seconds:
            'overlap': 0,
        })

        # Parameters for artefact detection/exclusion, is applied to each segment, see
        # nnsa.artefacts.artefact_detection.detect_artefact_signals().
        # If None, no artefact detection is done.
        artefact_criteria = None

        # Specify whether to demean (subtract mean) each segment or not.
        demean_segment = False

        # Specify a filter for filtering the segments, see nnsa.preprocessing.filter.filter_signal().
        # Specify None to not filter the segments.
        segment_filter_specification = None

        pars = {
            'segmentation': segmentation,
            'artefact_criteria': artefact_criteria,
            'demean_segment': demean_segment,
            'segment_filter_specification': segment_filter_specification,
        }

        return Parameters(**pars)

    @staticmethod
    def _compute_signal_stats(signal, fs):
        """
        Compute the statistics listed in SignalStats._STAT_NAMES for one single-channel signal.

        Args:
            signal (np.ndarray): 1D signal (one channel of one segment).
            fs (float): sample frequency of the signal, used to express the derivative per second.

        Returns:
            (list): the statistics, in the same order as SignalStats._STAT_NAMES.
        """
        abs_signal = np.abs(signal)
        # Absolute first derivative, expressed per second.
        abs_diff = np.abs(np.diff(signal) * fs)
        return [
            np.std(signal, ddof=1),  # Use ddof=1 for MATLAB compatibility.
            np.mean(abs_signal),
            np.dot(signal, signal) / len(signal),
            scipy.stats.skew(signal),
            scipy.stats.kurtosis(signal),
            np.max(abs_signal),
            np.max(abs_diff),
            np.mean(abs_diff),
        ]

    def signal_stats(self, data_matrix, fs, channel_labels=None, verbose=1):
        """
        Compute signal stats on multichannel data.

        Analysis pipeline ported from MATLAB code designed by Ofelie De Wel and Mario Lavanga.
        The pipeline consists of 4 steps:
            1) Segmentation
            2) Optional filtering
            3) Artefact exclusion
            4) Stats computation

        Args:
            data_matrix (np.ndarray): see check_multichannel_data_matrix().
            fs (float): sample frequency of the signals.
            channel_labels (list of str, optional): see check_multichannel_data_matrix().
            verbose (int, optional): verbose level. Values > 0 print the parameters and show a
                progress bar. Defaults to 1.

        Returns:
            (nnsa.SignalStatsResult): object containing statistics per segment per channel.
                Channels excluded by the artefact criteria get NaN for all statistics.
        """
        # Check input.
        data_matrix, channel_labels = check_multichannel_data_matrix(data_matrix, channel_labels)

        if verbose > 0:
            print(HORIZONTAL_RULE)
            print('Running signal_stats with parameters:')
            print(self.parameters)

        # Extract some parameters.
        seg_pars = self.parameters['segmentation']
        demean_segment = self.parameters['demean_segment']
        filter_specification = self.parameters['segment_filter_specification']
        artefact_criteria = self.parameters['artefact_criteria']

        n_channels = data_matrix.shape[0]

        # The output must be floating point. Artefacted channels are marked with NaN, and the
        # statistics are non-integer in general, even for integer input. Floating-point input
        # keeps its own precision.
        if np.issubdtype(data_matrix.dtype, np.floating):
            dtype = data_matrix.dtype
        else:
            dtype = np.float64

        # Segment the data in the time axis (create a generator).
        segment_kwargs = dict(segment_length=seg_pars['segment_length'],
                              overlap=seg_pars['overlap'],
                              fs=fs, axis=-1)
        seg_generator = segment_generator(data_matrix, **segment_kwargs)
        n_segments = compute_n_segments(data_matrix, **segment_kwargs)

        # Only create the progress bar when output is requested (it writes to sys.stdout).
        bar = pyprind.ProgBar(n_segments, stream=sys.stdout) if verbose > 0 else None

        # Loop over segments.
        all_stats = np.zeros((len(self._STAT_NAMES), n_channels, n_segments), dtype=dtype)
        for i_seg, seg in enumerate(seg_generator):
            # Preprocess segment (demean, optionally filter, find channels to exclude).
            seg, exclude_mask = preprocess_segment(seg, fs,
                                                   demean=demean_segment,
                                                   filter_specification=filter_specification,
                                                   artefact_criteria=artefact_criteria)

            # Loop over channels.
            for j_channel, (excl, signal) in enumerate(zip(exclude_mask, seg)):
                if excl:
                    # Use NaN to indicate an artefacted channel.
                    all_stats[:, j_channel, i_seg] = np.nan
                else:
                    all_stats[:, j_channel, i_seg] = self._compute_signal_stats(signal, fs)

            # Update progress bar.
            if bar is not None:
                bar.update()

        # Save the stats in a dictionary (one (n_channels, n_segments) matrix per statistic).
        stats = {name: all_stats[i] for i, name in enumerate(self._STAT_NAMES)}

        # Return as a SignalStatsResult object.
        return SignalStatsResult(stats, channel_labels=channel_labels,
                                 algorithm_parameters=self.parameters)
class SignalStatsResult(ResultBase):
    """
    High-level interface for processing signal statistics as computed by nnsa.SignalStats().

    Args:
        stats (dict): dict where each entry is some statistic computed per channel per segment.
            The entries are matrices with size (n_channels, n_segments).
        algorithm_parameters (nnsa.Parameters): see ResultBase.
        channel_labels (list of str, optional): labels of the channels corresponding to the rows
            of the matrices in stats. If None, default labels will be created.
            Default is None.
        data_info (str, optional): see ResultBase.
        segment_start_times (np.ndarray, optional): see ResultBase.
        segment_end_times (np.ndarray, optional): see ResultBase.
        fs (float, optional): see ResultBase.

    Raises:
        ValueError: if stats is empty or if the number of channel labels does not match the
            number of rows of the stats matrices.
    """
    def __init__(self, stats, algorithm_parameters, channel_labels=None, data_info=None,
                 segment_start_times=None, segment_end_times=None, fs=None):
        super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
                         segment_start_times=segment_start_times,
                         segment_end_times=segment_end_times, fs=fs)

        # Input check.
        if not stats:
            raise ValueError('stats must contain at least one statistic.')
        data_shape = next(iter(stats.values())).shape
        # Rows of the stats matrices correspond to channels.
        n_channels = data_shape[0]
        if channel_labels is None:
            channel_labels = enumerate_label(n_channels, label='Channel')
        elif len(channel_labels) != n_channels:
            raise ValueError('Length of channel_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(channel_labels), data_shape))

        # Store variables that are not already stored by the parent class (ResultBase).
        self.stats = stats
        self.channel_labels = channel_labels

    @property
    def num_segments(self):
        """
        Return the number of segments.

        Returns:
            (int): number of segments (size of the last axis of the stats matrices).
        """
        return next(iter(self.stats.values())).shape[-1]

    def _merge(self, other):
        """
        See ResultBase.

        Concatenates every statistic of other to the corresponding statistic of self along the
        last (segment) axis. self.stats is only modified if all concatenations succeed.

        Raises:
            ValueError: if self and other have different channel labels or different statistics.
        """
        # Check if the channel labels of self and other are the same.
        if list(self.channel_labels) != list(other.channel_labels):
            raise ValueError('Cannot merge results with different channel labels.')
        if set(self.stats) != set(other.stats):
            raise ValueError('Cannot merge results with different statistics.')

        # Build all merged arrays first, so a failing concatenation leaves self untouched.
        merged = {key: np.concatenate((val, other.stats[key]), axis=-1)
                  for key, val in self.stats.items()}
        self.stats.update(merged)

    @staticmethod
    def _decode_label(label):
        """
        Convert a channel label read from an hdf5 attribute to str.

        Args:
            label (bytes or str): label as returned by h5py.

        Returns:
            (str): the decoded label (UTF-8 for byte strings).
        """
        if isinstance(label, bytes):
            return label.decode('utf-8')
        return str(label)

    @staticmethod
    def _read_from_hdf5(filepath):
        """
        Read result from hdf5 file into a SignalStatsResult class.

        Args:
            filepath (str): see ResultBase._read_from_hdf5().

        Returns:
            result (nnsa.SignalStatsResult): instance of SignalStatsResult containing the
                SignalStats result.
        """
        # Read standard hdf5 header (use the ResultBase method).
        # NOTE(review): time_offset is read but not passed on, because __init__ does not accept it.
        algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, time_offset = \
            ResultBase._read_hdf5_header(filepath)[1:]

        # Re-open the file and read the rest of the file.
        with h5py.File(filepath, 'r') as f:
            s_ds = f['stats']

            # Read array data (one dataset per statistic).
            stats = {key: s_ds[key][:] for key in s_ds.keys()}

            # Read non-array data.
            channel_labels = [SignalStatsResult._decode_label(label)
                              for label in s_ds.attrs['channel_labels']]

        # Create a result object.
        result = SignalStatsResult(stats=stats,
                                   algorithm_parameters=algorithm_parameters,
                                   channel_labels=channel_labels,
                                   data_info=data_info,
                                   segment_start_times=segment_start_times,
                                   segment_end_times=segment_end_times,
                                   fs=fs)

        return result

    def _write_to_hdf5(self, filepath):
        """
        Write the contents of the object to an hdf5 file.

        Args:
            filepath (str): see ResultBase._write_to_hdf5().
        """
        # Write standard hdf5 header (use the ResultBase method).
        self._write_hdf5_header(filepath)

        # Append attributes to the hdf5 file.
        with h5py.File(filepath, 'a') as f:
            # Make sure the 'stats' group exists even if there are no statistics to write.
            stats_group = f.require_group('stats')

            # Write array data.
            for key, data in self.stats.items():
                stats_group.create_dataset(key, data=data)

            # Write non-array data as attributes to the 'stats' group. Store labels as UTF-8
            # byte strings. np.bytes_ is used because the np.string_ alias was removed in NumPy 2.0.
            stats_group.attrs['channel_labels'] = [
                np.bytes_(label.encode('utf-8') if isinstance(label, str) else label)
                for label in self.channel_labels
            ]