"""
Algorithms related to analysis of discontinuity of the EEG.
"""
import copy
import csv
import sys
import warnings
from collections import defaultdict
import h5py
import numpy as np
import pyprind
import scipy.signal
import pandas as pd
from matplotlib import pyplot as plt
from scipy.stats import skew, kurtosis
from nnsa.artefacts.artefact_detection import default_eeg_signal_quality_criteria
from nnsa.preprocessing.filter import RemezFIR, WinFIR
from nnsa.annotations.annotation_set import AnnotationSet
from nnsa.annotations.annotation import Annotation
from nnsa.annotations.config import NO_LABEL, SLEEP_LABELS
from nnsa.containers.time_series import TimeSeries
from nnsa.feature_extraction.envelope import power_envelope
from nnsa.feature_extraction.result import ResultBase
from nnsa.feature_extraction.common import check_multichannel_data_matrix, aggregate_channel_events, \
baseline_correction_min, \
prepare_postfix
from nnsa.feature_extraction.time_domain import compute_flatness
from nnsa.utils.event_detections import time_threshold, join_events, get_onsets_offsets
from nnsa.feature_extraction.fractality import LineLength
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.preprocessing.resample import resample_by_filtering
from nnsa.utils.arrays import interp_nan, moving_envelope, moving_average
from nnsa.utils.config import HORIZONTAL_RULE
from nnsa.utils.other import enumerate_label, convert_string_auto
from nnsa.utils.plotting import shade_axis
__all__ = [
'BurstDetection',
'BurstDetectionResult',
'SuppressionCurve',
'IbiFeaturesResult',
]
from nnsa.utils.segmentation import segment_generator
[docs]class BurstDetection(ClassWithParameters):
"""
Class for burst detection.
Main method: burst_detection().
Args:
see nnsa.ClassWithParameters.
Examples:
>>> np.random.seed(0)
>>> x = np.random.rand(10, 100000)
>>> bd = BurstDetection()
>>> print(type(bd.parameters).__name__)
Parameters
>>> result = bd.burst_detection(x, fs=256, verbose=0)
>>> print(type(result).__name__)
BurstDetectionResult
>>> result.bursts[0, 20000]
0.0
"""
@staticmethod
def default_parameters():
    """
    Build the default parameter set for burst detection.

    Returns:
        (nnsa.Parameters): a default set of parameters for the object.
    """
    # 'method': the algorithm used for burst detection. burst_detection() lowercases this value and
    # dispatches on 'dibi', 'envelope', 'line_length', 'NLEO' or 'OToole'. See the dedicated methods in
    # this class for more information (e.g. self.line_length_burst_detection()).
    #
    # 'method_kwargs': optional additional keyword arguments that are forwarded to the selected
    # detection method. Which keywords are valid depends on 'method'; e.g. if method is 'NLEO', see
    # self.nleo_burst_detection() for the keyword arguments that can be specified here.
    return Parameters(method='line_length', method_kwargs={})
def burst_detection(self, data_matrix, fs, channel_labels=None, verbose=1):
    """
    Perform burst detection on multichannel data.

    The algorithm is selected by self.parameters['method'] (case-insensitive) and receives
    self.parameters['method_kwargs'] as extra keyword arguments.

    Args:
        data_matrix (np.ndarray): EEG data. See check_multichannel_data_matrix().
            The data might be filtered before, depending on the method for burst detection.
            Most methods do their own specific filtering. See the documentation and code of the
            specific methods.
        fs (float): sample frequency of the EEG signals.
        channel_labels (list of str, optional): see check_multichannel_data_matrix().
        verbose (int, optional): verbose level. Forwarded to every detection method that accepts a
            `verbose` argument; an explicit 'verbose' entry in method_kwargs takes precedence.
            Defaults to 1.

    Returns:
        (nnsa.BurstDetectionResult): object containing the burst detection result per channel.

    Raises:
        ValueError: if the configured method is not one of the supported names.
    """
    # Check input.
    data_matrix, channel_labels = check_multichannel_data_matrix(data_matrix, channel_labels)

    if verbose > 0:
        print(HORIZONTAL_RULE)
        print('Running burst_detection with parameters:')
        print(self.parameters)

    # Extract some parameters.
    method = self.parameters['method'].lower()
    method_kwargs = self.parameters['method_kwargs'] or {}

    # Keyword arguments for the methods that take a `verbose` argument. The user-supplied kwargs are
    # merged last so that an explicit 'verbose' in method_kwargs overrides the one given here instead
    # of raising a duplicate-keyword TypeError.
    verbose_kwargs = {'verbose': verbose, **method_kwargs}

    # Call the corresponding function to compute the masks for the bursts and inter-burst-intervals (IBIs).
    if method == 'dibi':
        bursts, ibis = self.dibi_burst_detection(data_matrix, fs=fs, **verbose_kwargs)
        # The dIBI method combines information of all channels to give one global detection array.
        channel_labels = ['GLOBAL']

    elif method == 'envelope':
        # envelope_burst_detection() has no `verbose` argument.
        bursts, ibis = self.envelope_burst_detection(data_matrix, fs=fs, **method_kwargs)
        # The envelope method combines information of all channels to give one global detection array.
        channel_labels = ['GLOBAL']

    elif method == 'line_length':
        # This method also returns the sample frequency of its (downsampled) output arrays.
        bursts, ibis, fs = self.line_length_burst_detection(data_matrix, fs=fs, **verbose_kwargs)
        # The line length method combines information of all channels to give one global detection array.
        channel_labels = ['GLOBAL']

    elif method == 'nleo':
        # This method also returns the sample frequency of its (resampled) output arrays.
        bursts, ibis, fs = self.nleo_burst_detection(data_matrix, fs=fs, **verbose_kwargs)

    elif method == 'otoole':
        # NOTE(review): otoole_burst_detection() is not defined in the visible part of this class;
        # confirm that it is provided elsewhere (e.g. by a parent class).
        bursts, ibis = self.otoole_burst_detection(data_matrix, fs=fs, **verbose_kwargs)

    else:
        raise ValueError('Invalid method "{}". Choose from "{}".'.format(method, ['dibi', 'envelope', 'line_length',
                                                                                  'NLEO', 'OToole']))

    result = BurstDetectionResult(bursts=bursts, ibis=ibis, fs=fs,
                                  algorithm_parameters=self.parameters,
                                  channel_labels=channel_labels)
    return result
@staticmethod
def dibi_burst_detection(data_matrix, fs, verbose=1):
    """
    Burst detection method as proposed by Vladimir Matic.

    Thin wrapper around detect_dibi(); see that function for details.

    Args:
        data_matrix (np.ndarray): unfiltered EEG data, see check_multichannel_data_matrix().
        fs (float): sample frequency of the EEG signals.
        verbose (int, optional): verbose level, passed on to detect_dibi().
            Defaults to 1.

    Returns:
        bursts (np.ndarray): array with dimensions (1, time) containing 1s at locations of bursts and 0s
            at locations of non-bursts.
        ibis (np.ndarray): array with dimensions (1, time) containing 1s at locations of
            inter-burst-intervals (IBIs) and 0s at locations of non-IBIs.
    """
    global_ibis = detect_dibi(eeg=data_matrix, fs=fs, verbose=verbose, show_plots=False)

    # Prepend a channel axis: the result must have shape (channels, time).
    ibis = np.asanyarray(global_ibis)[np.newaxis, ...]

    # The burst mask is the binary complement of the IBI mask (0 -> 1, 1 -> 0, nan stays nan).
    bursts = np.square(ibis - 1)

    return bursts, ibis
@staticmethod
def envelope_burst_detection(data_matrix, fs, lc=0.5, hc=32, notch=50,
                             amplitude_thr_high=30, channel_thr_high=2,
                             min_ibi_dur=1, max_burst_dur=20,
                             amplitude_thr_low=30, channel_thr_low=None):
    """
    Detect bursts using the signal envelope (signal energy).

    Implementation of the algorithm described by Jennekens et al. 2011.

    Notes:
        The default parameters are the optimal values according to the paper by Jennekens et al. 2011.

        The notch filter is skipped when `notch` is None, or (with a warning) when `notch` is at or above
        the Nyquist frequency fs/2, since a notch cannot be designed there.

    References:
        Jennekens W , Ruijs LS , Lommen CML , Niemarkt HJ , Pasman JW , van Kranen–Mastenbroek VM , et al.
        Automatic burst detection for the EEG of the preterm infant.
        Physiol Meas 2011;32(10):1623–37 .

    Args:
        data_matrix (np.ndarray): unfiltered EEG data (in uV), see check_multichannel_data_matrix().
        fs (float): sample frequency of the EEG signals.
        lc (float, optional): low cut-off filter frequency (LC) in Hz.
            Defaults to 0.5.
        hc (float, optional): high cut-off filter frequency (HC) in Hz.
            Defaults to 32.
        notch (float or None, optional): notch filter frequency in Hz. None disables the notch filter.
            Defaults to 50.
        amplitude_thr_high (float, optional): amplitude-threshold-high in uV (ATH). Sample points with an
            envelope value strictly above ATH are considered high-voltage activity.
            Defaults to 30.
        channel_thr_high (int, optional): channel-threshold-high (CTH). If the number of channels with
            high voltage is >= CTH, the global EEG-activity is considered as high-voltage activity.
            Defaults to 2.
        min_ibi_dur (float, optional): the minimum separation between 2 bursts in seconds. If two detected
            periods of high-voltage activity appear within min_ibi_dur after each other it is assumed that
            they belong together and form one period.
            Defaults to 1.
        max_burst_dur (float, optional): the maximum duration of bursts in seconds. High-voltage
            periods >=max_burst_dur are classified as continuous patterns while periods < max_burst_dur
            are classified as bursts.
            Defaults to 20.
        amplitude_thr_low (float, optional): amplitude-threshold low in uV (ATL). If the envelope value is lower
            than the ATL, the corresponding sample is a candidate for IBI.
            Defaults to 30.
        channel_thr_low (int, optional): channel-threshold-low (CTL). The minimum number of channels with low
            voltage activity for the sample to be considered an IBI.
            If None, the total number of channels in the input data_matrix is used.
            Defaults to None.

    Returns:
        bursts (np.ndarray): array with dimensions (1, time) containing 1s at locations of bursts and 0s
            at locations of non-bursts. Samples affected by the envelope window boundaries are nan.
        ibis (np.ndarray): array with dimensions (1, time) containing 1s at locations of
            inter-burst-intervals (IBIs) and 0s at locations of non-IBIs. Samples affected by the
            envelope window boundaries are nan.

    Raises:
        ValueError: if fs is lower than 2*hc + 5 Hz.
    """
    if channel_thr_low is None:
        channel_thr_low = data_matrix.shape[0]  # All channels/electrodes.

    # Check if sample frequency is adequately high for filtering.
    min_fs = hc*2 + 5  # Use a margin of 5 Hz.
    if fs < min_fs:
        raise ValueError('Use a sample frequency >= {} Hz for this method (the method involves bandpass '
                         'filtering between {}-{} Hz. The given fs is {}.'.format(min_fs, lc, hc, fs))

    # Filter design.
    # High-pass Butterworth.
    b_hp, a_hp = scipy.signal.butter(N=4, Wn=lc, btype='highpass', fs=fs)
    # Low pass Butterworth.
    b_lp, a_lp = scipy.signal.butter(N=4, Wn=hc, btype='lowpass', fs=fs)

    # Filter the signals per channel (zero-phase), in the order high-pass -> notch -> low-pass.
    x = scipy.signal.filtfilt(b_hp, a_hp, data_matrix, axis=-1)
    if notch is not None:
        if notch >= fs / 2:
            # A notch at or above Nyquist cannot be designed (scipy raises); skip it instead of failing
            # for sample frequencies that did pass the min_fs check above.
            warnings.warn('Skipping notch filter: notch frequency ({} Hz) is not below the Nyquist '
                          'frequency ({} Hz).'.format(notch, fs / 2))
        else:
            b_notch, a_notch = scipy.signal.iirnotch(w0=notch, Q=30, fs=fs)
            x = scipy.signal.filtfilt(b_notch, a_notch, x, axis=-1)
    x = scipy.signal.filtfilt(b_lp, a_lp, x, axis=-1)

    # Detect bursts.
    # Compute envelope.
    n_window = int(1*fs)  # 1 second window.
    envelope = power_envelope(x, n_window=n_window)

    # Thresholding on amplitude.
    high_voltage_per_channel = (envelope > amplitude_thr_high).astype(int)

    # Combine channels.
    high_voltage = aggregate_channel_events(high_voltage_per_channel,
                                            min_channels=channel_thr_high,
                                            min_channels_elong=1)

    # Join periods with separation < min_ibi_dur.
    high_voltage = join_events(high_voltage, min_separation=min_ibi_dur * fs)

    # Time-threshold for detection of bursts.
    bursts = time_threshold(high_voltage, max_duration=max_burst_dur * fs)

    # Detect IBIs.
    # Threshold on amplitude.
    low_voltage_per_channel = (envelope < amplitude_thr_low).astype(int)

    # Threshold on number of channels with low amplitude; a sample in a (joined) high-voltage period
    # can never be an IBI candidate.
    low_voltage_num_channels = np.nansum(low_voltage_per_channel, axis=0)
    low_voltage = np.logical_and(~high_voltage.astype(bool), low_voltage_num_channels >= channel_thr_low)
    low_voltage = low_voltage.astype(int)

    # Time-threshold for detection of IBIs.
    ibis = time_threshold(low_voltage, min_duration=min_ibi_dur * fs)

    # Remove boundary effects from moving average window (floats are needed to insert nans).
    idx_boundary = int(np.ceil((n_window - 1) / 2))
    bursts = bursts.astype(float)
    bursts[:idx_boundary] = np.nan
    bursts[-idx_boundary:] = np.nan
    ibis = ibis.astype(float)
    ibis[:idx_boundary] = np.nan
    ibis[-idx_boundary:] = np.nan

    # Reshape to dimensions corresponding to (channel, time).
    bursts = bursts.reshape(1, -1)
    ibis = ibis.reshape(1, -1)

    return bursts, ibis
