"""
Module for creating an object that holds the data of multiple time series, annotations and features.
"""
import copy
import datetime
import os
import warnings
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
from matplotlib.patches import Rectangle
from matplotlib.ticker import FuncFormatter
from nnsa.artefacts.clean_detector_cnn import CleanDetectorCnn
from nnsa.feature_extraction.brain_age_cnn import BrainAgeSinc
from nnsa.feature_extraction.feature_sets.eeg import EegFeatures
from nnsa.preprocessing.data_cleaning import substitute_bad_channels
from nnsa.utils.arrays import get_range_mask, moving_max, moving_mean
from nnsa.utils.event_detections import get_onsets_offsets
from nnsa.utils.pkl import pickle_save, pickle_load
from nnsa.utils.testing import assert_equal
from nnsa.artefacts.artefact_detection import detect_artefact_signals, default_eeg_signal_quality_criteria, \
default_oxygen_sample_quality_criteria, detect_anomalous_channels
from nnsa.containers.time_series import TimeSeries
from nnsa.edfreadpy.io.utils import standardize_and_check_eeg_label
from nnsa.feature_extraction.brain_age import BrainAge
from nnsa.feature_extraction.discontinuity import BurstDetection
from nnsa.feature_extraction.connectivity import CoherenceGraph
from nnsa.feature_extraction.entropy import MultiScaleEntropy
from nnsa.feature_extraction.envelope import Envelope
from nnsa.feature_extraction.frequency_analysis import PowerAnalysis
from nnsa.feature_extraction.fractality import MultifractalAnalysis, LineLength
from nnsa.feature_extraction.sleep_stages import SleepStagesCnn, SleepStagesRobust
from nnsa.feature_extraction.statistics import SignalStats
from nnsa.io.hdf5 import write_dict_to_hdf5
from nnsa.preprocessing.combine_channels import combine_channels
from nnsa.preprocessing.filter import NotchIIR, Butterworth
from nnsa.preprocessing.saved_filters import get_eeg_fir_filter_a, get_eeg_fir_filter_b, get_filter
from nnsa.utils import check_directory_exists
from nnsa.utils.segmentation import segment_generator, get_segment_times, get_all_segments
from nnsa.utils.plotting import compute_linewidth, subplot_rows_columns
from nnsa.utils.conversions import convert_time_scale
from nnsa.utils.scalebars import add_scalebar
# Names exported by `from <this module> import *`.
__all__ = [
'DEFAULT_EEG_CHANNELS',
'BaseDataset',
'EegDataset',
'OxygenDataset',
]
# Default set of EEG channel labels (presumably 10-20 system electrode names; verify against callers).
DEFAULT_EEG_CHANNELS = ['Fp1', 'Fp2', 'C3', 'C4', 'Cz', 'T3', 'T4', 'O1', 'O2']
[docs]class BaseDataset(object):
"""
Base class with common functionalities for Datasets.
Args:
*args (optional): optional positional arguments may be TimeSeries objects containing signals. Those will
be added to the Dataset upon initialization. When a list is given, this list is expected to be a list of
TimeSeries objects.
label (str): optional label for dataset.
"""
def __init__(self, *args, label=None):
self._time_series = dict() # holds TimeSeries objects of signals, label is the key.
# Append the data objects passed as optional positional arguments.
for data in args:
self.append(data)
# Set label.
if label is None:
label = self.__class__.__name__
self.label = label
def __add__(self, other):
"""
Combine two datasets into a new dataset containing the signals in both sets.
Args:
other (BaseDataset): other dataset to add.
Returns:
ds_out (BaseDataset): new dataset with all signals.
Raises:
ValueError if there are overlapping keys/signal labels.
"""
# Other should be a BaseDataset.
if not isinstance(other, BaseDataset):
raise TypeError('Cannot add {} to {}.'.format(type(other), type(self)))
# Create a copy.
ds_out = copy.deepcopy(self)
time_series_labels_self = list(self._time_series.keys())
time_series_labels_other = list(other._time_series.keys())
if any(lab in time_series_labels_self for lab in time_series_labels_other):
raise ValueError('Cannot add {} to {} with overlapping signal labels.'.format(type(other), type(self)))
ds_out._time_series.update(other._time_series)
return ds_out
def __array__(self):
"""
We can convert to a numpy array using numpy.array or numpy.asarray, which will call this __array__ method.
Returns:
a copy of the 2D signal array.
"""
return self.asarray()
def __contains__(self, item):
"""
Returns True if item in self. False otherwise.
Returns:
found (bool): True if item in self. False otherwise.
"""
# Try getting the item. If succeeds, return True, else return False.
try:
_ = self[item]
found = True
except KeyError:
found = False
return found
def __repr__(self):
"""
Return a comprehensive info string about this object.
Returns:
(str): a comprehensive info string about this object.
"""
info_string_parts = ['{} with label {} containing:'.format(self.__class__.__name__, self.label),
'time series:\n\t{}'.format('\n\t'.join(('{}: {}'.format(k, v) for k, v in
self._time_series.items())))]
return '\n'.join(info_string_parts)
def __iter__(self):
"""
Iterate over time series objects.
Returns:
(iterator): iterator over time series objects.
"""
return iter(self.time_series.values())
def __len__(self):
"""
Return the number of time series objects.
Returns:
(int): length.
"""
return len(self.time_series)
def __getitem__(self, item):
"""
Evaluation of self[item].
"""
return self.time_series[item]
def __mul__(self, other):
ds_out = copy.deepcopy(self)
for key in ds_out.time_series:
if isinstance(other, BaseDataset):
x = other[key]
else:
x = other
ds_out.time_series[key] = ds_out[key] * x
return ds_out
def __truediv__(self, other):
ds_out = copy.deepcopy(self)
for key in ds_out.time_series:
if isinstance(other, BaseDataset):
x = other[key]
else:
x = other
ds_out.time_series[key] = ds_out[key] / x
return ds_out
@property
def channel_labels(self):
"""
Return a list of channel labels (labels of the time series objects).
Returns:
(list): list with channel labels.
"""
return list(self.time_series.keys())
@property
def dtype(self):
"""
Return the dtype if all TimeSeries have the same dtype.
Raises an error otherwise.
Returns:
dtype: dtype common to all signals.
"""
dtype_all = [ts.dtype for ts in self]
if len(set(dtype_all)) == 1:
return dtype_all[0]
else:
raise AttributeError('{} with time series with different dtypes does not have a `dtype` attribute.'
.format(self.__class__.__name__))
@property
def fs(self):
"""
Return the sample frequency if all TimeSeries have the same sample frequency.
Raises an error otherwise.
Returns:
fs (flaot): sample frequency common to all signals.
"""
fs_all = [ts.fs for ts in self]
if len(set(fs_all)) == 1:
return fs_all[0]
else:
raise AttributeError('{} with time series of different sample frequencies does not have an `fs` attribute.'
.format(self.__class__.__name__))
@property
def shape(self):
"""
Return the shape (n_channels, n_samples).
Raises an error if not all signals have the same lengths.
"""
len_all = [len(ts) for ts in self]
if len(set(len_all)) == 1:
n_samples = len_all[0]
else:
raise AttributeError('{} with time series of different lengths does not have a `shape` attribute.'
.format(self.__class__.__name__))
n_channels = len(self)
return n_channels, n_samples
@property
def unit(self):
"""
Return the unit if all TimeSeries have the same unit.
Raises an error otherwise.
Returns:
unit (str): unit common to all signals.
"""
unit_all = [ts.unit for ts in self]
if len(set(unit_all)) == 1:
return unit_all[0]
else:
raise AttributeError('{} with time series with different units does not have a `unit` attribute.'
.format(self.__class__.__name__))
@property
def info(self):
"""
Return the info dict of first signal in dataset.
Returns:
info (dict): info dict of first signal in dataset.
"""
for ts in self:
return ts.info
@property
def time(self):
"""
Return the time array if all TimeSeries have the same length, sample frequency and time_offset.
Raises an error otherwise.
Returns:
time (np.ndaaray): time array common to all signals.
"""
if len(self) == 0:
return []
fs_all = [ts.fs for ts in self]
time_offset_all = [ts.time_offset for ts in self]
len_all = [len(ts) for ts in self]
if len(set(fs_all)) == 1 and len(set(time_offset_all)) == 1 and len(set(len_all)) == 1:
return np.arange(len_all[0])/fs_all[0] + time_offset_all[0]
else:
raise AttributeError('{} with time series of different sample frequencies, length or time offset'
' does not have a `time` attribute.'
.format(self.__class__.__name__))
@property
def time_offset(self):
"""
Return the time offset if all TimeSeries have the same time offset.
Raises an error otherwise.
Returns:
time_offset (flaot): time offset in seconds common to all signals.
"""
to_all = [ts.time_offset for ts in self]
if len(set(to_all)) == 1:
return to_all[0]
else:
raise AttributeError('{} with time series of different time offsets does not have a `time_offset` attribute.'
.format(self.__class__.__name__))
@time_offset.setter
def time_offset(self, time_offset):
"""
Set time offset of all TimeSeries in the DataSet.
Note that the TimeSeries are affected inplace.
If you don not want this, use self.apply_time_offset(time_offset, inplace=False)
"""
self.set_time_offset(time_offset, inplace=True)
@property
def time_series(self):
"""
Return the TimeSeries objects of the Dataset object.
Returns:
(dict): the TimeSeries objects in the Dataset object, where the keys are the TimeSeries labels and the
values the corresponding TimeSeries objects.
"""
return self._time_series
def append(self, data_object):
    """
    Append a data object to the Dataset (in place).

    Valid data objects:
        TimeSeries (including subclasses)
        list or tuple of valid data objects (may be nested)

    Args:
        data_object (various): an nnsa data object. If data_object is a list or tuple, this function is called
            recursively on each item, such that data_object may also be a sequence of nnsa data objects.

    Raises:
        NotImplementedError: if data_object (or one of its items) has an unsupported type.
    """
    # isinstance (rather than an exact type comparison) so that subclasses are accepted as well.
    if isinstance(data_object, TimeSeries):
        label = data_object.label

        # Warn when an existing signal with the same label is about to be replaced.
        if label in self._time_series:
            warnings.warn('\n{} with label {} replaced in {}.'.format(data_object.__class__.__name__, label,
                                                                      self.label))

        # Add the TimeSeries object to the Dataset.
        self._time_series[label] = data_object

    elif isinstance(data_object, (list, tuple)):
        # A sequence of data objects: append them one by one.
        for item in data_object:
            self.append(item)

    else:
        raise NotImplementedError('{} cannot append object of type {}.'.format(self.__class__.__name__,
                                                                               type(data_object)))
