Source code for nnsa.containers.datasets

"""
Module for creating an object that holds the data of multiple time series, annotations and features.
"""
import copy
import datetime
import os
import warnings

import h5py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
from matplotlib.patches import Rectangle
from matplotlib.ticker import FuncFormatter

from nnsa.artefacts.clean_detector_cnn import CleanDetectorCnn
from nnsa.feature_extraction.brain_age_cnn import BrainAgeSinc
from nnsa.feature_extraction.feature_sets.eeg import EegFeatures
from nnsa.preprocessing.data_cleaning import substitute_bad_channels
from nnsa.utils.arrays import get_range_mask, moving_max, moving_mean
from nnsa.utils.event_detections import get_onsets_offsets
from nnsa.utils.pkl import pickle_save, pickle_load
from nnsa.utils.testing import assert_equal
from nnsa.artefacts.artefact_detection import detect_artefact_signals, default_eeg_signal_quality_criteria, \
    default_oxygen_sample_quality_criteria, detect_anomalous_channels
from nnsa.containers.time_series import TimeSeries
from nnsa.edfreadpy.io.utils import standardize_and_check_eeg_label
from nnsa.feature_extraction.brain_age import BrainAge
from nnsa.feature_extraction.discontinuity import BurstDetection
from nnsa.feature_extraction.connectivity import CoherenceGraph
from nnsa.feature_extraction.entropy import MultiScaleEntropy
from nnsa.feature_extraction.envelope import Envelope
from nnsa.feature_extraction.frequency_analysis import PowerAnalysis
from nnsa.feature_extraction.fractality import MultifractalAnalysis, LineLength
from nnsa.feature_extraction.sleep_stages import SleepStagesCnn, SleepStagesRobust
from nnsa.feature_extraction.statistics import SignalStats
from nnsa.io.hdf5 import write_dict_to_hdf5
from nnsa.preprocessing.combine_channels import combine_channels
from nnsa.preprocessing.filter import NotchIIR, Butterworth
from nnsa.preprocessing.saved_filters import get_eeg_fir_filter_a, get_eeg_fir_filter_b, get_filter
from nnsa.utils import check_directory_exists
from nnsa.utils.segmentation import segment_generator, get_segment_times, get_all_segments
from nnsa.utils.plotting import compute_linewidth, subplot_rows_columns
from nnsa.utils.conversions import convert_time_scale
from nnsa.utils.scalebars import add_scalebar

__all__ = [
    'DEFAULT_EEG_CHANNELS',
    'BaseDataset',
    'EegDataset',
    'OxygenDataset',
]


DEFAULT_EEG_CHANNELS = ['Fp1', 'Fp2', 'C3', 'C4', 'Cz', 'T3', 'T4', 'O1', 'O2']