@staticmethod
def line_length_burst_detection(data_matrix, fs, F_1=0.85, F_2=0.40, threshold_window=None,
                                min_ibi_dur=2, min_burst_amp=30, verbose=1, **line_length_kwargs):
    """
    Detect bursts using the line length as introduced by Koolen et al.

    Notes:
        The default parameters are the optimal values according to the paper by Koolen et al. 2014.
        In that paper, the algorithm was optimized for unipolar multi-channel EEG sampled at 250 Hz.
        If the given signals have a higher sample frequency, the signals are resampled to 250 Hz.

        The postprocessing that checks whether the high energy parts are valid bursts is a bit simpler
        than mentioned in the paper and also differs from the original Matlab code. In this code, high energy
        periods (candidates for bursts) are considered burst only if they contain a (normalized) line length value
        higher than some threshold (controlled by `F_2` parameter) AND a high voltage amplitude in any channel of
        the original signal (controlled by `min_burst_amp` parameter).
        More specifically for this first criterion: the median normalized line length (me_LL) in a high energy
        period must reach at least F_1*mean(me_LL) + F_2*std(me_LL) to be a burst (in Koolen et al., this is
        referred to as a 'pronounced peak').

        High energy periods that touch the start or the end of the recording are evaluated like any
        other period.

    References:
        Koolen, N. et al. Line length as a robust method to detect high-activity events: Automated
        burst detection in premature EEG recordings. Clinical Neurophysiology 125, 1985-1994
        (2014).

    Args:
        data_matrix (np.ndarray): (filtered) EEG data, see check_multichannel_data_matrix().
            Koolen et al. originally used EEG signals bandpass filtered between 1 - 20 Hz.
        fs (float): sample frequency of the EEG signals. Must be 250 Hz or higher.
        F_1 (float, optional): scale factor for the burst detection (`F` in Eq. 4 of the paper).
            Defaults to 0.85.
        F_2 (float, optional): scale factor for the standard deviation (used to detect pronounced peaks in me_LL).
            Defaults to 0.40.
        threshold_window (float, optional): window for the adaptive threshold in seconds. The threshold is adapted
            in periods of `threshold_window` seconds, using a moving average of the given window size.
            If set to None, the threshold is fixed for the entire signal.
            Defaults to None.
        min_ibi_dur (float, optional): the minimum separation between 2 bursts in seconds. If two detected
            periods of high-voltage activity appear within min_ibi_dur after each other it is assumed that
            they belong together and form one burst period.
            Defaults to 2.
        min_burst_amp (float, optional): the minimum maximum EEG amplitude of a high energy period to be considered
            a burst.
            Defaults to 30.
        verbose (int, optional): verbose level.
            Defaults to 1.
        **line_length_kwargs (optional): keyword arguments with parameters for
            nnsa.LineLength(**line_length_kwargs). Each given keyword replaces the corresponding
            default entry below as a whole.
            See nnsa.LineLength().

    Returns:
        bursts (np.ndarray): array with dimensions (1, time) containing 1s at locations of bursts and 0s
            at locations of non-bursts.
        ibis (np.ndarray): array with dimensions (1, time) containing 1s at locations of
            inter-burst-intervals (IBIs) and 0s at locations of non-IBIs.
        fs_LL (float): sample frequency corresponding to the output arrays.

    Raises:
        ValueError: if fs is lower than 250 Hz.
    """
    if fs > 250:
        if verbose > 0:
            print('Resampling to 250 Hz for burst detection using line length.')
        data_matrix = resample_by_filtering(data_matrix, fs=fs, fs_new=250, axis=-1)
        fs = 250
    elif fs < 250:
        raise ValueError('EEG sampling frequency must be >= 250 Hz. Got {} Hz.'.format(fs))

    # Line length parameters.
    default_line_length_kwargs = dict(
        segmentation={'segment_length': 1,
                      'overlap': 0.12},
        artefact_criteria={'max_nan_frac': 1e-12},
        # Set normalization window to 10 minutes by default, since Koolen et al. trained on 10 minute segments.
        line_length={'normalization_kind': 'mean_segments',
                     'normalization_window': 10*60,
                     'normalize_in_moving_window': False}
    )
    default_line_length_kwargs.update(line_length_kwargs)

    # Compute line length.
    line_length_obj = LineLength(**default_line_length_kwargs)
    line_length_result = line_length_obj.line_length(data_matrix, fs=fs, verbose=verbose)
    line_length = line_length_result.line_length

    # Sample frequency of the line length array (one value per segment hop).
    seg_pars = line_length_obj.parameters['segmentation']
    fs_LL = 1 / (seg_pars['segment_length'] - seg_pars['overlap'])

    # Take the median over the channels (Eq. 3 in the paper).
    me_LL = np.nanmedian(line_length, axis=0)

    # Compute the mean me_LL that scales the threshold for burst detection.
    if threshold_window is None:
        # Use the same threshold for all segments.
        mean_me_LL = np.full(me_LL.shape, np.nanmean(me_LL))
        std_me_LL = np.full(me_LL.shape, np.nanstd(me_LL))
    else:
        # Adapt the threshold in specified window.
        num_segments_in_win = int(np.round(threshold_window * fs_LL))
        mean_me_LL = moving_average(me_LL, n=num_segments_in_win)[0]

        # Moving std.
        z = (me_LL - mean_me_LL) ** 2
        std_me_LL = np.sqrt(moving_average(z, n=num_segments_in_win)[0])

    # Store locations of nan values so we can remove the output at these locations later.
    nan_mask = np.isnan(me_LL)

    # Interpolate nan values to make life easier.
    me_LL = interp_nan(me_LL)

    # Compute the thresholds for burst detection.
    thr_det = F_1 * mean_me_LL
    thr_diff = F_2 * std_me_LL

    # Determine if samples have the high energy.
    he = me_LL > thr_det

    # Find onsets (first sample of an event) and endings (first sample after an event) of high energy
    # events. Padding with a zero on both sides guarantees that every event has exactly one onset and
    # one ending, also when it starts at the first sample, runs until the last sample, or both.
    transitions = np.diff(np.concatenate(([0], he.astype(int), [0])))
    onsets = np.flatnonzero(transitions == 1)
    endings = np.flatnonzero(transitions == -1)

    # Loop over high energy events and check the secondary requirements for burst detection.
    bursts = np.zeros_like(me_LL)
    for on_idx, end_idx in zip(onsets, endings):
        # Check if the high energy event is a pronounced peak.
        # Define some minimum peak value and classify as pronounced peak if the burst exceeds this value.
        minimum_peak_value = thr_det[on_idx] + thr_diff[on_idx]
        pronounced_peak = np.any(me_LL[on_idx: end_idx] > minimum_peak_value)

        # Classify as high amplitude if the EEG (any channel) exceeds the amplitude threshold.
        # Segment indices are converted to sample indices of the (resampled) EEG.
        sample_start = int(round(on_idx * fs / fs_LL))
        sample_stop = int(round(end_idx * fs / fs_LL))
        high_amplitude = np.any(data_matrix[:, sample_start: sample_stop] > min_burst_amp)

        # Finally, classify as burst if the high energy event has a pronounced peak and high amplitude.
        if pronounced_peak and high_amplitude:
            bursts[on_idx: end_idx] = 1

    # IBIs array is the binary complement of bursts array.
    ibis = (bursts - 1) ** 2

    # Remove IBIs shorter than min_ibi_dur seconds.
    ibis = time_threshold(ibis, min_duration=min_ibi_dur * fs_LL)

    # Remove values at locations that were nan originally.
    ibis = ibis.astype(float)  # Needed to insert nans.
    ibis[nan_mask] = np.nan

    # Reshape to dimensions corresponding to (channel, time).
    ibis = ibis.reshape(1, -1)

    # Get the burst mask from the IBI mask (now without short IBIs).
    bursts = (ibis - 1) ** 2

    return bursts, ibis, fs_LL