[docs] def set_time_offset(self, time_offset, inplace=False):
"""
Set the same time offset to all time series.
Args:
time_offset (float): time offset to set.
inplace (bool, optional): whether to set in place (True) or create a new object (False).
Defaults to False.
Returns:
ds_out (BaseDataset): new dataset with time offset set to all time series (if inplace is True).
"""
if len(self) == 0:
# No timeseries, we cannot set the time_offset.
msg = 'No TimeSeries in {}.\n\tDid not set time_offset.'.format(self)
warnings.warn(msg)
return self
if inplace:
# We will apply the changes to the current Dataset object.
ds = self
else:
# We will apply the changes to a copy of the current Dataset object and return this changed copy.
ds = copy.deepcopy(self)
# Loop over time series.
for ts in ds:
ts.time_offset = time_offset
if not inplace:
return ds
[docs] def asarray(self, channels=None, return_channel_labels=False, channels_last=False):
"""
Return an array containing (a copy of) multi-channel data with shape (num_channels, num_samples) if channels_last is False.
Args:
channels (list, optional): list with labels of the signals that are put into the output array. The
order of the labels determines the order of the signals in the output array. I.e. the
signal with label labels[i] is put in array(i, :).
If None, all signals in the EegDataset are put in the output array.
Defaults to None.
return_channel_labels (bool, optional): if True, additionally outputs the labels that correspond
with the rows of the created array.
If False, only outputs the array.
Default is False.
channels_last (bool): if False, output shape is (num_channels, num_samples),
if True, output shape is (num_samples, num_channels).
Returns:
array (np.ndarray): array with shape (num_channels, num_samples) if channels_last == False,
containing (copies of) the multi-channel data.
channels (list, optional): the labels corresponding to the rows in the output array (only
if output_channel_labels is True).
"""
# Default channels are all channels in the order that .keys() returns.
if channels is None:
channels = self.channel_labels
else:
# Check if all channels are available in the Dataset.
channels = [self._check_label(c) for c in channels]
# Check if the signals are synchronized.
if not self.is_synchronized(channels=channels):
raise ValueError('Cannot convert {} to array because signals are not synchronized.'
.format(self.label))
# Initialize output array of shape (num_channels, num_samples).
example_signal = self.time_series[channels[0]].signal
num_channels = len(channels)
num_samples = example_signal.size
array = np.zeros((num_channels, num_samples), dtype=example_signal.dtype)
# Collect signals in array.
for i, label in enumerate(channels):
# Extract TimeSeries in array.
array[i, :] = self.time_series[label].signal
if channels_last:
# Transpose to (n_samples, n_channels).
array = array.T
if return_channel_labels:
# Return array and row labels corresponding to the channels.
return array, channels
else:
# Only return array.
return array
[docs] def astype(self, *args, inplace=False, **kwargs):
if inplace:
# We will apply the changes to the current Dataset object.
ds = self
else:
# We will apply the changes to a copy of the current Dataset object and return this changed copy.
ds = copy.deepcopy(self)
for ts in ds:
ts.astype(*args, **kwargs, inplace=True)
if not inplace:
return ds
[docs] def average_signals(self, channels=None, **kwargs):
"""
Average the signals and return as a new TimeSeries.
Raises an error if the signals are not synchronized.
Args:
channels (list, optional): list with labels of the signals to be average.
If None, all signals in the Dataset are averaged.
Defaults to None.
**kwargs (optional): keyword arguments for the creation of a the new TimeSeries object.
Returns:
(nnsa.TimeSeries): new TimeSeries object containing the average signal.
"""
# By default use all channels.
if channels is None:
channels = list(self.time_series.keys())
# Test whether signals to average are synchronized.
if not self.is_synchronized(channels=channels):
raise ValueError('Signals {} are not synchronized. Cannot compute average signal.'
.format(channels))
n_channels = len(channels)
if n_channels == 1:
# No need to average.
return self.time_series[channels[0]]
all_units = []
all_info = []
# First signal forms basis.
ts0 = self.time_series[channels[0]]
sum_signals = ts0.signal.copy()
fs = ts0.fs
all_units.append(ts0.unit)
all_info.append(str(ts0))
time_offset = ts0.time_offset
pars = ts0.parameters
# Loop over other signals.
for label in channels[1:]:
ts = self.time_series[label]
sum_signals += ts.signal
all_units.append(ts.unit)
all_info.append(str(ts))
# Compute average signal.
signal_avg = sum_signals/n_channels
# Create info string for new TimeSeries.
info = {'source': 'Average of {} TimeSeries: {}'.format(n_channels, '; '.join(all_info))}
# Default arguments for new TimeSeries.
ts_kwargs = {
'label': 'average {}'.format(' + '.join(channels)),
'unit': all_units[0] if len(np.unique(all_units)) == 1 else 'a.u.',
'check_label': False,
'check_unit': True,
}
# Add parameters to the keyword arguments.
ts_kwargs.update(pars)
# Override defaults with user-specified arguments.
ts_kwargs.update(kwargs)
# Create new TimeSeries.
ts_avg = TimeSeries(signal_avg, fs, info=info, time_offset=time_offset, **ts_kwargs)
return ts_avg
[docs] def clear(self):
"""
Empty the dataset (inplace).
"""
self.time_series.clear()
[docs] def combine_channels(self, channels=None, label=None, **kwargs):
"""
Combine the channels into one signal by cutting into segments and pasting the best channels together.
See nnsa.combine_channels().
Args:
channels (list, optional): list with labels of the signals to be combined.
If None, all signals in the Dataset are combined.
Defaults to None.
label (str, optional): a label for the combined time series.
**kwargs (optional): keyword arguments for nnsa.combine_channels().
Returns:
ts (nnsa.TimeSeries): TimeSeries object containing the combined signal.
"""
# By default use all channels.
if channels is None:
channels = list(self.time_series.keys())
n_channels = len(channels)
# Create a copy of one of the time series.
ts = copy.deepcopy(self.time_series[channels[0]])
if n_channels == 1:
# If only one signal, do not combine, just return the one signal.
return ts
# Get the data matrix and combine the channels.
data_matrix = self.asarray(channels=channels)
signal = combine_channels(data_matrix, **kwargs)
# Create info string for new TimeSeries.
all_info = []
for label in channels:
ts = self.time_series[label]
all_info.append(str(ts))
info = {'source': 'Combination of {} TimeSeries: {}'.format(n_channels, '; '.join(all_info))}
# Update some variables.
if label is None:
label = 'Combined {}'.format(' + '.join(channels))
ts._signal = signal
ts._label = label
ts.info = info
return ts
[docs] def compute_power_cwt(self, freq_low=None, freq_high=None, channels=None, inplace=False, **kwargs):
"""
Compute power in specified band for each signal in the dataset.
Args:
freq_low (float, optional): lower frequency cutoff for bandpower. If None, takes lowest possible.
freq_high (float, optional): upper frequency cutoff for bandpower. If None, takes highest possible.
channels (list, optional): list with labels of the signals to apply to.
If None, all signals in the Dataset are processed.
Defaults to None.
inplace (bool, optional): if True, replaces the signals in place by their power.
If False, returns a new object.
**kwargs (optional): optional keyword arguments for compute_power_cwt().
Returns:
ds (BaseDataset): dataset with power of the signals (if inplace is False).
"""
# By default use all channels.
if channels is None:
channels = self.time_series.keys()
if inplace:
# We will apply the changes to the current Dataset object.
ds = self
else:
# We will apply the changes to a copy of the current Dataset object and return this changed copy.
ds = copy.deepcopy(self)
for label in channels:
ds.time_series[label].compute_power_cwt(freq_low=freq_low, freq_high=freq_high,
inplace=True, **kwargs)
if not inplace:
return ds
[docs] def demean(self, channels=None, verbose=1, inplace=False, **kwargs):
"""
Demean each time series in the dataset.
"""
if verbose > 0:
print('Demeaning {}...'.format(self.label))
# By default use all channels.
if channels is None:
channels = self.time_series.keys()
if inplace:
# We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries.
ds_filtered = self
else:
# We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered
# TimeSeries.
# Create a copy of the Dataset object.
ds_filtered = copy.deepcopy(self)
# Filter the signals.
for label in channels:
ts_filtered = self.time_series[label].demean(**kwargs)
ds_filtered.time_series[label] = ts_filtered
# Only return if not in place filtering.
if not inplace:
return ds_filtered
def envelope(self, verbose=1, **kwargs):
    """
    Compute the envelope of the signals.

    Wrapper that converts the dataset to a data matrix, runs Envelope.envelope() on it and returns
    the result.

    Args:
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments for the Envelope class.

    Returns:
        result (nnsa.EnvelopeResult): the result of the envelope computation.

    Examples:
        To get the envelope data as a EegDataset:
        >>> signal = 4 * np.sin(np.arange(10000))
        >>> ts = TimeSeries(signal=signal, fs=100, label='EEG Cz', unit='uV', info={'source': 'random'})
        >>> ds = BaseDataset([ts])
        >>> envelope_result = ds.envelope(verbose=0)
        >>> ds_envelope = envelope_result.to_eeg_dataset()
        >>> print(type(ds_envelope).__name__)
        EegDataset
        >>> print(ds_envelope.time_series['EEG Cz'].signal.mean())
        3.9998155220627787
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        print('Computing envelope of {} with {} channels: {}'
              .format(self.label, len(self.time_series), labels))

    # The extractor receives the user specified keyword arguments.
    extractor = Envelope(**kwargs)

    # asarray() raises if the signals are not synchronized, so below all signals share one fs.
    data_matrix, channel_labels = self.asarray(return_channel_labels=True)
    first_ts = next(iter(self.time_series.values()))
    fs = first_ts.fs

    result = extractor.envelope(data_matrix, fs=fs,
                                channel_labels=channel_labels, verbose=verbose)

    # Let the dataset post-process the result.
    self._postprocess_result(result)

    return result
[docs] def fill_nan(self, value=0, inplace=False, **kwargs):
"""
Fill NaN values in the signals.
Args:
value (float): value to fill nans with.
inplace (bool): whether to fill nans in place or create a new object.
**kwargs: for TimeSeries.fill_nan().
Returns:
ds (BaseDataset): new Dataset object containing the new signals (only if inplace is False).
"""
if inplace:
# We will apply the changes to the current Dataset object.
ds = self
else:
# We will apply the changes to a copy of the current Dataset object and return this changed copy.
ds = copy.deepcopy(self)
# Loop over TimeSeries in Dataset.
for ts in ds:
ts.fill_nan(value=value, inplace=True, **kwargs)
if not inplace:
return ds
[docs] def filter(self, filter_obj, channels=None, inplace=False, verbose=1, **kwargs):
"""
Filter each of the signals in the Dataset by calling the filter() method on each TimeSeries objects.
Args:
filter_obj (FilterBase-derived): a child class from the nnsa.FilterBase class. The fs property does not
need to be set, will be done automatically.
channels (list, optional): list with labels of the signals that should be filtered.
If None, all signals in the Dataset are filtered.
Defaults to None.
inplace (bool, optional): if True, the filtered signals will replace the original signals in this
Dataset. If False (default), the function returns a new Dataset object containing the filtered
signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are
filtered, the newly created Dataset object will only contain those filtered channels.
verbose (int, optional): verbose level.
Defaults to 1.
**kwargs (optional): keyword arguments for TimeSeries.filter().
Returns:
(BaseDataset): new Dataset object containing the filtered signals (only if inplace is False).
"""
if verbose > 0:
print('One-way filtering {} with {}...'.format(self.label,
filter_obj))
# By default use all channels.
if channels is None:
channels = self.time_series.keys()
if inplace:
# We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries.
ds_filtered = self
else:
# We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered
# TimeSeries.
# Create a copy of the Dataset object.
ds_filtered = copy.deepcopy(self)
# Filter the signals.
for label in channels:
ts_filtered = self.time_series[label].filter(filter_obj, verbose=1 if verbose > 1 else 0,
**kwargs)
ds_filtered.time_series[label] = ts_filtered
# Only return if not in place filtering.
if not inplace:
return ds_filtered
[docs] def filtfilt(self, filter_obj, channels=None, inplace=False, verbose=1, **kwargs):
"""
Filter each of the signals in the Dataset by calling the filtfilt() method on each TimeSeries objects.
Args:
filter_obj (FilterBase-derived): a child class from the nnsa.FilterBase class. The fs property does not
need to be set, will be done automatically.
channels (list, optional): list with labels of the signals that should be filtered.
If None, all signals in the Dataset are filtered.
Defaults to None.
inplace (bool, optional): if True, the filtered signals will replace the original signals in this
Dataset. If False (default), the function returns a new Dataset object containing the filtered
signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are
filtered, the newly created Dataset object will only contain those filtered channels.
verbose (int, optional): verbose level.
Defaults to 1.
**kwargs (optional): keyword arguments for TimeSeries.filtfilt().
Returns:
(BaseDataset): new Dataset object containing the filtered signals (only if inplace is False).
"""
if verbose > 0:
print('Zero-phase filtering {} with {}...'.format(self.label,
filter_obj))
# By default use all channels.
if channels is None:
channels = self.time_series.keys()
if inplace:
# We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries.
ds_filtered = self
else:
# We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered
# TimeSeries.
# Create a copy of the Dataset object.
ds_filtered = copy.deepcopy(self)
# Filter the signals.
for label in channels:
ts_filtered = self.time_series[label].filtfilt(filter_obj, verbose=1 if verbose > 1 else 0,
**kwargs)
ds_filtered.time_series[label] = ts_filtered
# Only return if not in place filtering.
if not inplace:
return ds_filtered
[docs] def filter_saved_filter(self, filter_name, channels=None, inplace=False, verbose=1,
**kwargs):
"""
Filter each of the signals in the dataset by calling the filter_saved_filter() method on each TimeSeries
object.
Args:
filter_name (str): see nnsa.preprocessing.filter_saved_filter().
channels (list, optional): list with labels of the signals that should be filtered.
If None, all signals in the Dataset are filtered.
Defaults to None.
inplace (bool, optional): if True, the filtered signals will replace the original signals in this
Dataset. If False (default), the function returns a new Dataset object containing the filtered
signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are
filtered, the newly created Dataset object will only contain those filtered channels.
verbose (int, optional): verbose level.
Defaults to 1.
**kwargs (optional): keyword arguments for TimeSeries.filter_saved_filter().
Returns:
(BaseDataset): new Dataset object containing the filtered signals (only if inplace is False).
"""
if verbose > 0:
print('Filtering {} with saved filter "{}"...'.format(self.label,
filter_name))
# By default use all channels.
if channels is None:
channels = self.time_series.keys()
if inplace:
# We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries.
ds_filtered = self
else:
# We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered
# TimeSeries.
# Create a copy of the EegDataset object.
ds_filtered = copy.deepcopy(self)
# Filter all EEG signals.
for label in channels:
ts_filtered = self.time_series[label].filter_saved_filter(filter_name=filter_name,
verbose=1 if verbose > 1 else 0,
**kwargs)
ds_filtered.time_series[label] = ts_filtered
# Only return if not in place filtering.
if not inplace:
return ds_filtered
[docs] def from_data_dict(self, data_dict, **kwargs):
"""
Create a dataset based on data in a dict.
Args:
data_dict = dict(
data=data, # 2D array with shape (n_channels, n_samples)
fs=fs, # sampling frequency (Hz)
label=label, # Optional list of labels for each channel in data
filepath=filepath, # Optional filepath where the data comes from
t0=t0, # Optional time offset (time at t0) in seconds.
**kwargs: for self.__init__().
)
"""
# Check input dimensions.
data = np.asarray(data_dict['data'])
if data.ndim == 1:
# Assume one channel, change to shape (n_time, n_channels).
data = data.reshape(-1, 1)
if data.ndim != 2:
raise ValueError('`data` must have 2 dimensions. Got {} dimensions.'
.format(data.ndim))
channel_labels = data_dict.get('label', None)
if channel_labels is None:
# Create default labels.
channel_labels = ['Ch{}'.format(i+1) for i in range(min(data.shape))]
if len(data) != len(channel_labels):
# Try transpose.
data = np.transpose(data)
if len(data) != len(channel_labels):
raise ValueError('`len(data) ({}) does not equal len(channel_labels) ({})'
.format(len(data), len(channel_labels)))
# Extract more information.
fs = data_dict['fs']
time_offset = data_dict.get('t0', 0)
source = data_dict.get('filepath', 'unknown')
info = {'source': source}
# Create dataset (collection of TimeSeries).
opts = dict({'label': self.label}, **kwargs)
ds = self.__class__(**opts)
for signal, label in zip(data, channel_labels):
ts = TimeSeries(signal=signal, fs=fs, label=label, time_offset=time_offset, info=info)
ds.append(ts)
return ds
[docs] def get_non_artefact_masks(self, segment_length, overlap=0, artefact_criteria=None):
"""
For each signal get a mask indicating non artefact segments.
Args:
segment_length (float): segment length.
overlap (float, optional): overlap between segments.
Defaults to 0.
artefact_criteria (dict): dict with artefact criteria for determining whether each segment is an artefact.
Returns:
ds_out (BaseDataset): new dataset with non artefact masks for each signal.
"""
if artefact_criteria is None:
artefact_criteria = dict()
# Segment the datasets.
ds_out = BaseDataset()
for ts in self:
seg = np.asarray(list(segment_generator(ts.signal, segment_length=segment_length,
overlap=0, fs=ts.fs, axis=0)))
mask = ~detect_artefact_signals(seg, axis=-1, demean=True,
keepdims=False, **artefact_criteria)
ds_out.append(TimeSeries(signal=mask, fs=1 / (segment_length - overlap),
label=ts.label, check_label=False,
time_offset=ts.time_offset, info=ts.info))
return ds_out
[docs] def interp(self, t, channels=None, **kwargs):
"""
Linearly interpolate values.
Args:
t (np.ndarray): time instances at which to evaluate the signals.
channels (list): determines which and the order of the channels that are interpolated.
If None, uses all available channels.
**kwargs: for TimeSeries.interp().
Returns:
data (np.ndarray): 2D array with shape (n_channels, len(t)) containing the
interpolated data.
"""
if channels is None:
# All channels.
channels = list(self.time_series.keys())
# Loop over TimeSeries in Dataset.
data = []
for ch in channels:
yi = self[ch].interp(t=t, **kwargs)
data.append(yi)
data = np.array(data)
return data
[docs] def interp_nan(self, *args, **kwargs):
return self.interpolate_nan(*args, **kwargs)
[docs] def interpolate_nan(self, inplace=False, max_nan_length=None):
"""
Linearly interpolate nan values.
Args:
inplace (bool, optional): interpolate inplace (True) or return a copy (False).
Defaults to False.
max_nan_length (int or str, optional): number of maximum allowable consecutive nan
samples that we interpolate (in seconds). If 'auto', the average duration of a constant value
in the signal is taken. If None, no limit is applied, i.e. all nans are interpolated.
Defaults to None.
Returns:
ds (BaseDataset): dataset with nans interpolated (if inplace is False).
"""
if inplace:
# We will apply the changes to the current Dataset object.
ds = self
else:
# We will apply the changes to a copy of the current Dataset object and return this changed copy.
ds = copy.deepcopy(self)
# Loop over TimeSeries in Dataset.
for ts in ds:
ts.interpolate_nan(inplace=True, max_nan_length=max_nan_length)
if not inplace:
return ds
[docs] def isempty(self):
"""
Check if dataset is empty.
Returns:
(bool): True if the Dataset is empty, False if not.
"""
return len(self.time_series) == 0
[docs] def is_synchronized(self, channels=None):
"""
Check if all the TimeSeries in the Dataset are synchronized.
Args:
channels (list, optional): list with labels of signals to check.
If None, all signals are checked.
Defaults to None.
Returns:
(bool): True if all TimeSeries are in sync, False if not.
"""
if channels is None:
# All channels.
channels = list(self.time_series.keys())
# Synchronized if all TimeSeries have same signal length, fs and time_offset.
lengths_all = []
fs_all = []
time_offset_all = []
for label in channels:
lengths_all.append(len(self.time_series[label].signal))
fs_all.append(self.time_series[label].fs)
time_offset_all.append(self.time_series[label].time_offset)
return len(set(lengths_all)) == 1 and len(set(fs_all)) == 1 and len(set(time_offset_all)) == 1
def line_length(self, preprocess=False, verbose=1, **kwargs):
    """
    Compute line length feature.

    Wrapper that converts the (optionally preprocessed) dataset to a data matrix, runs
    LineLength.line_length() on it and returns the result.

    Args:
        preprocess (bool, optional): if True, the data is first zero-phase filtered with the filter returned
            by get_eeg_fir_filter_b() and subsequently with a default NotchIIR filter. The data in this
            dataset is not changed. If False, the data is used as it is.
            Defaults to False.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments for the LineLength class.

    Returns:
        result (nnsa.LineLengthResult): the result of the line length computation.
    """
    if preprocess:
        # Default processing routine (filtfilt returns new datasets, so self stays untouched).
        fir_filt = get_eeg_fir_filter_b()
        notch_filt = NotchIIR()
        ds = self.filtfilt(fir_filt)
        ds = ds.filtfilt(notch_filt)
    else:
        ds = self

    if verbose > 0:
        labels = [ts.label for ts in ds.time_series.values()]
        print('Computing Line Length of {} with {} channels: {}'
              .format(ds.__class__.__name__, len(ds.time_series), labels))

    # The extractor receives the user specified keyword arguments.
    extractor = LineLength(**kwargs)

    # asarray() raises if the signals are not synchronized, so below all signals share one fs.
    data_matrix, channel_labels = ds.asarray(return_channel_labels=True)
    first_ts = next(iter(ds.time_series.values()))
    fs = first_ts.fs

    result = extractor.line_length(data_matrix, fs=fs,
                                   channel_labels=channel_labels, verbose=verbose)

    # Let the dataset post-process the result.
    ds._postprocess_result(result)

    return result
def logical_and(self, other):
    """
    Combine each signal with `other` using numpy's logical_and and return the result as a new Dataset.

    Args:
        other (BaseDataset or object): if a BaseDataset, each TimeSeries is combined with the entry of
            `other` that has the same label. Anything else is used directly as the second operand
            for every TimeSeries.

    Returns:
        (BaseDataset): deep copy of self in which the `signal` of every TimeSeries is replaced
            by the outcome of np.logical_and.
    """
    result = copy.deepcopy(self)
    other_is_dataset = isinstance(other, BaseDataset)
    for label in result.time_series:
        operand = other[label] if other_is_dataset else other
        result.time_series[label].signal = np.logical_and(result[label], operand)
    return result
def logical_or(self, other):
    """
    Combine each signal with `other` using numpy's logical_or and return the result as a new Dataset.

    Args:
        other (BaseDataset or object): if a BaseDataset, each TimeSeries is combined with the entry of
            `other` that has the same label. Anything else is used directly as the second operand
            for every TimeSeries.

    Returns:
        (BaseDataset): deep copy of self in which the `signal` of every TimeSeries is replaced
            by the outcome of np.logical_or.
    """
    result = copy.deepcopy(self)
    other_is_dataset = isinstance(other, BaseDataset)
    for label in result.time_series:
        operand = other[label] if other_is_dataset else other
        result.time_series[label].signal = np.logical_or(result[label], operand)
    return result
def merge(self, other, inplace=False):
    """
    Merge one or more other Dataset objects into this one, TimeSeries by TimeSeries.

    Args:
        other (BaseDataset or list): a single Dataset or a list/tuple of Datasets.
            Each must be an instance of the class of self and contain every TimeSeries label of self.
        inplace (bool, optional): if True, the merged TimeSeries are stored in the current object.
            If False, they are stored in a deep copy, which is returned.
            Defaults to False.

    Returns:
        out (BaseDataset): object containing the merged data (only returned if inplace is False).

    Raises:
        ValueError: if an item is of an incompatible type or lacks one of the TimeSeries of self.
    """
    others = other if isinstance(other, (list, tuple)) else [other]
    out = self if inplace else copy.deepcopy(self)

    for item in others:
        # Only objects of the same (sub)class can be merged.
        if not isinstance(item, out.__class__):
            raise ValueError('Object of type "{}" cannot be merged with object of type "{}".'
                             .format(type(item), type(out)))

        # Merge per label. Only existing keys are reassigned, so iterating the dict is safe.
        for key in out.time_series:
            if key not in item.time_series:
                raise ValueError('Time series {} not in other dataset with time series {}.'.format(key, item.time_series.keys()))
            merged_ts = out[key].merge(item[key], inplace=False)
            out.time_series[key] = merged_ts

    if inplace:
        return None
    return out
def multi_scale_entropy(self, preprocess=False, verbose=1, **kwargs):
    """
    Run a multi-scale entropy (mse) analysis on the signals in the Dataset.

    Thin wrapper: optionally preprocesses the data, stacks it into one array and hands it to
    MultiScaleEntropy.multi_scale_entropy().

    Args:
        preprocess (bool, optional): if True, the data is filtered with the FIR filter from
            get_eeg_fir_filter_b(), then with a default NotchIIR filter and finally resampled to
            125 Hz using polyphase filtering. If False, the data is used as it is.
            Defaults to False.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): keyword arguments passed to the MultiScaleEntropy class.

    Returns:
        result (nnsa.MultiScaleEntropyResult): the result of the mse analysis.
    """
    ds = self
    if preprocess:
        # Default processing routine (both filters are created before any filtering is done).
        fir = get_eeg_fir_filter_b()
        notch = NotchIIR()
        ds = ds.filtfilt(fir)
        ds = ds.filtfilt(notch)
        ds = ds.resample(fs_new=125, method='polyphase_filtering')

    if verbose > 0:
        labels = [ts.label for ts in ds.time_series.values()]
        message = 'Multi-scale entropy analysis of {} with {} channels: {}'.format(
            ds.__class__.__name__, len(ds.time_series), labels)
        print(message)

    # Analysis object configured with the user specified keyword arguments.
    mse = MultiScaleEntropy(**kwargs)

    # Stack the signals into one input matrix.
    data_matrix, channel_labels = ds.asarray(return_channel_labels=True)

    # Take fs from the first TimeSeries (asarray() is relied upon to reject mixed sample frequencies).
    first_ts = next(iter(ds.time_series.values()))
    fs = first_ts.fs

    result = mse.multi_scale_entropy(data_matrix, fs=fs,
                                     channel_labels=channel_labels, verbose=verbose)

    # Attach data info and time offset to the result.
    ds._postprocess_result(result)
    return result
def multifractal_analysis(self, verbose=1, **kwargs):
    """
    Run a multifractal analysis (mfa) on the signals in the Dataset.

    Thin wrapper: stacks the data into one array and hands it to
    MultifractalAnalysis.multifractal_analysis().

    Args:
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): keyword arguments passed to the MultifractalAnalysis class.

    Returns:
        result (nnsa.MultifractalAnalysisResult): the result of the mfa analysis.
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        message = 'Multifractal analysis of {} with {} channels: {}'.format(
            self.label, len(self.time_series), labels)
        print(message)

    # Analysis object configured with the user specified keyword arguments.
    mfa = MultifractalAnalysis(**kwargs)

    # Stack the signals into one input matrix.
    data_matrix, channel_labels = self.asarray(return_channel_labels=True)

    # Take fs from the first TimeSeries (asarray() is relied upon to reject mixed sample frequencies).
    first_ts = next(iter(self.time_series.values()))
    fs = first_ts.fs

    result = mfa.multifractal_analysis(data_matrix, fs=fs,
                                       channel_labels=channel_labels, verbose=verbose)

    # Attach data info and time offset to the result.
    self._postprocess_result(result)
    return result
