"""
Module containing statistics-related functions and classes.
"""
import sys
import h5py
import numpy as np
import pyprind
import scipy.stats
from nnsa.feature_extraction.result import ResultBase
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.feature_extraction.common import check_multichannel_data_matrix, preprocess_segment
from nnsa.utils.segmentation import compute_n_segments, segment_generator
from nnsa.utils.config import HORIZONTAL_RULE
from nnsa.utils.other import enumerate_label
__all__ = [
'SignalStats',
'SignalStatsResult',
]
class SignalStats(ClassWithParameters):
    """
    Class for computing a set of statistical parameters of a signal.

    Main method: signal_stats().

    Args:
        see nnsa.ClassWithParameters

    Examples:
        >>> fs = 256
        >>> np.random.seed(43)
        >>> x = np.random.normal(loc=5, scale=10, size=(8, fs*300))
        >>> ss = SignalStats(artefact_criteria=None)
        >>> print(type(ss.parameters).__name__)
        Parameters

        >>> result = ss.signal_stats(x, fs=fs, verbose=0)
        >>> print(type(result).__name__)
        SignalStatsResult

        # Some stats of 3rd channel (index 2), 5th segment (index 4).
        >>> print(result.stats['std'][2, 4])
        9.96795466696633
        >>> print(result.stats['skewness'][2, 4])
        0.02042382947932932
        >>> print(result.stats['kurtosis'][2, 4])
        0.028582535331379777

        # Average std across all segments and all channels.
        >>> print(result.stats['std'].mean())
        10.00292371308106
    """

    # Names of the computed statistics, in the order in which
    # _compute_channel_stats() returns them. This single tuple determines both
    # the size of the result buffer and the keys of the returned stats dict.
    _STAT_NAMES = (
        'std',
        'mean_abs',
        'mean_squared',
        'skewness',
        'kurtosis',
        'max_amp',
        'max_der',
        'line_length',
    )

    @staticmethod
    def default_parameters():
        """
        Return the default parameters.

        Returns:
            (nnsa.Parameters): a default set of parameters for the object.
        """
        # Parameters for segmentation of EEG recording into small time segments/epochs.
        segmentation = Parameters(**{
            # Segment length in seconds:
            'segment_length': 30,
            # Overlap in segments in seconds:
            'overlap': 0,
        })

        # Parameters for artefact detection/exclusion, is applied to each segment, see
        # nnsa.artefacts.artefact_detection.detect_artefact_signals().
        # If None, no artefact detection is done.
        artefact_criteria = None

        # Specify whether to demean (subtract mean) each segment or not.
        demean_segment = False

        # Specify a filter for filtering the segments, see nnsa.preprocessing.filter.filter_signal().
        # Specify None to not filter the segments.
        segment_filter_specification = None

        pars = {
            'segmentation': segmentation,
            'artefact_criteria': artefact_criteria,
            'demean_segment': demean_segment,
            'segment_filter_specification': segment_filter_specification,
        }
        return Parameters(**pars)

    @staticmethod
    def _compute_channel_stats(signal, fs):
        """
        Compute the statistics of one channel of one segment.

        Args:
            signal (np.ndarray): 1D array with the samples of the channel.
            fs (float): sample frequency of the signal.

        Returns:
            (list): the statistics, ordered as in SignalStats._STAT_NAMES.
        """
        # First difference scaled by fs, i.e. the derivative per second.
        x_diff = np.diff(signal) * fs
        abs_signal = np.abs(signal)
        abs_diff = np.abs(x_diff)
        return [
            np.std(signal, ddof=1),  # Use ddof=1 for MATLAB compatibility.
            np.mean(abs_signal),
            np.dot(signal, signal) / len(signal),
            scipy.stats.skew(signal),
            scipy.stats.kurtosis(signal),
            np.max(abs_signal),
            np.max(abs_diff),
            np.mean(abs_diff),
        ]

    def signal_stats(self, data_matrix, fs, channel_labels=None, verbose=1):
        """
        Compute signal stats on multichannel data.

        Analysis pipeline ported from MATLAB code designed by Ofelie De Wel and Mario Lavanga.

        Pipeline consists of 4 steps:
            1) Segmentation
            2) Optional filtering
            3) Artefact exclusion
            4) Stats computation

        Args:
            data_matrix (np.ndarray): see check_multichannel_data_matrix().
            fs (float): sample frequency of the signals.
            channel_labels (list of str, optional): see check_multichannel_data_matrix().
            verbose (int, optional): verbose level. If larger than 0, the parameters and a
                progress bar are printed.
                Defaults to 1.

        Returns:
            (nnsa.SignalStatsResult): object containing statistics per segment per channel.
                Channels that are excluded by the preprocessing in a segment hold NaN for
                every statistic of that segment.
        """
        # Check input.
        data_matrix, channel_labels = check_multichannel_data_matrix(data_matrix, channel_labels)

        if verbose > 0:
            print(HORIZONTAL_RULE)
            print('Running signal_stats with parameters:')
            print(self.parameters)

        # Extract some parameters.
        seg_pars = self.parameters['segmentation']
        demean_segment = self.parameters['demean_segment']
        filter_specification = self.parameters['segment_filter_specification']
        artefact_criteria = self.parameters['artefact_criteria']

        n_channels = data_matrix.shape[0]

        # The statistics are real-valued and NaN marks excluded channels, so the result
        # buffer cannot have an integer (or boolean) dtype: assigning NaN would fail and
        # the statistics would be truncated. Keep the input dtype if it is already
        # floating/complex, fall back to float64 otherwise.
        in_dtype = data_matrix.dtype
        stats_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else np.float64

        # Segment the data in the time axis (create a generator).
        seg_kwargs = dict(segment_length=seg_pars['segment_length'],
                          overlap=seg_pars['overlap'], fs=fs, axis=-1)
        seg_generator = segment_generator(data_matrix, **seg_kwargs)
        n_segments = compute_n_segments(data_matrix, **seg_kwargs)

        # Initialize the progress bar only if it is going to be updated.
        bar = pyprind.ProgBar(n_segments, stream=sys.stdout) if verbose > 0 else None

        # Loop over segments.
        all_stats = np.zeros((len(self._STAT_NAMES), n_channels, n_segments), dtype=stats_dtype)
        for i_seg, seg in enumerate(seg_generator):
            # Preprocess segment (demean, optionally filter, find channels to exclude).
            seg, exclude_mask = preprocess_segment(seg, fs,
                                                   demean=demean_segment,
                                                   filter_specification=filter_specification,
                                                   artefact_criteria=artefact_criteria)

            # Loop over channels.
            for j_channel, excl, signal in zip(range(n_channels), exclude_mask, seg):
                if excl:
                    # Use NaN to indicate an artefacted (excluded) channel.
                    all_stats[:, j_channel, i_seg] = np.nan
                else:
                    all_stats[:, j_channel, i_seg] = self._compute_channel_stats(signal, fs)

            # Update progress bar.
            if bar is not None:
                bar.update()

        # Save the stats in a dictionary: one (n_channels, n_segments) matrix per statistic.
        stats = dict(zip(self._STAT_NAMES, all_stats))

        # Return as a SignalStatsResult object.
        return SignalStatsResult(stats, channel_labels=channel_labels, algorithm_parameters=self.parameters)