@staticmethod
def nleo_burst_detection(data_matrix, fs, lc=0.5, hc=10, max_ripple=1, min_attenuation=40,
                         window_avg=1.5, window_baseline=60, sat_thr=1.5, min_duration=1, verbose=1):
    """
    Detect bursts using the non-linear energy operator (NLEO) as introduced by Palmu et al.

    Notes:
        The default parameters are the optimal values according to the paper by Palmu et al. 2010.
        Here, the NLEO algorithm was optimized for the P3-P4 bipolar channel, using data sampled at 256 Hz.
        However, since the first step of the algorithm is bandpass filtering between 0.5-10 Hz,
        sampling frequencies >= 25 Hz are accepted. Then after filtering, the signal is resampled to
        256 Hz.

    References:
        Palmu K , Stevenson N , Wikström S , Hellström-Westas L , Vanhatalo S , Palva JM .
        Optimization of an NLEO-based algorithm for automated detection of spontaneous activity transients
        in early preterm EEG.
        Physiol Meas 2010;31(11):N85–93 .

    Args:
        data_matrix (np.ndarray): unfiltered EEG data (in uV), see check_multichannel_data_matrix().
        fs (float): sample frequency of the EEG signals.
        lc (float, optional): low cut-off filter frequency (LC) in Hz.
            Defaults to 0.5.
        hc (float, optional): high cut-off filter frequency (HC) in Hz.
            Defaults to 10.
        max_ripple (float, optional): max ripple in dB for the low pass elliptic filter, see scipy.signal.ellip().
            Note that this value was not mentioned in the paper.
            Defaults to 1.
        min_attenuation (float, optional): min attenuation in dB for the low pass elliptic filter,
            see scipy.signal.ellip().
            Note that this value was not mentioned in the paper.
            Defaults to 40.
        window_avg (float, optional): window in seconds for averaging the absolute value of the NLEO output (g).
            Defaults to 1.5.
        window_baseline (float, optional): window in seconds for finding the baseline value (the minimum value in
            this window is the baseline value, which is subtracted).
            Defaults to 60.
        sat_thr (float, optional): threshold on x_nleo (in uV^2) for detection of bursts.
            Defaults to 1.5.
        min_duration (float, optional): minimum duration of bursts in seconds.
            Defaults to 1.
        verbose (int, optional): verbose level. The progress bar is only shown if verbose > 0.
            Defaults to 1.

    Returns:
        bursts (np.ndarray): array with dimensions (channels, time) containing 1s at locations of bursts,
            0s at locations of non-bursts and nan where no valid output is available (boundaries).
        ibis (np.ndarray): array with dimensions (channels, time) containing 1s at locations of
            inter-burst-intervals (IBIs), 0s at locations of non-IBIs and nan at the boundaries.
        fs (float): sample frequency corresponding to the output arrays (256 Hz).

    Raises:
        ValueError: if fs is lower than 2*hc + 5 Hz.
    """
    # Check if sample frequency is adequately high for filtering.
    min_fs = hc*2 + 5  # Use a margin of 5 Hz.
    if fs < min_fs:
        raise ValueError('Use a sample frequency >= {} Hz for this method (the method involves bandpass '
                         'filtering between {}-{} Hz. The given fs is {}.'.format(min_fs, lc, hc, fs))

    # Filtering.
    # High-pass Butterworth.
    b_hp, a_hp = scipy.signal.butter(1, lc, btype='highpass', fs=fs)
    # Low pass Elliptic.
    b_lp, a_lp = scipy.signal.ellip(6, rp=max_ripple, rs=min_attenuation, Wn=hc, btype='lowpass', fs=fs)

    # Filter the signals per channel (zero-phase).
    x = scipy.signal.filtfilt(b_hp, a_hp, data_matrix, axis=-1)
    x = scipy.signal.filtfilt(b_lp, a_lp, x, axis=-1)

    # Resampling to 256 Hz.
    if fs != 256:
        x = resample_by_filtering(x, fs=fs, fs_new=256, axis=-1)
        fs = 256

    # Compute NLEO: |x[n]*x[n-3] - x[n-1]*x[n-2]|. The first 3 samples are undefined and left at 0.
    g = np.zeros(x.shape)
    g[:, 3:] = np.abs(x[:, 3:] * x[:, :-3] - x[:, 2:-1] * x[:, 1:-2])

    # Average in window (moving average along the time axis only).
    numtaps = int(np.ceil(window_avg * fs))
    kernel_shape = np.ones(g.ndim, dtype=int)
    kernel_shape[-1] = numtaps
    g_avg = scipy.signal.convolve(g, np.ones(kernel_shape) / numtaps, mode='same')

    # Remove boundary effects.
    idx_boundary = int(np.ceil((numtaps - 1) / 2))
    g_avg[:, :idx_boundary + 3] = np.nan  # Note that we skip the 3 first samples as they are not defined.
    if idx_boundary > 0:
        # Guarded, because a slice starting at -0 would select (and wipe) the entire array.
        g_avg[:, -idx_boundary:] = np.nan

    # Baseline correction.
    window_len_baseline = int(window_baseline * fs)
    x_nleo = np.zeros_like(g_avg)

    # Initialize progress bar (only when output is requested).
    bar = pyprind.ProgBar(g_avg.shape[0], stream=sys.stdout) if verbose > 0 else None

    # Loop over channels.
    for j, g_avg_channel in enumerate(g_avg):
        # Compute baseline corrected value for g_avg by subtracting the minimum value in a window before
        # each sample.
        x_nleo[j] = baseline_correction_min(g_avg_channel, window_length=window_len_baseline)

        # Update progress bar.
        if bar is not None:
            bar.update()

    # Set values to nan for which there is no preceding window of the desired size available for baseline
    # correction.
    n_invalid_start = idx_boundary + window_len_baseline - 1 + 3
    x_nleo[:, :n_invalid_start] = np.nan

    # Classification: a run of consecutive supra-threshold samples is a burst if it lasts at least
    # min_dur samples. This is equivalent to the sample-by-sample pseudo-code in Table 1
    # (Palmu et al. 2010), but only loops over runs instead of over every sample.
    min_dur = int(min_duration*fs)
    bursts_pre = x_nleo >= sat_thr  # nan compares as False.
    bursts = np.zeros(bursts_pre.shape, dtype=float)
    for j, burst_pre_channel in enumerate(bursts_pre):
        # Zero-pad so that runs touching either edge also get a start and a stop.
        transitions = np.diff(np.concatenate(([0], burst_pre_channel.astype(int), [0])))
        run_starts = np.flatnonzero(transitions == 1)
        run_stops = np.flatnonzero(transitions == -1)  # Exclusive end indices.
        for start, stop in zip(run_starts, run_stops):
            if stop - start >= min_dur:
                bursts[j, start: stop] = 1

    # Remove boundary effects.
    bursts[:, :n_invalid_start] = np.nan
    if idx_boundary > 0:
        bursts[:, -idx_boundary:] = np.nan

    # IBIs array is the binary complement of bursts array.
    ibis = (bursts - 1)**2

    # For verification/visualization of the algorithm, plotting x, g, g_avg, x_nleo, sat_thr and bursts
    # of one channel against time (np.arange(x.shape[-1]) / fs) may be insightful.

    return bursts, ibis, fs
[docs]class BurstDetectionResult(ResultBase):
"""
High-level interface for processing the results of burst detection analysis as created by
nnsa.BurstDetection().
Args:
bursts (np.ndarray): array with dimensions (channels, time) containing 1s or True at locations of bursts and 0s
or False at locations of non-bursts. May also contain np.nan for indicating missing values.
ibis (np.ndarray): array with dimensions (channels, time) containing 1s or True at locations of
inter-burst-intervals (IBIs) and 0s or False at locations of non-IBIs. May also contain np.nan for
indicating missing values.
algorithm_parameters (nnsa.Parameters): see ResultBase.
fs (float): sample frequency corresponding to the bursts and ibis arrays.
nan_mask (np.array, optional): boolean array with True at locations of missing values and False at locations
without missing values. If not None, this nan_mask is applied to `bursts` and `ibis` to set missing values
to np.nan. This way, the bursts and ibis arrays can be saved as boolean arrays, even if there are missing
values (memory efficient). Must have the same shape as `bursts` and `ibis`.
Defaults to None.
channel_labels (list of str, optional): labels of the channels corresponding to the channel dimensions of
the arrays.
If None, default labels will be created.
Defaults to None.
data_info (str, optional): see ResultBase.
segment_start_times (np.ndarray, optional): see ResultBase.
segment_end_times (np.ndarray, optional): see ResultBase.
"""
def __init__(self, bursts, ibis,
             algorithm_parameters, fs, nan_mask=None, channel_labels=None, data_info=None,
             segment_start_times=None, segment_end_times=None):
    """
    Validate the inputs and store the burst/IBI masks as float arrays.

    See the class docstring for a description of the arguments.

    Raises:
        ValueError: if `bursts` has fewer than 2 dimensions, if the number of channel labels does not
            match the first dimension of `bursts`, or if self.is_discontinuous() reports discontinuous data.
    """
    # Let the parent class (ResultBase) store everything it is responsible for.
    super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
                     segment_start_times=segment_start_times, segment_end_times=segment_end_times, fs=fs)

    # Input check.
    shape = bursts.shape
    if len(shape) < 2:
        raise ValueError('Invalid input shape: bursts.shape = {}. Bursts should have dimensions (channels, time).'
                         .format(shape))

    n_channels = shape[0]
    if channel_labels is None:
        # No labels given: generate default ones.
        channel_labels = enumerate_label(n_channels, label='Channel')
    elif len(channel_labels) != n_channels:
        raise ValueError('Length of channel_labels ({}) does not correspond to the shape of the data {}.'
                         .format(len(channel_labels), shape))

    # Float copies guarantee 0s and 1s and allow np.nan as marker for missing values.
    self.bursts = bursts.astype(float)
    self.ibis = ibis.astype(float)

    if nan_mask is None:
        # Derive the mask from the nans that are already present in either array.
        self.nan_mask = np.logical_or(np.isnan(self.bursts), np.isnan(self.ibis))
    else:
        # Apply the given mask to both arrays.
        self.nan_mask = nan_mask.astype(bool)
        self.bursts[self.nan_mask] = np.nan
        self.ibis[self.nan_mask] = np.nan

    self.channel_labels = channel_labels

    if self.is_discontinuous():
        raise ValueError('Data is discontinuous. This is invalid for the {} class.'.format(self.__class__.__name__))
@property
def burst_id(self):
    """
    Class number that identifies bursts.

    Returns:
        (int): class number for bursts (always 2).
    """
    burst_class_number = 2
    return burst_class_number
@property
def burst_label(self):
    """
    Class label that identifies bursts.

    Returns:
        (str): class label for bursts (always 'SAT').
    """
    label = 'SAT'
    return label
@property
def class_labels(self):
    """
    Map class numbers (see self.class_numbers()) to class labels.

    Both np.nan and -1 map to NO_LABEL; the IBI and burst class numbers map to their own labels.

    Returns:
        class_labels (dict): dictionary that maps a class number to a class label.
    """
    # NOTE(review): a nan key is only found when looked up with the np.nan object itself;
    # nan values taken from an array may not match it - verify against callers.
    labels = dict.fromkeys((np.nan, -1), NO_LABEL)
    labels[self.ibi_id] = self.ibi_label
    labels[self.burst_id] = self.burst_label
    return labels
@property
def duration(self):
    """
    Total duration of the recorded (non-nan) samples.

    Only the first channel of self.bursts is inspected; nan samples do not count.

    Returns:
        (float): duration in seconds.
    """
    first_channel = self.bursts[0]
    is_recorded = np.logical_not(np.isnan(first_channel))
    return is_recorded.sum() / self.fs