def normalize(self, how='zscore', channels=None, inplace=False, verbose=0, **kwargs):
    """
    Normalize the signals by calling TimeSeries.normalize() on each selected channel.

    Args:
        how (str): normalization method, passed on to TimeSeries.normalize() (e.g. "zscore").
        channels (list, optional): labels of the channels to normalize.
            If None, all channels are normalized.
        inplace (bool, optional): if True, the normalized TimeSeries replace the originals in this
            Dataset. If False, they replace the originals in a deep copy, which is returned.
        verbose (int, optional): verbosity level.
        **kwargs (dict): extra keyword arguments for TimeSeries.normalize().

    Returns:
        (BaseDataset): new Dataset object (only returned if inplace is False).
    """
    if verbose > 0:
        print('Normalizing {} ({})...'.format(self.label, how))

    labels = self.time_series.keys() if channels is None else channels

    # Store the normalized TimeSeries either in this object or in a copy of it.
    target = self if inplace else copy.deepcopy(self)

    for label in labels:
        target.time_series[label] = self.time_series[label].normalize(how=how, **kwargs)

    return None if inplace else target
def notch_filt(self, f0=50, channels=None, inplace=False, verbose=1, **kwargs):
    """
    Notch filter the signals by calling TimeSeries.notch_filt() on each selected channel.

    Args:
        f0 (float, optional): frequency passed to TimeSeries.notch_filt() as `f0`.
            Defaults to 50.
        channels (list, optional): labels of the channels to filter.
            If None, all channels are filtered.
        inplace (bool, optional): if True, the filtered TimeSeries replace the originals in this
            Dataset. If False, they replace the originals in a deep copy, which is returned.
            Defaults to False.
        verbose (int, optional): verbose level. The per-channel call only gets verbose=1 when this
            value is larger than 1.
            Defaults to 1.
        **kwargs (optional): extra keyword arguments for TimeSeries.notch_filt().

    Returns:
        (BaseDataset): new Dataset object (only returned if inplace is False).
    """
    if verbose > 0:
        print('Notch filtering {} at {} Hz...'.format(self.label, f0))

    labels = self.time_series.keys() if channels is None else channels

    # Store the filtered TimeSeries either in this object or in a copy of it.
    target = self if inplace else copy.deepcopy(self)

    # Channel-level verbosity is one step lower than the Dataset-level verbosity.
    channel_verbose = 1 if verbose > 1 else 0
    for label in labels:
        target.time_series[label] = self.time_series[label].notch_filt(
            f0=f0, verbose=channel_verbose, **kwargs)

    return None if inplace else target
def plot(self, *args, begin=None, end=None, channels=None, scale=None, relative_time=False,
         time_scale='seconds', add_offsets=False, ticklabels=True, subplots=False, color=None,
         legend=True, ax=None, **kwargs):
    """
    Plot the data of the specified channels by calling TimeSeries.plot() for each of them.

    Args:
        *args (optional): positional arguments forwarded to TimeSeries.plot().
        begin (float, optional): begin time, forwarded to TimeSeries.plot().
            Defaults to None.
        end (float, optional): end time, forwarded to TimeSeries.plot().
            Defaults to None.
        channels (list, optional): list of labels specifying the signals to plot.
            If None, all signals in the Dataset are plotted.
            Defaults to None.
        scale (optional): accepted for backward compatibility, currently not used by this method.
        relative_time (bool, optional): accepted for backward compatibility, currently not used by
            this method.
        time_scale (str, optional): the time scale, forwarded to TimeSeries.plot() and used in the
            x-axis label (e.g. 'seconds', 'minutes', 'hours').
            Defaults to 'seconds'.
        add_offsets (bool or float, optional): if True, each next channel is shifted down by 3 times
            the SD of the current channel for easier visualization. If a number, each next channel
            is shifted down by that fixed amount. If False, no offsets are applied.
            Defaults to False.
        ticklabels (bool, optional): if add_offsets is truthy and subplots is False, set to True to
            put the signal labels as yticks at the channel offsets. Set to False to keep numeric
            y-ticks.
            Defaults to True.
        subplots (bool, optional): if True, plots each signal in a separate subplot of a new figure.
            If False, plots all signals in `ax`.
            Defaults to False.
        color (list or dict, optional): colors per TimeSeries, indexed by channel position (list) or
            by label (dict). Any other value is passed as the color of every channel.
        legend (bool, optional): include legend (True) or not (False) if subplots is False.
            Defaults to True.
        ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
            Ignored if subplots is True.
            Defaults to None.
        **kwargs (optional): keyword arguments forwarded to TimeSeries.plot().

    Returns:
        ax (plt.axes): the axes that was plotted in last.
    """
    # By default, plot all channels.
    if channels is None:
        channels = list(self.time_series.keys())

    if subplots:
        nrows, ncols = subplot_rows_columns(len(channels))
        _, axes = plt.subplots(nrows, ncols, tight_layout=True, sharex='all')
        # Flatten, also when plt.subplots returned a single Axes object.
        axes = np.reshape([axes], -1)
    else:
        if ax is None:
            # Current axes.
            ax = plt.gca()
        axes = [ax]

    # bool is a subclass of int, so True must be excluded explicitly here. Otherwise
    # add_offsets=True would be treated as a fixed offset of 1 instead of the SD-based offset.
    use_fixed_offset = (isinstance(add_offsets, (int, float))
                        and not isinstance(add_offsets, bool))

    y_offset = 0
    yticks = []
    yticklabels = []
    for i, label in enumerate(channels):
        ax = axes[i] if subplots else axes[0]

        if isinstance(color, list):
            col = color[i]
        elif isinstance(color, dict):
            col = color[label]
        else:
            col = color

        ts_i = self.time_series[label]
        if add_offsets is not False:
            ts_i = ts_i + y_offset

            # Collect ticks.
            yticks.append(y_offset)
            yticklabels.append(ts_i.label)

            # Offset for the next channel.
            if use_fixed_offset:
                y_offset_add = add_offsets
            else:
                # Determine offset from SD.
                y_offset_add = np.std(ts_i)*3
            y_offset -= y_offset_add

        ts_i.plot(*args, begin=begin, end=end, time_scale=time_scale,
                  ax=ax, color=col, **kwargs)

    # Figure makeup. Done on `ax` itself so that a user-supplied axes is decorated even if it is
    # not matplotlib's current axes.
    if not subplots:
        ax.set_xlabel('Time ({})'.format(time_scale))
        ax.set_title('{}'.format(self.label))
        if add_offsets and ticklabels:
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
        else:
            ax.set_ylabel('Signal (a.u.)')
        if legend:
            ax.legend(loc='upper right')

    return ax
def power_analysis(self, verbose=1, **kwargs):
    """
    Run a power analysis on the signals in the Dataset.

    Thin wrapper: stacks the data into one array and hands it to PowerAnalysis.power_analysis().

    Args:
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): keyword arguments passed to the PowerAnalysis class.

    Returns:
        result (nnsa.PowerAnalysisResult): the result of the power analysis.
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        message = 'Power analysis of {} with {} channels: {}'.format(
            self.label, len(self.time_series), labels)
        print(message)

    # Analysis object configured with the user specified keyword arguments.
    analyzer = PowerAnalysis(**kwargs)

    # Stack the signals into one input matrix.
    data_matrix, channel_labels = self.asarray(return_channel_labels=True)

    # fs and unit are taken from the first TimeSeries: asarray() is relied upon to reject mixed
    # sample frequencies, and the unit is assumed to be the same for all signals.
    first_ts = next(iter(self.time_series.values()))
    fs = first_ts.fs
    unit = first_ts.unit

    result = analyzer.power_analysis(data_matrix, fs=fs,
                                     channel_labels=channel_labels, unit=unit,
                                     verbose=verbose)

    # Attach data info and time offset to the result.
    self._postprocess_result(result)
    return result
def nan_to_num(self, channels=None, inplace=False, **kwargs):
    """
    Apply numpy's nan_to_num function to each timeseries (e.g. to replace nans by zeros).

    The replacement is delegated to self.transform().

    Args:
        channels (list): see self.transform.
        inplace (bool): see self.transform.
        **kwargs (optional): keyword arguments for np.nan_to_num().

    Returns:
        Whatever self.transform() returns.
    """
    def _replace_nans(x):
        # The user supplied keyword arguments go to numpy, not to self.transform().
        return np.nan_to_num(x, **kwargs)

    return self.transform(fun=_replace_nans, channels=channels, inplace=inplace)
def remove(self, channel, inplace=False, verbose=1):
    """
    Delete one signal from the dataset.

    Args:
        channel (str): label of the channel to remove (validated with self._check_label()).
        inplace (bool, optional): if True, the channel is deleted from this Dataset object itself.
            If False, it is deleted from a deep copy, which is returned, leaving the current
            Dataset unchanged.
            Default is False.
        verbose (int, optional): verbose level.
            Defaults to 1.

    Returns:
        (BaseDataset): new Dataset object containing the same signals, except for the removed channel
            (only if inplace is False).
    """
    label = self._check_label(channel)

    # Delete from this object, or from a copy that will be returned.
    target = self if inplace else copy.deepcopy(self)
    del target.time_series[label]

    if verbose > 0:
        print('Removed "{}" from {}.'.format(label, self.label))

    return None if inplace else target
def remove_artefact_channels(self, inplace=False, **kwargs):
    """
    Drop every channel that detect_artefact_signals() flags for exclusion.

    Args:
        inplace (bool, optional): if True, the channels are removed from this Dataset object itself.
            If False, they are removed from a deep copy, which is returned, leaving the current
            Dataset unchanged.
            Default is False.
        **kwargs (optional): keyword arguments for
            nnsa.artefacts.artefact_detection.detect_artefact_signals().

    Returns:
        (BaseDataset): new Dataset object containing the same signals, except for the removed
            channels (only if inplace is False).
    """
    target = self if inplace else copy.deepcopy(self)

    # Iterate over a snapshot of the labels, since channels are deleted during the loop.
    for label in list(target.time_series.keys()):
        channel_signal = target.time_series[label].signal
        if detect_artefact_signals(channel_signal, **kwargs):
            target.remove(label, inplace=True)

    return None if inplace else target
def remove_artefacts(self, inplace=False, **kwargs):
    """
    Call TimeSeries.remove_artefacts() in place on every channel (replaces artefact samples by np.nan).

    Args:
        inplace (bool, optional): if True, the channels of this Dataset object itself are processed.
            If False, the channels of a deep copy are processed and that copy is returned, leaving
            the current Dataset unchanged.
            Default is False.
        **kwargs (optional): keyword arguments for TimeSeries.remove_artefacts().

    Returns:
        (BaseDataset): new Dataset object with the processed signals (only returned if inplace is False).
    """
    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for ts in list(target.time_series.values()):
        ts.remove_artefacts(inplace=True, **kwargs)

    return None if inplace else target
def remove_flatlines(self, inplace=False, verbose=1, **kwargs):
    """
    Call TimeSeries.remove_flatlines() in place on every TimeSeries (replaces flatline samples by np.nan).

    Args:
        inplace (bool, optional): if True, the TimeSeries of this Dataset object itself are processed.
            If False, the TimeSeries of a deep copy are processed and that copy is returned, leaving
            the current Dataset unchanged.
            Default is False.
        verbose (int): verbosity level.
        **kwargs (optional): keyword arguments for TimeSeries.remove_flatlines().

    Returns:
        (BaseDataset): new Dataset object with the processed signals (only returned if inplace is False).
    """
    if verbose > 0:
        print('Removing flatlines in {}...'.format(self.label))

    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for series in target:
        series.remove_flatlines(inplace=True, **kwargs)

    if inplace:
        return None
    return target
def remove_neighborhood_artefacts(self, *args, inplace=False, verbose=1, **kwargs):
    """
    Call TimeSeries.remove_neighborhood_artefacts() in place on every TimeSeries.

    According to the original description this replaces values by np.nan if samples in their
    neighborhood are nan.

    Args:
        *args: positional arguments for TimeSeries.remove_neighborhood_artefacts().
        inplace (bool, optional): if True, the TimeSeries of this Dataset object itself are processed.
            If False, the TimeSeries of a deep copy are processed and that copy is returned, leaving
            the current Dataset unchanged.
            Default is False.
        verbose (int): verbosity level.
        **kwargs (optional): keyword arguments for TimeSeries.remove_neighborhood_artefacts().

    Returns:
        (BaseDataset): new Dataset object with the processed signals (only returned if inplace is False).
    """
    if verbose > 0:
        print('Removing neighborhood artefacts in {}...'.format(self.label))

    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for series in target:
        series.remove_neighborhood_artefacts(*args, inplace=True, **kwargs)

    if inplace:
        return None
    return target
def remove_outliers(self, *args, inplace=False, verbose=1, **kwargs):
    """
    Call TimeSeries.remove_outliers() in place on every TimeSeries (replaces outlier samples by np.nan).

    Args:
        *args: positional arguments for TimeSeries.remove_outliers().
        inplace (bool, optional): if True, the TimeSeries of this Dataset object itself are processed.
            If False, the TimeSeries of a deep copy are processed and that copy is returned, leaving
            the current Dataset unchanged.
            Default is False.
        verbose (int): verbosity level.
        **kwargs (optional): keyword arguments for TimeSeries.remove_outliers().

    Returns:
        (BaseDataset): new Dataset object with the processed signals (only returned if inplace is False).
    """
    if verbose > 0:
        print('Removing outliers in {}...'.format(self.label))

    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for series in target:
        series.remove_outliers(*args, inplace=True, **kwargs)

    if inplace:
        return None
    return target
def remove_values(self, *args, inplace=False, verbose=1, **kwargs):
    """
    Call TimeSeries.remove_values() in place on every TimeSeries (replaces values by np.nan).

    Args:
        *args: positional arguments for TimeSeries.remove_values().
        inplace (bool, optional): if True, the TimeSeries of this Dataset object itself are processed.
            If False, the TimeSeries of a deep copy are processed and that copy is returned, leaving
            the current Dataset unchanged.
            Default is False.
        verbose (int): verbosity level.
        **kwargs (optional): keyword arguments for TimeSeries.remove_values().

    Returns:
        (BaseDataset): new Dataset object with the processed signals (only returned if inplace is False).
    """
    if verbose > 0:
        print('Removing values in {}...'.format(self.label))

    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for series in target:
        series.remove_values(*args, inplace=True, **kwargs)

    if inplace:
        return None
    return target
def resample(self, fs_new, method='polyphase_filtering', channels=None, inplace=False, verbose=1, **kwargs):
    """
    Resample signals in this Dataset to a new sampling frequency by calling TimeSeries.resample().

    Args:
        fs_new (float): new sampling frequency to resample to.
        method (str, optional): see nnsa.TimeSeries.resample()
        channels (list, optional): list with labels of the signals to be resampled.
            If None, all signals in the Dataset are resampled.
            Defaults to None.
        inplace (bool, optional): if True, the resampled signals replace the original signals in this
            Dataset. If False (default), they replace the originals in a deep copy, which is
            returned, leaving the data in the current Dataset unchanged. Channels that are not
            listed in `channels` are kept as they are.
        verbose (int, optional): verbose level. The per-channel call only gets verbose=1 when this
            value is larger than 1.
            Defaults to 1.
        **kwargs (optional): see nnsa.TimeSeries.resample()

    Returns:
        (BaseDataset): new Dataset object containing the resampled signals (only if inplace is False).
    """
    if verbose > 0:
        print('Resampling {} using {}...'.format(self.label, method))

    labels = self.time_series.keys() if channels is None else channels

    # Store the resampled TimeSeries either in this object or in a copy of it.
    target = self if inplace else copy.deepcopy(self)

    # Channel-level verbosity is one step lower than the Dataset-level verbosity.
    channel_verbose = 1 if verbose > 1 else 0
    for label in labels:
        target.time_series[label] = self.time_series[label].resample(
            fs_new, method=method, verbose=channel_verbose, **kwargs)

    return None if inplace else target
@staticmethod
def read_hdf5(filepath, ds=None, begin=None, end=None, **kwargs):
    """
    Read a Dataset object from a .hdf5 file.

    Every top-level entry of the file that is recognized as a TimeSeries is read with
    TimeSeries.read_hdf5() and appended to the Dataset. An entry is recognized as a TimeSeries if
    its 'type' attribute equals 'TimeSeries' or, when it has no 'type' attribute, if it has an
    'fs' attribute.

    Args:
        filepath (str): filepath to read.
        ds (optional): Dataset object to append the data to. If None, creates a BaseDataset object.
        begin (float, optional): start second, passed to TimeSeries.read_hdf5().
        end (float, optional): end second, passed to TimeSeries.read_hdf5().
        **kwargs (optional): accepted for backward compatibility, currently not used.

    Returns:
        ds (nnsa.BaseDataset): Dataset object with data read from the file.

    Examples:
        >>> signal = np.random.rand(8, 1000)
        >>> ds = BaseDataset()
        >>> ts_all = [TimeSeries(signal=signal[i], fs=10, label=i) for i in range(len(signal))]
        >>> ds.append(ts_all)
        >>> ds.save_hdf5('testfile.hdf5')
        >>> ds_read = BaseDataset.read_hdf5('testfile.hdf5')
        >>> ds_read_epoch = BaseDataset.read_hdf5('testfile.hdf5', begin=10, end=20)
        >>> os.remove('testfile.hdf5')
        >>> _ = [assert_equal(ts.signal, ts_read.signal) for ts, ts_read in zip (ds, ds_read)]
        >>> _ = [assert_equal(ts.signal, ts_read.signal) for ts, ts_read in zip (ds.extract_epoch(begin=10, end=20), ds_read_epoch)]
    """
    if ds is None:
        ds = BaseDataset()

    with h5py.File(filepath, 'r') as f:
        # Loop over datasets in HDF5 and if it is a TimeSeries, append to the Dataset.
        for lab in f:
            attrs = f[lab].attrs
            if 'type' in attrs:
                type_tag = attrs['type']
                # Depending on how the attribute was stored (and on the h5py version), it is
                # returned as bytes or as str. Only bytes need decoding.
                if isinstance(type_tag, bytes):
                    type_tag = type_tag.decode()
                is_time_series = type_tag == 'TimeSeries'
            else:
                # No type tag: an entry with a sample frequency is taken to be a TimeSeries.
                is_time_series = 'fs' in attrs

            if is_time_series:
                ts = TimeSeries.read_hdf5(f, label=lab, begin=begin, end=end)
                ds.append(ts)

    return ds
def save_hdf5(self, filepath, overwrite=False):
    """
    Save the Dataset object to a .hdf5 file.

    A new, empty file is created at `filepath` and each TimeSeries is then appended to it with
    TimeSeries.save_hdf5().

    Args:
        filepath (str): filepath to save to.
        overwrite (bool, optional): if True, overwrites existing files. If False, raises
            an error when `filepath` already exists.
            Defaults to False.

    Raises:
        ValueError: if `filepath` already exists and `overwrite` is False.
    """
    if not overwrite and os.path.exists(filepath):
        # Include the offending path in the message (the placeholder was previously left unfilled).
        raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.'
                         .format(filepath))

    check_directory_exists(filepath=filepath)

    # Create an empty HDF5 file (truncates an existing one).
    with h5py.File(filepath, mode='w'):
        pass

    # Save each timeseries to the hdf5.
    for ts in self:
        ts.save_hdf5(filepath, mode='a', overwrite=True)
def segment(self, segment_length, overlap=0):
    """
    Return a segment generator that segments the signals in the dataset (into smaller time segments).

    Only works if all signals have the same sample frequency. The generator ends as soon as the
    segment generator of any of the TimeSeries is exhausted.

    Args:
        segment_length (float): length of a segment in seconds.
        overlap (float, optional): overlap between successive segments in seconds.
            Defaults to 0.

    Yields:
        (BaseDataset): Dataset object (of the same class as self) holding the signals of the
            next segment.

    Raises:
        ValueError: if the TimeSeries do not all share one sample frequency (this includes a
            Dataset without TimeSeries). Raised when the first segment is requested.
    """
    # Verify all sample frequencies are the same.
    fs_all = [ts.fs for ts in self.time_series.values()]
    if len(set(fs_all)) != 1:
        raise ValueError('Cannot segment {} if it contains signals of different frequencies: {}.'
                         .format(self.__class__.__name__, fs_all))

    # Create segment generators for all channels in the Dataset.
    seg_generators = [ts.segment(segment_length, overlap=overlap) for ts in self.time_series.values()]

    # Advance all channel generators in lockstep. zip() stops cleanly once a channel runs out of
    # segments, whereas calling next() directly inside this generator would turn the
    # StopIteration into a RuntimeError (PEP 479).
    for channel_segments in zip(*seg_generators):
        ds = self.__class__()
        for ts_segment in channel_segments:
            ds.append(ts_segment)
        yield ds
def signal_stats(self, verbose=1, **kwargs):
    """
    Compute signal statistics of the signals in the Dataset.

    Thin wrapper: stacks the data into one array and hands it to SignalStats.signal_stats().

    Args:
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): keyword arguments passed to the SignalStats class.

    Returns:
        result (nnsa.SignalStatsResult): object containing the signal stats.
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        message = 'Computing signal statistics of {} with {} channels: {}'.format(
            self.label, len(self.time_series), labels)
        print(message)

    # Feature extraction object configured with the user specified keyword arguments.
    stats = SignalStats(**kwargs)

    # Stack the signals into one input matrix.
    data_matrix, channel_labels = self.asarray(return_channel_labels=True)

    # Take fs from the first TimeSeries (asarray() is relied upon to reject mixed sample frequencies).
    first_ts = next(iter(self.time_series.values()))
    fs = first_ts.fs

    result = stats.signal_stats(data_matrix, fs=fs,
                                channel_labels=channel_labels, verbose=verbose)

    # Attach data info and time offset to the result.
    self._postprocess_result(result)
    return result