[docs]class BaseDataset(object): """ Base class with common functionalities for Datasets. Args: *args (optional): optional positional arguments may be TimeSeries objects containing signals. Those will be added to the Dataset upon initialization. When a list is given, this list is expected to be a list of TimeSeries objects. label (str): optional label for dataset. """ def __init__(self, *args, label=None): self._time_series = dict() # holds TimeSeries objects of signals, label is the key. # Append the data objects passed as optional positional arguments. for data in args: self.append(data) # Set label. if label is None: label = self.__class__.__name__ self.label = label def __add__(self, other): """ Combine two datasets into a new dataset containing the signals in both sets. Args: other (BaseDataset): other dataset to add. Returns: ds_out (BaseDataset): new dataset with all signals. Raises: ValueError if there are overlapping keys/signal labels. """ # Other should be a BaseDataset. if not isinstance(other, BaseDataset): raise TypeError('Cannot add {} to {}.'.format(type(other), type(self))) # Create a copy. ds_out = copy.deepcopy(self) time_series_labels_self = list(self._time_series.keys()) time_series_labels_other = list(other._time_series.keys()) if any(lab in time_series_labels_self for lab in time_series_labels_other): raise ValueError('Cannot add {} to {} with overlapping signal labels.'.format(type(other), type(self))) ds_out._time_series.update(other._time_series) return ds_out def __array__(self): """ We can convert to a numpy array using numpy.array or numpy.asarray, which will call this __array__ method. Returns: a copy of the 2D signal array. """ return self.asarray() def __contains__(self, item): """ Returns True if item in self. False otherwise. Returns: found (bool): True if item in self. False otherwise. """ # Try getting the item. If succeeds, return True, else return False. try: _ = self[item] found = True except KeyError: found = False return found def __repr__(self): """ Return a comprehensive info string about this object. Returns: (str): a comprehensive info string about this object. """ info_string_parts = ['{} with label {} containing:'.format(self.__class__.__name__, self.label), 'time series:\n\t{}'.format('\n\t'.join(('{}: {}'.format(k, v) for k, v in self._time_series.items())))] return '\n'.join(info_string_parts) def __iter__(self): """ Iterate over time series objects. Returns: (iterator): iterator over time series objects. """ return iter(self.time_series.values()) def __len__(self): """ Return the number of time series objects. Returns: (int): length. """ return len(self.time_series) def __getitem__(self, item): """ Evaluation of self[item]. """ return self.time_series[item] def __mul__(self, other): ds_out = copy.deepcopy(self) for key in ds_out.time_series: if isinstance(other, BaseDataset): x = other[key] else: x = other ds_out.time_series[key] = ds_out[key] * x return ds_out def __truediv__(self, other): ds_out = copy.deepcopy(self) for key in ds_out.time_series: if isinstance(other, BaseDataset): x = other[key] else: x = other ds_out.time_series[key] = ds_out[key] / x return ds_out @property def channel_labels(self): """ Return a list of channel labels (labels of the time series objects). Returns: (list): list with channel labels. """ return list(self.time_series.keys()) @property def dtype(self): """ Return the dtype if all TimeSeries have the same dtype. Raises an error otherwise. Returns: dtype: dtype common to all signals. """ dtype_all = [ts.dtype for ts in self] if len(set(dtype_all)) == 1: return dtype_all[0] else: raise AttributeError('{} with time series with different dtypes does not have a `dtype` attribute.' .format(self.__class__.__name__)) @property def fs(self): """ Return the sample frequency if all TimeSeries have the same sample frequency. Raises an error otherwise. Returns: fs (flaot): sample frequency common to all signals. """ fs_all = [ts.fs for ts in self] if len(set(fs_all)) == 1: return fs_all[0] else: raise AttributeError('{} with time series of different sample frequencies does not have an `fs` attribute.' .format(self.__class__.__name__)) @property def shape(self): """ Return the shape (n_channels, n_samples). Raises an error if not all signals have the same lengths. """ len_all = [len(ts) for ts in self] if len(set(len_all)) == 1: n_samples = len_all[0] else: raise AttributeError('{} with time series of different lengths does not have a `shape` attribute.' .format(self.__class__.__name__)) n_channels = len(self) return n_channels, n_samples @property def unit(self): """ Return the unit if all TimeSeries have the same unit. Raises an error otherwise. Returns: unit (str): unit common to all signals. """ unit_all = [ts.unit for ts in self] if len(set(unit_all)) == 1: return unit_all[0] else: raise AttributeError('{} with time series with different units does not have a `unit` attribute.' .format(self.__class__.__name__)) @property def info(self): """ Return the info dict of first signal in dataset. Returns: info (dict): info dict of first signal in dataset. """ for ts in self: return ts.info @property def time(self): """ Return the time array if all TimeSeries have the same length, sample frequency and time_offset. Raises an error otherwise. Returns: time (np.ndaaray): time array common to all signals. """ if len(self) == 0: return [] fs_all = [ts.fs for ts in self] time_offset_all = [ts.time_offset for ts in self] len_all = [len(ts) for ts in self] if len(set(fs_all)) == 1 and len(set(time_offset_all)) == 1 and len(set(len_all)) == 1: return np.arange(len_all[0])/fs_all[0] + time_offset_all[0] else: raise AttributeError('{} with time series of different sample frequencies, length or time offset' ' does not have a `time` attribute.' .format(self.__class__.__name__)) @property def time_offset(self): """ Return the time offset if all TimeSeries have the same time offset. Raises an error otherwise. Returns: time_offset (flaot): time offset in seconds common to all signals. """ to_all = [ts.time_offset for ts in self] if len(set(to_all)) == 1: return to_all[0] else: raise AttributeError('{} with time series of different time offsets does not have a `time_offset` attribute.' .format(self.__class__.__name__)) @time_offset.setter def time_offset(self, time_offset): """ Set time offset of all TimeSeries in the DataSet. Note that the TimeSeries are affected inplace. If you don not want this, use self.apply_time_offset(time_offset, inplace=False) """ self.set_time_offset(time_offset, inplace=True) @property def time_series(self): """ Return the TimeSeries objects of the Dataset object. Returns: (dict): the TimeSeries objects in the Dataset object, where the keys are the TimeSeries labels and the values the corresponding TimeSeries objects. """ return self._time_series
[docs] def append(self, data_object): """ Append a data object to the Dataset (in place). Valid data objects: TimeSeries Args: data_object (various): an nnsa data object. If data_object is a list, this function will recursively be called on each item in the list such that data_object may also be a list of nnsa data objects. """ # Check type and append appropriately. object_type = type(data_object) if object_type is TimeSeries: # Extract label. label = data_object.label # Check if TimeSeries with this label is already in the Dataset. if label in self._time_series: # Give a warning that an old signal is being replaced. warnings.warn('\n{} with label {} replaced in {}.'.format(data_object.__class__.__name__, label, self.label)) # Add the TimeSeries object to the Dataset. self._time_series[label] = data_object elif object_type is list: # Input of append may be a list of data objects. for i in data_object: self.append(i) else: raise NotImplementedError('{} cannot append object of type {}.'.format(self.__class__.__name__, object_type))
[docs] def set_time_offset(self, time_offset, inplace=False): """ Set the same time offset to all time series. Args: time_offset (float): time offset to set. inplace (bool, optional): whether to set in place (True) or create a new object (False). Defaults to False. Returns: ds_out (BaseDataset): new dataset with time offset set to all time series (if inplace is True). """ if len(self) == 0: # No timeseries, we cannot set the time_offset. msg = 'No TimeSeries in {}.\n\tDid not set time_offset.'.format(self) warnings.warn(msg) return self if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) # Loop over time series. for ts in ds: ts.time_offset = time_offset if not inplace: return ds
[docs] def asarray(self, channels=None, return_channel_labels=False, channels_last=False): """ Return an array containing (a copy of) multi-channel data with shape (num_channels, num_samples) if channels_last is False. Args: channels (list, optional): list with labels of the signals that are put into the output array. The order of the labels determines the order of the signals in the output array. I.e. the signal with label labels[i] is put in array(i, :). If None, all signals in the EegDataset are put in the output array. Defaults to None. return_channel_labels (bool, optional): if True, additionally outputs the labels that correspond with the rows of the created array. If False, only outputs the array. Default is False. channels_last (bool): if False, output shape is (num_channels, num_samples), if True, output shape is (num_samples, num_channels). Returns: array (np.ndarray): array with shape (num_channels, num_samples) if channels_last == False, containing (copies of) the multi-channel data. channels (list, optional): the labels corresponding to the rows in the output array (only if output_channel_labels is True). """ # Default channels are all channels in the order that .keys() returns. if channels is None: channels = self.channel_labels else: # Check if all channels are available in the Dataset. channels = [self._check_label(c) for c in channels] # Check if the signals are synchronized. if not self.is_synchronized(channels=channels): raise ValueError('Cannot convert {} to array because signals are not synchronized.' .format(self.label)) # Initialize output array of shape (num_channels, num_samples). example_signal = self.time_series[channels[0]].signal num_channels = len(channels) num_samples = example_signal.size array = np.zeros((num_channels, num_samples), dtype=example_signal.dtype) # Collect signals in array. for i, label in enumerate(channels): # Extract TimeSeries in array. array[i, :] = self.time_series[label].signal if channels_last: # Transpose to (n_samples, n_channels). array = array.T if return_channel_labels: # Return array and row labels corresponding to the channels. return array, channels else: # Only return array. return array
[docs] def astype(self, *args, inplace=False, **kwargs): if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) for ts in ds: ts.astype(*args, **kwargs, inplace=True) if not inplace: return ds
[docs] def average_signals(self, channels=None, **kwargs): """ Average the signals and return as a new TimeSeries. Raises an error if the signals are not synchronized. Args: channels (list, optional): list with labels of the signals to be average. If None, all signals in the Dataset are averaged. Defaults to None. **kwargs (optional): keyword arguments for the creation of a the new TimeSeries object. Returns: (nnsa.TimeSeries): new TimeSeries object containing the average signal. """ # By default use all channels. if channels is None: channels = list(self.time_series.keys()) # Test whether signals to average are synchronized. if not self.is_synchronized(channels=channels): raise ValueError('Signals {} are not synchronized. Cannot compute average signal.' .format(channels)) n_channels = len(channels) if n_channels == 1: # No need to average. return self.time_series[channels[0]] all_units = [] all_info = [] # First signal forms basis. ts0 = self.time_series[channels[0]] sum_signals = ts0.signal.copy() fs = ts0.fs all_units.append(ts0.unit) all_info.append(str(ts0)) time_offset = ts0.time_offset pars = ts0.parameters # Loop over other signals. for label in channels[1:]: ts = self.time_series[label] sum_signals += ts.signal all_units.append(ts.unit) all_info.append(str(ts)) # Compute average signal. signal_avg = sum_signals/n_channels # Create info string for new TimeSeries. info = {'source': 'Average of {} TimeSeries: {}'.format(n_channels, '; '.join(all_info))} # Default arguments for new TimeSeries. ts_kwargs = { 'label': 'average {}'.format(' + '.join(channels)), 'unit': all_units[0] if len(np.unique(all_units)) == 1 else 'a.u.', 'check_label': False, 'check_unit': True, } # Add parameters to the keyword arguments. ts_kwargs.update(pars) # Override defaults with user-specified arguments. ts_kwargs.update(kwargs) # Create new TimeSeries. ts_avg = TimeSeries(signal_avg, fs, info=info, time_offset=time_offset, **ts_kwargs) return ts_avg
[docs] def clear(self): """ Empty the dataset (inplace). """ self.time_series.clear()
[docs] def combine_channels(self, channels=None, label=None, **kwargs): """ Combine the channels into one signal by cutting into segments and pasting the best channels together. See nnsa.combine_channels(). Args: channels (list, optional): list with labels of the signals to be combined. If None, all signals in the Dataset are combined. Defaults to None. label (str, optional): a label for the combined time series. **kwargs (optional): keyword arguments for nnsa.combine_channels(). Returns: ts (nnsa.TimeSeries): TimeSeries object containing the combined signal. """ # By default use all channels. if channels is None: channels = list(self.time_series.keys()) n_channels = len(channels) # Create a copy of one of the time series. ts = copy.deepcopy(self.time_series[channels[0]]) if n_channels == 1: # If only one signal, do not combine, just return the one signal. return ts # Get the data matrix and combine the channels. data_matrix = self.asarray(channels=channels) signal = combine_channels(data_matrix, **kwargs) # Create info string for new TimeSeries. all_info = [] for label in channels: ts = self.time_series[label] all_info.append(str(ts)) info = {'source': 'Combination of {} TimeSeries: {}'.format(n_channels, '; '.join(all_info))} # Update some variables. if label is None: label = 'Combined {}'.format(' + '.join(channels)) ts._signal = signal ts._label = label ts.info = info return ts
[docs] def compute_power_cwt(self, freq_low=None, freq_high=None, channels=None, inplace=False, **kwargs): """ Compute power in specified band for each signal in the dataset. Args: freq_low (float, optional): lower frequency cutoff for bandpower. If None, takes lowest possible. freq_high (float, optional): upper frequency cutoff for bandpower. If None, takes highest possible. channels (list, optional): list with labels of the signals to apply to. If None, all signals in the Dataset are processed. Defaults to None. inplace (bool, optional): if True, replaces the signals in place by their power. If False, returns a new object. **kwargs (optional): optional keyword arguments for compute_power_cwt(). Returns: ds (BaseDataset): dataset with power of the signals (if inplace is False). """ # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) for label in channels: ds.time_series[label].compute_power_cwt(freq_low=freq_low, freq_high=freq_high, inplace=True, **kwargs) if not inplace: return ds
[docs] def demean(self, channels=None, verbose=1, inplace=False, **kwargs): """ Demean each time series in the dataset. """ if verbose > 0: print('Demeaning {}...'.format(self.label)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_filtered = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_filtered = copy.deepcopy(self) # Filter the signals. for label in channels: ts_filtered = self.time_series[label].demean(**kwargs) ds_filtered.time_series[label] = ts_filtered # Only return if not in place filtering. if not inplace: return ds_filtered
[docs] def envelope(self, verbose=1, **kwargs): """ Compute the envelope of the signals. This is a wrapper that prepares the input for Envelope.envelope() and returns the result. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the Envelope class. Returns: result (nnsa.EnvelopeResult): the result of the envelope computation. Examples: To get the envelope data as a EegDataset: >>> signal = 4 * np.sin(np.arange(10000)) >>> ts = TimeSeries(signal=signal, fs=100, label='EEG Cz', unit='uV', info={'source': 'random'}) >>> ds = BaseDataset([ts]) >>> envelope_result = ds.envelope(verbose=0) >>> ds_envelope = envelope_result.to_eeg_dataset() >>> print(type(ds_envelope).__name__) EegDataset >>> print(ds_envelope.time_series['EEG Cz'].signal.mean()) 3.9998155220627787 """ if verbose > 0: print('Computing envelope of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize Envelope object (pass user specified keyword arguments). envelope = Envelope(**kwargs) # Prepare input matrix for envelope. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Run envelope. result = envelope.envelope(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def extract_channels(self, channels=None, make_copy=True): """ Extract the specified channels and return them in a new dataset Args: channels (list): labels of channels to return. make_copy (bool): whether to make a copy of the TimeSeries objects or not. If False and a TimeSeries object in the output dataset is mutated, that TimeSeries in the input dataset is alse mutated. Returns: ds_out (Dataset): new dataset with the requested channels. Raise: ValueError if channels do not exist. """ # Check inputs. if isinstance(channels, str): channels = [channels] # Create an empty dataset. ds_out = self.__class__() # Loop over channels and add them to the dataset. for chan in channels: ts = self[chan] if make_copy: ts = copy.deepcopy(ts) ds_out.append(ts) return ds_out
[docs] def extract_default_channels(self, **kwargs): """ Extract the 9 default channels ([]). Args: **kwargs: for self.extract_channels(). Returns: ds_out (Dataset): new dataset with the requested channels. """ ds_out = self.extract_channels(channels=DEFAULT_EEG_CHANNELS, **kwargs) return ds_out
[docs] def extract_epoch(self, begin=None, end=None, channels=None, **kwargs): """ Extract a piece of data in the specified interval. Args: begin (float, optional): start time in seconds (including time_offset) of the epoch to extract. If negative, begin is duration + begin (python-like indexing). If None, the beginning of the signals is taken. Defaults to None. end (float, optional): end time in seconds (including time_offset) of the epoch to extract. If negative, the end is duration + end (python-like indexing). If None, the end of the signal is taken. Must point to a time greater than `begin`. Defaults to None. channels (list, optional): list signal labels to extract an epoch from. If None, all channels are taken. Defaults to None. **kwargs (dict, optional): for self.__init__(). Returns: ds_epoch (nnsa.BaseDataset): a new Dataset with only the data of the specified epoch. """ if channels is None: channels = self.time_series.keys() # Create new Dataset object and append the epoch data of the requested channels. opts = dict({'label': self.label}, **kwargs) ds_epoch = self.__class__(**opts) for label in channels: ts = self.time_series[label] ts_epoch = ts.extract_epoch(begin=begin, end=end) ds_epoch.append(ts_epoch) return ds_epoch
[docs] def fill_nan(self, value=0, inplace=False, **kwargs): """ Fill NaN values in the signals. Args: value (float): value to fill nans with. inplace (bool): whether to fill nans in place or create a new object. **kwargs: for TimeSeries.fill_nan(). Returns: ds (BaseDataset): new Dataset object containing the new signals (only if inplace is False). """ if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) # Loop over TimeSeries in Dataset. for ts in ds: ts.fill_nan(value=value, inplace=True, **kwargs) if not inplace: return ds
[docs] def filter(self, filter_obj, channels=None, inplace=False, verbose=1, **kwargs): """ Filter each of the signals in the Dataset by calling the filter() method on each TimeSeries objects. Args: filter_obj (FilterBase-derived): a child class from the nnsa.FilterBase class. The fs property does not need to be set, will be done automatically. channels (list, optional): list with labels of the signals that should be filtered. If None, all signals in the Dataset are filtered. Defaults to None. inplace (bool, optional): if True, the filtered signals will replace the original signals in this Dataset. If False (default), the function returns a new Dataset object containing the filtered signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are filtered, the newly created Dataset object will only contain those filtered channels. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): keyword arguments for TimeSeries.filter(). Returns: (BaseDataset): new Dataset object containing the filtered signals (only if inplace is False). """ if verbose > 0: print('One-way filtering {} with {}...'.format(self.label, filter_obj)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_filtered = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_filtered = copy.deepcopy(self) # Filter the signals. for label in channels: ts_filtered = self.time_series[label].filter(filter_obj, verbose=1 if verbose > 1 else 0, **kwargs) ds_filtered.time_series[label] = ts_filtered # Only return if not in place filtering. if not inplace: return ds_filtered
[docs] def filtfilt(self, filter_obj, channels=None, inplace=False, verbose=1, **kwargs): """ Filter each of the signals in the Dataset by calling the filtfilt() method on each TimeSeries objects. Args: filter_obj (FilterBase-derived): a child class from the nnsa.FilterBase class. The fs property does not need to be set, will be done automatically. channels (list, optional): list with labels of the signals that should be filtered. If None, all signals in the Dataset are filtered. Defaults to None. inplace (bool, optional): if True, the filtered signals will replace the original signals in this Dataset. If False (default), the function returns a new Dataset object containing the filtered signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are filtered, the newly created Dataset object will only contain those filtered channels. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): keyword arguments for TimeSeries.filtfilt(). Returns: (BaseDataset): new Dataset object containing the filtered signals (only if inplace is False). """ if verbose > 0: print('Zero-phase filtering {} with {}...'.format(self.label, filter_obj)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_filtered = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_filtered = copy.deepcopy(self) # Filter the signals. for label in channels: ts_filtered = self.time_series[label].filtfilt(filter_obj, verbose=1 if verbose > 1 else 0, **kwargs) ds_filtered.time_series[label] = ts_filtered # Only return if not in place filtering. if not inplace: return ds_filtered
[docs] def filter_saved_filter(self, filter_name, channels=None, inplace=False, verbose=1, **kwargs): """ Filter each of the signals in the dataset by calling the filter_saved_filter() method on each TimeSeries object. Args: filter_name (str): see nnsa.preprocessing.filter_saved_filter(). channels (list, optional): list with labels of the signals that should be filtered. If None, all signals in the Dataset are filtered. Defaults to None. inplace (bool, optional): if True, the filtered signals will replace the original signals in this Dataset. If False (default), the function returns a new Dataset object containing the filtered signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are filtered, the newly created Dataset object will only contain those filtered channels. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): keyword arguments for TimeSeries.filter_saved_filter(). Returns: (BaseDataset): new Dataset object containing the filtered signals (only if inplace is False). """ if verbose > 0: print('Filtering {} with saved filter "{}"...'.format(self.label, filter_name)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_filtered = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the EegDataset object. ds_filtered = copy.deepcopy(self) # Filter all EEG signals. for label in channels: ts_filtered = self.time_series[label].filter_saved_filter(filter_name=filter_name, verbose=1 if verbose > 1 else 0, **kwargs) ds_filtered.time_series[label] = ts_filtered # Only return if not in place filtering. if not inplace: return ds_filtered
[docs] def from_data_dict(self, data_dict, **kwargs): """ Create a dataset based on data in a dict. Args: data_dict = dict( data=data, # 2D array with shape (n_channels, n_samples) fs=fs, # sampling frequency (Hz) label=label, # Optional list of labels for each channel in data filepath=filepath, # Optional filepath where the data comes from t0=t0, # Optional time offset (time at t0) in seconds. **kwargs: for self.__init__(). ) """ # Check input dimensions. data = np.asarray(data_dict['data']) if data.ndim == 1: # Assume one channel, change to shape (n_time, n_channels). data = data.reshape(-1, 1) if data.ndim != 2: raise ValueError('`data` must have 2 dimensions. Got {} dimensions.' .format(data.ndim)) channel_labels = data_dict.get('label', None) if channel_labels is None: # Create default labels. channel_labels = ['Ch{}'.format(i+1) for i in range(min(data.shape))] if len(data) != len(channel_labels): # Try transpose. data = np.transpose(data) if len(data) != len(channel_labels): raise ValueError('`len(data) ({}) does not equal len(channel_labels) ({})' .format(len(data), len(channel_labels))) # Extract more information. fs = data_dict['fs'] time_offset = data_dict.get('t0', 0) source = data_dict.get('filepath', 'unknown') info = {'source': source} # Create dataset (collection of TimeSeries). opts = dict({'label': self.label}, **kwargs) ds = self.__class__(**opts) for signal, label in zip(data, channel_labels): ts = TimeSeries(signal=signal, fs=fs, label=label, time_offset=time_offset, info=info) ds.append(ts) return ds
[docs] def get_non_artefact_masks(self, segment_length, overlap=0, artefact_criteria=None): """ For each signal get a mask indicating non artefact segments. Args: segment_length (float): segment length. overlap (float, optional): overlap between segments. Defaults to 0. artefact_criteria (dict): dict with artefact criteria for determining whether each segment is an artefact. Returns: ds_out (BaseDataset): new dataset with non artefact masks for each signal. """ if artefact_criteria is None: artefact_criteria = dict() # Segment the datasets. ds_out = BaseDataset() for ts in self: seg = np.asarray(list(segment_generator(ts.signal, segment_length=segment_length, overlap=0, fs=ts.fs, axis=0))) mask = ~detect_artefact_signals(seg, axis=-1, demean=True, keepdims=False, **artefact_criteria) ds_out.append(TimeSeries(signal=mask, fs=1 / (segment_length - overlap), label=ts.label, check_label=False, time_offset=ts.time_offset, info=ts.info)) return ds_out
[docs] def interp(self, t, channels=None, **kwargs): """ Linearly interpolate values. Args: t (np.ndarray): time instances at which to evaluate the signals. channels (list): determines which and the order of the channels that are interpolated. If None, uses all available channels. **kwargs: for TimeSeries.interp(). Returns: data (np.ndarray): 2D array with shape (n_channels, len(t)) containing the interpolated data. """ if channels is None: # All channels. channels = list(self.time_series.keys()) # Loop over TimeSeries in Dataset. data = [] for ch in channels: yi = self[ch].interp(t=t, **kwargs) data.append(yi) data = np.array(data) return data
[docs] def interp_nan(self, *args, **kwargs): return self.interpolate_nan(*args, **kwargs)
[docs] def interpolate_nan(self, inplace=False, max_nan_length=None): """ Linearly interpolate nan values. Args: inplace (bool, optional): interpolate inplace (True) or return a copy (False). Defaults to False. max_nan_length (int or str, optional): number of maximum allowable consecutive nan samples that we interpolate (in seconds). If 'auto', the average duration of a constant value in the signal is taken. If None, no limit is applied, i.e. all nans are interpolated. Defaults to None. Returns: ds (BaseDataset): dataset with nans interpolated (if inplace is False). """ if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) # Loop over TimeSeries in Dataset. for ts in ds: ts.interpolate_nan(inplace=True, max_nan_length=max_nan_length) if not inplace: return ds
[docs] def isempty(self): """ Check if dataset is empty. Returns: (bool): True if the Dataset is empty, False if not. """ return len(self.time_series) == 0
[docs] def is_synchronized(self, channels=None): """ Check if all the TimeSeries in the Dataset are synchronized. Args: channels (list, optional): list with labels of signals to check. If None, all signals are checked. Defaults to None. Returns: (bool): True if all TimeSeries are in sync, False if not. """ if channels is None: # All channels. channels = list(self.time_series.keys()) # Synchronized if all TimeSeries have same signal length, fs and time_offset. lengths_all = [] fs_all = [] time_offset_all = [] for label in channels: lengths_all.append(len(self.time_series[label].signal)) fs_all.append(self.time_series[label].fs) time_offset_all.append(self.time_series[label].time_offset) return len(set(lengths_all)) == 1 and len(set(fs_all)) == 1 and len(set(time_offset_all)) == 1
[docs] def line_length(self, preprocess=False, verbose=1, **kwargs): """ Compute line length feature. This is a wrapper that prepares the input for LineLength.line_length() and returns the result. Args: preprocess (bool, optional): apply a default preprocessing routine on the data before feature computation (True) or not (False). Defaults to False. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the LineLength class. Returns: result (nnsa.LineLengthResult): the result of the line length computation. """ # Preprocess. if preprocess: # Default processing routine. fir_filt = get_eeg_fir_filter_b() notch_filt = NotchIIR() ds = self.filtfilt(fir_filt).filtfilt(notch_filt)#.resample( #fs_new=250, method='polyphase_filtering') else: # Use data as it is. ds = self if verbose > 0: print('Computing Line Length of {} with {} channels: {}' .format(ds.__class__.__name__, len(ds.time_series), [ts.label for ts in ds.time_series.values()])) # Initialize LineLength object (pass user specified keyword arguments). line_length = LineLength(**kwargs) # Prepare input matrix for line_length. data_matrix, channel_labels = ds.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in ds.time_series.values())) # Run line_length. result = line_length.line_length(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. ds._postprocess_result(result) return result
[docs] def logical_and(self, other): ds_out = copy.deepcopy(self) for key in ds_out.time_series: if isinstance(other, BaseDataset): x = other[key] else: x = other ds_out.time_series[key].signal = np.logical_and(ds_out[key], x) return ds_out
[docs] def logical_or(self, other): ds_out = copy.deepcopy(self) for key in ds_out.time_series: if isinstance(other, BaseDataset): x = other[key] else: x = other ds_out.time_series[key].signal = np.logical_or(ds_out[key], x) return ds_out
[docs] def merge(self, other, inplace=False): """ Merge other BaseDataset object(s) into one object. Args: other (BaseDataset or list): list with (compatible) BaseDataset objects. Must contain the same TimeSeries (with same labels and fs) as self. inplace (bool, optional): if True, merges the data inplace by adding it to the current object. If False, a new object is returned, leaving the original ones unchanged. Defaults to False. Returns: out (BaseDataset): new BaseDataset object containing the merged data (if inplace is False). """ if not isinstance(other, (list, tuple)): other = [other] if inplace: out = self else: out = copy.deepcopy(self) # Loop over Datasets in other. for item in other: # Check class type. if not isinstance(item, out.__class__): raise ValueError('Object of type "{}" cannot be merged with object of type "{}".' .format(type(item), type(out))) # Merge the timeseries in the datasets. for key, ts in out.time_series.items(): if key not in item.time_series: raise ValueError('Time series {} not in other dataset with time series {}.'.format(key, item.time_series.keys())) out.time_series[key] = out[key].merge(item[key], inplace=False) if not inplace: return out
[docs] def multi_scale_entropy(self, preprocess=False, verbose=1, **kwargs): """ Perform a multi-scale entropy (mse) analysis. This is a wrapper that prepares the input for MultiScaleEntropy.multi_scale_entropy() and returns the result. Args: preprocess (bool, optional): apply a default preprocessing routine on the data before feature computation (True) or not (False). Defaults to False. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the MultiScaleEntropy class. Returns: result (nnsa.MultiScaleEntropyResult): the result of the mse analysis. """ # Preprocess. if preprocess: # Default processing routine. fir_filt = get_eeg_fir_filter_b() notch_filt = NotchIIR() ds = self.filtfilt(fir_filt).filtfilt(notch_filt).resample( fs_new=125, method='polyphase_filtering') else: # Use data as it is. ds = self if verbose > 0: print('Multi-scale entropy analysis of {} with {} channels: {}' .format(ds.__class__.__name__, len(ds.time_series), [ts.label for ts in ds.time_series.values()])) # Initialize MultiScaleEntropy object (updates default parameters with user specified keyword arguments). mse = MultiScaleEntropy(**kwargs) # Prepare input matrix for multi_scale_entropy. data_matrix, channel_labels = ds.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in ds.time_series.values())) # Run multi-scale entropy analysis. result = mse.multi_scale_entropy(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. ds._postprocess_result(result) return result
[docs] def multifractal_analysis(self, verbose=1, **kwargs): """ Perform a multifractal analysis (mfa). This is a wrapper that prepares the input for MultifractalAnalysis.multifractal_analysis() and returns the result. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the MultifractalAnalysis class. Returns: result (nnsa.MultifractalAnalysisResult): the result of the mfa analysis. """ if verbose > 0: print('Multifractal analysis of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize MultifractalAnalysis object (pass user specified keyword arguments). mfa = MultifractalAnalysis(**kwargs) # Prepare input matrix for multifractal_analysis. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Run multifractal_analysis. result = mfa.multifractal_analysis(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def normalize(self, how='zscore', channels=None, inplace=False, verbose=0, **kwargs): """ Normalize the data. Args: how (str): how to normalize. Choose from: - "zscore" channels (list, optional): which channels to normalize. If None, all channels are normalized. inplace (bool, optional): whether to normalize inplace (True) or not. verbose (int, optional): verbosity level. **kwargs (dict): for TimeSeries.normalize(). Returns: (BaseDataset): new Dataset object (only returned if inplace is False). """ if verbose > 0: print('Normalizing {} ({})...'.format(self.label, how)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_out = copy.deepcopy(self) # Normalize the signals. for label in channels: ts_out = self.time_series[label].normalize(how=how, **kwargs) ds_out.time_series[label] = ts_out # Only return if not inplace. if not inplace: return ds_out
[docs] def notch_filt(self, f0=50, channels=None, inplace=False, verbose=1, **kwargs): if verbose > 0: print('Notch filtering {} at {} Hz...'.format(self.label, f0)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_filtered = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_filtered = copy.deepcopy(self) # Filter the signals. for label in channels: ts_filtered = self.time_series[label].notch_filt(f0=f0, verbose=1 if verbose > 1 else 0, **kwargs) ds_filtered.time_series[label] = ts_filtered # Only return if not in place filtering. if not inplace: return ds_filtered
[docs] def plot(self, *args, begin=None, end=None, channels=None, scale=None, relative_time=False, time_scale='seconds', add_offsets=False, ticklabels=True, subplots=False, color=None, legend=True, ax=None, **kwargs): """ Plot the data of specified channels for a specifiec time frame. Plots in the current axis. Args: *args (optional): optional arguments for the plt.plot() function. begin (float, optional): begin time in seconds (incl. offset). If None, plots from the first sample. Defaults to None. end (float, optional): end time in seconds or None to specify the end of the entire signal. Relative to the beginning of the loaded signal ignores any offset. Defaults to None. channels (list, optional): list of labels specifying the signals to plot. If None, all signals in the Dataset are plotted. Defaults to None. relative_time (bool, optional): if True, the time axis is relative to the start of the segment to plot. If False, the time axis will correspond to the time in the recording. time_scale (str, optional): the time scale to use. Choose from 'seconds', 'minutes', 'hours'. Defaults to 'seconds'. add_offsets (bool, optional): if True, adds y offsets, based on signal SD, to each of the channels for easier visualization. If False, does not apply any offsets. If a float, adds this offset to each channel. ticklabels (bool, optional): if add_offsets is True, set ticklabels to True to put the signal labels as yticks. Set to False to keep numeric y-labels. subplots (bool, optional): if True, plots each signal in a separate subplot. If False, plots all signals in `ax``. Defaults to False. color (list or dict, optional): list or dictionary woth colors for each TimeSeries in the Dataset. legend (bool, optional): include legend (True) or not (False) if subplots is False. Defaults to True. ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes. If subplots is True, always plots in a new figure. Defaults to None. **kwargs (optional): optional keyword arguments for the plt.plot() function. Returns: ax (plt.axes): axes instance. """ # By default, plot all channels. if channels is None: channels = list(self.time_series.keys()) # Loop over channels and plot. if subplots: nrows, ncols = subplot_rows_columns(len(channels)) fig, axes = plt.subplots(nrows, ncols, tight_layout=True, sharex='all') axes = np.reshape([axes], -1) elif ax is None: # Current axes. ax = plt.gca() axes = [ax] else: axes = [ax] y_offset = 0 yticks = [] yticklabels = [] for i, label in enumerate(channels): if subplots: ax = axes[i] else: ax = axes[0] if isinstance(color, list): col = color[i] elif isinstance(color, dict): col = color[label] else: col = color ts_i = self.time_series[label] if add_offsets is not False: ts_i = ts_i + y_offset # Collect ticks. yticks.append(y_offset) yticklabels.append(ts_i.label) # Add offset for next channel. if isinstance(add_offsets, (int, float)): y_offset_add = add_offsets else: # Determine offset from SD. y_offset_add = np.std(ts_i)*3 y_offset -= y_offset_add ts_i.plot(*args, begin=begin, end=end, time_scale=time_scale, ax=ax, color=col, **kwargs) # Figure makeup. if not subplots: plt.xlabel('Time ({})'.format(time_scale)) plt.title('{}'.format(self.label)) if add_offsets and ticklabels: ax.set_yticks(yticks) ax.set_yticklabels(yticklabels) else: plt.ylabel('Signal (a.u.)') if legend: plt.legend(loc='upper right') return ax
[docs] def power_analysis(self, verbose=1, **kwargs): """ Perform a Fourier based power analysis. This is a wrapper that prepares the input for PowerAnalysis.power_analysis() and returns the result. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the PowerAnalysis class. Returns: result (nnsa.PowerAnalysisResult): the result of the power analysis. """ if verbose > 0: print('Power analysis of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize PowerAnalysis object (updates default parameters with user specified keyword arguments). power_analysis = PowerAnalysis(**kwargs) # Prepare input matrix. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Assume the units are the same for all signals. unit = next((ts.unit for ts in self.time_series.values())) # Run power_analysis. result = power_analysis.power_analysis(data_matrix, fs=fs, channel_labels=channel_labels, unit=unit, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def nan_to_num(self, channels=None, inplace=False, **kwargs): """ Apply numpy's nan_to_num function to each timeseries (e.g. to replace nans by zeros). Args: channels (list): see self.transform. inplace (bool): see self.transform. **kwargs (optional): kwargs for self.transform. """ return self.transform(fun=lambda x: np.nan_to_num(x, **kwargs), channels=channels, inplace=inplace)
[docs] def remove(self, channel, inplace=False, verbose=1): """ Remove a signal from the dataset. Args: channel (str): label of the channel to remove. inplace (bool, optional): If True, the channel is removed directly from the EegDataset object itself. If False, the function returns a copy of the original EegDataset object in which the specified channel is removed, leaving the data in the current EegDataset unchanged. Default is False. verbose (int, optional): verbose level. Defaults to 1. Returns: (BaseDataset): new Dataset object containing the same signals, except for the removed channel (only if inplace is False). """ # Check channel. label = self._check_label(channel) if inplace: # We will remove the channel from the current Dataset object. ds = self else: # We will remove the channel from a copy of the current Dataset object and return this copy. # Create a copy of the Dataset object. ds = copy.deepcopy(self) del ds.time_series[label] if verbose > 0: print('Removed "{}" from {}.'.format(label, self.label)) # Only return if not in place removing. if not inplace: return ds
[docs] def remove_artefact_channels(self, inplace=False, **kwargs): """ Remove channels that do not meet the quality criteria. Args: inplace (bool, optional): If True, the channel is removed directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the artefacted channels are removed, leaving the data in the current Dataset unchanged. Default is False. **kwargs (optional): optional keyword arguments for overruling default signal quality criteria (see nnsa.artefacts.artefact_detection.detect_artefact_signals()). Returns: (EegDataset): new Dataset object containing the same signals, except for the removed channels (only if inplace is False). """ if inplace: # We will remove the to be excluded channels from the current Dataset object. ds = self else: # We will remove the to be excluded channels from a copy of the current Dataset object and return this # copy. ds = copy.deepcopy(self) # Loop over channels. channels = list(ds.time_series.keys()) for label in channels: # Extract channel. ts = ds.time_series[label] # Check if the signal quality meets the signal criteria. exclude = detect_artefact_signals(ts.signal, **kwargs) if exclude: # Remove the channel. ds.remove(label, inplace=True) # Only return if not in place removing. if not inplace: return ds
[docs] def remove_artefacts(self, inplace=False, **kwargs): """ Replace samples that are artefacts by np.nan Args: inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the artefacted samples are replaced, leaving the data in the current Dataset unchanged. Default is False. **kwargs (optional): optional keyword arguments for overruling default sample quality criteria (see nnsa.artefacts.artefact_detection.detect_artefact_samples()). Returns: (BaseDataset): new Dataset object containing the same signals, but with artefacted samples changed to np.nan (only returned if inplace is False). """ if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) # Loop over channels. channels = list(ds.time_series.keys()) for label in channels: # Extract channel. ts = ds.time_series[label] # Remove artefacts from the ts inplace (copy has been made already if requested). ts.remove_artefacts(inplace=True, **kwargs) # Only return if not in place removing. if not inplace: return ds
[docs] def remove_flatlines(self, inplace=False, verbose=1, **kwargs): """ Replace flatline samples by np.nan Args: inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the flatline samples are replaced, leaving the data in the current Dataset unchanged. Default is False. verbose (int): verbosity level. **kwargs (optional): optional keyword arguments for TimeSeries.remove_flatlines(). Returns: (BaseDataset): new Dataset object containing the same signals, but with flatline samples changed to np.nan (only returned if inplace is False). """ if verbose > 0: print('Removing flatlines in {}...'.format(self.label)) if inplace: # We will replace the original TimeSeries objects in the current EegDataset by the new TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current EegDataset by the new # TimeSeries. # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) for ts in ds_out: ts.remove_flatlines(inplace=True, **kwargs) # Only return if not in place referencing. if not inplace: return ds_out
[docs] def remove_neighborhood_artefacts(self, *args, inplace=False, verbose=1, **kwargs): """ Replace values by np.nan if samples in their neighborhood are nan. Args: *args: arguments for TimeSeries.remove_neighborhood_artefacts(). inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the samples are replaced, leaving the data in the current Dataset unchanged. Default is False. verbose (int): verbosity level. **kwargs (optional): optional keyword arguments for TimeSeries.remove_neighborhood_artefacts(). Returns: (BaseDataset): new Dataset object containing the same signals, but with specific samples changed to np.nan (only returned if inplace is False). """ if verbose > 0: print('Removing neighborhood artefacts in {}...'.format(self.label)) if inplace: # We will replace the original TimeSeries objects in the current EegDataset by the new TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current EegDataset by the new # TimeSeries. # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) for ts in ds_out: ts.remove_neighborhood_artefacts(*args, inplace=True, **kwargs) # Only return if not in place referencing. if not inplace: return ds_out
[docs] def remove_outliers(self, *args, inplace=False, verbose=1, **kwargs): """ Replace outlier samples by np.nan Args: *args: arguments for TimeSeries.remove_outliers(). inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the samples are replaced, leaving the data in the current Dataset unchanged. Default is False. verbose (int): verbosity level. **kwargs (optional): optional keyword arguments for TimeSeries.remove_outliers(). Returns: (BaseDataset): new Dataset object containing the same signals, but with anomaly samples changed to np.nan (only returned if inplace is False). """ if verbose > 0: print('Removing outliers in {}...'.format(self.label)) if inplace: # We will replace the original TimeSeries objects in the current EegDataset by the new TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current EegDataset by the new # TimeSeries. # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) for ts in ds_out: ts.remove_outliers(*args, inplace=True, **kwargs) # Only return if not in place referencing. if not inplace: return ds_out
[docs] def remove_values(self, *args, inplace=False, verbose=1, **kwargs): """ Replace values by np.nan Args: *args: arguments for TimeSeries.remove_values(). inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the samples are replaced, leaving the data in the current Dataset unchanged. Default is False. verbose (int): verbosity level. **kwargs (optional): optional keyword arguments for TimeSeries.remove_values(). Returns: (BaseDataset): new Dataset object containing the same signals, but with specific samples changed to np.nan (only returned if inplace is False). """ if verbose > 0: print('Removing values in {}...'.format(self.label)) if inplace: # We will replace the original TimeSeries objects in the current EegDataset by the new TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current EegDataset by the new # TimeSeries. # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) for ts in ds_out: ts.remove_values(*args, inplace=True, **kwargs) # Only return if not in place referencing. if not inplace: return ds_out
[docs] def resample(self, fs_new, method='polyphase_filtering', channels=None, inplace=False, verbose=1, **kwargs): """ Resample the signals in this Dataset to new sampling frequency fs_new. Args: fs_new (float): new sampling frequency to resample to. method (str, optional): see nnsa.TimeSeries.resample() channels (list, optional): list with labels of the signals to be resampled. If None, all signals in the Dataset are resampled. Defaults to None. inplace (bool, optional): if True, the resampled signals will replace the original signals in this Dataset. If False (default), the function returns a new Dataset object containing the resampled signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are resampled, the newly created Dataset object will only contain those resampled channels. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): see nnsa.TimeSeries.resample() Returns: (BaseDataset): new Dataset object containing the resampled signals (only if inplace is False). """ if verbose > 0: print('Resampling {} using {}...'.format(self.label, method)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the resampled TimeSeries. ds_resampled = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the resampled # TimeSeries. # Create a copy of the Dataset object. ds_resampled = copy.deepcopy(self) # Resample the signals. for label in channels: ts_resampled = self.time_series[label].resample( fs_new, method=method, verbose=1 if verbose > 1 else 0, **kwargs) ds_resampled.time_series[label] = ts_resampled # Only return if not in place resampling. if not inplace: return ds_resampled
[docs] @staticmethod def read_hdf5(filepath, ds=None, begin=None, end=None, **kwargs): """ Read the Dataset object to a .hdf5 file. Args: filepath (str): filepath to read. begin (float, optional): start second (wrt time offset). end (float, optional): end second (wrt time offset). ds (optional): Dataset object to append the data to. If None, creates a BaseDataset object. **kwargs (optional): keyword arguments for TimeSeries(). Returns: ds (nnsa.BaseDataset): Dataset object with data read from the file. Examples: >>> signal = np.random.rand(8, 1000) >>> ds = BaseDataset() >>> ts_all = [TimeSeries(signal=signal[i], fs=10, label=i) for i in range(len(signal))] >>> ds.append(ts_all) >>> ds.save_hdf5('testfile.hdf5') >>> ds_read = BaseDataset.read_hdf5('testfile.hdf5') >>> ds_read_epoch = BaseDataset.read_hdf5('testfile.hdf5', begin=10, end=20) >>> os.remove('testfile.hdf5') >>> _ = [assert_equal(ts.signal, ts_read.signal) for ts, ts_read in zip (ds, ds_read)] >>> _ = [assert_equal(ts.signal, ts_read.signal) for ts, ts_read in zip (ds.extract_epoch(begin=10, end=20), ds_read_epoch)] """ if ds is None: ds = BaseDataset() with h5py.File(filepath, 'r') as f: # Loop over datasets in HDF5 and if it is a TimeSeries, append to the Dataset. for lab in f: is_time_series = False sig = f[lab] if 'type' in sig.attrs: if sig.attrs['type'].decode() == 'TimeSeries': # Found time series. is_time_series = True elif 'fs' in sig.attrs: # Found time series. is_time_series = True if is_time_series: # Read file. ts = TimeSeries.read_hdf5(f, label=lab, begin=begin, end=end) ds.append(ts) return ds
[docs] def save_hdf5(self, filepath, overwrite=False): """ Save the Dataset object to a .hdf5 file. Args: filepath (str): filepath to save to. overwrite (bool, optional): if True, overwrites existing files. If False, raises an error when `filepath` already exists and mode is 'w'. Defaults to False. """ if not overwrite: # Check if filepath already exists. if os.path.exists(filepath): raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.') check_directory_exists(filepath=filepath) # Create an empty HDF5 file. with h5py.File(filepath, mode='w') as f: pass # Save each timeseries to the hdf5. for ts in self: ts.save_hdf5(filepath, mode='a', overwrite=True)
[docs] def segment(self, segment_length, overlap=0): """ Return a segment generator that segments the signals in the dataset (into smaller time segments). Only works is all signals have the same sample frequency. Args: segment_length (float): length of a segment in seconds. overlap (float, optional): overlap between succesive segments in seconds. Defaults to 0. Yields: (BaseDataset): Dataset object holding the signals of the next segment. """ # Get frequencies. fs_all = [ts.fs for ts in self.time_series.values()] # Verify all sample frequencies are the same. if len(set(fs_all)) != 1: raise ValueError('Cannot segment {} if it contains signals of different frequencies: {}.' .format(self.__class__.__name__, fs_all)) # Create segment generators for all channels in the Dataset. seg_generators = [ts.segment(segment_length, overlap=overlap) for ts in self.time_series.values()] # Start infinite loop. while True: # Initialize Dataset. ds = self.__class__() # Loop over channels. for gen in seg_generators: # Put the next segment in the Dataset for the segment. ds.append(next(gen)) # Yield Dataset. yield ds
[docs] def signal_stats(self, verbose=1, **kwargs): """ Compute signal statistics. This is a wrapper that prepares the input for SignalStats.signal_stats() and returns the result. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the SignalStats class. Returns: result (nnsa.SignalStatsResult): object containing the signal stats. """ if verbose > 0: print('Computing signal statistics of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize feature extraction object (updates default parameters with user specified keyword arguments). ss = SignalStats(**kwargs) # Prepare input matrix for signal stats. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Run signal stats. result = ss.signal_stats(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result and update the time vectors if needed. self._postprocess_result(result) return result
[docs] def squeeze(self): """ Returns the TimeSeries if only one TimeSeries in the Dataset. Otherwise returns the Dataset itself. """ if len(self) == 1: return list(self.time_series.values())[0] else: return self
[docs] def stepwise_moving_average(self, *args, avg_fun=np.nanmean, **kwargs): return self.stepwise_reduce(*args, reduce_fun=avg_fun, **kwargs)
[docs] def stepwise_reduce(self, *args, channels=None, inplace=False, **kwargs): """ Compute stepwise values for the signals in this Dataset. Args: *args: inputs for nnsa.TimeSeries.stepwise_reduce() channels (list, optional): list with labels of the signals to apply to. If None, all signals in the Dataset are processed. Defaults to None. inplace (bool, optional): if True, the new signals will replace the original signals in this Dataset. If False (default), the function returns a new Dataset object containing the new signals, leaving the data in the current Dataset unchanged. Note that if only certain channels are processed, the newly created Dataset object will only contain those processed channels. **kwargs (optional): keyword arguments for nnsa.TimeSeries.stepwise_reduce() Returns: (BaseDataset): new Dataset object containing the new signals (only if inplace is False). """ # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the new TimeSeries. ds_new = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the new # TimeSeries. # Create a copy of the Dataset object. ds_new = copy.deepcopy(self) # Compute the signals. for label in channels: ds_new.time_series[label].stepwise_reduce( *args, inplace=True, **kwargs) # Only return if not in place. if not inplace: return ds_new
[docs] def to_time_series(self, make_copy=False): """ If the Dataset contains only one TimeSeries, return this. Otherwise raise a ValueError. Args: make_copy (bool, optional): if True, the returned TimeSeries is a copy of the TimeSeries in self. If False, no copy is made. Returns: ts (TimeSeries): nnsa TimeSeries object. """ if len(self) == 1: ts = self.time_series[self.channel_labels[0]] if make_copy: ts = copy.deepcopy(ts) return ts else: raise ValueError('Cannot convert {} to TimeSeries, because there are {} TimeSeries. ' 'to_time_series only works is there is only a single TimeSeries.'.format( self.label, len(self)))
[docs] def transform(self, fun, channels=None, inplace=False): """ Apply a custom function to each signal in the dataset. Args: fun (function): function that transform the data. Takes in a 1D array and returns a 1D array. channels (list, optional): list with labels of the signals to transform. If None, all signals in the Dataset are transformed. Defaults to None. inplace (bool): transform the data inplace (True) or return a new dataset instance (False). Returns: ds (BaseDataset): dataset with the transformed data (if inplace is False). """ if inplace: # We will apply the changes to the current Dataset object. ds = self else: # We will apply the changes to a copy of the current Dataset object and return this changed copy. ds = copy.deepcopy(self) if channels is None: channels = list(ds.time_series.keys()) # Loop over channels. for label in channels: # Transform channel. ds.time_series[label].transform(fun, inplace=True) # Only return if not in place removing. if not inplace: return ds
def _check_label(self, label): """ Return the label or raise an error if the given label is not in the current dataset. Args: label (str): label. Returns: label (str): label if the given label was a valid label. Raises: ValueError if the label is not in the dataset. """ # Raise error if the specified channel is not in the dataset. if label not in self.time_series: raise ValueError('Signal "{}" not in {}. Signals in dataset: {}.' .format(label, self.label, list(self.time_series.keys()))) return label def _postprocess_result(self, result, channels=None): """ Postprocess the result class common to all feature extractions. Add info string about data to result. Update the segment time vectors of the result if time_offset of the TimeSeries was not 0. Args: result (nnsa.ResultBase-derived): ResultBase-derived object. channels (list, optional): list with labels of the signals of which to return the info. If None, all signals in the Dataset are used. Defaults to None. """ # By default use all channels. if channels is None: channels = self.time_series.keys() # Loop over channels. time_offset_all = [] all_info = [] for i, label in enumerate(channels): ts = self[label] time_offset_all.append(ts.time_offset) str_info = str(ts.info) if i == 0: # Return a string version of this info object. all_info.append(str_info) else: if str_info not in all_info: # Only add if new info. all_info.append(str_info) # Merge all data info strings into one. result.data_info = ';\n '.join(all_info) # Time offset. if len(set(time_offset_all)) == 1: time_offset = time_offset_all[0] else: raise ValueError('TimeSeries in {} not synchronized.'.format(self.label)) result.time_offset = time_offset
[docs]class EegDataset(BaseDataset): """ High-level interface for processing (multichannel) EEG for neonatal signal analysis. """ def __getitem__(self, label): """ Evaluate self[label]. Makes it possible to extract the time series using self['Cz'], even when the proper key is 'EEG Cz'. """ if label not in self.time_series: label = self._check_label(label) return self.time_series[label] @property def signal(self): """ Return the EEG data as a numpy array (channels, time). Returns: (np.ndarray): EEG data matrix. """ return self.asarray()
[docs] def amplitude_eeg(self, channel_1='EEG C3', channel_2='EEG C4', verbose=1, **kwargs): """ Compute amplitude-integrated EEG (aEEG) from single bipolar channel continuous EEG. Use on raw EEG data. Args: channel_1 (str, optional): the label of the first channel. Defaults to 'EEG C3'. channel_2 (str, optional): the label of the second channel (will be subtracted from the first channel). Defaults to 'EEG C4'. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the AmplitudeEeg class. Returns: result (nnsa.AmplitudeEegResult): object containing the aEEG. """ # Create bipolar channel. ts = self.create_bipolar_channel(channel_1, channel_2) # Emulate aEEG. return ts.amplitude_eeg(verbose=verbose, **kwargs)
[docs] def as_brain_rt(self, bp_filter=True): """ Preprocess the data to make it similar to the view as in BrainRT. Creates bipolar channels and filter. Args: bp_filter (bool): if True, filter the data as in Brain RT (with a first order bandpass Butterworth). Returns: eeg_ds (nnsa.EegDataset): preprocessed dataset (new object). """ # Bipolar channels. channels_1 = ['Fp2', 'C4', 'Fp2', 'T4', 'Fp1', 'C3', 'Fp1', 'T3', 'T4', 'C4', 'Cz', 'C3'][::-1] channels_2 = ['C4', 'O2', 'T4', 'O2', 'C3', 'O1', 'T3', 'O1', 'C4', 'Cz', 'C3', 'T3'][::-1] # In BrainRT, apparently, the derivations are the other way around (e.g., C3-C4 is actually C4-C3). eeg_ds = self.create_bipolar_channels( channels_1=channels_2, channels_2=channels_1) if bp_filter: # Filter. In BrainRT, the filter is a first order butterworth. butter = Butterworth(fn=[0.27, 30], order=1) eeg_ds.filter(butter, inplace=True) return eeg_ds
[docs] def brain_age_sinc(self, verbose=1, **kwargs): """ Compute the brain age from the EEG using the deep Sinc network. Use on raw EEG data. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the BrainAgeSinc class. Returns: result (BrainAgeResult): object containing the brain age. """ if verbose > 1: print('Computing brain age of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) eeg_ds = self # Initialize feature extraction object (updates default parameters with user specified keyword arguments). brain_age_sinc = BrainAgeSinc(verbose=verbose > 1, **kwargs) data_requirements = brain_age_sinc.data_requirements # Reference if reference channel in EEG data (otherwise assume it is already referenced correctly). reference_channel = data_requirements['reference_channel'] if reference_channel in eeg_ds: eeg_ds = eeg_ds.reference(reference_channel, verbose=verbose) # The BrainAgeSInc expect the data in a specific order, which we can access here. channel_order = brain_age_sinc.data_requirements['channel_order'] # Extract EEG data as array with required channel order. eeg, ch_labs = eeg_ds.asarray( channels=channel_order, return_channel_labels=True, channels_last=True) channel_labels = [c.replace('EEG', '').strip() for c in ch_labs] # Remove EEG from label. fs = eeg_ds.fs # Process. result = brain_age_sinc.process(eeg, fs, batch_size=1000, axis=0, verbose=verbose) # Add info string about data to result. eeg_ds._postprocess_result(result) return result
[docs] def brain_age_rf(self, verbose=1, **kwargs): """ Compute the brain age from the EEG using the random forest approach. Use on raw EEG data. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the BrainAge class. Returns: result (BrainAgeResult): object containing the brain age. """ raise DeprecationWarning if verbose > 0: print('Computing brain age of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize feature extraction object (updates default parameters with user specified keyword arguments). brain_age = BrainAge(**kwargs) # Prepare input matrix for feature extraction function. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all EEG signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Process. result = brain_age.process(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def burst_detection(self, create_bipolar_channels=True, verbose=1, **kwargs): """ Perform burst detection on all channels in the dataset. This is a wrapper that prepares the input for BurstDetection.burst_detection() and returns the result. Args: create_bipolar_channels (bool, optional): if True, automatically create bipolar channels according to the method used. If False, run the algorithm on the channels as currently in the object. Defaults to True. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments for the BurstDetection class. Returns: result (nnsa.BurstDetectionResult): the result of the burst detection. """ # Initialize BurstDetection object (pass user specified keyword arguments). bd = BurstDetection(**kwargs) # Create bipolar data if requested. if create_bipolar_channels: if bd.parameters['method'].lower() == 'otoole': # Create bipolar channels as required by O'Toole. ds = self.create_bipolar_channels(['C4', 'C3', 'T4', 'C4', 'Cz', 'C3'], ['O2', 'O1', 'C4', 'Cz', 'C3', 'T3']) elif bd.parameters['method'].lower() == 'nleo': # Create bipolar channels that are similar to P3-P4 (as used by Palmu et a. in optimization of # NLEO algorithm). ds = self.create_bipolar_channels(['O1', 'C3', 'T3'], ['O2', 'C4', 'T4']) else: ds = self else: ds = self # Checks. for ts in ds.time_series.values(): if any('Filtered' in log for log in ts.info['processing']): msg = "\nWARNING: Signal(s) seems to be filtered: ts.info['processing'] = {}.\n" \ "Verify that the burst detection method works with filtered data.".format(ts.info['processing']) warnings.warn(msg) if verbose > 0: print('Performing burst detection on {} with {} channels: {}' .format(ds.__class__.__name__, len(ds.time_series), [ts.label for ts in ds.time_series.values()])) # Prepare input matrix for burst_detection. data_matrix, channel_labels = ds.asarray(return_channel_labels=True) # Sample frequency of all EEG signals is the same if self.asarray() did not raise an error. fs = ds.fs # Run burst_detection. result = bd.burst_detection(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. ds._postprocess_result(result) return result
[docs] def clamp(self, threshold=250, channels=None, inplace=False, verbose=1): """ Clamp EEG values by applying a clamping function to reduce dynamic range. Args: threshold (float): threshold for the clamping function. inplace (bool, optional): whether to apply inplace (True) or not. verbose (int, optional): verbosity level. Returns: ds_out (EegDataset): new Dataset object (only returned if inplace is False). """ if verbose > 0: print('Clamping values in {} at {}...'.format(self.label, threshold)) # By default use all channels. if channels is None: channels = self.time_series.keys() if inplace: # We will replace the original TimeSeries objects in the current Dataset by the filtered TimeSeries. ds_out = self else: # We will replace the original TimeSeries objects in A COPY OF the current Dataset by the filtered # TimeSeries. # Create a copy of the Dataset object. ds_out = copy.deepcopy(self) # Filter the signals. for label in channels: ts_out = self.time_series[label].clamp(threshold=threshold) ds_out.time_series[label] = ts_out # Only return if not in place. if not inplace: return ds_out
[docs] def create_bipolar_channel(self, channel_1, channel_2, label=None): """ Create a bipolar channel by subtracting channel_2 from channel_1, i.e. bipolar channel = channel_1 - channel_2. Args: channel_1 (str): the label of the first channel. channel_2 (str): the label of the second channel (will be subtracted from the first channel). label (str, optional): label for the new bipolar channel. If None, a label will automatically be created from channel_1 and channel_2. Defaults to None. Returns: ts_bipolar (nnsa.TimeSeries): TimeSeries object holding the bipolar channel data. """ if label is None: label = '{}-{}'.format(channel_1, channel_2) # Subtract the signals. ts_1 = self[channel_1] ts_2 = self[channel_2] ts_bipolar = ts_1 - ts_2 # Set the label. ts_bipolar.label = label return ts_bipolar
[docs] def create_bipolar_channels(self, channels_1, channels_2, labels=None, missing_mode='error'): """ Create a set of bipolar channels by subtracting channels_2 from channels_1. channels_new[i] = channels_1[i] - channels_2[i] Args: channels_1 (list): the labels of the first channels. channels_2 (list): the labels of the second channels (will be subtracted from the first channel). Must have the same length as channels_1. labels (str, optional): label for the new bipolar channels. Must have the same length as channels_1. If None, labels will automatically be created from channels_1 and channels_2. Defaults to None. missing_mode (str, optional): what to do if a specified channel is missing from the dataset. Choose from: - 'error': raise a ValueError if a channel is specified that is not in the EEG dataset. - 'ignore': ignore the requested bipolar channels that could not be created. - 'warn': ignore, but display a warning. Returns: ds_bipolar (nnsa.EegDataset): EegDataset object holding the bipolar channels data. """ if not isinstance(channels_1, list): channels_1 = [channels_1] if not isinstance(channels_2, list): channels_2 = [channels_2] if len(channels_1) != len(channels_2): raise ValueError('channels_1 and channels_2 must have same length. Got lengths {} and {}.' .format(len(channels_1), len(channels_2))) if labels is not None: if len(channels_1) != len(labels): raise ValueError('channels_1 and labels must have same length. Got lengths {} and {}.' .format(len(channels_1), len(labels))) else: labels = [None] * len(channels_1) # Initialize an empty EegDataset. ds_bipolar = EegDataset() # Compute bipolar channels from each pair in channels_1 and channels_2. for i in range(len(channels_1)): channel_1_i = channels_1[i] channel_2_i = channels_2[i] label_i = labels[i] # Create the bipolar channel (TimeSeries object). try: ts_i = self.create_bipolar_channel(channel_1_i, channel_2_i, label=label_i) except (KeyError, ValueError) as e: if missing_mode == 'error': raise e else: if missing_mode == 'warn': msg = 'Could not create bipolar channel {}-{}.'.format(channel_1_i, channel_2_i) warnings.warn(msg) continue # Add the TimeSeries object to the EegDataset. ds_bipolar.append(ts_i) return ds_bipolar
[docs] def coherence_graph(self, preprocess=False, verbose=1, **kwargs): """ Compute a connectivity graph, where the weights of the edges are based on the coherence between channels. This is a wrapper that prepares the input for CoherenceGraph.coherence_graph() and returns the result. Args: preprocess (bool, optional): apply a default preprocessing routine on the data before feature computation (True) or not (False). Defaults to False. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the CoherenceGraph class. Returns: result (nnsa.CoherenceGraphResult): an object containing the resulting adjacency matrices of the graphs. """ # Preprocess. if preprocess: # Default processing routine. fir_filt = get_eeg_fir_filter_a() notch_filt = NotchIIR() ds = self.filtfilt(fir_filt).filtfilt(notch_filt).resample( fs_new=128, method='polyphase_filtering') else: # Use data as it is. ds = self if verbose > 0: print('Computing coherence graph of {} with {} channels: {}' .format(ds.__class__.__name__, len(ds.time_series), [ts.label for ts in ds.time_series.values()])) # Initialize PowerAnalysis object (updates default parameters with user specified keyword arguments). coherence_graph = CoherenceGraph(**kwargs) # Prepare input matrix. data_matrix, channel_labels = ds.asarray(return_channel_labels=True) # Sample frequency of all EEG signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in ds.time_series.values())) # Run coherence_graph. result = coherence_graph.coherence_graph(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. ds._postprocess_result(result) return result
[docs] def compute_features(self, verbose=1, chunk_size=None, **kwargs): """ Compute EEG feature set. Args: kwargs: for EegFeatures(). verbose (int): verbosity level. chunk_size (int): number of time samples to process at once. If None, a suitable chunk_size is chosen automatically. Returns: result (FeatureSetResult): feature set result. """ if verbose > 0: print('Computing features of {} with {} channels: {}' .format(self.__class__.__name__, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize feature extraction object (updates default parameters with user specified keyword arguments). cf = EegFeatures(**kwargs) # Prepare input matrix for feature extraction function. eeg, channel_labels = self.asarray(return_channel_labels=True, channels_last=False) # Sample frequency. fs = self.fs # Process. result = cf.process(eeg=eeg, fs=fs, channel_labels=channel_labels, axis=-1, chunk_size=chunk_size, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def detect_artefacts(self, multi_channel_cnn=True, verbose=1): """ Helper function to quickly detect artefacts using a combination of methods (see code). Apply to raw EEG. Args: multi_channel_cnn (bool): whether to use the multi-channel CNN (True) or not (False). verbose (int): verbosity level. Returns: af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values: True at locations of artefacts, False otherwise. """ # Select channels and reference. channels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2'] if 'Cz' in self: eeg_ds = self.reference('Cz', inplace=False, verbose=verbose) else: eeg_ds = self eeg_ds = eeg_ds.extract_channels(channels=channels, make_copy=False) # Preprocess for artefact detection. eeg_ds_pp = eeg_ds. \ filtfilt(NotchIIR(f0=50), verbose=verbose). \ filter(Butterworth(fn=[0.27, 30], order=1), verbose=verbose). \ resample(fs_new=128, method='interpolation', verbose=verbose) # Detect artefacts using CNN. Note that the preprocessing above is specific to # this method. af_ds = eeg_ds_pp.detect_artefacts_cnn( multi_channel=multi_channel_cnn, detect_flats=True, detect_peaks=True, preprocess=False, verbose=verbose*2) # Detect anomalies. anomaly_ds = eeg_ds_pp.detect_anomalous_channels( window=3, std_factor=8, p_trim=0.25, verbose=verbose) # Combine artefact detection methods. af_ds = af_ds.logical_or(anomaly_ds) # Interpolate/upsample artefact mask to original size. af_mask = af_ds.astype(float).interp( t=eeg_ds.time, channels=eeg_ds.channel_labels, left=1, right=1) > 0.5 # Create a copy of the EegDataset object and fill it with the boolean masks. af_ds = copy.deepcopy(eeg_ds) for ts, mask in zip(af_ds, af_mask): assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def detect_artefacts_amplitude_kaupilla(self, notch_filter=True, bp_filter=True): """ Detect locations of artefacts by settings thresholds on the amplitude as proposed by Kaupilla et al. 2018. To apply to raw EEG data: set `notch_filter` and `bp_filter` to True. Args: notch_filter (bool): specify whether the EEG needs to be filtered. Set to True if `eeg` is raw data. bp_filter (bool): specify whether the EEG needs to be filtered. Set to True if `eeg` is raw data. Returns: af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values: True at locations of artefacts, False otherwise. """ # Create a copy of the EegDataset object. eeg_ds = copy.deepcopy(self) if notch_filter: # Filter powerline from EEG. eeg_ds.filtfilt(NotchIIR(f0=50), inplace=True) if bp_filter: # Filter bandpass. eeg_ds.filtfilt(get_filter('1-40'), inplace=True) amp_kwargs = dict( threshold_low=0.5, threshold_high=200, window=1 ) return eeg_ds.detect_artefacts_amplitude(**amp_kwargs)
[docs] def detect_artefacts_amplitude( self, threshold_high, threshold_low=0, window=1, how='max'): """ Detect locations of artefacts by settings thresholds on the running mean/maximum of the absolute amplitude. Returns a boolean mask where True values indicate locations where x < `threshold_low` or x > `threshold_high`, where x is the moving average/maximum of the absolute signal values in a window of `window` seconds. Args: threshold_high (float): threshold in uV. threshold_low (float): threshold in uV. window (float): window for the moving average in seconds. how (str): whether to take the moving 'mean' or 'max'. Returns: af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values: True at locations of artefacts, False otherwise. """ # Create a copy of the EegDataset object. af_ds = copy.deepcopy(self) # Replace the EEG data with their artefact masks. for ts in af_ds: # Take absolute. x_abs = np.abs(ts.signal) # Window in samples. n = round(window * ts.fs) # Detect artefacts. if how == 'mean': amp = moving_mean(x_abs, n=n) elif how == 'max': amp = moving_max(x_abs, n=n) else: raise ValueError(f'Invalid option how="{how}". Choose from {["mean", "max"]}.') mask = np.logical_or(amp < threshold_low, amp > threshold_high) assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def detect_anomalous_channels(self, window=3, std_factor=8, p_trim=0.25, verbose=0): """ Simple method to detect anomalous high-frequency/amplitude channels by setting a threshold on the line length. The log line length of the (1-`p_strip`)*n_channels with lowest line lengts is used to determine mean and std and if any channel exceeds mean + `std_factor`*std, then it is flagged as artefact. This is done per window. Args: window (float): the window length (seconds) in which to compute line length. std_factor (float): the factor for std determining the threshold (mean+`std_factor*std). p_trim (float): the relative amount of channels to strip before computing the mean and std. The channels with the highest line lengths are stripped/excluded. verbose (int): verbosity level. """ if verbose > 0: print('Detecting anomalous channels in {} with {} channels: {}' .format(self.__class__.__name__, len(self.time_series), [ts.label for ts in self.time_series.values()])) # To array to prepare for function. eeg_ar = self.asarray(channels_last=False) fs = self.fs # Detect anomalous channels. mask_anomaly = detect_anomalous_channels( x=eeg_ar, fs=fs, window=window, std_factor=std_factor, p_trim=p_trim) # Create a copy of the EegDataset object in which we will save the artefact mask. af_ds = copy.deepcopy(self) for ts, mask in zip(af_ds, mask_anomaly): assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def detect_artefacts_kota(self, notch_filter=True, bp_filter=True): """ Simple method to detect small and large amplitude artefacts. Apply to raw data referenced to Cz with notch_filter and bp_filter set to True. https://doi.org/10.1016/j.pediatrneurol.2021.06.001 """ # Create a copy of the EegDataset object. af_ds = copy.deepcopy(self) if notch_filter: # Filter powerline from EEG. af_ds.filtfilt(NotchIIR(f0=50), inplace=True) if bp_filter: # Filter bandpass. af_ds.filtfilt(Butterworth(order=4, fn=[0.3, 20]), inplace=True) # Replace the EEG data with their artefact masks. for ts in af_ds: # Detect artefacts. sd = np.sqrt(moving_mean(ts.signal**2, n=1 * ts.fs)) # Mean can be assumed zero after filtering so this is a good and fast way to compute moving SD. max = moving_max(np.abs(ts.signal), n=0.5 * ts.fs) # Total of 1 second removed when finding high value. mask = (sd < 0.01) | (sd > 50) | (max > 300) assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def detect_artefacts_method(self, how, **kwargs): """ Shortcut to any of the artefact removal functions. """ how = how.lower() if how == 'raw': return None elif how == 'amp': return self.detect_artefacts_amplitude_kaupilla(**kwargs) elif how == 'kota': return self.detect_artefacts_kota(**kwargs) elif how == 'rfc': return self.detect_artefacts_rfc(**kwargs) elif how == 'cnn': return self.detect_artefacts_cnn(**kwargs) else: raise ValueError('Invalid how="{}".'.format(how))
[docs] def detect_artefacts_cnn(self, preprocess=None, detect_flats=True, detect_peaks=True, verbose=1, **kwargs): """ Detect locations of artefacts based on clean EEG detection with CNN. EEG data must be referenced to Cz. If Cz is in self, referencing will be done automatically. If Cz is not present, referencing will be ignored (assuming the EEG data is already referenced). Apply to raw EEG data and set `preprocess_eeg` to True. Args: preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled). Set to True if `eeg` is raw data. If not specified, preprocessing will be done if `fs` is not as required by the model, otherwise not. detect_flats (bool): if True, computes moving std in short windows and if its below a threshold, the sample is marked as artefact (since the CNN might not catch this). detect_peaks (bool): if True, computes moving max abs amplitude in short windows and if its above a threshold, the sample is marked as artefact (since the CNN might not catch this). verbose (int): verbosity level. **kwargs (dict): for CleanDetectorCnn(). Returns: af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values: True at locations of artefacts, False otherwise. """ # Init. eeg_ds = self cd = CleanDetectorCnn(**kwargs) data_requirements = cd.data_requirements # Reference if reference channel in EEG data (otherwise assume it is already referenced correctly). reference_channel = data_requirements['reference_channel'] if reference_channel in eeg_ds: eeg_ds = eeg_ds.reference(reference_channel) # Extract channels in specific order if needed. if data_requirements['channel_order'] is not None: eeg_ds = eeg_ds.extract_channels(data_requirements['channel_order'], make_copy=False) # Extract array data from eeg_ds object. eeg, channel_labels = eeg_ds.asarray(return_channel_labels=True, channels_last=True) fs = eeg_ds.fs if verbose > 1: print('Detecting artefacts using CNN in {} with {} channels: {}' .format(eeg_ds.__class__.__name__, len(eeg_ds), [ts.label for ts in eeg_ds])) # Predict/detect clean locations. clean_mask = cd.predict(eeg, fs=fs, preprocess=preprocess, detect_flats=detect_flats, detect_peaks=detect_peaks, verbose=verbose)[0] # Clean mask to artefact mask. af_mask = clean_mask == 0 # Mark NaNs as artefacts too. af_mask = af_mask | np.isnan(eeg) # Create a copy of the EegDataset object. af_ds = copy.deepcopy(eeg_ds) # Replace the EEG data with their artefact masks. for ts, mask in zip(af_ds, af_mask.T): assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def detect_artefacts_rfc(self, pma=None, preprocess=None, verbose=1, **kwargs): """ Detect locations of artefacts based on a sample supervised random forest classifier from the artefact_detection package. EEG data must be referenced to Cz. If Cz is in self, referencing will be done automatically. If Cz is not present, referencing will be ignored (assuming the EEG data is already referenced). To apply to raw EEG data: set `preprocess_eeg` to True. See Also: RfcArtefactDetector Args: pma (float): PMA of the neonate at time of recording. Optional dependeing on model, see RfcArtefactDetector.predict(). preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled). Set to True if `eeg` is raw data (but it should still be referenced to Cz). If not specified, preprocessing will be done if `fs` is not 128, otherwise not. verbose (int): verbosity level. **kwargs (dict): for RfcArtefactDetector(). Returns: af_ds (EegDataset): EegDataset with artefact masks consisting of boolean values: True at locations of artefacts, False otherwise. """ raise DeprecationWarning from artefact_detection.models.rfc import ArtefactDetectorRfc ad = ArtefactDetectorRfc(**kwargs) # Reference if Cz in dataset. eeg_ds = self if 'Cz' in eeg_ds: eeg_ds = eeg_ds.reference('Cz') # Extract array data from eeg_ds object. eeg, channel_labels = eeg_ds.asarray(return_channel_labels=True, channels_last=True) fs = eeg_ds.fs if verbose > 0: print('Detecting artefacts using random forest classifier in {} with {} channels: {}' .format(eeg_ds.__class__.__name__, len(eeg_ds), [ts.label for ts in eeg_ds])) # Predict/detect artefacts. af_mask = ad.predict(eeg, fs=fs, pma=pma, preprocess=preprocess, verbose=verbose) # Create a copy of the EegDataset object. af_ds = copy.deepcopy(eeg_ds) # Replace the EEG data with their artefact masks. for ts, mask in zip(af_ds, af_mask.T): assert len(ts) == len(mask) ts.parameters.update(dtype=bool) ts.signal = mask return af_ds
[docs] def filter(self, *args, remove_mean=True, **kwargs): """ Calls the parent function, but with remove_mean set to True by default. """ return super().filter(*args, remove_mean=remove_mean, **kwargs)
[docs] def filter_saved_filter(self, *args, remove_mean=True, **kwargs): """ Calls the parent function, but with remove_mean set to True by default. """ return super().filter_saved_filter(*args, remove_mean=remove_mean, **kwargs)
[docs] def filtfilt(self, *args, remove_mean=True, **kwargs): """ Calls the parent function, but with remove_mean set to True by default. """ return super().filtfilt(*args, remove_mean=remove_mean, **kwargs)
[docs] def notch_filt(self, *args, remove_mean=True, **kwargs): """ Calls the parent function, but with remove_mean set to True by default. """ return super().notch_filt(*args, remove_mean=remove_mean, **kwargs)
[docs] @staticmethod def from_array(eeg, fs, channel_labels=None, **kwargs): """ Create an EegDataset from a 2D array (channels x time). Args: eeg (np.ndarray): 2d array with dimensions (channels, time). fs (float): sample frequency in Hz. channel_labels (list): list with length len(eeg) containing labels for the channels. If None, will create default numbered names, like Ch1, Ch2, Ch3, ... **kwargs: optional keyword arguments for nnsa.TimeSeries. Returns: ds (nnsa.EegDataset): EegDataset with the data from the array. Examples: >>> eeg = np.random.rand(8, 10000) >>> eeg_ds = EegDataset.from_array(eeg, fs=250, channel_labels=list(range(8))) """ # Check input dimensions. eeg = np.asarray(eeg) if eeg.ndim == 1: # Assume one channel, change to shape (n_time, n_channels). eeg = eeg.reshape(-1, 1) if eeg.ndim != 2: raise ValueError('`eeg` must have 2 dimensions. Got {} dimensions.' .format(eeg.ndim)) if channel_labels is None: # Create default labels. channel_labels = ['Ch{}'.format(i+1) for i in range(min(eeg.shape))] if len(eeg) != len(channel_labels): # Try transpose. eeg = np.transpose(eeg) if len(eeg) != len(channel_labels): raise ValueError('`len(eeg) ({}) does not equal len(channel_labels) ({})' .format(len(eeg), len(channel_labels))) # Create dataset (collection of TimeSeries). ds = EegDataset() for signal, label in zip(eeg, channel_labels): ts = TimeSeries(signal=signal, fs=fs, label=label, **kwargs) ds.append(ts) return ds
[docs] def get_channel(self, channel, make_copy=True): """ Return a channel as TimeSeries. Can also be a bipolar channel. Args: channel (str): channel label, or a bipolar channel. make_copy (bool): if True, the returned TimeSeries is always a copy. If False, the returned TimeSeries may be the same object as in self, if `channel` is in self. Returns: ts (TimeSeries): time series containing the channel signal. """ # Get the requested channel. if channel in self: ts = self[channel] if make_copy: ts = copy.deepcopy(ts) elif '-' in channel: # Create bipolar channel. channels_1, channels_2 = channel.replace(' ', '').split('-') ts = self.create_bipolar_channel(channels_1, channels_2) else: raise ValueError('Invalid channel "{}".'.format(channel)) return ts
[docs] def get_segments(self, segment_length, overlap=0, channels_last=True): """ Segment the EEG data and return the segmented data as a 3D array. Args: segment_length (float): segment length in seconds. overlap (float): segment overlap in seconds. channels_last (bool): if True, the returned array has shape (n_segments, n_time, n_channels). If False, the shape is (n_segments, n_channels, n_time). Returns: eeg_seg (np.ndarray): 3D array with length equal to the number of segments. The other two dimensions are time and channels, their order depends on `channels_last`. """ eeg_ar = self.asarray(channels_last=channels_last) eeg_seg = get_all_segments( eeg_ar, segment_length=segment_length, overlap=overlap, axis=0 if channels_last else 1, fs=self.fs) return eeg_seg
[docs] def kirubin_features(self, reference_channel=None, preprocess=False, inplace=False, verbose=1, **kwargs): """ Compute Kirubin's feature set. This is a wrapper that prepares the input for KirubinFeatures.wct() and returns the result. Args: reference_channel (str, optional): specify how to reference the data. If no referencing is desired, specify None. preprocess (bool, optional): apply a default preprocessing routine on the data before feature computation (True) or not (False). Defaults to False. inplace (bool, optional): if `preprocess_eeg` is True, specify if the preprocessing should be inplace or not. Defaults to False. verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the KirubinFeatures class. Returns: result (nnsa.KirubinFeaturesResult): the result with the features. """ # Preprocess. if preprocess: if inplace: ds = self else: ds = copy.deepcopy(self) # Default processing routine. fir_filt = get_eeg_fir_filter_a() notch_filt = NotchIIR(f0=50, Q=35) if reference_channel is not None: ds.reference(reference_channel=reference_channel, inplace=True) ds.remove_flatlines(n_flatline=1, inplace=True) ds.resample(fs_new=250, method='polyphase_filtering', inplace=True) ds.filtfilt(fir_filt, inplace=True) ds.filtfilt(notch_filt, inplace=True) else: # Use data as it is. ds = self if verbose > 0: print('Computing Kirubin features of {} with {} channels: {}' .format(ds.__class__.__name__, len(ds.time_series), [ts.label for ts in ds.time_series.values()])) # Initialize feature extraction object (updates default parameters with user specified keyword arguments). from nnsa.feature_extraction.feature_sets_old import KirubinFeatures kf = KirubinFeatures(**kwargs) # Prepare input matrix for feature extraction function. data_matrix, channel_labels = ds.asarray(return_channel_labels=True) # Sample frequency of all EEG signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in ds.time_series.values())) # Process. result = kf.process(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. ds._postprocess_result(result) return result
[docs] def mean(self, channels=None, label=None, **kwargs): """ Return the nanmean of the specified channels as a TimeSeries. Args: channels (list, optional): list with labels of the signals to be average. If None, all signals in the Dataset are averaged. Defaults to None. label (str, optional): label for returned TimeSeries. **kwargs (optional): keyword arguments for TimeSeries(). Returns: ts (nnsa.TimeSeries): TimeSeries containing the median of the channels. """ eeg, channel_labels = self.asarray(channels=channels, return_channel_labels=True) fs = self.fs # Take median across channels. if len(eeg) == 1: signal = eeg.copy() else: signal = np.nanmean(eeg, axis=0) if label is None: label = 'Mean of {}'.format(', '.join( [lab.replace('EEG ', '') for lab in channel_labels])) ts = TimeSeries(signal=signal, fs=fs, label=label, unit=self.unit, info=self.info, time_offset=self.time_offset, **kwargs) return ts
[docs] def median(self, channels=None, label=None, **kwargs): """ Return the nanmedian of the specified channels as a TimeSeries. Args: channels (list, optional): list with labels of the signals to be average. If None, all signals in the Dataset are averaged. Defaults to None. label (str, optional): label for returned TimeSeries. **kwargs (optional): keyword arguments for TimeSeries(). Returns: ts (nnsa.TimeSeries): TimeSeries containing the median of the channels. """ eeg, channel_labels = self.asarray(channels=channels, return_channel_labels=True) fs = self.fs # Take median across channels. if len(eeg) == 1: signal = eeg.copy() else: signal = np.nanmedian(eeg, axis=0) if label is None: label = 'Median of {}'.format(', '.join( [lab.replace('EEG ', '') for lab in channel_labels])) ts = TimeSeries(signal=signal, fs=fs, label=label, unit=self.unit, info=self.info, time_offset=self.time_offset, **kwargs) return ts
[docs] def plot(self, begin=None, end=None, channels=None, color=None, scale=None, relative_time=False, time_scale=None, ax=None, legend=True, **kwargs): """ Plot the EEG data of specified channels for a specifiec time frame. Assumes that all channels have same sampling frequency and unit. Raises a NotImplementedError if this is not the case. Args: begin (float, optional): begin time in seconds, relative to the beginning of the loaded signal (ignores any offset). Defaults to 0. end (float, optional): end time in seconds or None to specify the end of the entire signal. Relative to the beginning of the loaded signal ignores any offset. Defaults to None. channels (list, optional): list of labels specifying the channels to plot, in order bottom to up. If None, all EEG channels in the EegDataset are plotted. Defaults to None. color (str or array or None): if str or array, use this as the color for all channels. If None, the default colors will be used to plot each channel in a different color. scale (int, optional): value (in the same unit as the data) by which to scale the data (data is divided by scale). This factor must aim to squeeze most data between -0.5 and 0.5. If None, this scale is determined automatically. Default is None. relative_time (bool, optional): if True, the time axis is relative to the start of the segment to plot. If False, the time axis will correspond to the time in the recording. time_scale (str, optional): the time scale to use. Choose from 'seconds', 'minutes', 'hours', None. If None or 'timedelta', plots with a datetime.timedelta format (hh:mm:ss). ax (plt.Axes, optional): axes to plot in. If None, plots in a new figure. Defaults to None. legend (bool, optional): add a legend (True) indicating the scale or not (False). Defaults to True. **kwargs (optional): optional keyword arguments for the plt.plot() function. Returns: ax (plt.axes): axes instance. """ # By default, plot all channels. if channels is None: channels = list(self.time_series.keys())[::-1] if isinstance(channels, str): channels = [channels] if not self.is_synchronized(): raise ValueError('EEG signals not synchronized. Cannot plot.') # # Get fs and time offset. # fs, time_offset = next((ts.fs, ts.time_offset) for ts in self.time_series.values()) # Extract unit. unit_list = [self[ch].unit for ch in channels] if len(np.unique(unit_list)) != 1: NotImplementedError('Plot function not compatible with EEG signals of different units.') unit = np.unique(unit_list)[0] # Create time. time = self.time time_mask = get_range_mask(time, x_min=begin, x_max=end) time = time[time_mask] # Collect (segments of) signals to plot. signal_list = [self[ch].signal[time_mask] for ch in channels] # Compute scaling: we want all signals to lie (for the most part within in a range of 1). if scale is None: deviation_range = np.nanmax( [np.nanpercentile(sig, 95) - np.nanpercentile(sig, 5) for sig in signal_list]) if np.isnan(deviation_range): deviation_range = 0 scale = int(round(deviation_range)) if scale == 0: scale = 1 # Compute offsets. offset_list = np.arange(len(channels)) + 1 # Subtract t0 is relative time is requested. if relative_time: time -= time[0] # Convert time scale if requested. if time_scale is not None: if time_scale == 'timedelta': time_scale = None else: time = convert_time_scale(time, time_scale) # Default plot keyword arguments. plot_kwargs = { 'linewidth': compute_linewidth(signal_list[0]), } # Update plot keyword arguments with user specified keyword arguments. plot_kwargs.update(kwargs) # Set current axis. if ax is not None: plt.sca(ax) else: ax = plt.gca() if color is None: colors = ['C{}'.format(i) for i in range(len(channels))] else: colors = [color]*len(channels) # Loop over channels and plot. for label, signal, offset, color in zip(channels, signal_list, offset_list, colors): # Create a scaled version of the channel signal. y = signal / scale + offset # Plot. plt.plot(time, y, color=color, **plot_kwargs) # Remove label from plot_kwargs (only put a label on the first channel). plot_kwargs['label'] = None # Figure makeup. plt.yticks(offset_list, [ch.replace('EEG', '') for ch in channels]) for color, text in zip(colors, plt.yticks()[1]): text.set_color(color) plt.ylim(0, len(channels) + 1) if legend: add_scalebar(plt.gca(), labely='{}\n{}'.format(scale, unit), matchx=False, matchy=False, hidex=False, hidey=False, sizey=1, barwidth=3, sep=4, loc=4, barcolor='black') plt.grid(True) plt.title(self.label) # x-axis. if time_scale is None: def timeTicks(x, pos): d = datetime.timedelta(seconds=x) return str(d) formatter = FuncFormatter(timeTicks) ax.xaxis.set_major_formatter(formatter) ax.set_xlabel('Time (h:mm:ss)') else: ax.set_xlabel('Time ({})'.format(time_scale)) return ax
[docs] def plot_mask(self, mask, time_scale='seconds', color=None, skip_value=0, ax=None, **kwargs): """ Plot a shaded mask over an EEG plot. Args: mask (np.ndarray): mask with shape (n_time, n_channels). time_scale (str): time scale in which to plot. color (str, array, dict): color or dictionary of colors for plotting the labels in mask. skip_value (int, bool, float or list): value(s) in mask to not plot. ax (plt.Axes): axes to plot in. **kwargs: for plot(). """ if ax is None: ax = plt.gca() mask = np.asarray(mask).astype(int) if mask.ndim == 1: # 1D array is given, assume it holds for all channels: stack copies of the mask. mask = np.tile(mask.reshape(-1, 1), [1, len(self)]) if len(self) != mask.shape[1]: # Try transposing mask. mask = mask.T if len(self) != mask.shape[1]: raise ValueError('`mask` should have same shape number of channels as EEG ({}). Got {}.' .format(len(self), mask.shape[1])) # By default, flip mask, so that the first channels appear at the top (like in self.plot()). mask = mask[:, ::-1] if not isinstance(skip_value, (list, tuple, set)): skip_value = [skip_value] skip_value = list(skip_value) # Unique labels. if not isinstance(color, dict): # Default color for each unique label. keys = list(set(np.unique(mask)) - set(skip_value)) if color is None: values = ['C{}'.format(i) for i in range(len(keys))] else: values = [color]*len(keys) color_mapping = dict(zip(keys, values)) else: color_mapping = color plot_kwargs = dict({ 'zorder': -1, 'alpha': 0.4, 'edgecolor': None, }, **kwargs) # Get time array. t = convert_time_scale(self.time, time_scale=time_scale) # Loop over labels in mask. for value, color in color_mapping.items(): if value in skip_value: continue for i, m in enumerate(mask.T): mask_i = (m == value) y_i = i + 1 onsets_i, offsets_i = get_onsets_offsets(mask_i, fs=self.fs) durations_i = convert_time_scale((offsets_i - onsets_i), time_scale=time_scale) for t_i, dur_i in zip(onsets_i, durations_i): ax.add_patch(Rectangle( xy=(t_i + self.time_offset, y_i - 0.5), width=dur_i, height=1, facecolor=color, **plot_kwargs))
[docs] def power_analysis(self, verbose=1, **kwargs): """ Calls the paretn method, but sets some common frequency bands for EEG. """ feature_kwargs = dict({'frequency_bands': { 'delta_1': [0, 2], 'delta_2': [2, 4], 'theta': [4, 8], 'alpha': [8, 16], 'beta': [16, 32] }}, **kwargs) return super().power_analysis(verbose=verbose, **feature_kwargs)
[docs] def reference(self, reference_channel='Cz', remove_reference=True, inplace=False, verbose=1): """ Reference all EEG signals to reference_channel, i.e. subtract the signal of the reference_channel from all other EEG channels. Args: reference_channel (str, optional): label of the channel to use as reference. Can also be 'average' to subtract the average channel. Default is 'Cz'. remove_reference (bool, optional): if True, remove the reference channel from the referenced dataset. If False, keep the reference channel after referencing (this channel will be zero everywhere). Default is True. inplace (bool, optional): if True, the referenced signals will replace the original signals in this EegDataset. If False (default), the function returns a new EegDataset object containing the referenced signals, leaving the data in the current EegDataset unchanged. Note that if only certain channels are referenced, the newly created EegDataset object will only contain those channels. verbose (int, optional): verbose level. Defaults to 1. """ if verbose > 0: print('Referencing {} with {}...'.format(self.label, reference_channel)) if inplace: # We will replace the original TimeSeries objects in the current EegDataset by the referenced TimeSeries. ds_referenced = self else: # We will replace the original TimeSeries objects in A COPY OF the current EegDataset by the referenced # TimeSeries. # Create a copy of the EegDataset object. ds_referenced = copy.deepcopy(self) if reference_channel.lower() in ['mean', 'average', 'avg']: # Compute average. reference_channel = 'average' ref_signal = np.mean(self.asarray(), axis=0) else: # Extract reference signal. reference_channel = self._check_label(reference_channel) ref_signal = self[reference_channel].signal if remove_reference: # Remove reference from dataset. ds_referenced.remove(reference_channel, inplace=True, verbose=verbose) # Reference all EEG signals. for label in ds_referenced.time_series.keys(): ts_referenced = self.time_series[label].reference( ref_signal, reference_channel, verbose=1 if verbose > 1 else 0) ds_referenced.time_series[label] = ts_referenced # Only return if not in place referencing. if not inplace: return ds_referenced
[docs] def remove_artefacts_method(self, how, **kwargs): """ Shortcut to any of the artefact removal functions. """ if how == 'raw': return self elif how == 'amp': return self.remove_artefacts_amplitude(**kwargs) elif how == 'kota': return self.remove_artefacts_kota(**kwargs) else: raise ValueError('Invalid how="{}".'.format(how))
[docs] def remove_artefact_channels(self, inplace=False, **kwargs): """ Remove channels that do not meet the quality criteria. Args: inplace (bool, optional): If True, the channel is removed directly from the EegDataset object itself. If False, the function returns a copy of the original EegDataset object in which the artefacted channels are removed, leaving the data in the current EegDataset unchanged. Default is False. **kwargs (optional): optional keyword arguments for overruling default signal quality criteria (see nnsa.artefacts.artefact_detection.default_eeg_signal_quality_criteria()). Returns: (EegDataset): new EegDataset object containing the same signals, except for the removed channels (only if inplace is False). """ # Update default criteria with use-specified arguments. criteria = default_eeg_signal_quality_criteria() criteria.update(kwargs) return super().remove_artefact_channels(inplace=inplace, **criteria)
[docs] def remove_artefacts_amplitude(self, notch_filter=True, bp_filter=True, inplace=False): """ Remove artefacts as detected by self.detect_artefacts_amplitude(). """ if inplace: # We will replace the original TimeSeries objects in the current EegDataset. ds_out = self else: # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) # Detect the artefacts. af_ds = self.detect_artefacts_amplitude_kaupilla(notch_filter=notch_filter, bp_filter=bp_filter) # Replace artefacts with nan. for ts, mask in zip(ds_out, af_ds): assert ts.label == mask.label ts._signal[mask] = np.nan if not inplace: return ds_out
[docs] def remove_artefacts_derivative(self, threshold=23000, inplace=False): """ Remove artefacts based on a threshold on the derivative of the filtered signal. Apply this to raw EEG data. Args: TODO inplace: Returns: EegDataset where the artefacts are replaced by nans. """ if inplace: # We will replace the original TimeSeries objects in the current EegDataset. ds_out = self else: # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) # Filter before checking the derivative. fir_filt = get_eeg_fir_filter_a() notch_filt = NotchIIR(f0=50, Q=35) ds_filt = ds_out.filtfilt(notch_filt, inplace=False) ds_filt.filtfilt(fir_filt, inplace=True) # Remove artefacts based on derivative. threshold = threshold/ds_filt.fs # Threshold scales with fs (as fs is lower, threshold is higher). ds_filt.remove_artefacts(inplace=True, max_diff=threshold) # Get artefact masks to apply the detected artefacts to the original (unfiltered) signals. ds_af_mask = ds_filt.transform(fun=np.isnan).astype(bool) # Replace artefacts with nan. for ts, mask in zip(ds_out, ds_af_mask): ts._signal[mask] = np.nan if not inplace: return ds_out
[docs] def remove_artefacts_kota(self, notch_filter=True, bp_filter=True, inplace=False, verbose=1): """ Remove artefacts as detected by self.detect_artefacts_kota(). """ if verbose: print('Removing artefacts with Kota.') if inplace: # We will replace the original TimeSeries objects in the current EegDataset. ds_out = self else: # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) # Detect the artefacts. af_ds = self.detect_artefacts_kota(notch_filter=notch_filter, bp_filter=bp_filter) # Replace artefacts with nan. for ts, mask in zip(ds_out, af_ds): assert ts.label == mask.label ts._signal[mask] = np.nan if not inplace: return ds_out
[docs] def remove_artefacts_mask(self, mask_ds, inplace=False): """ Insert nans where boolean mask is True. Args: mask_ds (BaseDataset): dataset with boolean masks for each of the signals in self. """ if inplace: # We will replace the original TimeSeries objects in the current EegDataset. ds_out = self else: # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) if not isinstance(mask_ds, BaseDataset): if len(mask_ds) != len(ds_out): raise ValueError('`mask_ds` does not have the same length as the dataset.') for ii, ts in enumerate(ds_out): if isinstance(mask_ds, BaseDataset): mask = np.asarray(mask_ds[ts.label].signal) else: mask = np.asarray(mask_ds[ii]) if mask.dtype is not np.dtype('bool'): raise TypeError('Signals in `mask_ds` should have np.dtype `bool`. Got {}.'.format(mask.dtype)) ts._signal[mask] = np.nan if not inplace: return ds_out
[docs] def remove_artefacts_rfc(self, pma, preprocess=None, inplace=False, **kwargs): """ Remove artefacts based on a sample supervised random forest classifier from the artefact_detection package. EEG data must be referenced to Cz. If Cz is in self, referencing will be done automatically. If Cz is not present, referencing will be ignored (assuming the EEG data is already referenced). To apply to raw EEG data: set `preprocess_eeg` to True. See Also: RfcArtefactDetector, self.detect_artefacts_rfc(). Args: pma (np.ndarray): PMA of the neonate at time of recording. preprocess (bool): specify whether the EEG needs to be preprocessed (filtered, resampled). Set to True if `eeg` is raw data (but it should still be referenced to Cz). If not specified, preprocessing will be done if `fs` is not 128, otherwise not. inplace (bool): whether to remove in place or not. **kwargs (dict): for RfcArtefactDetector(). Returns: ds_out (EegDataset): EegDataset where the artefacts are replaced by nans. """ if inplace: # We will replace the original TimeSeries objects in the current EegDataset. ds_out = self else: # Create a copy of the EegDataset object. ds_out = copy.deepcopy(self) # Detect the artefacts. af_ds = self.detect_artefacts_rfc(pma=pma, preprocess=preprocess, **kwargs) # Replace artefacts with nan. for ts, mask in zip(ds_out, af_ds): assert ts.label == mask.label ts._signal[mask] = np.nan if not inplace: return ds_out
[docs] @staticmethod def read_begin_end_time(filepath): """ Read the start and end time (wrt time offset) of the EEG data in a HDF5 file. Args: filepath (str): filepath to read. Returns: begin, end (float): begin and end time in seconds. """ with h5py.File(filepath, 'r') as f: sig = f['EEG'] fs = sig.attrs['fs'] time_offset = sig.attrs['time_offset'] data_len = sig.shape[-1] begin = time_offset end = begin + data_len/fs return begin, end
[docs] @staticmethod def read_hdf5(filepath, begin=None, end=None, **kwargs): """ Read an EegDataset object from an .hdf5 file. Args: filepath (str): filepath to read. begin (float, optional): start second (wrt time offset). end (float, optional): end second (wrt time offset). **kwargs (optional): keyword arguments for TimeSeries(). Returns: ds (nnsa.EegDataset): EegDataset object with data read from the file. Examples: >>> signal = np.random.rand(8, 1000) >>> eeg_ds = EegDataset() >>> ts_all = [TimeSeries(signal=signal[i], fs=10, label=i) for i in range(len(signal))] >>> eeg_ds.append(ts_all) >>> eeg_ds.save_hdf5('testfile.hdf5') >>> ds_read = EegDataset.read_hdf5('testfile.hdf5') >>> ds_read_epoch = EegDataset.read_hdf5('testfile.hdf5', begin=10, end=20) >>> os.remove('testfile.hdf5') >>> assert_equal(eeg_ds.signal, ds_read.signal) >>> assert_equal(eeg_ds.extract_epoch(begin=10, end=20).signal, ds_read_epoch.signal) """ ds = EegDataset() with h5py.File(filepath, 'r') as f: # Read data. sig = f['EEG'] fs = sig.attrs['fs'] time_offset = sig.attrs['time_offset'] channel_labels = sig.attrs['channel_labels'] unit = sig.attrs['unit'] # Determine start and end index. end_idx = None if end is None else max([0, min([int((end - time_offset)*fs), sig.shape[-1]])]) begin_idx = 0 if begin is None else max([int((begin - time_offset)*fs), 0]) # Extract requested part. eeg = sig[:, begin_idx:end_idx] # Update time_offset (depending on `begin`). time_offset = time_offset + begin_idx/fs # Do not read info, just put the current filepath. info = {'source': f.filename} # Add each channel to the dataset. for x, label in zip(eeg, channel_labels): ds.append(TimeSeries(x, fs=fs, label=label.decode(), unit=unit.decode(), time_offset=time_offset, info=info, **kwargs)) return ds
[docs] @staticmethod def read_edf(filepath, begin=None, end=None, **kwargs): """ Read an EegDataset object from an .edf file. Args: filepath (str): filepath to read. begin (float, optional): start second. end (float, optional): end second. **kwargs (optional): keyword arguments for EdfReader.read_eeg_dataset(). Returns: ds (nnsa.EegDataset): EegDataset object with data read from the file. """ from nnsa.io.readers import EdfReader with EdfReader(filepath) as r: ds = r.read_eeg_dataset(begin=begin, end=end, **kwargs) return ds
[docs] @staticmethod def read_mat(filepath, begin=None, end=None, loadmat_kwargs=None, **kwargs): """ Read an EegDataset object from a .mat file. Args: filepath (str): filepath to read. begin (float, optional): start second (wrt time offset). end (float, optional): end second (wrt time offset). loadmat_kwargs (dict, optional): dict with keyword arguments for scipy.io.loadmat()/ **kwargs (optional): keyword arguments for TimeSeries(). Returns: ds (nnsa.EegDataset): EegDataset object with data read from the file. """ # Read data. loadmat_kwargs = dict({ 'struct_as_record': True, 'squeeze_me': True, 'mat_dtype': False, }, **loadmat_kwargs if loadmat_kwargs is not None else dict()) data = scipy.io.loadmat(filepath, **loadmat_kwargs) sig = data['EEG'] fs = data['fs'] time_offset = data.get('time_offset', 0) channel_labels = data.get('channel_labels', None) unit = data.get('unit', 'a.u.') # Check shape and transpose if needed (require n_channels, n_time). if sig.shape[0] > sig.shape[-1]: sig = sig.T # Determine start and end index. end_idx = None if end is None else max([0, min([int((end - time_offset) * fs), sig.shape[-1]])]) begin_idx = 0 if begin is None else max([int((begin - time_offset) * fs), 0]) # Extract requested part. eeg = sig[:, begin_idx:end_idx] # Update time_offset (depending on `begin`). time_offset = time_offset + begin_idx / fs # Just put the current filepath. info = {'source': filepath} # Add each channel to the dataset. ds = EegDataset() for x, label in zip(eeg, channel_labels): ds.append(TimeSeries(x, fs=fs, label=label, unit=unit, time_offset=time_offset, info=info, **kwargs)) return ds
[docs] @staticmethod def read_pickle(filepath, begin=None, end=None, **kwargs): """ Read an EegDataset object from a .pkl file. Args: filepath (str): filepath to read. begin (float, optional): start second (wrt time offset). end (float, optional): end second (wrt time offset). **kwargs (optional): keyword arguments for TimeSeries(). Returns: ds (nnsa.EegDataset): EegDataset object with data read from the file. """ # Read data. data = pickle_load(filepath) sig = data['EEG'] fs = data['fs'] time_offset = data.get('time_offset', 0) channel_labels = data.get('channel_labels', None) unit = data.get('unit', 'a.u.') # Check shape and transpose if needed (require n_channels, n_time). if sig.shape[0] > sig.shape[-1]: sig = sig.T # Determine start and end index. end_idx = None if end is None else max([0, min([int((end - time_offset) * fs), sig.shape[-1]])]) begin_idx = 0 if begin is None else max([int((begin - time_offset) * fs), 0]) # Extract requested part. eeg = sig[:, begin_idx:end_idx] # Update time_offset (depending on `begin`). time_offset = time_offset + begin_idx / fs # Just put the current filepath. info = {'source': filepath} # Add each channel to the dataset. ds = EegDataset() for x, label in zip(eeg, channel_labels): ds.append(TimeSeries(x, fs=fs, label=label, unit=unit, time_offset=time_offset, info=info, **kwargs)) return ds
[docs] @staticmethod def read_set(filepath, begin=None, end=None, loadmat_kwargs=None, **kwargs): """ Read an EegDataset object from a .set file (from Matlab EEGLAB). Args: filepath (str): filepath to read. begin (float, optional): start second (wrt time offset). end (float, optional): end second (wrt time offset). loadmat_kwargs (dict, optional): dict with keyword arguments for scipy.io.loadmat()/ **kwargs (optional): keyword arguments for TimeSeries(). Returns: ds (nnsa.EegDataset): EegDataset object with data read from the file. """ # Read data. loadmat_kwargs = dict({ 'struct_as_record': True, 'squeeze_me': True, 'mat_dtype': False, }, **loadmat_kwargs if loadmat_kwargs is not None else dict()) data = scipy.io.loadmat(filepath, **loadmat_kwargs) sig = data['data'] fs = data['srate'] time_offset = data['times'][0] channel_labels = None if 'chanlocs' in data: chanlocs = data['chanlocs'] channel_labels = [cl[0] for cl in chanlocs] unit = data.get('unit', 'a.u.') # Check shape and transpose if needed (require n_channels, n_time). if sig.shape[0] > sig.shape[-1]: sig = sig.T # Determine start and end index. end_idx = None if end is None else max([0, min([int((end - time_offset)*fs), sig.shape[-1]])]) begin_idx = 0 if begin is None else max([int((begin - time_offset)*fs), 0]) # Extract requested part. eeg = sig[:, begin_idx:end_idx] # Update time_offset (depending on `begin`). time_offset = time_offset + begin_idx/fs # Just put the current filepath. info = {'source': filepath} # Add each channel to the dataset. ds = EegDataset() for x, label in zip(eeg, channel_labels): ds.append(TimeSeries(x, fs=fs, label=label, unit=unit, time_offset=time_offset, info=info, **kwargs)) return ds
[docs] def save_hdf5(self, filepath, mode='w', overwrite=False): """ Save the EegDataset data to a .hdf5 file. Args: filepath (str): filepath to save to. mode (str, optional): 'w' for write mode or 'a' for append mode. Defaults to 'w'. overwrite (bool, optional): if True, overwrites existing files. If False, raises an error when `filepath` already exists. Defaults to False. """ _, file_extension = os.path.splitext(filepath) if not file_extension: # Add file extension. filepath = '{}.hdf5'.format(filepath) elif file_extension.lower() not in ['.hdf5', '.h5']: raise ValueError('Invalid file extension "{}". Use one of {}.' .format(file_extension, ['.hdf5', '.h5'])) if not overwrite: # Check if filepath already exists. if os.path.exists(filepath): raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.' .format(filepath)) # Check the directory and create if it does not exist. check_directory_exists(filepath=filepath) # Get EEG as array. eeg, channel_labels = self.asarray(return_channel_labels=True) unit = list(set([self.time_series[label].unit for label in channel_labels]))[0] info = self.time_series[channel_labels[0]].info # Just take first channel. # Write hdf5 file. with h5py.File(filepath, mode=mode) as f: # Write array data. sig = f.create_dataset('EEG', data=eeg) # Write non-array data as attributes to the signal array. sig.attrs['fs'] = float(self.fs) sig.attrs['time_offset'] = float(self.time_offset) # Convert strings to np.string_ type as recommended for compatibility. sig.attrs['channel_labels'] = [np.string_(label) for label in channel_labels] sig.attrs['unit'] = np.string_(unit) # Write dict as a separate dataset. write_dict_to_hdf5(f, info, 'info')
[docs] def save_mat(self, filepath, overwrite=False): """ Save the EegDataset data to a .mat file. Args: filepath (str): filepath to save to. overwrite (bool, optional): if True, overwrites existing files. If False, raises an error when `filepath` already exists. Defaults to False. """ _, file_extension = os.path.splitext(filepath) if not file_extension: # Add file extension. filepath = '{}.mat'.format(filepath) elif file_extension.lower() not in ['.mat']: raise ValueError('Invalid file extension "{}". Use one of {}.' .format(file_extension, ['.mat'])) if not overwrite: # Check if filepath already exists. if os.path.exists(filepath): raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.' .format(filepath)) # Check the directory and create if it does not exist. check_directory_exists(filepath=filepath) # Get EEG as array. eeg, channel_labels = self.asarray(return_channel_labels=True) # Collect data in dict. data = { 'EEG': eeg, 'channel_labels': channel_labels, 'fs': self.fs, 'time_offset': self.time_offset, 'unit': self.unit, 'info': str(self.info), } # Save dict. scipy.io.savemat(filepath, data)
[docs] def save_csv(self, filepath, overwrite=False, **kwargs): """ Save the EegDataset data to a .csv (comma-separated value) file. This is slow, as it used the pandas functionality. Args: filepath (str): filepath to save to. overwrite (bool, optional): if True, overwrites existing files. If False, raises an error when `filepath` already exists. Defaults to False. **kwargs (dict, optional): for pandas.DataFrame.to_csv(). """ _, file_extension = os.path.splitext(filepath) if not file_extension: # Add file extension. filepath = '{}.csv'.format(filepath) elif file_extension.lower() not in ['.csv']: raise ValueError('Invalid file extension "{}". Use {}.' .format(file_extension, '.csv')) if not overwrite: # Check if filepath already exists. if os.path.exists(filepath): raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.' .format(filepath)) # Check the directory and create if it does not exist. check_directory_exists(filepath=filepath) # Get EEG as array. eeg, channel_labels = self.asarray(return_channel_labels=True) # To dataframe. eeg_df = pd.DataFrame( eeg.T, columns=[lab.replace('EEG', '').strip() for lab in channel_labels]) eeg_df['Time'] = self.time # Save. eeg_df.to_csv(filepath, index=False, **kwargs)
[docs] def save_pickle(self, filepath, overwrite=False, **kwargs): """ Save the EegDataset data to a .pkl (pickle) file. Args: filepath (str): filepath to save to. overwrite (bool, optional): if True, overwrites existing files. If False, raises an error when `filepath` already exists. Defaults to False. **kwargs (dict, optional): for nnsa.pickle_save(). """ _, file_extension = os.path.splitext(filepath) if not file_extension: # Add file extension. filepath = '{}.pkl'.format(filepath) elif file_extension.lower() not in ['.pkl', '.pickle']: raise ValueError('Invalid file extension "{}". Use one of {}.' .format(file_extension, ['.pkl', '.pickle'])) if not overwrite: # Check if filepath already exists. if os.path.exists(filepath): raise ValueError('File "{}" already exists. Overwriting can be enabled by setting overwrite=True.' .format(filepath)) # Check the directory and create if it does not exist. check_directory_exists(filepath=filepath) # Get EEG as array. eeg, channel_labels = self.asarray(return_channel_labels=True) # Collect data in dict. data = { 'EEG': eeg, 'channel_labels': channel_labels, 'fs': self.fs, 'time_offset': self.time_offset, 'unit': self.unit, 'info': str(self.info), } # Save dict. pickle_save(filepath, data, **kwargs)
[docs] def sleep_stages_cnn(self, preprocess=None, remove_flats=True, remove_baseline_wander=False, stepsize=None, verbose=1, preprocess_kwargs=None, **kwargs): """ Automatic sleep stage classification using a Convolutional Neural Network. This is a wrapper that prepares the input for SleepStagesCnn.sleep_stages_cnn() and returns the result. Args: preprocess (bool, optional): if True, the EEG data in ds will be processed by filtering and resampling correspondingly. If False, the data in ds will not be preprocessed (use this when the ds already contains correspondingly prepreocessed data). If None, preprocessing will be done if the sampling frequency of the data does not match the required fs by the model. remove_flats (bool, optional): if True, removes flatlines. remove_baseline_wander (bool): if True, removes excessive baseline wander by applying a HP filter. This was not part of original preprocessing during training, and most EEGs don't need this (in the original preprocessing there is already a HP filter, but milder). If not needed, don't do this additional baseline removal. stepsize (float): stepsize (in seconds) for sleep stage prediction. If None, stepsize is equal to segment length (no overlap). verbose (int, optional): verbose level (0 or 1). Defaults to 1. preprocess_kwargs (dict, optional): dict with optional keyword arguments for SleepStagesCnn.preprocess_recording(). Defaults to None. **kwargs (optional): optional keyword arguments to overrule default parameters of the SleepStagesCnn class. Returns: result (nnsa.SleepStagesCnnResult): nnsa object containing the results of the CNN sleep stage classification. """ if verbose > 1: print('Classifying sleep stages with CNN of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize SleepStagesCnn object. sleep_stages_cnn = SleepStagesCnn(**kwargs) # Prepare the dataset for the CNN and initialize a SleepStagesCnn object. x = self._prepare_sleep_stages_cnn( sleep_stages_cnn=sleep_stages_cnn, preprocess=preprocess, remove_flats=remove_flats, remove_baseline_wander=remove_baseline_wander, stepsize=stepsize, verbose=verbose > 1, **(preprocess_kwargs if preprocess_kwargs is not None else dict())) segment_length = sleep_stages_cnn.data_requirements['segment_length'] stepsize = stepsize if stepsize is not None else segment_length segment_start_times = get_segment_times( num_segments=len(x), segment_length=segment_length, overlap=segment_length - stepsize, offset=0) # Time offset will be added by self._postprocess_result. # Run sleep stages cnn algorithm. result = sleep_stages_cnn.sleep_stages_cnn( x, segment_start_times=segment_start_times, segment_end_times=segment_start_times + segment_length, verbose=verbose) # If using a stepsize, keep only 1 prediction per segment length. if stepsize is not None: stride = int(round(segment_length // stepsize)) result.probabilities = result.probabilities[:, ::stride] result.segment_start_times = result.segment_start_times[::stride] result.segment_end_times = result.segment_end_times[::stride] if result.latent_features is not None: result.latent_features = result.latent_features[:, ::stride] # Add info string about data to result. channel_order = sleep_stages_cnn.data_requirements['channel_order'] eeg_labels = ['EEG {}'.format(ch) for ch in channel_order] self._postprocess_result(result, channels=eeg_labels) return result
[docs] def sleep_stages_robust(self, verbose=1, **kwargs): """ Automatic robust sleep stage classification. This is a wrapper that prepares the input for SleepStagesRobust.process() and returns the result. Args: verbose (int, optional): verbose level (0 or 1). Defaults to 1. **kwargs (optional): optional keyword arguments to the SleepStagesRobust class. Returns: result (nnsa.SleepStagesCnnResult): nnsa object containing the results. """ if verbose > 1: print('Classifying robust sleep stages of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Select channels and reference. if 'Cz' in self: eeg_ds = self.reference('Cz', inplace=False, verbose=0) else: # Assume already referenced. eeg_ds = self channels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2'] eeg_ds = eeg_ds.extract_channels(channels=channels, make_copy=False) # Prepare to process. eeg_ = eeg_ds.asarray(channels_last=True) fs_ = eeg_ds.fs # Process. result = SleepStagesRobust(**kwargs).process(eeg=eeg_, fs=fs_) eeg_ds._postprocess_result(result) return result
[docs] def substitute_bad_channels(self, af_mask, inplace=False, **kwargs): """ Find segments with bad channels in the EEG and substitute by the mean of the good channels. Args: af_mask (np.ndarray or EegDataset): boolean mask for the EEG array data with True at locations of artefacts. Shape should be (n_channels, n_time). An incompatible shape will be transposed if that makes it compatible, or raises an error otherwise. inplace (bool): whether to apply inplace (True) or not (False). kwargs: for substitute_bad_channels. Returns: ds_out (EegDataset): cleaned EegDataset with bad channels substituted. """ if inplace: ds_out = self else: ds_out = copy.deepcopy(self) # Get data. eeg, channel_labels = ds_out.asarray(channels_last=False, return_channel_labels=True) if isinstance(af_mask, EegDataset): af_mask = af_mask.asarray(channels=channel_labels, channels_last=False) else: af_mask = np.asarray(af_mask) # Check shapes. if eeg.shape != af_mask.shape: if af_mask.T.shape == eeg.shape: # Transpose mask. af_mask = af_mask.T else: raise ValueError('`af_mask` ({}) should have the same shape as the EEG data ({}).' .format(af_mask.shape, eeg.shape)) # Clean. eeg_cleaned, _ = substitute_bad_channels(x=eeg, af_mask=af_mask, fs=ds_out.fs, axis=-1, **kwargs) # Replace the old by the new signals. for eeg_chan, lab in zip(eeg_cleaned, channel_labels): ds_out[lab].signal = eeg_chan # Only return if not in place. if not inplace: return ds_out
def _check_label(self, label): """ Return the standardized label for the given (unstandardized) EEG label or raise an error if the given label is not a valid label for EEG or is not in the current dataset. Args: label (str): (unstandardized) EEG label. Returns: std_label (str): standardized EEG label if the given label was a valid EEG label. Raises: ValueError if the label is not in the dataset. """ if label in self.time_series: return label # Standardize the EEG label. std_label = standardize_and_check_eeg_label(label)[0] if std_label in self.time_series: return std_label # Raise error if the specified channel is not in the dataset. raise KeyError('Channel "{}" not in dataset. Channels in dataset: {}.' .format(label, list(self.time_series.keys()))) def _prepare_sleep_stages_cnn(self, sleep_stages_cnn, preprocess=None, remove_flats=False, remove_baseline_wander=False, stepsize=None, verbose=1, **kwargs): """ Prepare the input array for SleepStagesCnn.sleep_stages_cnn(). Args: sleep_stages_cnn (nnsa.SleepStagesCnn): initialized SleepStagesCnn object. preprocess (bool, optional): see self.sleep_stages_cnn(). If None, sets it to True is sampling frequency is not as required for sleep staging input. remove_flats (bool, optional): if True, removes flatlines in EEG (replace by nans). remove_baseline_wander (bool): if True, removes excessive baseline wander by applying a HP filter. This was not part of original preprocessing during training, and most EEGs don't need this (in the original preprocessing there is already a highpass filter, but milder). If not needed, don't do this additional baseline removal. stepsize (float): stepsize (in seconds) for sleep stage prediction. If None, is set to segment length. verbose (int, optional): verbosity level. Defaults to 1. **kwargs (optional): optinal arguments for the sleep_stages_cnn.preprocess_recording function. Returns: x (np.ndarray): input for SleepStagesCnn.sleep_stages_cnn(). """ # Extract data requirements. channel_order = sleep_stages_cnn.data_requirements['channel_order'] segment_length = sleep_stages_cnn.data_requirements['segment_length'] fs_req = sleep_stages_cnn.data_requirements['fs'] stepsize = stepsize if stepsize is not None else segment_length # Make a copy of self so that we do not mutate the original object. eeg_ds = copy.deepcopy(self) fs = eeg_ds.fs if preprocess is None: # Preprocess if sampling frequency is not as required. preprocess = fs != fs_req # EEG labels (in required order that they must appear in the input array). eeg_labels = ['EEG {}'.format(ch) for ch in channel_order] # Reference to Cz. If Cz is not in the collection, assume that it is already referenced to Cz. if 'Cz' in eeg_ds: eeg_ds.reference('Cz', verbose=verbose, inplace=True) # Remove flatlines. if remove_flats: eeg_ds.remove_flatlines(n_flatline=1, inplace=True, verbose=verbose) # Remove baseline wander. if remove_baseline_wander: hp_filter = Butterworth(fn=[0.01], filter_type='highpass', order=2, fs=fs) eeg_ds.filter(hp_filter, inplace=True) # EEG data as array (time, channels). eeg_array = eeg_ds.asarray(channels=eeg_labels, channels_last=True) if preprocess: if verbose > 0: print('Preprocessing EEG data for sleep stage classification...') eeg_array, fs = sleep_stages_cnn.preprocess_recording( raw_eeg=eeg_array, raw_fs=fs, verbose=verbose, **kwargs) # Segment the EEG signals (all channels simultaneously). x = get_all_segments(eeg_array, fs=fs, segment_length=segment_length, overlap=segment_length - stepsize, axis=0) return x
[docs]class OxygenDataset(BaseDataset): """ High-level interface for processing oxygen data for neonatal signal analysis. """
[docs] def nirs_features(self, verbose=1, **kwargs): """ Compute feature set for NIRS. This is a wrapper that prepares the input for NirsFeatures.wct() and returns the result. Args: verbose (int, optional): verbose level. Defaults to 1. **kwargs (optional): optional keyword arguments to overrule default parameters of the NirsFeatures class. Returns: result (nnsa.FeatureSetResult): the result with the features. """ if verbose > 0: print('Computing NIRS features of {} with {} channels: {}' .format(self.label, len(self.time_series), [ts.label for ts in self.time_series.values()])) # Initialize feature extraction object (updates default parameters with user specified keyword arguments). from nnsa.feature_extraction.feature_sets_old import NirsFeatures nf = NirsFeatures(**kwargs) # Prepare input matrix for feature extraction function. data_matrix, channel_labels = self.asarray(return_channel_labels=True) # Sample frequency of all EEG signals is the same if self.asarray() did not raise an error. fs = next((ts.fs for ts in self.time_series.values())) # Process. result = nf.process(data_matrix, fs=fs, channel_labels=channel_labels, verbose=verbose) # Add info string about data to result. self._postprocess_result(result) return result
[docs] def remove_artefacts(self, inplace=False, **kwargs): """ Replace samples that are artefacts by np.nan Args: inplace (bool, optional): If True, the samples are replaced directly from the Dataset object itself. If False, the function returns a copy of the original Dataset object in which the artefacted samples are replaced, leaving the data in the current Dataset unchanged. Default is False. **kwargs (optional): optional keyword arguments for overruling default sample quality criteria (see nnsa.artefacts.artefact_detection.default_oxygen_sample_quality_criteria()). Returns: (OxygenDataset): new Dataset object containing the same signals,but with artefacted samples changed to np.nan (only returned if inplace is False). """ # Update default criteria with user-specified arguments. criteria = default_oxygen_sample_quality_criteria() criteria.update(kwargs) return super().remove_artefacts(inplace=inplace, **criteria)