@property
def ibi_id(self):
    """
    Class number that identifies inter-burst-intervals (IBIs).

    Returns:
        (int): class number for ibis (always 1).
    """
    ibi_class_number = 1
    return ibi_class_number
@property
def ibi_label(self):
    """
    Class label that identifies inter-burst-intervals (IBIs).

    Returns:
        (str): class label for ibis (always 'IBI').
    """
    label = 'IBI'
    return label
@property
def num_segments(self):
    """
    Number of segments, i.e. the number of samples along the last (time) axis of self.bursts.

    Returns:
        (int): number of segments/samples.
    """
    *_, n_time = self.bursts.shape
    return n_time
def aggregate_bursts(self, min_channels_frac=2/8, min_channels_elong_frac=1/8):
    """
    Combine the burst detection of all channels to get an aggregate, global burst detection array.

    Args:
        min_channels_frac (float, optional): minimum fraction of channels that must detect a burst in order
            to consider it a global burst (after aggregation). The resulting channel count is rounded and
            never lower than 1. See aggregate_channel_events().
            Defaults to 2/8.
        min_channels_elong_frac (float, optional): minimum fraction of channels that must detect a burst when
            elongating the globally detected bursts. The resulting channel count is rounded and never
            lower than 1. See aggregate_channel_events().
            Defaults to 1/8.

    Returns:
        bursts (np.ndarray): array with dimensions corresponding to (1, time) containing 1s and 0s to indicate
            time instants where bursts occur. If there is only one channel, self.bursts itself is returned.
    """
    n_channels = self.bursts.shape[0]
    if n_channels <= 1:
        # Nothing to aggregate.
        return self.bursts

    # Convert fractions to channel counts. Clamp to at least 1: round() rounds halves to the nearest
    # even integer (e.g. 2/8 of 2 channels gives round(0.5) == 0), and a minimum of 0 channels would
    # no longer require any channel to detect the event.
    min_channels = max(1, int(round(min_channels_frac * n_channels)))
    min_channels_elong = max(1, int(round(min_channels_elong_frac * n_channels)))

    bursts = aggregate_channel_events(self.bursts, min_channels=min_channels,
                                      min_channels_elong=min_channels_elong)
    return bursts.reshape(1, -1)
def aggregate_ibis(self, min_channels_frac=1, min_channels_elong_frac=1):
    """
    Combine the IBI detection of all channels to get an aggregate, global IBI detection array.

    Args:
        min_channels_frac (float, optional): minimum fraction of channels that must detect an IBI in order
            to consider it a global IBI (after aggregation). The resulting channel count is rounded and
            never lower than 1. See aggregate_channel_events().
            Defaults to 1.
        min_channels_elong_frac (float, optional): minimum fraction of channels that must detect an IBI when
            elongating the globally detected IBIs. The resulting channel count is rounded and never
            lower than 1. See aggregate_channel_events().
            Defaults to 1.

    Returns:
        ibis (np.ndarray): array with dimensions corresponding to (1, time) containing 1s and 0s to indicate
            time instants where IBIs occur. If there is only one channel, self.ibis itself is returned.
    """
    n_channels = self.ibis.shape[0]
    if n_channels <= 1:
        # Nothing to aggregate.
        return self.ibis

    # Convert fractions to channel counts. Clamp to at least 1, because rounding a small fraction
    # can give 0, and a minimum of 0 channels would no longer require any channel to detect the event.
    min_channels = max(1, int(round(min_channels_frac * n_channels)))
    min_channels_elong = max(1, int(round(min_channels_elong_frac * n_channels)))

    ibis = aggregate_channel_events(self.ibis, min_channels=min_channels,
                                    min_channels_elong=min_channels_elong)
    return ibis.reshape(1, -1)
def class_numbers(self):
    """
    Return an array holding the class number of every sample.

    Samples where self.bursts > 0 get self.burst_id, samples where self.ibis > 0 get self.ibi_id
    (IBIs are written last, so they win where both masks are set) and all other samples are np.nan.

    Returns:
        class_numbers (np.ndarray): array with same shape as self.bursts and self.ibis. 2s in the array
            correspond to bursts, 1s to IBIs.
    """
    is_burst = self.bursts > 0
    is_ibi = self.ibis > 0

    numbers = np.full(self.bursts.shape, fill_value=np.nan)
    numbers[is_burst] = self.burst_id
    numbers[is_ibi] = self.ibi_id
    return numbers
@staticmethod
def compute_event_durations(detected, fs):
    """
    Compute the durations of the events in `detected`.

    Only complete events are measured: an event is skipped when the sample right before its onset or the
    sample right after its last detected sample is nan, and events without a visible onset or ending are
    not measured either.

    Args:
        detected (np.ndarray): 2D array with 1s (detected) and 0s (not-detected). Dimensions correspond to
            (channels, time).
        fs (float): sample frequency of `detected` in Hz.

    Returns:
        out (tuple): tuple with one array per channel, holding the durations of the detected events in
            seconds. A channel without any measurable event gets the array [0].

    Raises:
        ValueError: if `detected` is not 2D.

    Examples:
        >>> BurstDetectionResult.compute_event_durations(np.array([[1, 1, 1, 0, 0, 1, 1, 0]]), fs=1)
        (array([2.]),)

        >>> BurstDetectionResult.compute_event_durations(np.array([[0, 1, 1, 1, 0, 1, 1, 0]]), fs=1)
        (array([3., 2.]),)

        >>> BurstDetectionResult.compute_event_durations(np.array([[1, 1, 1, 0, 0, np.nan, np.nan]]), fs=1)
        (array([0]),)

        >>> BurstDetectionResult.compute_event_durations(np.array([[np.nan, 1, 1, 0, 0, np.nan, np.nan]]), fs=1)
        (array([0]),)

        >>> BurstDetectionResult.compute_event_durations(np.array([[np.nan, 0, 1, 1, 1, 0]]), fs=1)
        (array([3.]),)

        >>> BurstDetectionResult.compute_event_durations(np.array([[np.nan, 0, 1, np.nan, 1, 0]]), fs=1)
        (array([0]),)
    """
    if detected.ndim != 2:
        raise ValueError('`detected` should be a 2D array (channels, time). Got an array with shape {}.'
                         .format(detected.shape))

    # Remember where data is missing, then treat missing samples as not-detected.
    is_missing = np.isnan(detected)
    transitions = np.diff(np.nan_to_num(detected), axis=-1)

    per_channel = []
    for steps, missing in zip(transitions, is_missing):
        # Index of the first sample of each event, and of the first sample after each event.
        starts = np.flatnonzero(steps == 1) + 1
        stops = np.flatnonzero(steps == -1) + 1

        # An ending that precedes the first onset belongs to an event whose onset we did not see.
        if starts.size and stops.size and stops[0] < starts[0]:
            stops = stops[1:]

        # Pair onsets with endings in order; a trailing unmatched onset or ending is dropped.
        n_pairs = min(starts.size, stops.size)
        starts, stops = starts[:n_pairs], stops[:n_pairs]

        # Keep only events that are not bordered by missing data.
        complete = ~(missing[starts - 1] | missing[stops])
        seconds = (stops[complete] - starts[complete])/fs

        # No measurable events: report a single zero duration.
        per_channel.append(seconds if seconds.size else np.array([0]))

    return tuple(per_channel)
@staticmethod
def compute_event_occurrences(detected):
    """
    Count the number of events in `detected`.

    An event is counted at every 0 -> 1 transition along the last axis, plus one if the very first sample
    is already a detection. np.nan values are replaced by zeros (np.nan_to_num()) before counting.

    Args:
        detected (np.ndarray): array with 1s (detected) and 0s (not-detected). The last dimension
            corresponds to the time dimension. E.g. bursts array.

    Returns:
        num_events (int or np.ndarray): the number of detected events (per channel).
    """
    clean = np.nan_to_num(detected)

    # Onsets that are visible as a step of +1 between consecutive samples.
    n_onsets = np.sum(np.diff(clean, axis=-1) == 1, axis=-1)

    # An event that is already active at the first sample has no visible onset, but did occur.
    active_at_start = clean.take(indices=0, axis=-1).astype(int)

    return n_onsets + active_at_start
@staticmethod
def compute_event_percentage(detected):
    """
    Compute the percentage of time in which events are detected.

    nan values are left out of the average (np.nanmean()).

    Args:
        detected (np.ndarray): array with 1s (detected) and 0s (not-detected). The last dimension
            corresponds to the time dimension. E.g. bursts array.

    Returns:
        (float or np.ndarray): percentage of detected events (per channel).
    """
    detected_fraction = np.nanmean(detected, axis=-1)
    return detected_fraction*100
def compute_features(self, bursts, ibis, channel_labels, postfix=None):
    """
    Compute burst/IBI features for each channel.

    Args:
        bursts (np.ndarray): array with dimensions (channels, time) containing 1s at locations of bursts and 0s
            at locations of non-bursts. May also contain np.nan for indicating missing values.
        ibis (np.ndarray): array with dimensions (channels, time) containing 1s at locations of
            inter-burst-intervals (IBIs) and 0s at locations of non-IBIs. May also contain np.nan for
            indicating missing values.
        channel_labels (list): list with the same length as `bursts` and `ibis` containing the labels of the
            channels (the first dimension of the arrays). May also be a string in case of only one channel.
        postfix (string, optional): optional postfix to add to the feature name in the output dictionary.

    Returns:
        features (dict): dictionary of features values, hashed by feature name. Each channel is treated as a
            separate feature. Per channel the keys are 'DISC_burst_perc_*', 'DISC_ibi_perc_*' (time
            percentages), 'DISC_burst_occ_*', 'DISC_ibi_occ_*' (occurrences per hour of valid data) and
            'DISC_ibi_max_*', 'DISC_ibi_median_*', 'DISC_ibi_mean_*' (IBI durations in seconds).

    Raises:
        ValueError: if the number of channel labels does not match the number of rows in `bursts`/`ibis`.
    """
    # Check array lengths.
    if isinstance(channel_labels, str):
        channel_labels = [channel_labels]
    if len(channel_labels) != len(bursts) or len(channel_labels) != len(ibis):
        raise ValueError('Length of `channel_labels` ({}) does not correspond to the length of `bursts` ({}) '
                         'and/or `ibis` ({}).'.format(len(channel_labels), len(bursts), len(ibis)))

    # Samples per hour.
    samples_per_hour = self.fs * 3600

    # Compute features per channel.
    percentage_bursts_all = self.compute_event_percentage(bursts)
    percentage_ibis_all = self.compute_event_percentage(ibis)
    occurrences_bursts_all = self.compute_event_occurrences(bursts)
    occurrences_ibis_all = self.compute_event_occurrences(ibis)
    duration_ibis_all = self.compute_event_durations(ibis, fs=self.fs)

    # Number of valid (non-nan) samples PER CHANNEL. The occurrence rate of a channel must be normalized
    # by the amount of valid data in that channel only; counting the valid samples of all channels
    # together would underestimate the rate by roughly the number of channels.
    valid_bursts_all = np.sum(~np.isnan(bursts), axis=-1)
    valid_ibis_all = np.sum(~np.isnan(ibis), axis=-1)

    # Prepare postfix.
    postfix = prepare_postfix(postfix)

    # For each channel, save features.
    features = dict()
    for i, chan_label in enumerate(channel_labels):
        # Percentage of bursts and IBIs.
        features['DISC_burst_perc_{}{}'.format(chan_label, postfix)] = percentage_bursts_all[i]
        features['DISC_ibi_perc_{}{}'.format(chan_label, postfix)] = percentage_ibis_all[i]

        # Number of burst and IBI occurrences per hour (a channel without valid samples divides by zero
        # and yields nan/inf following numpy semantics).
        features['DISC_burst_occ_{}{}'.format(chan_label, postfix)] = \
            occurrences_bursts_all[i] / valid_bursts_all[i] * samples_per_hour
        features['DISC_ibi_occ_{}{}'.format(chan_label, postfix)] = \
            occurrences_ibis_all[i] / valid_ibis_all[i] * samples_per_hour

        # Maximum, median and mean duration of IBIs in seconds.
        features['DISC_ibi_max_{}{}'.format(chan_label, postfix)] = np.max(duration_ibis_all[i])
        features['DISC_ibi_median_{}{}'.format(chan_label, postfix)] = np.median(duration_ibis_all[i])
        features['DISC_ibi_mean_{}{}'.format(chan_label, postfix)] = np.mean(duration_ibis_all[i])

    return features