def squeeze(self):
    """
    Unwrap a Dataset that holds exactly one TimeSeries.

    Returns:
        The single TimeSeries if len(self) is 1, otherwise the Dataset itself.
    """
    if len(self) != 1:
        return self
    all_series = list(self.time_series.values())
    return all_series[0]
def stepwise_moving_average(self, *args, avg_fun=np.nanmean, **kwargs):
    """
    Compute stepwise averages by calling self.stepwise_reduce() with `avg_fun` as the reduce function.

    Args:
        *args: positional arguments for self.stepwise_reduce().
        avg_fun (callable, optional): passed to self.stepwise_reduce() as `reduce_fun`.
            Defaults to np.nanmean.
        **kwargs (optional): keyword arguments for self.stepwise_reduce().

    Returns:
        Whatever self.stepwise_reduce() returns.
    """
    averaged = self.stepwise_reduce(*args, reduce_fun=avg_fun, **kwargs)
    return averaged
def stepwise_reduce(self, *args, channels=None, inplace=False, **kwargs):
    """
    Compute stepwise values by calling TimeSeries.stepwise_reduce() in place on each selected channel.

    Args:
        *args: inputs for nnsa.TimeSeries.stepwise_reduce()
        channels (list, optional): list with labels of the signals to apply to.
            If None, all signals in the Dataset are processed.
            Defaults to None.
        inplace (bool, optional): if True, the signals of this Dataset object itself are processed.
            If False (default), the signals of a deep copy are processed and that copy is returned,
            leaving the data in the current Dataset unchanged. Channels that are not listed in
            `channels` are kept as they are.
        **kwargs (optional): keyword arguments for nnsa.TimeSeries.stepwise_reduce()

    Returns:
        (BaseDataset): new Dataset object containing the new signals (only if inplace is False).
    """
    labels = self.time_series.keys() if channels is None else channels

    target = self if inplace else copy.deepcopy(self)

    # The TimeSeries are always changed in place: a copy has been made already if requested.
    for label in labels:
        series = target.time_series[label]
        series.stepwise_reduce(*args, inplace=True, **kwargs)

    return None if inplace else target
def to_time_series(self, make_copy=False):
    """
    Return the only TimeSeries of the Dataset.

    Args:
        make_copy (bool, optional): if True, a deep copy of the TimeSeries in self is returned.
            If False, the TimeSeries object itself is returned.

    Returns:
        ts (TimeSeries): nnsa TimeSeries object.

    Raises:
        ValueError: if the Dataset does not contain exactly one TimeSeries.
    """
    n_series = len(self)
    if n_series != 1:
        raise ValueError('Cannot convert {} to TimeSeries, because there are {} TimeSeries. '
                         'to_time_series only works is there is only a single TimeSeries.'.format(
                             self.label, n_series))

    only_label = self.channel_labels[0]
    ts = self.time_series[only_label]
    return copy.deepcopy(ts) if make_copy else ts
def _check_label(self, label):
    """
    Validate that a label refers to a TimeSeries in the current dataset.

    Args:
        label (str): label to validate.

    Returns:
        label (str): the given label, unchanged, if it is a key of `self.time_series`.

    Raises:
        ValueError: if the label is not in the dataset.
    """
    if label in self.time_series:
        return label

    known_labels = list(self.time_series.keys())
    raise ValueError('Signal "{}" not in {}. Signals in dataset: {}.'
                     .format(label, self.label, known_labels))
def _postprocess_result(self, result, channels=None):
    """
    Postprocessing step shared by the feature extraction wrappers.

    Sets `result.data_info` to the joined, de-duplicated info strings of the selected TimeSeries
    and sets `result.time_offset` to their common time offset.

    Args:
        result (nnsa.ResultBase-derived): result object to update.
        channels (list, optional): list with labels of the signals to take the info from.
            If None, all signals in the Dataset are used.
            Defaults to None.

    Raises:
        ValueError: if the selected TimeSeries do not all have the same time offset
            (`result.data_info` has already been set at that point).
    """
    labels = self.time_series.keys() if channels is None else channels

    offsets = []
    unique_infos = []
    for label in labels:
        ts = self[label]
        offsets.append(ts.time_offset)

        # Keep each distinct info string once, in order of first appearance.
        info_str = str(ts.info)
        if info_str not in unique_infos:
            unique_infos.append(info_str)

    # Merge all data info strings into one.
    result.data_info = ';\n '.join(unique_infos)

    # A single time offset can only be reported if all TimeSeries agree on it.
    if len(set(offsets)) != 1:
        raise ValueError('TimeSeries in {} not synchronized.'.format(self.label))
    result.time_offset = offsets[0]
[docs]class EegDataset(BaseDataset):
"""
High-level interface for processing (multichannel) EEG for neonatal signal analysis.
"""
def __getitem__(self, label):
    """
    Evaluate self[label].

    A label that is not a key of self.time_series is first passed through self._check_label().
    This makes it possible to extract the time series using self['Cz'], even when the proper key
    is 'EEG Cz'.
    """
    key = label if label in self.time_series else self._check_label(label)
    return self.time_series[key]
@property
def signal(self):
    """
    EEG data of the dataset as one numpy array (channels, time).

    Returns:
        (np.ndarray): EEG data matrix, as returned by self.asarray().
    """
    eeg_matrix = self.asarray()
    return eeg_matrix
def amplitude_eeg(self, channel_1='EEG C3', channel_2='EEG C4', verbose=1, **kwargs):
    """
    Emulate amplitude-integrated EEG (aEEG) on a single bipolar derivation of the continuous EEG.

    The derivation channel_1 - channel_2 is created first, after which the aEEG computation of the
    resulting TimeSeries is invoked.

    Use on raw EEG data.

    Args:
        channel_1 (str, optional): the label of the first channel.
            Defaults to 'EEG C3'.
        channel_2 (str, optional): the label of the second channel (will be subtracted from the
            first channel).
            Defaults to 'EEG C4'.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments for the AmplitudeEeg class.

    Returns:
        result (nnsa.AmplitudeEegResult): object containing the aEEG.
    """
    bipolar_ts = self.create_bipolar_channel(channel_1, channel_2)
    aeeg_result = bipolar_ts.amplitude_eeg(verbose=verbose, **kwargs)
    return aeeg_result
def as_brain_rt(self, bp_filter=True):
    """
    Preprocess the data to make it similar to the view as in BrainRT.

    Creates a fixed set of bipolar channels and optionally band-pass filters them.

    Args:
        bp_filter (bool): if True, filter the data as in Brain RT (with a first order bandpass
            Butterworth).

    Returns:
        eeg_ds (nnsa.EegDataset): preprocessed dataset (new object).
    """
    # Each pair is (first channel, channel subtracted from it), listed in the order in which the
    # bipolar channels are created.
    # In BrainRT, apparently, the derivations are the other way around (e.g., C3-C4 is actually
    # C4-C3), which is why e.g. the BrainRT derivation C3-T3 is created here as T3 - C3.
    derivations = [
        ('T3', 'C3'), ('C3', 'Cz'), ('Cz', 'C4'), ('C4', 'T4'),
        ('O1', 'T3'), ('T3', 'Fp1'), ('O1', 'C3'), ('C3', 'Fp1'),
        ('O2', 'T4'), ('T4', 'Fp2'), ('O2', 'C4'), ('C4', 'Fp2'),
    ]
    first_channels = [pair[0] for pair in derivations]
    second_channels = [pair[1] for pair in derivations]

    eeg_ds = self.create_bipolar_channels(
        channels_1=first_channels, channels_2=second_channels)

    if not bp_filter:
        return eeg_ds

    # In BrainRT, the filter is a first order butterworth.
    eeg_ds.filter(Butterworth(fn=[0.27, 30], order=1), inplace=True)
    return eeg_ds
def brain_age_sinc(self, verbose=1, **kwargs):
    """
    Estimate the brain age from the EEG with the deep Sinc network.

    Use on raw EEG data. If the reference channel required by the model is present in the dataset,
    the data is re-referenced to it; otherwise the data is assumed to be referenced correctly
    already.

    Args:
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments for the BrainAgeSinc class.

    Returns:
        result (BrainAgeResult): object containing the brain age.
    """
    if verbose > 1:
        labels = [ts.label for ts in self.time_series.values()]
        print('Computing brain age of {} with {} channels: {}'
              .format(self.label, len(self.time_series), labels))

    # Feature extraction object (default parameters updated with the user specified keyword arguments).
    model = BrainAgeSinc(verbose=verbose > 1, **kwargs)
    requirements = model.data_requirements

    # Re-reference only if the required reference channel is available.
    reference_channel = requirements['reference_channel']
    if reference_channel in self:
        eeg_ds = self.reference(reference_channel, verbose=verbose)
    else:
        eeg_ds = self

    # The model expects the channels in a specific order.
    eeg, ch_labs = eeg_ds.asarray(
        channels=model.data_requirements['channel_order'],
        return_channel_labels=True,
        channels_last=True)

    # NOTE(review): these cleaned labels (without 'EEG') are not used further down.
    channel_labels = [lab.replace('EEG', '').strip() for lab in ch_labs]

    # Process (time runs along axis 0 because of channels_last=True).
    result = model.process(eeg, eeg_ds.fs, batch_size=1000, axis=0, verbose=verbose)

    # Add info string about data to result.
    eeg_ds._postprocess_result(result)
    return result
def brain_age_rf(self, verbose=1, **kwargs):
    """
    Deprecated: brain age from the EEG using the random forest approach.

    This method is no longer available and always raises. The former implementation was
    unreachable behind the unconditional raise and has been removed. See brain_age_sinc() for the
    brain age estimate based on the deep Sinc network.

    Args:
        verbose (int, optional): unused; kept so that the signature stays backward compatible.
        **kwargs (optional): unused; kept so that the signature stays backward compatible.

    Raises:
        DeprecationWarning: always.
    """
    raise DeprecationWarning(
        'brain_age_rf() is deprecated and no longer available. '
        'See brain_age_sinc() for the Sinc network based brain age.')
def burst_detection(self, create_bipolar_channels=True, verbose=1, **kwargs):
    """
    Perform burst detection on all channels in the dataset.

    This is a wrapper that prepares the input for BurstDetection.burst_detection() and returns
    the result.

    Args:
        create_bipolar_channels (bool, optional): if True, automatically create bipolar channels
            according to the method used ('otoole' and 'nleo' have a predefined montage, any other
            method uses the channels as they are). If False, run the algorithm on the channels as
            currently in the object.
            Defaults to True.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments for the BurstDetection class.

    Returns:
        result (nnsa.BurstDetectionResult): the result of the burst detection.
    """
    # Initialize BurstDetection object (pass user specified keyword arguments).
    bd = BurstDetection(**kwargs)

    # Bipolar montage (channels_1, channels_2) per detection method.
    montages = {
        # Bipolar channels as required by O'Toole.
        'otoole': (['C4', 'C3', 'T4', 'C4', 'Cz', 'C3'],
                   ['O2', 'O1', 'C4', 'Cz', 'C3', 'T3']),
        # Bipolar channels that are similar to P3-P4 (as used by Palmu et al. in optimization
        # of the NLEO algorithm).
        'nleo': (['O1', 'C3', 'T3'],
                 ['O2', 'C4', 'T4']),
    }

    ds = self
    if create_bipolar_channels:
        montage = montages.get(bd.parameters['method'].lower())
        if montage is not None:
            ds = self.create_bipolar_channels(*montage)

    # Warn for every signal whose processing log mentions filtering.
    for ts in ds.time_series.values():
        processing_log = ts.info['processing']
        if any('Filtered' in entry for entry in processing_log):
            msg = ("\nWARNING: Signal(s) seems to be filtered: ts.info['processing'] = {}.\n"
                   "Verify that the burst detection method works with filtered data."
                   .format(processing_log))
            warnings.warn(msg)

    if verbose > 0:
        labels = [ts.label for ts in ds.time_series.values()]
        print('Performing burst detection on {} with {} channels: {}'
              .format(ds.__class__.__name__, len(ds.time_series), labels))

    # Prepare input matrix for burst_detection.
    data_matrix, channel_labels = ds.asarray(return_channel_labels=True)

    # Run burst_detection.
    result = bd.burst_detection(data_matrix, fs=ds.fs,
                                channel_labels=channel_labels, verbose=verbose)

    # Add info string about data to result.
    ds._postprocess_result(result)
    return result
def clamp(self, threshold=250, channels=None, inplace=False, verbose=1):
    """
    Clamp EEG values by applying the clamping function of the TimeSeries to reduce dynamic range.

    Args:
        threshold (float): threshold for the clamping function.
        channels (list, optional): keys in self.time_series of the signals to clamp.
            If None, all signals in the Dataset are clamped.
            Defaults to None.
        inplace (bool, optional): whether to apply inplace (True) or not.
        verbose (int, optional): verbosity level.

    Returns:
        ds_out (EegDataset): new Dataset object (only returned if inplace is False).
    """
    if verbose > 0:
        print('Clamping values in {} at {}...'.format(self.label, threshold))

    labels = self.time_series.keys() if channels is None else channels

    # The clamped TimeSeries replace the originals, either in this Dataset or in a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)
    for label in labels:
        ds_out.time_series[label] = self.time_series[label].clamp(threshold=threshold)

    if inplace:
        return None
    return ds_out
def create_bipolar_channel(self, channel_1, channel_2, label=None):
    """
    Create one bipolar channel: bipolar channel = channel_1 - channel_2.

    Args:
        channel_1 (str): the label of the first channel.
        channel_2 (str): the label of the second channel (will be subtracted from the first channel).
        label (str, optional): label for the new bipolar channel.
            If None, the label '<channel_1>-<channel_2>' is used.
            Defaults to None.

    Returns:
        ts_bipolar (nnsa.TimeSeries): TimeSeries object holding the bipolar channel data.
    """
    bipolar_label = '{}-{}'.format(channel_1, channel_2) if label is None else label

    # Subtraction of the two TimeSeries yields the bipolar signal.
    ts_bipolar = self[channel_1] - self[channel_2]
    ts_bipolar.label = bipolar_label
    return ts_bipolar
def create_bipolar_channels(self, channels_1, channels_2, labels=None, missing_mode='error'):
    """
    Create a set of bipolar channels by subtracting channels_2 from channels_1.

    channels_new[i] = channels_1[i] - channels_2[i]

    Args:
        channels_1 (list): the labels of the first channels. A non-list value is treated as a
            single channel.
        channels_2 (list): the labels of the second channels (will be subtracted from the first
            channels). Must have the same length as channels_1.
        labels (list, optional): labels for the new bipolar channels. Must have the same length as
            channels_1.
            If None, labels will automatically be created from channels_1 and channels_2.
            Defaults to None.
        missing_mode (str, optional): what to do if a bipolar channel could not be created (a
            KeyError or ValueError was raised while creating it). Choose from:
            - 'error': re-raise the error.
            - 'ignore': ignore the requested bipolar channels that could not be created.
            - 'warn': ignore, but display a warning.
            Any other value behaves like 'ignore'.

    Returns:
        ds_bipolar (nnsa.EegDataset): EegDataset object holding the bipolar channels data.

    Raises:
        ValueError: if channels_1, channels_2 and labels do not have the same length.
    """
    # Accept a single channel for either argument.
    if not isinstance(channels_1, list):
        channels_1 = [channels_1]
    if not isinstance(channels_2, list):
        channels_2 = [channels_2]

    n_pairs = len(channels_1)
    if n_pairs != len(channels_2):
        raise ValueError('channels_1 and channels_2 must have same length. Got lengths {} and {}.'
                         .format(n_pairs, len(channels_2)))

    if labels is None:
        # A label of None lets create_bipolar_channel() generate the label.
        labels = [None] * n_pairs
    elif n_pairs != len(labels):
        raise ValueError('channels_1 and labels must have same length. Got lengths {} and {}.'
                         .format(n_pairs, len(labels)))

    # Collect the bipolar channels in a new, empty EegDataset.
    ds_bipolar = EegDataset()
    for first, second, bipolar_label in zip(channels_1, channels_2, labels):
        try:
            ts_bipolar = self.create_bipolar_channel(first, second, label=bipolar_label)
        except (KeyError, ValueError):
            if missing_mode == 'error':
                raise
            if missing_mode == 'warn':
                warnings.warn('Could not create bipolar channel {}-{}.'.format(first, second))
            # Skip the pair that could not be created.
            continue

        ds_bipolar.append(ts_bipolar)

    return ds_bipolar
def coherence_graph(self, preprocess=False, verbose=1, **kwargs):
    """
    Compute a connectivity graph, where the weights of the edges are based on the coherence between
    channels.

    This is a wrapper that prepares the input for CoherenceGraph.coherence_graph() and returns the
    result.

    Args:
        preprocess (bool, optional): apply a default preprocessing routine (FIR filter, notch
            filter, resampling to 128 Hz) on the data before feature computation (True) or not
            (False).
            Defaults to False.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments to overrule default parameters of the
            CoherenceGraph class.

    Returns:
        result (nnsa.CoherenceGraphResult): an object containing the resulting adjacency matrices
            of the graphs.
    """
    # Use the data as it is, unless the default preprocessing is requested.
    ds = self
    if preprocess:
        fir_filt = get_eeg_fir_filter_a()
        notch_filt = NotchIIR()
        ds = ds.filtfilt(fir_filt)
        ds = ds.filtfilt(notch_filt)
        ds = ds.resample(fs_new=128, method='polyphase_filtering')

    if verbose > 0:
        labels = [ts.label for ts in ds.time_series.values()]
        print('Computing coherence graph of {} with {} channels: {}'
              .format(ds.__class__.__name__, len(ds.time_series), labels))

    # Initialize CoherenceGraph object (updates default parameters with user specified keyword arguments).
    graph_builder = CoherenceGraph(**kwargs)

    # Prepare input matrix.
    data_matrix, channel_labels = ds.asarray(return_channel_labels=True)

    # Sample frequency is taken from the first signal in the dataset.
    fs = next(ts.fs for ts in ds.time_series.values())

    # Run coherence_graph.
    result = graph_builder.coherence_graph(data_matrix, fs=fs,
                                           channel_labels=channel_labels,
                                           verbose=verbose)

    # Add info string about data to result.
    ds._postprocess_result(result)
    return result