class SignalStatsResult(ResultBase):
    """
    High-level interface for processing signal statistics as computed by nnsa.SignalStats().

    Args:
        stats (dict): dict where each entry is some statistic computed per channel per segment.
            The entries are matrices with size (n_channels, n_segments).
        algorithm_parameters (nnsa.Parameters): see ResultBase.
        channel_labels (list of str, optional): labels of the channels corresponding to the rows of the
            values in stats.
            If None, default labels will be created.
            Default is None.
        data_info (str, optional): see ResultBase.
        segment_start_times (np.ndarray, optional): see ResultBase.
        segment_end_times (np.ndarray, optional): see ResultBase.
        fs (float, optional): see ResultBase.

    Raises:
        ValueError: if stats is empty, or if the number of channel_labels does not match the
            number of rows (channels) of the entries in stats.
    """
    def __init__(self, stats, algorithm_parameters, channel_labels=None, data_info=None,
                 segment_start_times=None, segment_end_times=None, fs=None):
        super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
                         segment_start_times=segment_start_times, segment_end_times=segment_end_times, fs=fs)

        # Input check.
        if not stats:
            raise ValueError('stats must contain at least one entry.')

        # Use the first entry as reference for the shape (n_channels, n_segments).
        data_shape = next(iter(stats.values())).shape
        n_channels = data_shape[0]
        if channel_labels is None:
            # One default label per channel (rows of the stats matrices).
            channel_labels = enumerate_label(n_channels, label='Channel')
        elif len(channel_labels) != n_channels:
            raise ValueError('Length of channel_labels ({}) does not correspond to the shape of the data {}.'
                             .format(len(channel_labels), data_shape))

        # Store variables that are not already stored by the parent class (ResultBase).
        self.stats = stats
        self.channel_labels = channel_labels

    @property
    def num_segments(self):
        """
        Return the number of segments.

        Returns:
            (int): number of segments (size of the last axis of the first entry in stats).
        """
        return next(iter(self.stats.values())).shape[-1]

    def _merge(self, other):
        """
        See ResultBase.

        Appends the stats of other to the stats of self along the segment (last) axis.

        Raises:
            ValueError: if the channel labels of self and other differ.
            KeyError: if other lacks a statistic that self has.
        """
        # Check if the channel labels of self and other are the same.
        # Compare as lists, so that array-like labels do not yield an elementwise result.
        if list(self.channel_labels) != list(other.channel_labels):
            raise ValueError('Cannot merge results with different channel labels.')

        # Build all merged arrays first, so that self.stats is left untouched if anything
        # fails halfway (e.g. a missing key or incompatible shapes).
        merged = {key: np.concatenate((val, other.stats[key]), axis=-1)
                  for key, val in self.stats.items()}
        self.stats.update(merged)

    @staticmethod
    def _read_from_hdf5(filepath):
        """
        Read result from hdf5 file into a SignalStatsResult class.

        Args:
            filepath (str): see ResultBase._read_from_hdf5().

        Returns:
            result (nnsa.SignalStatsResult): instance of SignalStatsResult containing the SignalStats result.
        """
        # Read standard hdf5 header (use the ResultBase method).
        # time_offset is read, but not forwarded: the constructor has no such argument.
        algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, time_offset = \
            ResultBase._read_hdf5_header(filepath)[1:]

        # Re-open the file and read the rest of the file.
        with h5py.File(filepath, 'r') as f:
            s_ds = f['stats']

            # Read array data.
            stats = {key: s_ds[key][:] for key in s_ds.keys()}

            # Read non-array data. Labels are written as bytes by _write_to_hdf5(), but also
            # accept labels that come back as str.
            channel_labels = [label.decode() if isinstance(label, bytes) else str(label)
                              for label in s_ds.attrs['channel_labels']]

        # Create a result object.
        result = SignalStatsResult(stats=stats, algorithm_parameters=algorithm_parameters,
                                   channel_labels=channel_labels, data_info=data_info,
                                   segment_start_times=segment_start_times,
                                   segment_end_times=segment_end_times,
                                   fs=fs)
        return result

    def _write_to_hdf5(self, filepath):
        """
        Write the contents of the object to an hdf5 file.

        Args:
            filepath (str): see ResultBase._write_to_hdf5().
        """
        # Write standard hdf5 header (use the ResultBase method).
        self._write_hdf5_header(filepath)

        # Append attributes to the hdf5 file.
        with h5py.File(filepath, 'a') as f:
            # Write array data.
            for key, data in self.stats.items():
                f.create_dataset('stats/{}'.format(key), data=data)

            # Write non-array data as attributes to the 'stats' group.
            # Convert strings to fixed-length bytes for compatibility. np.bytes_ is used
            # instead of its former alias np.string_, which was removed in NumPy 2.0.
            f['stats'].attrs['channel_labels'] = [np.bytes_(label) for label in self.channel_labels]