def compute_global_features(self, **kwargs):
    """
    Compute the features of self.compute_features() on the complete result.

    Args:
        **kwargs (optional): keyword arguments for pd.DataFrame. If no `index` is given, a single-row
            index [0] is used.

    Returns:
        df (pd.DataFrame): dataframe with one row, and feature values in columns.
    """
    features_dict = self.compute_features(self.bursts, self.ibis, self.channel_labels, postfix=None)

    # The feature values are scalars, and pandas cannot build a DataFrame from scalars only unless an
    # index is supplied. Provide a one-row index by default so that the call works without kwargs.
    if features_dict and 'index' not in kwargs:
        kwargs['index'] = [0]

    # Collect features in a DataFrame.
    return pd.DataFrame(features_dict, **kwargs)
def compute_global_features_per_sleep_stage(self, sleep_stages_result, sleep_labels=None,
                                            line_length_result=None, **kwargs):
    """
    Compute the global features separately for a number of sleep stages.

    For every label a deep copy of this result is made in which all samples outside the requested sleep
    stage are set to np.nan, after which compute_global_features() is called on the copy.

    Args:
        sleep_stages_result: object providing `class_mapping` and `create_mask()`, used to find the samples
            that belong to each sleep label.
        sleep_labels (list, optional): labels to compute the features for. The special label 'ALL' uses
            all data except the artefact and no_label annotations.
            If None, 'ALL' plus every key of sleep_stages_result.class_mapping (except the no_label and
            artefact labels) is used.
        line_length_result (optional): if given, the features are additionally computed in the most
            suppressed period of 600 seconds, found via its to_suppression_curve(). The corresponding
            columns get the suffix '_SUP'.
        **kwargs (optional): kwargs for self.compute_global_features().

    Returns:
        df (pd.DataFrame): the features of all labels side by side; the column names are suffixed with
            '_<label>'.
    """
    if sleep_labels is None:
        # By default use ALL + any available sleep label (except 'no_label', 'artefact').
        available = list(sleep_stages_result.class_mapping.keys())
        ignored = [SLEEP_LABELS['no_label'], SLEEP_LABELS['artefact']]
        sleep_labels = ['ALL'] + [name for name in available if name not in ignored]

    df = pd.DataFrame()
    for label in sleep_labels:
        # Work on a copy, since samples are going to be blanked out.
        result = copy.deepcopy(self)

        # Determine the samples that must be ignored for this label.
        if label != 'ALL':
            # Everything that is not annotated with the current sleep label.
            nan_mask = ~sleep_stages_result.create_mask(
                class_label=label,
                query_times=result.segment_start_times)
        else:
            # Only the artefact and no label annotations.
            nan_mask = sleep_stages_result.create_mask(
                class_label=[SLEEP_LABELS['artefact'],
                             SLEEP_LABELS['no_label']],
                query_times=result.segment_start_times,
                check_class_label=False)

        # Blank out the samples that are not of interest.
        result.bursts[:, nan_mask] = np.nan
        result.ibis[:, nan_mask] = np.nan

        # Compute the features and tag the columns with the sleep label.
        df_i = result.compute_global_features(**kwargs).add_suffix('_{}'.format(label))
        df = pd.concat([df, df_i], axis=1, sort=False)

    if line_length_result is not None:
        # Features in the most suppressed 10 minutes (600 seconds).
        suppression_curve = line_length_result.to_suppression_curve()
        start_time, stop_time = suppression_curve.get_most_suppressed_period(period_length=600)
        result = self.extract_epoch(begin=start_time, end=stop_time)

        df_i = result.compute_global_features(**kwargs).add_suffix('_{}'.format('SUP'))
        df = pd.concat([df, df_i], axis=1, sort=False)

    return df
def plot(self):
    """
    Plot the occurrence of bursts/IBIs as a function of time, one line per channel.

    Draws in the current matplotlib axes.

    TODO Option to plot one channel only or to plot the aggregate.
    """
    # Class number per channel and sample (see self.class_numbers()).
    codes = self.class_numbers()

    # One trace per channel.
    for label, trace in zip(self.channel_labels, codes):
        plt.plot(self.segment_start_times, trace, label=label)

    # Figure make up: add a small vertical margin and name the two classes on the y-axis.
    plt.xlabel('Time (seconds)')
    lower, upper = plt.ylim()
    plt.ylim([lower - 0.1, upper + 0.1])
    plt.yticks([self.ibi_id, self.burst_id], [self.ibi_label, self.burst_label])
    plt.title('Burst detection')
    plt.legend(loc='upper right')
def shade_axis(self, *args, channel=None, **kwargs):
    """
    Shade the current axis based on bursts and IBIs.

    Thin wrapper: converts this result with self.to_annotation_set() and forwards the remaining
    arguments to the shade_axis() method of the returned annotation set.

    Args:
        *args (optional): see nnsa.AnnotationSet.shade_axis().
        channel (str, optional): see self.to_annotation_set().
        **kwargs (optional): see nnsa.AnnotationSet.shade_axis().
    """
    self.to_annotation_set(channel=channel).shade_axis(*args, **kwargs)
def to_aggregate_result(self):
    """
    Compute the aggregate result by combining the burst detection of all channels.

    The aggregate masks come from self.aggregate_bursts() and self.aggregate_ibis() with their default
    arguments. Only the masks, fs, algorithm_parameters and the channel label are passed to the new
    object.

    Returns:
        (nnsa.BurstDetectionResult): a new burst detection result object, with only one channel (AGG), which
            is an aggregate of all channels.
    """
    return BurstDetectionResult(
        bursts=self.aggregate_bursts(),
        ibis=self.aggregate_ibis(),
        fs=self.fs,
        algorithm_parameters=self.algorithm_parameters,
        channel_labels=['AGG'],
    )
def to_annotation_set(self, channel=None):
    """
    Convert to AnnotationSet object.

    Consecutive samples with the same class number are merged into one annotation. Onsets start at 0 and
    are derived from the sample period 1/self.fs.

    Args:
        channel (str): the channel label for which to return an AnnotationSet.
            If None, an aggregate burst detection, combining all channels will be converted to an AnnotationSet.
            Defaults to None.

    Returns:
        (nnsa.AnnotationSet): AnnotationSet with one annotation per run of equal class numbers; the text of
            each annotation is looked up in self.class_labels.
    """
    # Class numbers of the requested channel, or of the aggregate over all channels.
    if channel is not None:
        channel_idx = self._get_channel_idx(channel)
        codes = self.class_numbers()[channel_idx]
    else:
        codes = self.to_aggregate_result().class_numbers()[0]

    # Replace nan values with -1 (-1 is easier to handle than nan).
    codes[np.isnan(codes)] = -1

    annotation_set = AnnotationSet(label='burst_detection_{}'.format(self.algorithm_parameters['method']))

    # Last sample index of every run of equal class numbers (the final sample closes the last run).
    sample_period = 1/self.fs
    run_ends = np.append(np.flatnonzero(np.diff(codes)), len(codes) - 1)

    onset = 0
    for last_idx in run_ends:
        # The run lasts from the current onset up to and including sample `last_idx`.
        span = (last_idx + 1)*sample_period - onset
        text = self.class_labels[codes[last_idx]]
        annotation_set.append(Annotation(onset=onset, duration=span, text=text), inplace=True)

        # The next run starts where this one ends.
        onset += span

    return annotation_set
def _get_channel_idx(self, channel):
    """
    Look up the position of a channel label in self.channel_labels.

    Args:
        channel (str): the label of the channel.

    Returns:
        (int): the index of the channel in channel_labels, which corresponds also to the index of the channel in
            self.bursts and self.ibis.

    Raises:
        ValueError: if the specified channel is not in channel_labels.
    """
    if channel in self.channel_labels:
        return self.channel_labels.index(channel)

    raise ValueError('Channel "{}" not in self.channel_labels: {}.'
                     .format(channel, self.channel_labels))
def _merge(self, other, index):
    """
    See ResultBase.

    Appends the bursts/ibis of `other` to those of self, such that the data of `other` starts at sample
    `index`. If self is longer than `index` samples, the surplus is cut off (with a warning); if it is
    shorter, the gap is filled with np.nan. self.nan_mask is recomputed afterwards.

    Raises:
        ValueError: if the channel labels of self and other differ.
    """
    if self.channel_labels != other.channel_labels:
        raise ValueError('Cannot merge objects with different channel labels.')

    n_channels, n_samples = self.bursts.shape
    gap = index - n_samples
    if gap < 0:
        # Own data runs past the merge point: cut piece off.
        msg = 'Overwriting data while merging.'
        warnings.warn(msg)
        self.bursts = self.bursts[:, :index]
        self.ibis = self.ibis[:, :index]
    else:
        # Own data stops before (or at) the merge point: bridge the gap with nans.
        filler = np.full((n_channels, gap), fill_value=np.nan)
        self.bursts = np.concatenate([self.bursts, filler], axis=-1)
        self.ibis = np.concatenate([self.ibis, filler], axis=-1)

    # Append the other result along the time axis.
    self.bursts = np.concatenate([self.bursts, other.bursts], axis=-1)
    self.ibis = np.concatenate([self.ibis, other.ibis], axis=-1)
    self.nan_mask = np.logical_or(np.isnan(self.bursts), np.isnan(self.ibis))