def compute_features(self, verbose=1, chunk_size=None, **kwargs):
    """
    Compute the EEG feature set of all channels in the dataset.

    Args:
        verbose (int): verbosity level.
        chunk_size (int): number of time samples to process at once.
            If None, a suitable chunk_size is chosen automatically.
        **kwargs: for EegFeatures().

    Returns:
        result (FeatureSetResult): feature set result.
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        print('Computing features of {} with {} channels: {}'
              .format(self.__class__.__name__, len(self.time_series), labels))

    # Feature extraction object (default parameters updated with the user specified keyword arguments).
    extractor = EegFeatures(**kwargs)

    # Input matrix with time along the last axis (channels_last=False).
    eeg, channel_labels = self.asarray(return_channel_labels=True, channels_last=False)

    # Process.
    result = extractor.process(eeg=eeg, fs=self.fs, channel_labels=channel_labels,
                               axis=-1, chunk_size=chunk_size, verbose=verbose)

    # Add info string about data to result.
    self._postprocess_result(result)
    return result
def detect_artefacts(self, multi_channel_cnn=True, verbose=1):
    """
    Helper function to quickly detect artefacts using a combination of methods: the CNN based
    clean detector (including flat and peak detection) and the anomalous channel detection.

    Apply to raw EEG.

    Args:
        multi_channel_cnn (bool): whether to use the multi-channel CNN (True) or not (False).
        verbose (int): verbosity level.

    Returns:
        af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values:
            True at locations of artefacts, False otherwise.
    """
    # Reference to Cz if available, then select the channels to analyse.
    channels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']
    eeg_ds = self.reference('Cz', inplace=False, verbose=verbose) if 'Cz' in self else self
    eeg_ds = eeg_ds.extract_channels(channels=channels, make_copy=False)

    # Preprocess for artefact detection. Note that this preprocessing is specific to the CNN.
    eeg_ds_pp = eeg_ds.filtfilt(NotchIIR(f0=50), verbose=verbose)
    eeg_ds_pp = eeg_ds_pp.filter(Butterworth(fn=[0.27, 30], order=1), verbose=verbose)
    eeg_ds_pp = eeg_ds_pp.resample(fs_new=128, method='interpolation', verbose=verbose)

    # Detect artefacts using CNN (data is already preprocessed, hence preprocess=False).
    af_ds = eeg_ds_pp.detect_artefacts_cnn(
        multi_channel=multi_channel_cnn, detect_flats=True, detect_peaks=True,
        preprocess=False, verbose=verbose*2)

    # Detect anomalies and combine both artefact detection methods.
    anomaly_ds = eeg_ds_pp.detect_anomalous_channels(
        window=3, std_factor=8, p_trim=0.25, verbose=verbose)
    af_ds = af_ds.logical_or(anomaly_ds)

    # Interpolate/upsample artefact mask to the time axis of the (unresampled) EEG.
    # Samples outside the mask's range are filled with 1, i.e. flagged as artefact.
    af_interp = af_ds.astype(float).interp(
        t=eeg_ds.time, channels=eeg_ds.channel_labels, left=1, right=1)
    af_mask = af_interp > 0.5

    # Create a copy of the EegDataset object and fill it with the boolean masks.
    af_ds = copy.deepcopy(eeg_ds)
    for ts, mask in zip(af_ds, af_mask):
        assert len(ts) == len(mask)
        ts.parameters.update(dtype=bool)
        ts.signal = mask

    return af_ds
def detect_artefacts_amplitude_kaupilla(self, notch_filter=True, bp_filter=True):
    """
    Detect locations of artefacts by setting thresholds on the amplitude as proposed by
    Kaupilla et al. 2018.

    The (optionally filtered) copy of the data is passed to detect_artefacts_amplitude() with
    threshold_low=0.5, threshold_high=200 and window=1.

    To apply to raw EEG data: set `notch_filter` and `bp_filter` to True.

    Args:
        notch_filter (bool): specify whether the powerline (50 Hz) needs to be filtered out.
            Set to True if the EEG is raw data.
        bp_filter (bool): specify whether the EEG needs to be bandpass filtered.
            Set to True if the EEG is raw data.

    Returns:
        af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values:
            True at locations of artefacts, False otherwise.
    """
    # Work on a copy, since the filters are applied in place.
    eeg_ds = copy.deepcopy(self)

    # (enabled, filter factory) per optional filter stage, in order of application.
    stages = (
        (notch_filter, lambda: NotchIIR(f0=50)),
        (bp_filter, lambda: get_filter('1-40')),
    )
    for enabled, make_filter in stages:
        if enabled:
            eeg_ds.filtfilt(make_filter(), inplace=True)

    return eeg_ds.detect_artefacts_amplitude(
        threshold_low=0.5, threshold_high=200, window=1)
def detect_artefacts_amplitude(
        self, threshold_high, threshold_low=0, window=1, how='max'):
    """
    Detect locations of artefacts by setting thresholds on the running mean/maximum
    of the absolute amplitude.

    Returns a boolean mask where True values indicate locations where
    x < `threshold_low` or x > `threshold_high`,
    where x is the moving average/maximum of the absolute signal values in a window of `window`
    seconds.

    Args:
        threshold_high (float): threshold in uV.
        threshold_low (float): threshold in uV.
        window (float): window for the moving average/maximum in seconds.
        how (str): whether to take the moving 'mean' or 'max'.

    Returns:
        af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values:
            True at locations of artefacts, False otherwise.

    Raises:
        ValueError: if `how` is not 'mean' or 'max'.
    """
    # Validate up front, so that an invalid option fails before the (potentially large) dataset
    # is copied, and also when the dataset holds no signals.
    if how not in ('mean', 'max'):
        raise ValueError(f'Invalid option how="{how}". Choose from {["mean", "max"]}.')
    moving_stat = moving_mean if how == 'mean' else moving_max

    # Create a copy of the EegDataset object.
    af_ds = copy.deepcopy(self)

    # Replace the EEG data with their artefact masks.
    for ts in af_ds:
        # Window in samples.
        n = round(window * ts.fs)

        # Moving statistic of the absolute amplitude.
        amp = moving_stat(np.abs(ts.signal), n=n)

        # Artefact where the amplitude is suspiciously low or high.
        mask = np.logical_or(amp < threshold_low, amp > threshold_high)
        assert len(ts) == len(mask)
        ts.parameters.update(dtype=bool)
        ts.signal = mask

    return af_ds
def detect_anomalous_channels(self, window=3, std_factor=8, p_trim=0.25, verbose=0):
    """
    Simple method to detect anomalous high-frequency/amplitude channels by setting a threshold on
    the line length.

    The log line length of the (1-`p_trim`)*n_channels with lowest line lengths is used to
    determine mean and std and if any channel exceeds mean + `std_factor`*std, then it
    is flagged as artefact. This is done per window.

    Args:
        window (float): the window length (seconds) in which to compute line length.
        std_factor (float): the factor for std determining the threshold (mean+`std_factor`*std).
        p_trim (float): the relative amount of channels to strip before computing the mean and std.
            The channels with the highest line lengths are stripped/excluded.
        verbose (int): verbosity level.

    Returns:
        af_ds (EegDataset): copy of the dataset in which the signal of each channel is replaced by
            the mask computed for that channel.
    """
    if verbose > 0:
        labels = [ts.label for ts in self.time_series.values()]
        print('Detecting anomalous channels in {} with {} channels: {}'
              .format(self.__class__.__name__, len(self.time_series), labels))

    # Delegate to the function of the same name imported from the artefact detection module,
    # which works on a plain array with time along the last axis (channels_last=False).
    mask_anomaly = detect_anomalous_channels(
        x=self.asarray(channels_last=False),
        fs=self.fs,
        window=window,
        std_factor=std_factor,
        p_trim=p_trim)

    # Create a copy of the EegDataset object in which we will save the artefact mask.
    af_ds = copy.deepcopy(self)
    for ts, channel_mask in zip(af_ds, mask_anomaly):
        assert len(ts) == len(channel_mask)
        ts.parameters.update(dtype=bool)
        ts.signal = channel_mask

    return af_ds
def detect_artefacts_kota(self, notch_filter=True, bp_filter=True):
    """
    Simple method to detect small and large amplitude artefacts.

    A sample is flagged if the moving SD (1 second window) is below 0.01 or above 50, or if the
    moving maximum of the absolute amplitude (0.5 second window) exceeds 300.

    Apply to raw data referenced to Cz with notch_filter and bp_filter set to True.

    https://doi.org/10.1016/j.pediatrneurol.2021.06.001

    Args:
        notch_filter (bool): if True, filter the powerline (50 Hz) from the EEG first.
        bp_filter (bool): if True, bandpass filter the EEG first (4th order Butterworth,
            fn=[0.3, 20]).

    Returns:
        af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values:
            True at locations of artefacts, False otherwise.
    """
    # Work on a copy, since the filters are applied in place.
    af_ds = copy.deepcopy(self)

    if notch_filter:
        af_ds.filtfilt(NotchIIR(f0=50), inplace=True)

    if bp_filter:
        af_ds.filtfilt(Butterworth(order=4, fn=[0.3, 20]), inplace=True)

    # Replace the EEG data with their artefact masks.
    for ts in af_ds:
        # NOTE(review): the window lengths are passed as seconds * fs without rounding, unlike in
        # detect_artefacts_amplitude() - confirm the moving_* helpers accept a non-integer n.
        # Mean can be assumed zero after filtering so this is a good and fast way to compute
        # the moving SD.
        moving_sd = np.sqrt(moving_mean(ts.signal**2, n=1 * ts.fs))

        # Total of 1 second removed when finding high value.
        moving_peak = moving_max(np.abs(ts.signal), n=0.5 * ts.fs)

        mask = (moving_sd < 0.01) | (moving_sd > 50) | (moving_peak > 300)
        assert len(ts) == len(mask)
        ts.parameters.update(dtype=bool)
        ts.signal = mask

    return af_ds
def detect_artefacts_method(self, how, **kwargs):
    """
    Shortcut to any of the artefact detection functions, selected by name.

    Args:
        how (str): which method to run (case-insensitive):
            - 'raw': no artefact detection, None is returned.
            - 'amp': detect_artefacts_amplitude_kaupilla().
            - 'kota': detect_artefacts_kota().
            - 'rfc': detect_artefacts_rfc().
            - 'cnn': detect_artefacts_cnn().
        **kwargs: keyword arguments for the selected method.

    Returns:
        The return value of the selected method, or None if how is 'raw'.

    Raises:
        ValueError: if `how` is not one of the options above.
    """
    how = how.lower()
    if how == 'raw':
        return None

    # Name of the method implementing each option (looked up only for the selected one).
    method_names = {
        'amp': 'detect_artefacts_amplitude_kaupilla',
        'kota': 'detect_artefacts_kota',
        'rfc': 'detect_artefacts_rfc',
        'cnn': 'detect_artefacts_cnn',
    }
    if how not in method_names:
        raise ValueError('Invalid how="{}".'.format(how))

    detector = getattr(self, method_names[how])
    return detector(**kwargs)
def detect_artefacts_cnn(self, preprocess=None,
                         detect_flats=True, detect_peaks=True,
                         verbose=1, **kwargs):
    """
    Detect locations of artefacts based on clean EEG detection with CNN.

    EEG data must be referenced to the reference channel required by the model. If that channel
    is in self, referencing will be done automatically. If it is not present, referencing will be
    skipped (assuming the EEG data is already referenced).
    Samples that are NaN in the EEG are marked as artefact too.

    Apply to raw EEG data and set `preprocess` to True.

    Args:
        preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled).
            Set to True if the EEG is raw data.
            If not specified, preprocessing will be done if `fs` is not as required by the model,
            otherwise not.
        detect_flats (bool): if True, computes moving std in short windows
            and if its below a threshold, the sample is marked as artefact (since the CNN might
            not catch this).
        detect_peaks (bool): if True, computes moving max abs amplitude in short windows
            and if its above a threshold, the sample is marked as artefact (since the CNN might
            not catch this).
        verbose (int): verbosity level.
        **kwargs (dict): for CleanDetectorCnn().

    Returns:
        af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values:
            True at locations of artefacts, False otherwise.
    """
    detector = CleanDetectorCnn(**kwargs)
    requirements = detector.data_requirements

    # Reference if reference channel in EEG data (otherwise assume it is already referenced correctly).
    reference_channel = requirements['reference_channel']
    eeg_ds = self.reference(reference_channel) if reference_channel in self else self

    # Extract channels in specific order if needed.
    channel_order = requirements['channel_order']
    if channel_order is not None:
        eeg_ds = eeg_ds.extract_channels(channel_order, make_copy=False)

    # Array data with channels along the last axis; the labels are not needed here.
    eeg, _ = eeg_ds.asarray(return_channel_labels=True, channels_last=True)

    if verbose > 1:
        labels = [ts.label for ts in eeg_ds]
        print('Detecting artefacts using CNN in {} with {} channels: {}'
              .format(eeg_ds.__class__.__name__, len(eeg_ds), labels))

    # Predict/detect clean locations (first element of the returned value).
    clean_mask = detector.predict(eeg, fs=eeg_ds.fs, preprocess=preprocess,
                                  detect_flats=detect_flats, detect_peaks=detect_peaks,
                                  verbose=verbose)[0]

    # Artefact wherever the sample is not clean or the EEG is NaN.
    af_mask = (clean_mask == 0) | np.isnan(eeg)

    # Create a copy of the EegDataset object and replace the EEG data with the artefact masks
    # (transposed, so that iterating yields one mask per channel).
    af_ds = copy.deepcopy(eeg_ds)
    for ts, channel_mask in zip(af_ds, af_mask.T):
        assert len(ts) == len(channel_mask)
        ts.parameters.update(dtype=bool)
        ts.signal = channel_mask

    return af_ds
def detect_artefacts_rfc(self, pma=None, preprocess=None, verbose=1, **kwargs):
    """
    Deprecated: artefact detection with the random forest classifier from the artefact_detection
    package.

    This method is no longer available and always raises. The former implementation was
    unreachable behind the unconditional raise and has been removed. Other detectors in this
    class are detect_artefacts_cnn(), detect_artefacts_kota() and
    detect_artefacts_amplitude_kaupilla().

    Args:
        pma (float): unused; kept so that the signature stays backward compatible.
        preprocess (bool): unused; kept so that the signature stays backward compatible.
        verbose (int): unused; kept so that the signature stays backward compatible.
        **kwargs (dict): unused; kept so that the signature stays backward compatible.

    Raises:
        DeprecationWarning: always.
    """
    raise DeprecationWarning(
        'detect_artefacts_rfc() is deprecated and no longer available.')
def filter(self, *args, remove_mean=True, **kwargs):
    """
    Same as the filter() of the parent class, except that remove_mean is True by default.

    Args:
        *args: positional arguments for the parent's filter().
        remove_mean (bool, optional): passed on to the parent's filter().
            Defaults to True.
        **kwargs: keyword arguments for the parent's filter().

    Returns:
        The return value of the parent's filter().
    """
    parent = super()
    return parent.filter(*args, remove_mean=remove_mean, **kwargs)
def filter_saved_filter(self, *args, remove_mean=True, **kwargs):
    """
    Same as the filter_saved_filter() of the parent class, except that remove_mean is True by
    default.

    Args:
        *args: positional arguments for the parent's filter_saved_filter().
        remove_mean (bool, optional): passed on to the parent's filter_saved_filter().
            Defaults to True.
        **kwargs: keyword arguments for the parent's filter_saved_filter().

    Returns:
        The return value of the parent's filter_saved_filter().
    """
    parent = super()
    return parent.filter_saved_filter(*args, remove_mean=remove_mean, **kwargs)
def filtfilt(self, *args, remove_mean=True, **kwargs):
    """
    Same as the filtfilt() of the parent class, except that remove_mean is True by default.

    Args:
        *args: positional arguments for the parent's filtfilt().
        remove_mean (bool, optional): passed on to the parent's filtfilt().
            Defaults to True.
        **kwargs: keyword arguments for the parent's filtfilt().

    Returns:
        The return value of the parent's filtfilt().
    """
    parent = super()
    return parent.filtfilt(*args, remove_mean=remove_mean, **kwargs)
def notch_filt(self, *args, remove_mean=True, **kwargs):
    """
    Same as the notch_filt() of the parent class, except that remove_mean is True by default.

    Args:
        *args: positional arguments for the parent's notch_filt().
        remove_mean (bool, optional): passed on to the parent's notch_filt().
            Defaults to True.
        **kwargs: keyword arguments for the parent's notch_filt().

    Returns:
        The return value of the parent's notch_filt().
    """
    parent = super()
    return parent.notch_filt(*args, remove_mean=remove_mean, **kwargs)
@staticmethod
def from_array(eeg, fs, channel_labels=None, **kwargs):
    """
    Build an EegDataset from array data, one TimeSeries per channel.

    Args:
        eeg (np.ndarray): 2d array with dimensions (channels, time). If the first dimension does
            not match the number of channel labels, the transposed array is tried.
            A 1d array is treated as a single channel.
        fs (float): sample frequency in Hz.
        channel_labels (list): list with length len(eeg) containing labels for the channels.
            If None, will create default numbered names, like Ch1, Ch2, Ch3, ...
            (as many as the smallest dimension of `eeg`).
        **kwargs: optional keyword arguments for nnsa.TimeSeries.

    Returns:
        ds (nnsa.EegDataset): EegDataset with the data from the array.

    Raises:
        ValueError: if `eeg` is not 1d or 2d, or if neither dimension of `eeg` matches the
            number of channel labels.

    Examples:
        >>> eeg = np.random.rand(8, 10000)
        >>> eeg_ds = EegDataset.from_array(eeg, fs=250, channel_labels=list(range(8)))
    """
    eeg = np.asarray(eeg)

    # A 1d array is one channel; make it a column so that the orientation logic below applies.
    if eeg.ndim == 1:
        eeg = eeg.reshape(-1, 1)
    if eeg.ndim != 2:
        raise ValueError('`eeg` must have 2 dimensions. Got {} dimensions.'
                         .format(eeg.ndim))

    if channel_labels is None:
        # Default labels: the smallest dimension is taken to be the channel dimension.
        n_channels = min(eeg.shape)
        channel_labels = ['Ch{}'.format(number) for number in range(1, n_channels + 1)]

    # Make sure the channels run along the first axis, transposing if that is what it takes.
    n_labels = len(channel_labels)
    if len(eeg) != n_labels:
        eeg = np.transpose(eeg)
    if len(eeg) != n_labels:
        raise ValueError('`len(eeg) ({}) does not equal len(channel_labels) ({})'
                         .format(len(eeg), n_labels))

    # Create dataset (collection of TimeSeries).
    ds = EegDataset()
    for channel_signal, channel_label in zip(eeg, channel_labels):
        ds.append(TimeSeries(signal=channel_signal, fs=fs, label=channel_label, **kwargs))

    return ds
def get_channel(self, channel, make_copy=True):
    """
    Return a channel as TimeSeries. Can also be a bipolar channel.

    Args:
        channel (str): channel label, or a bipolar channel written as '<first>-<second>'
            (spaces are ignored), which is created as first minus second.
        make_copy (bool): if True, the returned TimeSeries is always a copy.
            If False, the returned TimeSeries may be the same object as in self,
            if `channel` is in self.

    Returns:
        ts (TimeSeries): time series containing the channel signal.

    Raises:
        ValueError: if `channel` is not in self and is not a bipolar channel.
    """
    # Existing channel: hand out the stored TimeSeries, or a copy of it.
    if channel in self:
        stored = self[channel]
        return copy.deepcopy(stored) if make_copy else stored

    if '-' not in channel:
        raise ValueError('Invalid channel "{}".'.format(channel))

    # Bipolar channel.
    first, second = channel.replace(' ', '').split('-')
    return self.create_bipolar_channel(first, second)
def get_segments(self, segment_length, overlap=0, channels_last=True):
    """
    Cut the EEG data into segments and return them stacked in a 3D array.

    Args:
        segment_length (float): segment length in seconds.
        overlap (float): segment overlap in seconds.
        channels_last (bool): if True, the returned array has shape (n_segments, n_time, n_channels).
            If False, the shape is (n_segments, n_channels, n_time).

    Returns:
        eeg_seg (np.ndarray): 3D array with length equal to the number of segments.
            The other two dimensions are time and channels, their order depends on `channels_last`.
    """
    # Time runs along axis 0 when the channels come last, along axis 1 otherwise.
    time_axis = 0 if channels_last else 1
    eeg_ar = self.asarray(channels_last=channels_last)

    # NOTE(review): get_all_segments is not among the imports visible at the top of this file -
    # confirm that it is imported.
    return get_all_segments(
        eeg_ar, segment_length=segment_length,
        overlap=overlap, axis=time_axis, fs=self.fs)
def kirubin_features(self, reference_channel=None, preprocess=False, inplace=False, verbose=1, **kwargs):
    """
    Compute Kirubin's feature set.

    This is a wrapper that prepares the input for KirubinFeatures.process() and returns the result.

    Args:
        reference_channel (str, optional): specify how to reference the data during preprocessing.
            If no referencing is desired, specify None.
            Only used if `preprocess` is True.
        preprocess (bool, optional): apply a default preprocessing routine on the data before
            feature computation (True) or not (False).
            Defaults to False.
        inplace (bool, optional): if `preprocess` is True, specify if the preprocessing should be
            inplace or not.
            Defaults to False.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments to overrule default parameters of the
            KirubinFeatures class.

    Returns:
        result (nnsa.KirubinFeaturesResult): the result with the features.
    """
    if not preprocess:
        # Use data as it is.
        ds = self
    else:
        # Preprocess this dataset itself, or a copy of it.
        ds = self if inplace else copy.deepcopy(self)

        # Default processing routine.
        fir_filt = get_eeg_fir_filter_a()
        notch_filt = NotchIIR(f0=50, Q=35)
        if reference_channel is not None:
            ds.reference(reference_channel=reference_channel, inplace=True)
        ds.remove_flatlines(n_flatline=1, inplace=True)
        ds.resample(fs_new=250, method='polyphase_filtering', inplace=True)
        for filt in (fir_filt, notch_filt):
            ds.filtfilt(filt, inplace=True)

    if verbose > 0:
        labels = [ts.label for ts in ds.time_series.values()]
        print('Computing Kirubin features of {} with {} channels: {}'
              .format(ds.__class__.__name__, len(ds.time_series), labels))

    # Initialize feature extraction object (updates default parameters with user specified keyword arguments).
    from nnsa.feature_extraction.feature_sets_old import KirubinFeatures
    extractor = KirubinFeatures(**kwargs)

    # Prepare input matrix for feature extraction function.
    data_matrix, channel_labels = ds.asarray(return_channel_labels=True)

    # Sample frequency is taken from the first signal in the dataset.
    fs = next(ts.fs for ts in ds.time_series.values())

    # Process.
    result = extractor.process(data_matrix, fs=fs, channel_labels=channel_labels,
                               verbose=verbose)

    # Add info string about data to result.
    ds._postprocess_result(result)
    return result
def mean(self, channels=None, label=None, **kwargs):
    """
    Return the nan-ignoring mean across the specified channels as a TimeSeries.

    Args:
        channels (list, optional): list with labels of the signals to average.
            If None, all signals in the dataset are averaged.
            Defaults to None.
        label (str, optional): label for the returned TimeSeries.
            If None, a label of the form 'Mean of <labels>' is generated, with 'EEG ' removed
            from the channel labels.
            Defaults to None.
        **kwargs (optional): keyword arguments for TimeSeries().

    Returns:
        ts (nnsa.TimeSeries): TimeSeries containing the mean of the channels.
    """
    eeg, channel_labels = self.asarray(channels=channels, return_channel_labels=True)
    fs = self.fs

    # Take the mean across channels (first axis of the array).
    if len(eeg) == 1:
        # Only one channel: the mean is that channel itself. Take the row (not the whole 2D array)
        # so that the signal is 1D, exactly as in the multi-channel branch.
        signal = eeg[0].copy()
    else:
        signal = np.nanmean(eeg, axis=0)

    if label is None:
        label = 'Mean of {}'.format(', '.join(
            [lab.replace('EEG ', '') for lab in channel_labels]))

    ts = TimeSeries(signal=signal, fs=fs, label=label,
                    unit=self.unit, info=self.info, time_offset=self.time_offset,
                    **kwargs)

    return ts
def plot(self, begin=None, end=None, channels=None, color=None, scale=None, relative_time=False,
         time_scale=None, ax=None, legend=True, **kwargs):
    """
    Plot the EEG data of the specified channels for a specific time frame.

    The channels are drawn below each other in one axes: every channel is divided by `scale` and
    shifted to its own integer offset (1, 2, ...) on the y-axis.

    Args:
        begin (float, optional): begin of the time frame to plot. Passed as lower bound to
            get_range_mask() together with self.time.
            Defaults to None.
        end (float, optional): end of the time frame to plot. Passed as upper bound to
            get_range_mask() together with self.time.
            Defaults to None.
        channels (list or str, optional): label(s) specifying the channels to plot, in order bottom to up.
            If None, all channels in the dataset are plotted (the first channel at the top).
            Defaults to None.
        color (str or array or None): if str or array, use this as the color for all channels.
            If None, the default color cycle ('C0', 'C1', ...) is used to plot each channel in a
            different color.
        scale (int, optional): value (in the same unit as the data) by which the data is divided.
            This factor should aim to squeeze most data between -0.5 and 0.5. If None, the scale is
            determined automatically from the largest 5th-95th percentile range over the channels.
            A scale of 0 is replaced by 1.
            Defaults to None.
        relative_time (bool, optional): if True, the time axis is relative to the start of the segment to plot.
            If False, the time axis will correspond to the time in the recording.
        time_scale (str, optional): the time scale to use, e.g. 'seconds', 'minutes', 'hours'.
            If None or 'timedelta', plots with a datetime.timedelta format (h:mm:ss).
        ax (plt.Axes, optional): axes to plot in. If None, plots in the current axes (plt.gca()).
            Defaults to None.
        legend (bool, optional): add a scalebar indicating the scale (True) or not (False).
            Defaults to True.
        **kwargs (optional): optional keyword arguments for the plt.plot() function.

    Returns:
        ax (plt.axes): axes instance.

    Raises:
        ValueError: if the signals in the dataset are not synchronized.
        NotImplementedError: if the requested channels do not all share one unit.
    """
    # By default, plot all channels (reversed, since channels are drawn bottom to up).
    if channels is None:
        channels = list(self.time_series.keys())[::-1]
    if isinstance(channels, str):
        channels = [channels]

    if not self.is_synchronized():
        raise ValueError('EEG signals not synchronized. Cannot plot.')

    # Extract unit: the scalebar can only show one unit.
    unit_list = [self[ch].unit for ch in channels]
    units = np.unique(unit_list)
    if len(units) != 1:
        # The exception must be raised: merely constructing it would let mixed units pass silently.
        raise NotImplementedError('Plot function not compatible with EEG signals of different units.')
    unit = units[0]

    # Create time (boolean indexing returns a copy, so self.time is not modified below).
    time = self.time
    time_mask = get_range_mask(time, x_min=begin, x_max=end)
    time = time[time_mask]

    # Collect (segments of) signals to plot.
    signal_list = [self[ch].signal[time_mask] for ch in channels]

    # Compute scaling: we want all signals to lie (for the most part) within a range of 1.
    if scale is None:
        deviation_range = np.nanmax(
            [np.nanpercentile(sig, 95) - np.nanpercentile(sig, 5)
             for sig in signal_list])
        if np.isnan(deviation_range):
            deviation_range = 0
        scale = int(round(deviation_range))

    # Guard against division by zero (flat/empty data, or a user-specified scale of 0).
    if scale == 0:
        scale = 1

    # Compute offsets: channel i is centred at y = i + 1.
    offset_list = np.arange(len(channels)) + 1

    # Subtract t0 if relative time is requested.
    if relative_time:
        time = time - time[0]

    # Convert time scale if requested ('timedelta' is an alias for the default formatting).
    if time_scale is not None:
        if time_scale == 'timedelta':
            time_scale = None
        else:
            time = convert_time_scale(time, time_scale)

    # Default plot keyword arguments, updated with user specified keyword arguments.
    plot_kwargs = {
        'linewidth': compute_linewidth(signal_list[0]),
    }
    plot_kwargs.update(kwargs)

    # Set current axis.
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    if color is None:
        colors = ['C{}'.format(i) for i in range(len(channels))]
    else:
        colors = [color]*len(channels)

    # Loop over channels and plot.
    for signal, offset, channel_color in zip(signal_list, offset_list, colors):
        # Create a scaled and shifted version of the channel signal.
        y = signal / scale + offset

        plt.plot(time, y, color=channel_color, **plot_kwargs)

        # Remove label from plot_kwargs (only put a label on the first channel).
        plot_kwargs['label'] = None

    # Figure makeup: channel names as y-ticks, colored like their traces.
    plt.yticks(offset_list, [ch.replace('EEG', '') for ch in channels])
    for tick_color, text in zip(colors, plt.yticks()[1]):
        text.set_color(tick_color)
    plt.ylim(0, len(channels) + 1)
    if legend:
        add_scalebar(plt.gca(), labely='{}\n{}'.format(scale, unit), matchx=False, matchy=False, hidex=False,
                     hidey=False, sizey=1, barwidth=3, sep=4, loc=4, barcolor='black')
    plt.grid(True)
    plt.title(self.label)

    # x-axis.
    if time_scale is None:
        def timeTicks(x, pos):
            # Format seconds as h:mm:ss.
            d = datetime.timedelta(seconds=x)
            return str(d)

        formatter = FuncFormatter(timeTicks)
        ax.xaxis.set_major_formatter(formatter)
        ax.set_xlabel('Time (h:mm:ss)')
    else:
        ax.set_xlabel('Time ({})'.format(time_scale))

    return ax
def plot_mask(self, mask, time_scale='seconds', color=None,
              skip_value=0, ax=None, **kwargs):
    """
    Plot a shaded mask over an EEG plot.

    For every label value in the mask and every channel, one rectangle of height 1 is drawn per
    consecutive run of that value, centred on the y-position of the channel (1, 2, ...).

    Args:
        mask (np.ndarray): mask with shape (n_time, n_channels). A 1D mask is applied to all
            channels. A mask with shape (n_channels, n_time) is transposed. Values are cast to int.
        time_scale (str): time scale in which to plot (should match the time scale of the EEG plot).
        color (str, array, dict): color for all labels, or dictionary mapping label values to colors.
            If None, the default color cycle ('C0', 'C1', ...) is used.
        skip_value (int, bool, float or list): value(s) in mask to not plot.
        ax (plt.Axes): axes to plot in. If None, the current axes is used.
        **kwargs: for matplotlib.patches.Rectangle (overrules the defaults zorder, alpha, edgecolor).

    Raises:
        ValueError: if the mask does not have as many columns (or rows) as there are channels.
    """
    if ax is None:
        ax = plt.gca()

    mask = np.asarray(mask).astype(int)
    if mask.ndim == 1:
        # 1D array is given, assume it holds for all channels: stack copies of the mask.
        mask = np.tile(mask.reshape(-1, 1), [1, len(self)])

    if len(self) != mask.shape[1]:
        # Try transposing mask.
        mask = mask.T
        if len(self) != mask.shape[1]:
            raise ValueError('`mask` should have same shape number of channels as EEG ({}). Got {}.'
                             .format(len(self), mask.shape[1]))

    # Flip mask, so that the first channels appear at the top (like the default in self.plot()).
    mask = mask[:, ::-1]

    if not isinstance(skip_value, (list, tuple, set)):
        skip_value = [skip_value]
    skip_value = list(skip_value)

    if not isinstance(color, dict):
        # One color for each unique label that is not skipped.
        keys = list(set(np.unique(mask)) - set(skip_value))
        if color is None:
            values = ['C{}'.format(i) for i in range(len(keys))]
        else:
            values = [color]*len(keys)
        color_mapping = dict(zip(keys, values))
    else:
        color_mapping = color

    plot_kwargs = dict({
        'zorder': -1,
        'alpha': 0.4,
        'edgecolor': None,
    }, **kwargs)

    # Loop over labels in mask.
    for value, value_color in color_mapping.items():
        if value in skip_value:
            continue
        for i, m in enumerate(mask.T):
            mask_i = (m == value)
            y_i = i + 1
            onsets_i, offsets_i = get_onsets_offsets(mask_i, fs=self.fs)

            # Convert both the start positions and the widths to the requested time scale, so the
            # patches line up with an x-axis that is not in seconds.
            starts_i = convert_time_scale(onsets_i + self.time_offset, time_scale=time_scale)
            durations_i = convert_time_scale((offsets_i - onsets_i), time_scale=time_scale)

            for t_i, dur_i in zip(starts_i, durations_i):
                ax.add_patch(Rectangle(
                    xy=(t_i, y_i - 0.5), width=dur_i, height=1,
                    facecolor=value_color, **plot_kwargs))
def power_analysis(self, verbose=1, **kwargs):
    """
    Call the parent method with a set of common EEG frequency bands as default.

    Args:
        verbose (int, optional): verbose level, passed on to the parent method.
            Defaults to 1.
        **kwargs (optional): keyword arguments for the parent method. A `frequency_bands` entry
            overrules the default bands defined here.

    Returns:
        Whatever the parent power_analysis() returns.
    """
    # Band edges as [low, high] (presumably in Hz - confirm against the parent method).
    eeg_bands = {
        'delta_1': [0, 2],
        'delta_2': [2, 4],
        'theta': [4, 8],
        'alpha': [8, 16],
        'beta': [16, 32]
    }

    # User-specified keyword arguments take precedence over the default bands.
    feature_kwargs = {'frequency_bands': eeg_bands, **kwargs}

    return super().power_analysis(verbose=verbose, **feature_kwargs)
def reference(self, reference_channel='Cz', remove_reference=True, inplace=False, verbose=1):
    """
    Reference all EEG signals to reference_channel, i.e. subtract the reference signal from every channel.

    Args:
        reference_channel (str, optional): label of the channel to use as reference.
            Can also be 'mean', 'average' or 'avg' (case-insensitive) to use the average over all channels.
            Default is 'Cz'.
        remove_reference (bool, optional): if True, remove the reference channel from the referenced dataset.
            If False, keep the reference channel after referencing.
            Has no effect when referencing to the average.
            Default is True.
        inplace (bool, optional): if True, the referenced signals replace the original signals in this
            dataset and nothing is returned. If False, a deep copy holding the referenced signals is
            returned, leaving the data in the current dataset unchanged.
            Default is False.
        verbose (int, optional): verbose level.
            Defaults to 1.

    Returns:
        ds_referenced (EegDataset): dataset with the referenced signals (only if inplace is False).
    """
    if verbose > 0:
        print('Referencing {} with {}...'.format(self.label,
                                                 reference_channel))

    # The referenced TimeSeries are stored either in this dataset itself or in a copy of it.
    ds_referenced = self if inplace else copy.deepcopy(self)

    use_average = reference_channel.lower() in ['mean', 'average', 'avg']
    if use_average:
        # Reference signal is the average over all channels.
        reference_channel = 'average'
        ref_signal = np.mean(self.asarray(), axis=0)
    else:
        # Reference signal is one of the channels in the dataset.
        reference_channel = self._check_label(reference_channel)
        ref_signal = self[reference_channel].signal

        if remove_reference:
            # Drop the reference channel from the output dataset.
            ds_referenced.remove(reference_channel, inplace=True, verbose=verbose)

    # The per-channel referencing only reports when verbose exceeds 1.
    channel_verbose = 1 if verbose > 1 else 0

    # Replace each remaining channel by its referenced version (computed from the data in self).
    for label in ds_referenced.time_series.keys():
        ds_referenced.time_series[label] = self.time_series[label].reference(
            ref_signal, reference_channel, verbose=channel_verbose)

    # Only return if not in place referencing.
    if not inplace:
        return ds_referenced
def remove_artefacts_method(self, how, **kwargs):
    """
    Dispatch to one of the artefact removal methods by name.

    Args:
        how (str): 'raw' to do nothing (returns self), 'amp' for self.remove_artefacts_amplitude()
            or 'kota' for self.remove_artefacts_kota().
        **kwargs (optional): keyword arguments for the selected removal method (ignored for 'raw').

    Returns:
        self if `how` is 'raw', else the return value of the selected removal method.

    Raises:
        ValueError: if `how` is not one of the supported options.
    """
    if how == 'raw':
        # No artefact removal requested.
        return self

    if how == 'amp':
        return self.remove_artefacts_amplitude(**kwargs)

    if how == 'kota':
        return self.remove_artefacts_kota(**kwargs)

    raise ValueError('Invalid how="{}".'.format(how))
def remove_artefact_channels(self, inplace=False, **kwargs):
    """
    Remove channels that do not meet the EEG signal quality criteria.

    Delegates to the parent method, using the default EEG criteria unless overruled.

    Args:
        inplace (bool, optional): passed on to the parent method. If True, channels are removed from
            this dataset itself. If False, the data in the current dataset is left unchanged.
            Default is False.
        **kwargs (optional): optional keyword arguments for overruling default signal quality criteria (see
            nnsa.artefacts.artefact_detection.default_eeg_signal_quality_criteria()).

    Returns:
        The return value of the parent remove_artefact_channels().
    """
    # Default criteria, with user-specified values taking precedence.
    criteria = {**default_eeg_signal_quality_criteria(), **kwargs}

    return super().remove_artefact_channels(inplace=inplace, **criteria)
def remove_artefacts_amplitude(self, notch_filter=True, bp_filter=True, inplace=False):
    """
    Replace the artefacts detected by self.detect_artefacts_amplitude_kaupilla() with nans.

    Args:
        notch_filter (bool, optional): passed on to the detector.
            Defaults to True.
        bp_filter (bool, optional): passed on to the detector.
            Defaults to True.
        inplace (bool, optional): if True, modify the signals in this dataset and return nothing.
            If False, modify and return a deep copy.
            Defaults to False.

    Returns:
        ds_out (EegDataset): dataset where the artefacts are replaced by nans (only if inplace is False).
    """
    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Detection always runs on the data in self.
    artefact_ds = self.detect_artefacts_amplitude_kaupilla(notch_filter=notch_filter, bp_filter=bp_filter)

    # Replace artefacts with nan, channel by channel.
    for ts, mask in zip(ds_out, artefact_ds):
        # Signals and masks must be in the same channel order.
        assert ts.label == mask.label
        ts._signal[mask] = np.nan

    if not inplace:
        return ds_out
def remove_artefacts_derivative(self, threshold=23000, inplace=False):
    """
    Remove artefacts based on a threshold on the derivative of the filtered signal.

    The artefacts are detected on a notch- and FIR-filtered copy of the data, and the resulting
    nan locations are then applied to the unfiltered signals. Apply this to raw EEG data.

    Args:
        threshold (float, optional): threshold on the derivative. It is divided by the sampling
            frequency before being passed as `max_diff` to remove_artefacts().
            Defaults to 23000.
        inplace (bool, optional): if True, modify the signals in this dataset and return nothing.
            If False, modify and return a deep copy.
            Defaults to False.

    Returns:
        ds_out (EegDataset): dataset where the artefacts are replaced by nans (only if inplace is False).
    """
    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Filter before checking the derivative (on a separate dataset, ds_out stays unfiltered).
    bandpass = get_eeg_fir_filter_a()
    notch = NotchIIR(f0=50, Q=35)
    ds_filt = ds_out.filtfilt(notch, inplace=False)
    ds_filt.filtfilt(bandpass, inplace=True)

    # Threshold scales with fs (as fs is lower, threshold is higher).
    max_diff = threshold/ds_filt.fs
    ds_filt.remove_artefacts(inplace=True, max_diff=max_diff)

    # The nans in the filtered data mark the artefacts: turn them into boolean masks.
    ds_af_mask = ds_filt.transform(fun=np.isnan).astype(bool)

    # Apply the masks to the original (unfiltered) signals.
    for ts, mask in zip(ds_out, ds_af_mask):
        ts._signal[mask] = np.nan

    if not inplace:
        return ds_out
def remove_artefacts_kota(self, notch_filter=True, bp_filter=True, inplace=False, verbose=1):
    """
    Replace the artefacts detected by self.detect_artefacts_kota() with nans.

    Args:
        notch_filter (bool, optional): passed on to the detector.
            Defaults to True.
        bp_filter (bool, optional): passed on to the detector.
            Defaults to True.
        inplace (bool, optional): if True, modify the signals in this dataset and return nothing.
            If False, modify and return a deep copy.
            Defaults to False.
        verbose (int, optional): if nonzero, print a message.
            Defaults to 1.

    Returns:
        ds_out (EegDataset): dataset where the artefacts are replaced by nans (only if inplace is False).
    """
    if verbose:
        print('Removing artefacts with Kota.')

    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Detection always runs on the data in self.
    artefact_ds = self.detect_artefacts_kota(notch_filter=notch_filter, bp_filter=bp_filter)

    # Replace artefacts with nan, channel by channel.
    for ts, mask in zip(ds_out, artefact_ds):
        # Signals and masks must be in the same channel order.
        assert ts.label == mask.label
        ts._signal[mask] = np.nan

    if not inplace:
        return ds_out
def remove_artefacts_mask(self, mask_ds, inplace=False):
    """
    Insert nans where a boolean mask is True.

    Args:
        mask_ds (BaseDataset or sequence): boolean masks for each of the signals in self.
            If a BaseDataset, the mask of a signal is looked up by the label of that signal.
            Otherwise it must be a sequence with one mask per signal, in the order of the signals in self.
        inplace (bool, optional): if True, modify the signals in this dataset and return nothing.
            If False, modify and return a deep copy.
            Defaults to False.

    Returns:
        ds_out (EegDataset): dataset where the masked samples are replaced by nans (only if inplace is False).

    Raises:
        ValueError: if `mask_ds` is a sequence whose length differs from the number of signals.
        TypeError: if a mask does not have boolean dtype.
    """
    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Decide once how masks are looked up: by label (dataset) or by position (sequence).
    masks_by_label = isinstance(mask_ds, BaseDataset)
    if not masks_by_label and len(mask_ds) != len(ds_out):
        raise ValueError('`mask_ds` does not have the same length as the dataset.')

    for ii, ts in enumerate(ds_out):
        if masks_by_label:
            mask = np.asarray(mask_ds[ts.label].signal)
        else:
            mask = np.asarray(mask_ds[ii])

        # Compare dtypes by equality: an integer mask would be interpreted as indices, not as a mask.
        if mask.dtype != np.dtype('bool'):
            raise TypeError('Signals in `mask_ds` should have np.dtype `bool`. Got {}.'.format(mask.dtype))

        ts._signal[mask] = np.nan

    if not inplace:
        return ds_out
def remove_artefacts_rfc(self, pma, preprocess=None,
                         inplace=False, **kwargs):
    """
    Replace the artefacts detected by self.detect_artefacts_rfc() with nans.

    The detection (referencing, optional preprocessing, classification) is entirely handled by
    self.detect_artefacts_rfc(); see that method for the requirements on the EEG data.

    See Also:
        RfcArtefactDetector, self.detect_artefacts_rfc().

    Args:
        pma: PMA of the neonate at time of recording. Passed on to the detector.
        preprocess (bool, optional): passed on to the detector, which uses it to decide whether the
            EEG needs to be preprocessed (filtered, resampled) first.
            Defaults to None.
        inplace (bool): if True, modify the signals in this dataset and return nothing.
            If False, modify and return a deep copy.
            Defaults to False.
        **kwargs (dict): passed on to self.detect_artefacts_rfc().

    Returns:
        ds_out (EegDataset): dataset where the artefacts are replaced by nans (only if inplace is False).
    """
    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Detection always runs on the data in self.
    artefact_ds = self.detect_artefacts_rfc(pma=pma, preprocess=preprocess, **kwargs)

    # Replace artefacts with nan, channel by channel.
    for ts, mask in zip(ds_out, artefact_ds):
        # Signals and masks must be in the same channel order.
        assert ts.label == mask.label
        ts._signal[mask] = np.nan

    if not inplace:
        return ds_out
@staticmethod
def read_begin_end_time(filepath):
    """
    Read the begin and end time of the EEG data in a HDF5 file, without loading the data itself.

    Args:
        filepath (str): path of the HDF5 file, which must contain a dataset 'EEG' with attributes
            'fs' and 'time_offset'.

    Returns:
        begin, end (float): begin time (the stored time offset) and end time (begin plus the
            duration of the data) in seconds.
    """
    with h5py.File(filepath, 'r') as f:
        eeg_node = f['EEG']
        fs = eeg_node.attrs['fs']
        begin = eeg_node.attrs['time_offset']

        # Time is the last axis of the stored array.
        n_samples = eeg_node.shape[-1]

    end = begin + n_samples/fs

    return begin, end
@staticmethod
def read_hdf5(filepath, begin=None, end=None, **kwargs):
    """
    Create an EegDataset from (a part of) the EEG data in an .hdf5 file.

    Only the requested time range is read from the file.

    Args:
        filepath (str): filepath to read.
        begin (float, optional): start second (wrt time offset). If None, read from the first sample.
        end (float, optional): end second (wrt time offset). If None, read up to the last sample.
        **kwargs (optional): keyword arguments for TimeSeries().

    Returns:
        ds (nnsa.EegDataset): EegDataset object with data read from the file.

    Examples:
        >>> signal = np.random.rand(8, 1000)
        >>> eeg_ds = EegDataset()
        >>> ts_all = [TimeSeries(signal=signal[i], fs=10, label=i) for i in range(len(signal))]
        >>> eeg_ds.append(ts_all)
        >>> eeg_ds.save_hdf5('testfile.hdf5')
        >>> ds_read = EegDataset.read_hdf5('testfile.hdf5')
        >>> ds_read_epoch = EegDataset.read_hdf5('testfile.hdf5', begin=10, end=20)
        >>> os.remove('testfile.hdf5')
        >>> assert_equal(eeg_ds.signal, ds_read.signal)
        >>> assert_equal(eeg_ds.extract_epoch(begin=10, end=20).signal, ds_read_epoch.signal)
    """
    ds = EegDataset()

    with h5py.File(filepath, 'r') as f:
        sig = f['EEG']
        attrs = sig.attrs
        fs = attrs['fs']
        time_offset = attrs['time_offset']
        channel_labels = attrs['channel_labels']
        unit = attrs['unit']
        n_samples = sig.shape[-1]

        # Translate the requested times to sample indices, clipped to the available data.
        if begin is None:
            begin_idx = 0
        else:
            begin_idx = max(int((begin - time_offset)*fs), 0)

        if end is None:
            end_idx = None
        else:
            end_idx = max(0, min(int((end - time_offset)*fs), n_samples))

        # Read only the requested part from file.
        eeg = sig[:, begin_idx:end_idx]

        # The first sample that was read defines the new time offset.
        time_offset = time_offset + begin_idx/fs

        # Do not read info, just put the current filepath.
        info = {'source': f.filename}

        # One TimeSeries per row; labels and unit are stored as bytes and need decoding.
        for x, label in zip(eeg, channel_labels):
            ds.append(TimeSeries(x, fs=fs, label=label.decode(), unit=unit.decode(),
                                 time_offset=time_offset, info=info, **kwargs))

    return ds
@staticmethod
def read_edf(filepath, begin=None, end=None, **kwargs):
    """
    Create an EegDataset from an .edf file, using EdfReader.read_eeg_dataset().

    Args:
        filepath (str): filepath to read.
        begin (float, optional): start second, passed on to the reader.
        end (float, optional): end second, passed on to the reader.
        **kwargs (optional): keyword arguments for EdfReader.read_eeg_dataset().

    Returns:
        ds (nnsa.EegDataset): EegDataset object with data read from the file.
    """
    # Imported here rather than at module level, as in the original code.
    from nnsa.io.readers import EdfReader

    with EdfReader(filepath) as reader:
        ds = reader.read_eeg_dataset(begin=begin, end=end, **kwargs)

    return ds
@staticmethod
def read_mat(filepath, begin=None, end=None, loadmat_kwargs=None, **kwargs):
    """
    Read an EegDataset object from a .mat file.

    The file must contain the variables 'EEG', 'fs' and 'channel_labels'. The variables
    'time_offset' (default 0) and 'unit' (default 'a.u.') are optional.

    Args:
        filepath (str): filepath to read.
        begin (float, optional): start second (wrt time offset). If None, read from the first sample.
        end (float, optional): end second (wrt time offset). If None, read up to the last sample.
        loadmat_kwargs (dict, optional): dict with keyword arguments for scipy.io.loadmat(), overruling
            the defaults used here (struct_as_record=True, squeeze_me=True, mat_dtype=False).
        **kwargs (optional): keyword arguments for TimeSeries().

    Returns:
        ds (nnsa.EegDataset): EegDataset object with data read from the file.

    Raises:
        ValueError: if the file does not contain channel labels.
    """
    # Read data.
    loadmat_kwargs = dict({
        'struct_as_record': True,
        'squeeze_me': True,
        'mat_dtype': False,
    }, **(loadmat_kwargs if loadmat_kwargs is not None else dict()))
    data = scipy.io.loadmat(filepath, **loadmat_kwargs)

    # squeeze_me collapses single-channel data to 1D: restore the channel axis.
    sig = np.atleast_2d(data['EEG'])
    fs = data['fs']
    time_offset = data.get('time_offset', 0)
    channel_labels = data.get('channel_labels', None)
    unit = data.get('unit', 'a.u.')

    if channel_labels is None:
        raise ValueError('File "{}" does not contain "channel_labels".'.format(filepath))

    # squeeze_me collapses a single label to a bare string, which would otherwise be iterated
    # character by character.
    channel_labels = np.atleast_1d(channel_labels)

    # Check shape and transpose if needed (require n_channels, n_time).
    if sig.shape[0] > sig.shape[-1]:
        sig = sig.T

    # Determine start and end index, clipped to the available data.
    end_idx = None if end is None else max([0, min([int((end - time_offset) * fs), sig.shape[-1]])])
    begin_idx = 0 if begin is None else max([int((begin - time_offset) * fs), 0])

    # Extract requested part.
    eeg = sig[:, begin_idx:end_idx]

    # Update time_offset (depending on `begin`).
    time_offset = time_offset + begin_idx / fs

    # Just put the current filepath.
    info = {'source': filepath}

    # Add each channel to the dataset.
    ds = EegDataset()
    for x, label in zip(eeg, channel_labels):
        if isinstance(label, str):
            # Labels stored as a MATLAB char matrix appear to be space-padded to equal length.
            label = label.strip()
        ds.append(TimeSeries(x, fs=fs, label=label, unit=unit,
                             time_offset=time_offset, info=info, **kwargs))

    return ds
@staticmethod
def read_pickle(filepath, begin=None, end=None, **kwargs):
    """
    Read an EegDataset object from a .pkl file.

    The file must contain a dict with the keys 'EEG', 'fs' and 'channel_labels'. The keys
    'time_offset' (default 0) and 'unit' (default 'a.u.') are optional.

    Args:
        filepath (str): filepath to read.
        begin (float, optional): start second (wrt time offset). If None, read from the first sample.
        end (float, optional): end second (wrt time offset). If None, read up to the last sample.
        **kwargs (optional): keyword arguments for TimeSeries().

    Returns:
        ds (nnsa.EegDataset): EegDataset object with data read from the file.

    Raises:
        ValueError: if the file does not contain channel labels.
    """
    # Read data.
    # NOTE(review): pickle_load presumably unpickles the file, which can execute arbitrary code.
    # Only read pickle files from trusted sources.
    data = pickle_load(filepath)
    sig = data['EEG']
    fs = data['fs']
    time_offset = data.get('time_offset', 0)
    channel_labels = data.get('channel_labels', None)
    unit = data.get('unit', 'a.u.')

    if channel_labels is None:
        raise ValueError('File "{}" does not contain "channel_labels".'.format(filepath))

    # Check shape and transpose if needed (require n_channels, n_time).
    if sig.shape[0] > sig.shape[-1]:
        sig = sig.T

    # Determine start and end index, clipped to the available data.
    end_idx = None if end is None else max([0, min([int((end - time_offset) * fs), sig.shape[-1]])])
    begin_idx = 0 if begin is None else max([int((begin - time_offset) * fs), 0])

    # Extract requested part.
    eeg = sig[:, begin_idx:end_idx]

    # Update time_offset (depending on `begin`).
    time_offset = time_offset + begin_idx / fs

    # Just put the current filepath.
    info = {'source': filepath}

    # Add each channel to the dataset.
    ds = EegDataset()
    for x, label in zip(eeg, channel_labels):
        ds.append(TimeSeries(x, fs=fs, label=label, unit=unit,
                             time_offset=time_offset, info=info, **kwargs))

    return ds
@staticmethod
def read_set(filepath, begin=None, end=None, loadmat_kwargs=None, **kwargs):
    """
    Read an EegDataset object from a .set file (from Matlab EEGLAB).

    The file is read with scipy.io.loadmat() and must contain the variables 'data', 'srate',
    'times' and 'chanlocs'. The variable 'unit' (default 'a.u.') is optional.

    Args:
        filepath (str): filepath to read.
        begin (float, optional): start second (wrt time offset). If None, read from the first sample.
        end (float, optional): end second (wrt time offset). If None, read up to the last sample.
        loadmat_kwargs (dict, optional): dict with keyword arguments for scipy.io.loadmat(), overruling
            the defaults used here (struct_as_record=True, squeeze_me=True, mat_dtype=False).
        **kwargs (optional): keyword arguments for TimeSeries().

    Returns:
        ds (nnsa.EegDataset): EegDataset object with data read from the file.

    Raises:
        ValueError: if the file does not contain channel locations to take the labels from.
    """
    # Read data.
    loadmat_kwargs = dict({
        'struct_as_record': True,
        'squeeze_me': True,
        'mat_dtype': False,
    }, **(loadmat_kwargs if loadmat_kwargs is not None else dict()))
    data = scipy.io.loadmat(filepath, **loadmat_kwargs)

    # squeeze_me collapses single-channel data to 1D: restore the channel axis.
    sig = np.atleast_2d(data['data'])
    fs = data['srate']
    # NOTE(review): the first entry of 'times' is used as offset in seconds - confirm the unit in the file.
    time_offset = data['times'][0]

    if 'chanlocs' not in data:
        raise ValueError('File "{}" does not contain "chanlocs" with the channel labels.'.format(filepath))

    # The label is taken from the first field of each channel location record
    # (atleast_1d keeps a single squeezed record iterable).
    channel_labels = [cl[0] for cl in np.atleast_1d(data['chanlocs'])]
    unit = data.get('unit', 'a.u.')

    # Check shape and transpose if needed (require n_channels, n_time).
    if sig.shape[0] > sig.shape[-1]:
        sig = sig.T

    # Determine start and end index, clipped to the available data.
    end_idx = None if end is None else max([0, min([int((end - time_offset)*fs), sig.shape[-1]])])
    begin_idx = 0 if begin is None else max([int((begin - time_offset)*fs), 0])

    # Extract requested part.
    eeg = sig[:, begin_idx:end_idx]

    # Update time_offset (depending on `begin`).
    time_offset = time_offset + begin_idx/fs

    # Just put the current filepath.
    info = {'source': filepath}

    # Add each channel to the dataset.
    ds = EegDataset()
    for x, label in zip(eeg, channel_labels):
        ds.append(TimeSeries(x, fs=fs, label=label, unit=unit,
                             time_offset=time_offset, info=info, **kwargs))

    return ds
def save_hdf5(self, filepath, mode='w', overwrite=False):
    """
    Save the EegDataset data to a .hdf5 file.

    The signals are written as one array 'EEG', with the sampling frequency, time offset, channel
    labels and unit as attributes of that array. The info dict of the first channel is written
    under 'info'.

    Args:
        filepath (str): filepath to save to. If it has no extension, '.hdf5' is appended.
        mode (str, optional): file mode for h5py.File(): 'w' for write mode or 'a' for append mode.
            Defaults to 'w'.
        overwrite (bool, optional): if True, overwrites existing files. If False, raises
            an error when `filepath` already exists.
            Defaults to False.

    Raises:
        ValueError: if the file extension is not '.hdf5' or '.h5', or if the file already exists
            and `overwrite` is False.
    """
    _, file_extension = os.path.splitext(filepath)
    if not file_extension:
        # Add file extension.
        filepath = '{}.hdf5'.format(filepath)
    elif file_extension.lower() not in ['.hdf5', '.h5']:
        raise ValueError('Invalid file extension "{}". Use one of {}.'
                         .format(file_extension, ['.hdf5', '.h5']))

    if not overwrite:
        # Check if filepath already exists.
        if os.path.exists(filepath):
            raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.'
                             .format(filepath))

    # Check the directory and create if it does not exist.
    check_directory_exists(filepath=filepath)

    # Get EEG as array.
    eeg, channel_labels = self.asarray(return_channel_labels=True)

    # Only one unit is stored: an arbitrary one if the channels have different units.
    unit = list(set([self.time_series[label].unit for label in channel_labels]))[0]
    info = self.time_series[channel_labels[0]].info  # Just take first channel.

    # Write hdf5 file.
    with h5py.File(filepath, mode=mode) as f:
        # Write array data.
        sig = f.create_dataset('EEG', data=eeg)

        # Write non-array data as attributes to the signal array.
        sig.attrs['fs'] = float(self.fs)
        sig.attrs['time_offset'] = float(self.time_offset)

        # Store strings as fixed-length bytes for compatibility. np.bytes_ is the same type as the
        # former alias np.string_, which no longer exists in NumPy 2.0.
        sig.attrs['channel_labels'] = [np.bytes_(label) for label in channel_labels]
        sig.attrs['unit'] = np.bytes_(unit)

        # Write dict as a separate dataset.
        write_dict_to_hdf5(f, info, 'info')
def save_mat(self, filepath, overwrite=False):
    """
    Write the data of this EegDataset to a .mat file.

    The file gets the variables 'EEG', 'channel_labels', 'fs', 'time_offset', 'unit' and 'info'
    (the info as its string representation).

    Args:
        filepath (str): filepath to save to. If it has no extension, '.mat' is appended.
        overwrite (bool, optional): if True, overwrites existing files. If False, raises
            an error when `filepath` already exists.
            Defaults to False.

    Raises:
        ValueError: if the file extension is not '.mat', or if the file already exists and
            `overwrite` is False.
    """
    valid_extensions = ['.mat']
    extension = os.path.splitext(filepath)[1]
    if not extension:
        # No extension given: add the default one.
        filepath = '{}.mat'.format(filepath)
    elif extension.lower() not in valid_extensions:
        raise ValueError('Invalid file extension "{}". Use one of {}.'
                         .format(extension, ['.mat']))

    # Refuse to replace an existing file unless explicitly allowed.
    if not overwrite and os.path.exists(filepath):
        raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.'
                         .format(filepath))

    # Check the directory and create if it does not exist.
    check_directory_exists(filepath=filepath)

    # Collect the signals and the metadata in one dict.
    eeg, channel_labels = self.asarray(return_channel_labels=True)
    contents = {
        'EEG': eeg,
        'channel_labels': channel_labels,
        'fs': self.fs,
        'time_offset': self.time_offset,
        'unit': self.unit,
        'info': str(self.info),
    }

    scipy.io.savemat(filepath, contents)
def save_csv(self, filepath, overwrite=False, **kwargs):
    """
    Write the data of this EegDataset to a .csv (comma-separated value) file.

    The file gets one column per channel (named after the channel label without 'EEG') and a
    final column 'Time'. This is slow, as it uses the pandas functionality.

    Args:
        filepath (str): filepath to save to. If it has no extension, '.csv' is appended.
        overwrite (bool, optional): if True, overwrites existing files. If False, raises
            an error when `filepath` already exists.
            Defaults to False.
        **kwargs (dict, optional): for pandas.DataFrame.to_csv().

    Raises:
        ValueError: if the file extension is not '.csv', or if the file already exists and
            `overwrite` is False.
    """
    extension = os.path.splitext(filepath)[1]
    if not extension:
        # No extension given: add the default one.
        filepath = '{}.csv'.format(filepath)
    elif extension.lower() not in ['.csv']:
        raise ValueError('Invalid file extension "{}". Use {}.'
                         .format(extension, '.csv'))

    # Refuse to replace an existing file unless explicitly allowed.
    if not overwrite and os.path.exists(filepath):
        raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.'
                         .format(filepath))

    # Check the directory and create if it does not exist.
    check_directory_exists(filepath=filepath)

    # One column per channel (the array has one row per channel, hence the transpose).
    eeg, channel_labels = self.asarray(return_channel_labels=True)
    column_names = [lab.replace('EEG', '').strip() for lab in channel_labels]
    eeg_df = pd.DataFrame(eeg.T, columns=column_names)
    eeg_df['Time'] = self.time

    eeg_df.to_csv(filepath, index=False, **kwargs)
def save_pickle(self, filepath, overwrite=False, **kwargs):
    """
    Write the data of this EegDataset to a .pkl (pickle) file.

    A dict is saved with the keys 'EEG', 'channel_labels', 'fs', 'time_offset', 'unit' and 'info'
    (the info as its string representation).

    Args:
        filepath (str): filepath to save to. If it has no extension, '.pkl' is appended.
        overwrite (bool, optional): if True, overwrites existing files. If False, raises
            an error when `filepath` already exists.
            Defaults to False.
        **kwargs (dict, optional): for nnsa.pickle_save().

    Raises:
        ValueError: if the file extension is not '.pkl' or '.pickle', or if the file already exists
            and `overwrite` is False.
    """
    valid_extensions = ['.pkl', '.pickle']
    extension = os.path.splitext(filepath)[1]
    if not extension:
        # No extension given: add the default one.
        filepath = '{}.pkl'.format(filepath)
    elif extension.lower() not in valid_extensions:
        raise ValueError('Invalid file extension "{}". Use one of {}.'
                         .format(extension, ['.pkl', '.pickle']))

    # Refuse to replace an existing file unless explicitly allowed.
    if not overwrite and os.path.exists(filepath):
        raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.'
                         .format(filepath))

    # Check the directory and create if it does not exist.
    check_directory_exists(filepath=filepath)

    # Collect the signals and the metadata in one dict.
    eeg, channel_labels = self.asarray(return_channel_labels=True)
    contents = {
        'EEG': eeg,
        'channel_labels': channel_labels,
        'fs': self.fs,
        'time_offset': self.time_offset,
        'unit': self.unit,
        'info': str(self.info),
    }

    pickle_save(filepath, contents, **kwargs)
def sleep_stages_cnn(self, preprocess=None,
                     remove_flats=True, remove_baseline_wander=False,
                     stepsize=None, verbose=1,
                     preprocess_kwargs=None, **kwargs):
    """
    Automatic sleep stage classification using a Convolutional Neural Network.

    Wrapper that lets self._prepare_sleep_stages_cnn() build the model input, runs
    SleepStagesCnn.sleep_stages_cnn() on it and returns the result.

    Args:
        preprocess (bool, optional): passed on to self._prepare_sleep_stages_cnn(). If True, the EEG
            data is preprocessed (filtered, resampled) for the model. If False, the data is used as it is
            (use this when the dataset already contains correspondingly preprocessed data).
            If None, the preparation routine decides.
        remove_flats (bool, optional): if True, removes flatlines.
        remove_baseline_wander (bool): if True, removes excessive baseline wander by applying a HP filter.
            This was not part of original preprocessing during training, and most EEGs don't need this
            (in the original preprocessing there is already a HP filter, but milder).
            If not needed, don't do this additional baseline removal.
        stepsize (float): stepsize (in seconds) for sleep stage prediction. If None, stepsize is equal to
            segment length (no overlap).
        verbose (int, optional): verbose level. Levels above 1 additionally print a summary of the
            data and enable verbosity of the preparation step.
            Defaults to 1.
        preprocess_kwargs (dict, optional): dict with optional keyword arguments for
            self._prepare_sleep_stages_cnn().
            Defaults to None.
        **kwargs (optional): optional keyword arguments to overrule default parameters of the SleepStagesCnn class.

    Returns:
        result (nnsa.SleepStagesCnnResult): nnsa object containing the results of the CNN sleep stage
            classification.
    """
    if verbose > 1:
        labels = [ts.label for ts in self.time_series.values()]
        print('Classifying sleep stages with CNN of {} with {} channels: {}'
              .format(self.label,
                      len(self.time_series),
                      labels))

    classifier = SleepStagesCnn(**kwargs)

    if preprocess_kwargs is None:
        preprocess_kwargs = dict()

    # Build the input array for the CNN.
    x = self._prepare_sleep_stages_cnn(
        sleep_stages_cnn=classifier,
        preprocess=preprocess,
        remove_flats=remove_flats,
        remove_baseline_wander=remove_baseline_wander,
        stepsize=stepsize,
        verbose=verbose > 1,
        **preprocess_kwargs)

    segment_length = classifier.data_requirements['segment_length']
    if stepsize is None:
        # No overlap between consecutive segments.
        stepsize = segment_length

    # Time offset will be added by self._postprocess_result, hence offset=0 here.
    segment_start_times = get_segment_times(
        num_segments=len(x), segment_length=segment_length,
        overlap=segment_length - stepsize, offset=0)

    # Run sleep stages cnn algorithm.
    result = classifier.sleep_stages_cnn(
        x,
        segment_start_times=segment_start_times,
        segment_end_times=segment_start_times + segment_length,
        verbose=verbose)

    # Subsample the outputs with a stride of segment_length // stepsize (stride 1, i.e. no
    # subsampling, when stepsize equals the segment length). stepsize is never None here.
    stride = int(round(segment_length // stepsize))
    result.probabilities = result.probabilities[:, ::stride]
    result.segment_start_times = result.segment_start_times[::stride]
    result.segment_end_times = result.segment_end_times[::stride]
    if result.latent_features is not None:
        result.latent_features = result.latent_features[:, ::stride]

    # Let the dataset add its info to the result, for the channels the model uses.
    channel_order = classifier.data_requirements['channel_order']
    eeg_labels = ['EEG {}'.format(ch) for ch in channel_order]
    self._postprocess_result(result, channels=eeg_labels)

    return result
def sleep_stages_robust(self, verbose=1, **kwargs):
    """
    Automatic robust sleep stage classification.

    Wrapper that references the data to Cz (if present), selects the eight channels
    Fp1, Fp2, C3, C4, T3, T4, O1 and O2 and passes them to SleepStagesRobust.process().

    Args:
        verbose (int, optional): verbose level. A summary of the data is printed for levels above 1.
            Defaults to 1.
        **kwargs (optional): optional keyword arguments to the SleepStagesRobust class.

    Returns:
        result: the object returned by SleepStagesRobust.process().
    """
    if verbose > 1:
        labels = [ts.label for ts in self.time_series.values()]
        print('Classifying robust sleep stages of {} with {} channels: {}'
              .format(self.label,
                      len(self.time_series),
                      labels))

    # Reference to Cz if it is available, else assume the data is already referenced.
    eeg_ds = self.reference('Cz', inplace=False, verbose=0) if 'Cz' in self else self

    # Select the channels (in this order) that are fed to the classifier.
    channels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']
    eeg_ds = eeg_ds.extract_channels(channels=channels, make_copy=False)

    # Classify on the array with channels along the last axis.
    classifier = SleepStagesRobust(**kwargs)
    result = classifier.process(eeg=eeg_ds.asarray(channels_last=True), fs=eeg_ds.fs)

    # Let the (channel-selected) dataset add its info to the result.
    eeg_ds._postprocess_result(result)

    return result
def substitute_bad_channels(self, af_mask, inplace=False, **kwargs):
    """
    Clean the EEG with the module-level substitute_bad_channels() function, given an artefact mask.

    Args:
        af_mask (np.ndarray or EegDataset): boolean mask for the EEG array data with True at locations
            of artefacts. Shape should be (n_channels, n_time). An incompatible shape will be
            transposed if that makes it compatible, or raises an error otherwise.
            If an EegDataset, its channels are matched to the channels of this dataset by label.
        inplace (bool): whether to apply inplace (True) or not (False).
            Defaults to False.
        **kwargs: for the substitute_bad_channels() function.

    Returns:
        ds_out (EegDataset): cleaned EegDataset with bad channels substituted (only if inplace is False).

    Raises:
        ValueError: if the shape of `af_mask` (or its transpose) does not match the EEG data.
    """
    # Write either into this dataset or into a copy of it.
    ds_out = self if inplace else copy.deepcopy(self)

    # Signals as (n_channels, n_time) array.
    eeg, channel_labels = ds_out.asarray(channels_last=False, return_channel_labels=True)

    # Bring the mask into array form, with channels in the same order as the signals.
    if isinstance(af_mask, EegDataset):
        af_mask = af_mask.asarray(channels=channel_labels, channels_last=False)
    else:
        af_mask = np.asarray(af_mask)

    # Accept a transposed mask, reject anything else that does not fit.
    if af_mask.shape != eeg.shape:
        if af_mask.T.shape != eeg.shape:
            raise ValueError('`af_mask` ({}) should have the same shape as the EEG data ({}).'
                             .format(af_mask.shape, eeg.shape))
        af_mask = af_mask.T

    # Clean (time is the last axis).
    eeg_cleaned, _ = substitute_bad_channels(x=eeg, af_mask=af_mask, fs=ds_out.fs, axis=-1, **kwargs)

    # Replace the old by the new signals.
    for lab, cleaned_channel in zip(channel_labels, eeg_cleaned):
        ds_out[lab].signal = cleaned_channel

    # Only return if not in place.
    if not inplace:
        return ds_out
def _check_label(self, label):
    """
    Resolve an EEG label to the key under which the channel is stored in this dataset.

    The label is returned unchanged if it is already a key of self.time_series. Otherwise it is
    standardized with standardize_and_check_eeg_label() and the standardized label is returned
    if that is a key of self.time_series.

    Args:
        label (str): (unstandardized) EEG label.

    Returns:
        std_label (str): the given label, or its standardized form, whichever is in the dataset.

    Raises:
        KeyError: if neither the label nor its standardized form is in the dataset.
    """
    available = self.time_series

    # Exact match: nothing to standardize.
    if label in available:
        return label

    # Try again with the standardized EEG label (first element of the helper's return value).
    std_label = standardize_and_check_eeg_label(label)[0]
    if std_label not in available:
        raise KeyError('Channel "{}" not in dataset. Channels in dataset: {}.'
                       .format(label, list(available.keys())))
    return std_label
def _prepare_sleep_stages_cnn(self, sleep_stages_cnn, preprocess=None,
                              remove_flats=False, remove_baseline_wander=False,
                              stepsize=None, verbose=1, **kwargs):
    """
    Build the segmented input array for the sleep stage CNN.

    Operates on a deep copy of this dataset, so the original object is not mutated.

    Args:
        sleep_stages_cnn (nnsa.SleepStagesCnn): initialized SleepStagesCnn object. Its
            `data_requirements` provide the channel order, segment length and sampling frequency.
        preprocess (bool, optional): whether to run sleep_stages_cnn.preprocess_recording() on the
            EEG array. If None, it is set to True if the sampling frequency of the dataset differs
            from the required one.
        remove_flats (bool, optional): if True, removes flatlines in the EEG via
            remove_flatlines(n_flatline=1).
        remove_baseline_wander (bool): if True, applies a 2nd order Butterworth highpass filter
            (fn=[0.01]) to remove excessive baseline wander. This was not part of the original
            preprocessing during training, and most EEGs don't need it (the original preprocessing
            already has a milder highpass filter). If not needed, leave this off.
        stepsize (float): stepsize (in seconds) for sleep stage prediction. If None, it is set to
            the segment length (i.e. zero overlap).
        verbose (int, optional): verbosity level.
            Defaults to 1.
        **kwargs (optional): optional arguments for sleep_stages_cnn.preprocess_recording().

    Returns:
        x (np.ndarray): segmented EEG as returned by get_all_segments().
    """
    # What the network expects from its input.
    requirements = sleep_stages_cnn.data_requirements
    channel_order = requirements['channel_order']
    segment_length = requirements['segment_length']
    fs_req = requirements['fs']

    if stepsize is None:
        stepsize = segment_length

    # Operate on a copy; referencing and filtering below are done in place on that copy.
    eeg_ds = copy.deepcopy(self)
    fs = eeg_ds.fs

    if preprocess is None:
        # Only preprocess when the sampling frequency deviates from the required one.
        preprocess = fs != fs_req

    # Labels in the order in which the channels must appear in the input array.
    eeg_labels = [f'EEG {ch}' for ch in channel_order]

    # Reference to Cz. Without a Cz channel the data is assumed to be referenced to Cz already.
    if 'Cz' in eeg_ds:
        eeg_ds.reference('Cz', verbose=verbose, inplace=True)

    if remove_flats:
        eeg_ds.remove_flatlines(n_flatline=1, inplace=True, verbose=verbose)

    if remove_baseline_wander:
        baseline_filter = Butterworth(fn=[0.01], filter_type='highpass', order=2, fs=fs)
        eeg_ds.filter(baseline_filter, inplace=True)

    # EEG data as array (time, channels).
    eeg_array = eeg_ds.asarray(channels=eeg_labels, channels_last=True)

    if preprocess:
        if verbose > 0:
            print('Preprocessing EEG data for sleep stage classification...')
        # Preprocessing returns the new array together with its sampling frequency.
        eeg_array, fs = sleep_stages_cnn.preprocess_recording(
            raw_eeg=eeg_array, raw_fs=fs,
            verbose=verbose,
            **kwargs)

    # Segment along the time axis (all channels simultaneously).
    overlap = segment_length - stepsize
    return get_all_segments(eeg_array, fs=fs, segment_length=segment_length,
                            overlap=overlap, axis=0)
class OxygenDataset(BaseDataset):
    """
    High-level interface for processing oxygen data for neonatal signal analysis.
    """

    def nirs_features(self, verbose=1, **kwargs):
        """
        Compute the NIRS feature set.

        Wrapper that prepares the input for NirsFeatures.process() and returns the result.

        Args:
            verbose (int, optional): verbose level.
                Defaults to 1.
            **kwargs (optional): optional keyword arguments to overrule default parameters of the
                NirsFeatures class.

        Returns:
            result (nnsa.FeatureSetResult): the result with the features.
        """
        if verbose > 0:
            channel_names = [ts.label for ts in self.time_series.values()]
            print('Computing NIRS features of {} with {} channels: {}'
                  .format(self.label, len(self.time_series), channel_names))

        # Imported locally, as in the original code.
        from nnsa.feature_extraction.feature_sets_old import NirsFeatures

        # Feature extraction object (default parameters updated with user specified keyword arguments).
        extractor = NirsFeatures(**kwargs)

        # Input matrix and the matching channel labels.
        data_matrix, channel_labels = self.asarray(return_channel_labels=True)

        # Take the sampling frequency of the first time series.
        # NOTE(review): presumably self.asarray() raises if the signals have differing fs - confirm.
        fs = next(iter(self.time_series.values())).fs

        result = extractor.process(data_matrix, fs=fs, channel_labels=channel_labels,
                                   verbose=verbose)

        # Let the dataset post-process the result (the original comment says this adds an info
        # string about the data).
        self._postprocess_result(result)
        return result

    def remove_artefacts(self, inplace=False, **kwargs):
        """
        Replace samples that are artefacts by np.nan.

        The default oxygen sample quality criteria are updated with the user-specified keyword
        arguments and passed on to the remove_artefacts() method of the parent class.

        Args:
            inplace (bool, optional): If True, the samples are replaced directly in the Dataset object
                itself. If False, a copy of the original Dataset object is returned in which the
                artefacted samples are replaced, leaving the data in the current Dataset unchanged.
                Default is False.
            **kwargs (optional): optional keyword arguments for overruling default sample quality
                criteria (see nnsa.artefacts.artefact_detection.default_oxygen_sample_quality_criteria()).

        Returns:
            (OxygenDataset): new Dataset object containing the same signals, but with artefacted samples
                changed to np.nan (only returned if inplace is False).
        """
        # Start from the defaults; user-specified values take precedence.
        quality_criteria = default_oxygen_sample_quality_criteria()
        quality_criteria.update(kwargs)
        return super().remove_artefacts(inplace=inplace, **quality_criteria)