@staticmethod
def _read_from_csv(filepath):
    """
    Read result from csv file into a BurstDetectionResult class.

    Expects the layout written by _write_to_csv(): 4 standard header lines, a header line and a value
    line with the channel labels and array shape, followed by a header line and a line of flattened data
    for 'bursts' and for 'ibis'.

    Args:
        filepath (str): see ResultBase._read_from_csv().

    Returns:
        result (nnsa.BurstDetectionResult): instance of BurstDetectionResult containing the
            burst detection result.

    Raises:
        ValueError: if one of the header lines does not contain the expected text.
    """
    # Lines 1-4: Standard csv header (use the ResultBase method).
    algorithm_parameters, data_info, fs = ResultBase._read_csv_header(filepath)[1:]

    # Re-open the file and read the rest of the file, line by line.
    # The lines are consumed outside of any `assert`: assert statements are stripped when running with
    # -O, which would silently skip the read and misalign all following lines.
    with open(filepath, 'r', newline='') as f:
        reader = csv.reader(f)

        def read_flat_array(name):
            """Check the header line of array `name` and return the data line after it as floats."""
            header = next(reader)
            if header[:1] != [name]:
                raise ValueError('Unexpected array header in "{}": expected "{}", got {}.'
                                 .format(filepath, name, header))
            return [float(value) for value in next(reader)]

        # Lines 1-4: Standard csv header (already read, skip).
        for _ in range(4):
            next(reader)

        # Line 5: Non-array data header (check and skip).
        header = next(reader)
        if header != ['channel_labels', 'data.shape']:
            raise ValueError('Unexpected non-array data header in "{}": {}.'.format(filepath, header))

        # Line 6: Non-array data.
        channel_labels, data_shape = [
            convert_string_auto(i) for i in next(reader)]

        # Lines 7-8: bursts header and flattened data.
        bursts_as_list = read_flat_array('bursts')

        # Lines 9-10: ibis header and flattened data.
        ibis_as_list = read_flat_array('ibis')

    # Convert to numpy array and reshape.
    bursts = np.reshape(bursts_as_list, data_shape)
    ibis = np.reshape(ibis_as_list, data_shape)

    # Create a result object.
    result = BurstDetectionResult(bursts=bursts,
                                  ibis=ibis,
                                  fs=fs,
                                  nan_mask=None,
                                  channel_labels=channel_labels,
                                  algorithm_parameters=algorithm_parameters,
                                  data_info=data_info)

    return result
@staticmethod
def _read_from_hdf5(filepath):
    """
    Read result from hdf5 file into a BurstDetectionResult class.

    Reads the datasets 'bursts', 'ibis' and (if present) 'nan_mask', plus the channel labels stored as
    attribute of the 'bursts' dataset.

    Args:
        filepath (str): see ResultBase._read_from_csv().

    Returns:
        result (nnsa.BurstDetectionResult): instance of BurstDetectionResult containing the
            burst detection result.
    """
    # Read standard hdf5 header (use the ResultBase method).
    # NOTE(review): time_offset is read here but not passed on to the result object - confirm intended.
    header = ResultBase._read_hdf5_header(filepath)[1:]
    algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, time_offset = header

    # Re-open the file and read the rest of the file.
    with h5py.File(filepath, 'r') as f:
        # Array data (loaded into memory by the [:] slice).
        bursts = f['bursts'][:]
        ibis = f['ibis'][:]
        # The nan mask may be absent from the file.
        nan_mask = f['nan_mask'][:] if 'nan_mask' in f else None

        # Non-array data: labels are stored as bytes.
        channel_labels = [label.decode() for label in f['bursts'].attrs['channel_labels']]

    return BurstDetectionResult(bursts=bursts,
                                ibis=ibis,
                                nan_mask=nan_mask,
                                channel_labels=channel_labels,
                                algorithm_parameters=algorithm_parameters,
                                data_info=data_info,
                                segment_start_times=segment_start_times,
                                segment_end_times=segment_end_times,
                                fs=fs)
def _write_to_csv(self, filepath):
    """
    Write the contents of the object to a csv file.

    After the standard header (lines 1-4) six lines are appended: a header and value line for the
    channel labels and array shape, and a header plus flattened data line for the bursts and the ibis.

    Args:
        filepath (str): see ResultBase._write_to_csv().
    """
    # Lines 1-4: Standard csv header (use the ResultBase method).
    self._write_csv_header(filepath)

    rows = [
        # Line 5: Non-array data header.
        ['channel_labels', 'data.shape'],
        # Line 6: Non-array data.
        [self.channel_labels, self.bursts.shape],
        # Lines 7-8: Array header and array data (flattened).
        ['bursts'],
        self.bursts.reshape(-1).tolist(),
        # Lines 9-10: Array header and array data (flattened).
        ['ibis'],
        self.ibis.reshape(-1).tolist(),
    ]

    # Append to the csv file that already holds the header.
    with open(filepath, 'a', newline='') as csvfile:
        csv.writer(csvfile).writerows(rows)
def _write_to_hdf5(self, filepath):
    """
    Write the contents of the object to an hdf5 file.

    Stores the datasets 'bursts' and 'ibis' (as booleans), the dataset 'nan_mask', and the channel labels
    as attribute of the 'bursts' dataset.

    Args:
        filepath (str): see ResultBase._write_to_hdf5().
    """
    # Write standard hdf5 header (use the ResultBase method).
    self._write_hdf5_header(filepath)

    # Append attributes to the hdf5 file.
    with h5py.File(filepath, 'a') as f:
        # Write array data. The masks are stored as booleans, which cannot represent nan; the locations
        # of missing values are kept in the separate nan_mask dataset.
        f.create_dataset('bursts', data=self.bursts.astype(bool), compression='gzip')
        f.create_dataset('ibis', data=self.ibis.astype(bool), compression='gzip')
        f.create_dataset('nan_mask', data=self.nan_mask, compression='gzip')

        # Write non-array data as attributes.
        # Convert strings to fixed-width bytes as recommended for compatibility. np.bytes_ is used
        # because its alias np.string_ was removed in NumPy 2.0.
        f['bursts'].attrs['channel_labels'] = [np.bytes_(label)
                                               for label in self.channel_labels]
class SuppressionCurve(TimeSeries):
    """
    High-level object containing the suppression curve (Dereymaeker et al. 2015).

    References:
        A. Dereymaeker, N. Koolen, K. Jansen, J. Vervisch, E. Ortibus, M. De Vos, S. Van Huffel, G. Naulaers,
        The suppression curve as a quantitative approach for measuring brain maturation in preterm infants,
        Clinical Neurophysiology, Volume 127, Issue 8, 2016, Pages 2760-2765,

    Args:
        suppression (np.ndarray): 1D array containing the suppression values as function of time.
        window_length (float): the window length corresponding to one sample in `suppression`. I.e. the length of the
            window in which the suppression value was computed in seconds.
        label (str, optional): label passed on to TimeSeries.
        time_offset (float, optional): time offset passed on to TimeSeries.
        **kwargs (optional): additional keyword arguments for TimeSeries.
    """
    def __init__(self, suppression, window_length, label='Suppresion curve', time_offset=0, **kwargs):
        # Drop singleton dimensions; whatever remains must be at most 1D.
        values = np.asarray(suppression).squeeze()
        if values.ndim > 1:
            raise ValueError('`suppression` must be 1D. Got array with shape {}.'.format(values.shape))

        # One sample per window, so the sample frequency is the inverse of the window length.
        super().__init__(signal=values, fs=1/window_length, label=label, time_offset=time_offset, **kwargs)
        self.window_length = window_length

    def get_most_suppressed_period(self, period_length=600):
        """
        Get the start and end time of the most suppressed period.

        The period is the run of round(period_length/self.window_length) consecutive samples with the
        highest mean suppression value. Candidate periods whose mean is nan are never selected.

        Args:
            period_length (float): length of the suppressed period to extract (in seconds).
                Choose a multiple of self.window_length.

        Returns:
            start_time (float): value of self.time at the first sample of the selected period.
            end_time (float): value of self.time at the last sample of the selected period.
        """
        n_windows = int(round(period_length/self.window_length))

        # Running average of the suppression over every candidate period.
        averaging_kernel = np.ones(n_windows)/n_windows
        running_mean = np.convolve(self.signal, averaging_kernel, mode='valid')

        # np.convolve() flips its kernel, so a one-hot kernel with the 1 at the last position picks the
        # time of the first sample of every candidate period, and a 1 at the first position picks the
        # time of the last sample.
        picks_first = np.arange(n_windows) == (n_windows - 1)
        picks_last = np.arange(n_windows) == 0
        start_times = np.convolve(self.time, picks_first, mode='valid')
        end_times = np.convolve(self.time, picks_last, mode='valid')

        # Exclude periods with a nan mean and select the period with the highest mean.
        running_mean[np.isnan(running_mean)] = -np.inf
        best = np.argmax(running_mean)

        return start_times[best], end_times[best]
def dibi_burst_detection_matlab(eeg, fs, verbose=1):
    """
    Burst detection method as implemented by Vladimir Matic in MATLAB.

    The data is passed to the MATLAB function compute_dIBI in batches, and the returned start and end
    times are converted into a mask.

    References:
        V. Matic et al., “Improving Reliability of Monitoring Background EEG Dynamics in Asphyxiated Infants,”
        IEEE Transactions on Biomedical Engineering, vol. 63, no. 5, pp. 973–983, May 2016

    Args:
        eeg (np.ndarray): 2D array with dimensions (channels, time) containing filtered! EEG data in uV.
        fs (float): sample frequency of the EEG signals.
        verbose (int, optional): verbose level. If nonzero, the progress bar is updated after every batch.
            Defaults to 1.

    Returns:
        dibi (np.ndarray): array with dimensions (time,) containing 1s at locations of
            inter-burst-intervals (IBIs) and 0s at locations of non-IBIs.
    """
    from nnsa.matlab.utils import matlab_engine, ml_array

    eeg = check_multichannel_data_matrix(eeg)[0]

    # Process in batches to reduce the amount of data that needs to be transferred between Python and matlab.
    # The batch size is a fixed number of samples (3 hours of data at 256 Hz).
    batchsize = int(3 * 3600 * 256)
    tot_len = eeg.shape[-1]

    # Progress bar (one step per batch).
    bar = pyprind.ProgBar(int(np.ceil(tot_len / batchsize)), stream=sys.stdout)

    # IBI mask of every batch.
    ibis = []

    # Initiate MATLAB engine.
    with matlab_engine() as eng:
        print("Running the dIBI algorithm...")

        for batch_start in range(0, tot_len, batchsize):
            # Convert the current batch to a MATLAB array.
            batch_stop = min(tot_len, batch_start + batchsize)
            seg = eeg[:, batch_start: batch_stop]
            seg_ml = ml_array(array=seg, eng=eng, dtype='double')

            # Call MATLAB function.
            starts_i, ends_i = eng.compute_dIBI(seg_ml, float(fs), nargout=2)

            # Turn into flat arrays (starts_i could be empty, a float, or a matlabarray) and convert the
            # times to sample indices.
            starts_idx = np.round(np.asarray([starts_i]).reshape(-1)*fs).astype(int)
            ends_idx = np.round(np.asarray([ends_i]).reshape(-1)*fs).astype(int)

            # Create IBI mask with ones and zeros for this batch.
            ibis_i = np.zeros((seg.shape[-1]))
            for first, last in zip(starts_idx, ends_idx):
                ibis_i[first:last] = 1
            ibis.append(ibis_i)

            if verbose:
                bar.update()

    # Stitch the batches together.
    return np.concatenate(ibis)
def dibi_burst_detection_python(eeg, fs, per_channel=False, verbose=1, show_plots=False):
    """
    Detect dynamic inter-burst-intervals (dIBI) on (filtered) multichannel EEG data.

    Implementation of the 'part A' algorithm described in sections 4.2.3-4.2.5 of the PhD thesis
    of Vladimir Matic:
    V. Matic, “Neonatal EEG Signal Processing,” Katholieke Universiteit Leuven, 2015.

    Outline: (1) per channel, the signal is cut into segments at the peaks of a 'total difference' curve
    that compares the mean absolute amplitude and mean absolute first difference of two adjacent 1 second
    windows; (2) segments with a low supra-10 uV amplitude are marked as low profile (LP); (3) a dIBI is
    detected where more than half of the channels is LP for at least 3 seconds.

    Args:
        eeg (np.ndarray): 2D array with dimensions (channels, time) containing filtered! EEG data in uV.
        fs (float): sample frequency of the EEG.
        per_channel (bool, optional): if True, the output will be 2D with same shape as `eeg`, containing a dIBI mask
            for each channel. If False, returns a 1D array (not channel specific).
        verbose (int, optional): verbosity level. If > 0, the progress bar is updated per channel.
        show_plots (bool, optional): toggle to show plots of (intermediate) results.

    Returns:
        dibi (np.ndarray): float array containing 1's at locations of dIBIs and zeros elsewhere.
            If `per_channel` is False: 1D array with length equal to eeg.shape[-1], with np.nan where more
            than half of the channels is nan.
            If `per_channel` is True: 2D array with the same shape as `eeg` (low profile samples of each
            channel that lie within a global dIBI), with np.nan where `eeg` is nan.
    """
    x = check_multichannel_data_matrix(eeg)[0]

    # Replace nans by zeros (an untouched copy is kept for the amplitude criterion further on).
    # NOTE(review): x is modified in place; if check_multichannel_data_matrix() returns the caller's array
    # instead of a copy, the nans in the caller's data are zeroed as well - confirm.
    x_original = x.copy()
    nan_mask = np.isnan(x)
    x[nan_mask] = 0

    # Precompute absolute values and differences.
    x_abs = np.abs(x)
    x_abs_dif = np.concatenate((np.zeros((len(x), 1), dtype=x.dtype),
                                np.abs(np.diff(x, axis=1))), axis=1)

    # Compute the moving average using convolution (use zero-padding) to create the output of the first window.
    numtaps = int(fs * 1)
    kernel = np.ones(numtaps, dtype=x.dtype) / numtaps
    Aw1 = scipy.signal.lfilter(kernel, 1, x_abs, axis=-1).astype(x.dtype)
    Fw1 = scipy.signal.lfilter(kernel, 1, x_abs_dif, axis=-1).astype(x.dtype)

    # Roll with one second to create the output of the second window (np.roll wraps around at the end).
    Aw2 = np.roll(Aw1, -numtaps, axis=-1)
    Fw2 = np.roll(Fw1, -numtaps, axis=-1)

    # Compute the total difference.
    TD = np.abs(Aw1 - Aw2) + np.abs(Fw1 - Fw2) - 1

    # Clip TD at 0.
    TD[TD < 0] = 0

    # Moving average of the TD (0.25 second window).
    numtaps = int(0.25 * fs)
    kernel = np.ones(numtaps, dtype=int) / numtaps
    for i in range(len(TD)):
        # Do this per channel, may be a bit slower, but requires less memory.
        TD[i] = np.convolve(TD[i], kernel, mode='same')

    # Detect peaks in total diff (at least 0.75 second apart), per channel.
    locs = list(map(lambda a: scipy.signal.find_peaks(a, distance=int(0.75 * fs))[0], TD))

    # Initialize LP (low profile mask per channel).
    LP = np.zeros(x.shape)

    # Loop over channels.
    bar = pyprind.ProgBar(len(locs), stream=sys.stdout)
    for i, locs_i in enumerate(locs):
        # The peaks, together with the start and end of the signal, delimit the segments.
        locs_i = np.concatenate(([0], locs_i, [x.shape[-1]]))

        # Loop over segments.
        for j_start, j_end in zip(locs_i, locs_i[1:]):
            x_seg = x_original[i, j_start: j_end]

            # Compute area amp 10: mean of the amplitude in excess of 10 uV.
            x_seg_abs = np.abs(x_seg) - 10
            x_seg_abs[x_seg_abs < 0] = 0
            area_amp_10 = np.nanmean(x_seg_abs) * fs  # Per second (from Vladimir's Matlab code).

            if area_amp_10 < 50:
                LP[i, j_start: j_end] = 1

        # Update progress bar.
        if verbose > 0:
            bar.update()

    # Add nans back.
    LP[nan_mask] = np.nan

    # Merge channel information: fraction of the (non-nan) channels that is low profile.
    LTP = np.nanmean(LP, axis=0)
    LTP[np.mean(np.isnan(LP), axis=0) > 0.5] = np.nan

    # Set nans to 0 (not detected).
    LP[np.isnan(LP)] = 0
    LTP[np.isnan(LTP)] = 0

    # dIBI if more than half of the channels is LP for more than 3 seconds.
    dibi = time_threshold(LTP > 0.5, min_duration=int(3 * fs))

    if show_plots:
        time = np.arange(x.shape[1]) / fs

        # Figure 4.3.
        plt.figure()
        ax1 = plt.subplot(2, 1, 1)
        plt.grid()
        plt.plot(time, x[0], color='k')
        plt.subplot(2, 1, 2, sharex=ax1)
        plt.plot(time, TD[0], color='b')
        plt.grid()
        plt.xlabel('Time (s)')
        plt.suptitle('Figure 4.3')

        # Figure 4.4.
        plt.figure()
        ax1 = plt.subplot(2, 1, 1)
        plt.grid()
        plt.plot(time, TD[0], color='b')
        plt.scatter(time[locs[0]], TD[0][locs[0]], color='r')
        plt.subplot(2, 1, 2, sharex=ax1)
        plt.plot(time, x[0], color='k')
        for idx in locs[0]:
            plt.axvline(time[idx], color='r')
        plt.grid()
        plt.xlabel('Time (s)')
        plt.suptitle('Figure 4.4')

        # Figure 4.6 (own artistic interpretation).
        locs_final = np.where(np.diff(dibi) != 0)[0] + 1
        nchans = len(x)
        # squeeze=False: always get a 2D array of axes, so that indexing also works for a single channel
        # (with squeeze=True a single Axes object is returned, which cannot be indexed).
        fig, axes = plt.subplots(nchans, 1, sharex='all', sharey='all', squeeze=False)
        for i in range(nchans):
            plt.sca(axes[i, 0])
            plt.plot(time, x[i], color='k')
            # NOTE(review): matplotlib interprets `rotation` in degrees, so np.pi / 2 is about 1.57
            # degrees - confirm whether 90 was intended.
            plt.ylabel('Ch{} (uV)'.format(i), rotation=np.pi / 2)
            for idx in locs_final:
                plt.axvline(time[idx], color='k', alpha=1)

            # Channel specific dIBI selection.
            onsets, offsets = get_onsets_offsets(LP[i])
            shade_axis(time[onsets], (offsets - onsets) / fs, color='b', alpha=0.25)

            # Global.
            onsets, offsets = get_onsets_offsets(dibi)
            shade_axis(time[onsets], (offsets - onsets) / fs, color='k', alpha=0.25)

        plt.ylim([-50, 50])
        plt.xlabel('Time (s)')
        plt.suptitle('Figure 4.6')

    if per_channel:
        # Return LP, but set low profile to zero for non IBIs.
        dibi = LP * dibi
        dibi[nan_mask] = np.nan
    else:
        # If more than half of the channels is nan, set to nan.
        dibi = 1.0*dibi  # To float.
        dibi[np.mean(nan_mask, axis=0) > 0.5] = np.nan

    return dibi
def detect_dibi(eeg, fs, use_matlab=False, verbose=1, **kwargs):
    """
    Preprocess and detect dynamic inter-burst-intervals (dIBI) on multichannel EEG data using either
    the matlab or python implementation.

    The Python and MATLAB implementations are not identical, but yield similar results.
    The Python implementation is shorter, simpler and faster, but the
    MATLAB implementation serves as the reference. Try both to see which one works best.

    Args:
        eeg (np.ndarray): 2D array with dimensions (channels, time) containing raw EEG data in uV.
        fs (float): sample frequency of the EEG.
        use_matlab (bool): whether to use the python implementation (False) or matlab implementation (True).
        verbose (int, optional): verbosity level.
        **kwargs (optional): keyword arguments specific to the function called
            (dibi_burst_detection_matlab() or dibi_burst_detection_python()).

    Returns:
        dibi (np.ndarray): array with length equal to the length of the eeg rows,
            containing 1's at locations of dIBIs and zeros elsewhere.
        x (np.ndarray): filtered EEG (same shape as `eeg`). Returned for further processing
            (e.g. computation of amplitude of IBIs).
    """
    eeg = check_multichannel_data_matrix(eeg)[0]
    talkative = verbose > 0

    # Bandpass filter in [0.7, 20] Hz band (see section 4.2.1 in the thesis). The exact filter design that
    # was used originally is not known; a filter with a flat bandpass response is chosen here.
    # A high filter order is important to make sure there is no baseline drift (baseline drift will
    # affect the mean amplitude).
    if talkative:
        print('Filtering...')
    bandpass = WinFIR(numtaps=int(3*fs + 1), cutoff=[0.7, 20],
                      fs=fs, pass_zero='bandpass')
    x = bandpass.filtfilt(eeg)

    if talkative:
        print('Detecting dIBIs...')
    detector = dibi_burst_detection_matlab if use_matlab else dibi_burst_detection_python
    dibi = detector(x, fs, verbose=verbose, **kwargs)

    return dibi, x
def compute_bsr(x, fs, threshold=5.0, min_length=0.5):
    """
    Very simple method to compute burst suppression ratio (Lacan et al. 2021).

    Samples with an absolute amplitude below `threshold` are candidates for suppression; the candidates
    are passed through time_threshold() with a minimum duration of int(min_length*fs) samples
    (presumably keeping only runs that last at least that long - see time_threshold()).

    Args:
        x (np.ndarray): signal to analyze. Samples that are nan are excluded from the ratio.
        fs (float): sample frequency of `x`.
        threshold (float, optional): amplitude below which a sample counts as low amplitude.
        min_length (float, optional): minimum duration passed to time_threshold(), expressed in the time
            unit that corresponds to `fs`.

    Returns:
        bsr (float): percentage of the non-nan samples that is marked as suppressed.
    """
    min_samples = int(min_length*fs)
    suppressed = time_threshold(np.abs(x) < threshold, min_duration=min_samples).astype(float)

    # Missing samples must not count as non-suppressed.
    suppressed[np.isnan(x)] = np.nan

    return np.nanmean(suppressed)*100
def _compute_continuity_features(env, fs, segment_length, segment_overlap):
    """
    Compute distribution and line length features of an envelope, per segment.

    Args:
        env (array-like): envelope signal; must be 1-dimensional after squeezing. The ratio features
            raise 10 to the power of the envelope statistics, i.e. they treat `env` as log10 values.
        fs (float): sample frequency of `env`.
        segment_length (float): length of a segment, in the time unit corresponding to `fs`.
        segment_overlap (float): overlap between consecutive segments, in the same unit.

    Returns:
        df (pd.DataFrame): one row per segment and one float column per feature, plus the columns 'onset'
            (segment index times (segment_length - segment_overlap)) and 'offset' (onset + segment_length).
            If `env` is not longer than one segment, the entire signal is treated as a single segment.

    Raises:
        ValueError: if `env` is not 1-dimensional.
    """
    dt = 1/fs
    env = np.asarray(env).squeeze()
    if env.ndim != 1:
        raise ValueError('Should be 1-dimensional.')

    stepsize = segment_length - segment_overlap
    if len(env) > segment_length*fs:
        # Long enough to cut into segments.
        segments = segment_generator(env, segment_length, segment_overlap, fs=fs)
    else:
        # Too short: take entire signal.
        segments = [env]

    data = defaultdict(list)
    for i_seg, x in enumerate(segments):
        q5, q10, q25, q50, q75, q90, q95 = np.nanpercentile(x, q=[5, 10, 25, 50, 75, 90, 95])

        # Absolute first derivative (line length per unit of time).
        dx = np.abs(np.diff(x))/dt

        # Features of this segment, in the column order of the output.
        row = {
            'SD': np.nanstd(x),
            'Skewness': skew(x, nan_policy='omit'),
            'Kurtosis': kurtosis(x, nan_policy='omit'),
            'meanLL': np.nanmean(dx),
            'medianLL': np.nanmedian(dx),
            'meanlogLL': np.nanmean(np.log10(dx)),
            'medianlogLL': np.nanmedian(np.log10(dx)),
            'q5': q5,
            'q10': q10,
            'q50': q50,
            'q90': q90,
            'q95': q95,
            'mean': np.nanmean(x),
            'mean/median': 10**np.nanmean(x)/10**q50,
            'q10/q90': 10**q10/10**q90,
            'q5/q95': 10**q5/10**q95,
            'q25/q75': 10**q25/10**q75,
            'q50-q25': q50 - q25,
            'q50-q5': q50 - q5,
            'q75-q25': q75 - q25,
            'q95-q5': q95 - q5,
            'q50-q5/q95-q5': (q50 - q5)/(q95 - q5),
            'q75-q10/q90-q10': (q75 - q10)/(q90 - q10),
            'q50-q10/q90-q10': (q50 - q10)/(q90 - q10),
            'onset': i_seg*stepsize,
        }
        for name, value in row.items():
            data[name].append(value)

    df = pd.DataFrame(data)
    df['offset'] = df['onset'] + segment_length

    # Skew and kurtosis may have '--' values instead of nans.
    return df.replace('--', np.nan).astype(float)
def compute_continuity_features(eeg, fs, overlap_frac=0.75, time_offset=0,
                                normalize=True):
    """
    Compute continuity features of the log-envelope of a single EEG channel.

    Pipeline: optional normalization by a moving RMS, envelope computation (moving_envelope()), removal
    of very small envelope values, log10 transform, resampling to 1 Hz and feature extraction in
    segments of 3600 seconds (see _compute_continuity_features()).

    Args:
        eeg (array-like): EEG signal; must be 1-dimensional after squeezing.
        fs (float): sample frequency of `eeg` in Hz.
        overlap_frac (float, optional): overlap between consecutive segments as fraction of the segment
            length.
        time_offset (float, optional): value added to the 'onset' and 'offset' columns of the output.
        normalize (bool, optional): if True, divide the signal by the square root of the moving average
            of its square before computing the envelope.

    Returns:
        df (pd.DataFrame): one row per segment with the features of _compute_continuity_features().

    Raises:
        NotImplementedError: if `eeg` is not 1-dimensional.
    """
    eeg = np.asarray(eeg).squeeze()
    if eeg.ndim != 1:
        raise NotImplementedError('Must be 1 dimensional.')

    # Settings.
    normalize_window = 3600
    segment_length = 3600
    fs_res = 1
    segment_overlap = overlap_frac*segment_length

    # Normalize by the moving RMS if requested.
    if normalize:
        running_power = moving_average(eeg ** 2, n=normalize_window * fs)[0]
        eeg = eeg / np.sqrt(running_power)

    # Compute envelope.
    env = moving_envelope(eeg, n=2, fs=fs)

    # Remove amplitudes < 0.01 uV that are most likely artefacts.
    env[env < 0.01] = np.nan

    # To log, then downsample.
    env = np.log10(env)
    env = resample_by_filtering(x=env, fs=fs, fs_new=fs_res)

    # Compute features on the downsampled log-envelope.
    df = _compute_continuity_features(
        env, fs=fs_res, segment_length=segment_length,
        segment_overlap=segment_overlap)

    # Shift the segment times by the requested offset.
    for column in ('onset', 'offset'):
        df[column] += time_offset

    return df
def compute_dibi_features(eeg, fs, per_channel=False, time_offset=0, use_matlab=False, verbose=1):
    """
    Detect and compute features of dynamic inter-burst-intervals in EEG data.

    Features are taken from the thesis of Vladimir Matic, KU Leuven (2015).

    Args:
        eeg (np.ndarray): raw EEG data with shape (channels, time).
        fs (float): sample frequency of the EEG (in Hz).
        per_channel (bool, optional): passed on to detect_dibi(). If True, the returned dIBI mask is
            treated as a per-channel mask with shape (channels, time): a global dIBI is present when
            any channel is in a dIBI, and the features of each global dIBI are computed only from the
            samples of channels that are in a dIBI themselves (all other samples are ignored).
            If False, all channels are used.
        time_offset (float, optional): optional time offset (in seconds) added to the onsets.
        use_matlab (bool): if True, use the matlab code to compute the IBIs (see detect_dibi()).
            Cannot be combined with `per_channel`.
        verbose (int): verbosity level. If falsy, nothing is printed by this function.

    Returns:
        df (pd.DataFrame()): pandas DataFrame with the results. Each row is a dIBI, with columns
            `onset` (s), `duration` (s), `amplitude` and `flatness`. The columns are present even
            when no dIBI was detected.

    Raises:
        ValueError: if both `per_channel` and `use_matlab` are True.
        AssertionError: if an offset index precedes its onset index.
    """
    if per_channel and use_matlab:
        raise ValueError('`per_channel` option not available when `use_matlab` is True. '
                         'Set `per_channel` or `use_matlab` to False.')

    # Preprocess and detect inter-burst-intervals.
    # Returns preprocessed (filtered) EEG.
    dIBI, eeg = detect_dibi(eeg, fs, use_matlab=use_matlab, verbose=verbose, per_channel=per_channel)

    dIBI_per_channel = None
    if per_channel:
        # Boolean mask, so that it can safely be used for element-wise selection below.
        dIBI_per_channel = np.asarray(dIBI, dtype=bool)  # Per channel.
        dIBI = np.any(dIBI_per_channel, axis=0)  # Global.

    # Find onsets and offsets of dIBIs (indices).
    onsets, offsets = get_onsets_offsets(dIBI)

    # Only create the progress bar when verbose, so that nothing is written to stdout otherwise.
    bar = None
    if verbose:
        print('Extracting features...')
        bar = pyprind.ProgBar(len(onsets), stream=sys.stdout)

    # Predeclare the columns so that the result has them even if there are no dIBIs.
    data = {'onset': [], 'duration': [], 'amplitude': [], 'flatness': []}
    for on, off in zip(onsets, offsets):
        if (off - on) < 0:
            raise AssertionError('This should never happen.')

        # Select dIBI segment.
        eeg_seg = eeg[:, on: off]

        if per_channel:
            # Ignore channels with no dIBI using np.nan: keep only the samples of channels that are
            # in a dIBI. np.where creates a new array, so the preprocessed EEG is not modified.
            in_dibi = dIBI_per_channel[:, on: off]
            eeg_seg = np.where(in_dibi, eeg_seg, np.nan)

        # Extract and save some features.
        data['onset'].append(on / fs + time_offset)  # In seconds.
        data['duration'].append((off - on) / fs)  # In seconds.
        data['amplitude'].append(np.nanmedian(np.abs(eeg_seg)))  # = 'suppression' parameter.
        # Flatness correlates strongly with amplitude.
        data['flatness'].append(np.nanmedian([compute_flatness(x) for x in eeg_seg]))

        if bar is not None:
            bar.update()

    # Export it as a DataFrame.
    df = pd.DataFrame(data)

    return df
[docs]class IbiFeaturesResult(ResultBase):
"""
High-level interface for processing IbiFeatures.
Args:
df (pd.DataFrame): dataframe with IBI features (onset, duration, amplitude, ...).
"""
def __init__(self, df, *args, **kwargs):
    """
    Initialize the base class and store the IBI features.

    Args:
        df (pd.DataFrame): dataframe with IBI features; kept as-is (not copied) in `self.df`.
        *args: positional arguments forwarded to the parent class constructor.
        **kwargs: keyword arguments forwarded to the parent class constructor.
    """
    # Let the parent class initialize first, then attach the features table.
    super().__init__(*args, **kwargs)
    self.df = df
def compute_eeg_grade(self, window=3600, overlap=0):
    """
    Compute EEG grade according to Dereymaeker et al. 2019.

    The grade is derived per window from the `amplitude` and `duration` values returned by
    compute_median_in_window():

        grade 5: amplitude < 5 and duration > 60
        grade 4: amplitude < 15 and duration > 10
        grade 3: amplitude < 15 and duration <= 10
        grade 2: 15 <= amplitude < 25 and duration > 10
        grade 1: 15 <= amplitude < 25 otherwise
        NaN:     anything else (e.g. amplitude >= 25 or missing values).

    Returns the time as median of the onset times of the IBIs in the window.

    Args:
        window (float): time window (in seconds).
        overlap (float): overlap between windows (in seconds).

    Returns:
        grade (pd.Series): Series with the EEG grades, indexed like the frame returned by
            compute_median_in_window() (`onset`).
    """
    df_median = self.compute_median_in_window(window=window, overlap=overlap,
                                              features=['duration', 'amplitude'])

    def eeg_grade(row):
        amp = row['amplitude']
        dur = row['duration']
        if amp < 5 and dur > 60:
            grade = 5
        elif amp < 15 and dur > 10:
            grade = 4
        elif amp < 15 and dur <= 10:
            grade = 3
        elif 15 <= amp < 25:
            # Lower bound is inclusive: an amplitude of exactly 15 would otherwise match neither
            # the `< 15` branches nor this one and silently end up ungraded (NaN).
            if dur > 10:
                grade = 2
            else:
                grade = 1
        else:
            grade = np.nan
        return grade

    grade = df_median.apply(eeg_grade, axis=1)

    return grade
[docs] def compute_ibi_percentage(self, window=3600, overlap=0):
"""
Compute IBI time percentage in specific time windows.
Returns the time as median of the onset times of the IBIs in the window.
Args:
window (float): time window (in seconds).
overlap (float): overlap between windows (in seconds).
Returns:
percentage (pd.Series): Series with the IBI percentage and with `onset` as index.
"""
df = self.df
step = window - overlap
onsets = df['onset'].values
data = []
for t_start in np.arange(0, onsets[-1], step):
t_end = t_start + window
window_mask = np.logical_and(onsets >= t_start, onsets < t_end)
with warnings.catch_warnings(): # Catch numpy mean of empty slice warnings.
warnings.simplefilter("ignore", category=RuntimeWarning)
onset_i = np.nanmedian(df['onset'][window_mask].values, axis=0)
if np.isnan(onset_i):
onset_i = t_start + window/2
duration_i = np.nansum(df['duration'][window_mask].values, axis=0)
percentage_i = duration_i/window*100
data.append([onset_i, percentage_i])
data = np.array(data)
percentage = pd.DataFrame(data=data, columns=['onset', 'percentage']).dropna(how='any').set_index('onset')
return percentage