Burst detection

Demonstration code for burst detection with the BurstDetection class.

Link to script: feature_extraction/burst_detection.py

import os
import numpy as np
import matplotlib.pyplot as plt

from nnsa import print_object_summary, SUPPORTED_RESULT_FILE_TYPES, read_result_from_file, assert_equal, \
    BurstDetection, EdfReader, EegDataset, TimeSeries, RemezFIR

# Close every currently open matplotlib figure so the figures created below start from a clean slate.
plt.close(fig='all')

Parameters.

# Show the default parameters of BurstDetection (each parameter is described in the
# default_parameters() code).
print(BurstDetection().default_parameters())

# Build a BurstDetection instance with custom settings that overrule some of the defaults.
custom_settings = dict(method='NLEO', method_kwargs={'window_baseline': 50})
burst_detection = BurstDetection(**custom_settings)

# Check that the custom parameters were taken over by the instance:
print('\nCustom parameters:')
print(burst_detection.parameters)

Main method: burst_detection.

# With a configured BurstDetection object we can call its main method, burst_detection(), which
# performs the actual detection.
# Test input: 8 channels x 100000 samples of a sine wave plus Gaussian noise (scaled by 200). Each row
# continues the time axis of the previous row. Samples 25000-74999 of every channel are attenuated by
# a factor 100 to create a low-amplitude stretch.
fs = 256
t = (1 / fs * np.arange(8 * 100000)).reshape(8, -1)
noise = np.random.normal(scale=0.1, size=t.shape)
signal = 200 * (np.sin(t) + noise)
signal[:, 25000:75000] /= 100
result = burst_detection.burst_detection(signal, fs=fs)

Save the result.

# We can save the result to any supported file. To list the supported file types/extensions:
print('Supported result file types: {}'.format(SUPPORTED_RESULT_FILE_TYPES))

# E.g. save as hdf5 (the file type is detected automatically from the extension of the filename).
# The save/reload round trip is wrapped in try/finally so the temporary file is always cleaned up,
# even when saving or reading fails.
filename = 'temp_result.hdf5'
try:
    result.save_to_file(filename)

    # Reload the result from file, back into a BurstDetectionResult object:
    result_loaded = read_result_from_file(filename)
finally:
    # Remove the temporary file. The existence check avoids masking the original error when saving
    # failed before the file was created.
    if os.path.exists(filename):
        os.remove(filename)

# Verify that the loaded object is equal to the original object:
assert_equal(actual=result_loaded, desired=result)  # Will raise an AssertionError if not equal.
print('Loaded result object is equal to original object.')

The object returned by burst_detection() is a BurstDetectionResult, a high-level interface for manipulating, visualizing and saving the result:

# Print an overview of the BurstDetectionResult object using nnsa's print_object_summary().
print('\nSummary of result object:')
print_object_summary(result)

BurstDetectionResult attributes.

# The main attributes are `bursts` and `ibis`: binary masks marking bursts and inter-burst intervals
# (IBIs), respectively. Some entries may be np.nan because of boundary effects.
for attr_name in ('bursts', 'ibis'):
    print('result.{}: {}'.format(attr_name, getattr(result, attr_name)))

BurstDetectionResult methods.

# Bursts/IBIs of multiple channels can be aggregated into a single array that combines the information
# of all channels. A sample only counts as a burst (or IBI) when it is detected in a sufficient
# fraction of the channels, set by min_channels_frac. min_channels_elong_frac is a second channel
# fraction; see the aggregate_bursts()/aggregate_ibis() docs for its exact role.
agg_bursts = result.aggregate_bursts(min_channels_frac=2/8, min_channels_elong_frac=1/8)
print('result.aggregate_bursts: {}'.format(agg_bursts))
agg_ibis = result.aggregate_ibis(min_channels_frac=1, min_channels_elong_frac=1)
print('result.aggregate_ibis: {}'.format(agg_ibis))

# The aggregated bursts/IBIs can also be turned into a new BurstDetectionResult object:
result_agg = result.to_aggregate_result()
print('result.to_aggregate_result: {}'.format(result_agg))

# Top panel: the aggregated bursts/IBIs.
plt.figure()
ax1 = plt.subplot(2, 1, 1)
result_agg.plot()

# Bottom panel (x-axis shared with the top panel): the original signals, wrapped in an EegDataset.
plt.subplot(2, 1, 2, sharex=ax1)
time_series = [
    TimeSeries(signal=channel_signal, fs=fs, label='Random Ch {}'.format(channel_no), unit='uV',
               info={'source': 'random'})
    for channel_no, channel_signal in enumerate(signal, start=1)
]
ds = EegDataset(time_series)
ds.plot()

# Convert the bursts/IBIs to an AnnotationSet object:
burst_annot = result_agg.to_annotation_set()

# Shade the current axes according to the bursts/IBIs. This is presumably the lower signal panel
# made active above.
result_agg.shade_axis()

# Extract global features that characterize the entire recording:
features = result_agg.extract_global_features()

Burst detection on an EegDataset.

# Path to an EDF(+) recording, read below with an EdfReader. Adjust it to a file available on your machine.
filepath = 'C:/data_temp/test.edf'

# Load the EEG channels into an EegDataset object. A lower-precision dtype such as float32 means
# less memory usage.
with EdfReader(filepath) as reader:
    ds = reader.read_eeg_dataset(dtype=np.float32)

# Preprocess: re-reference to Cz (with remove_reference=False), then resample to 128 Hz.
ds_referenced = ds.reference('Cz', remove_reference=False)
ds_preprocessed = ds_referenced.resample(fs_new=128)

# Multichannel burst detection is available through the EegDataset wrapper method 'burst_detection'.
# BurstDetection parameters are passed to it as keyword arguments.
result_ds = ds_preprocessed.burst_detection(
    method='envelope',
    method_kwargs={'max_burst_dur': 25},
    create_bipolar_channels=True,
)

# Plot the preprocessed EEG (end=None, presumably up to the end of the recording) and shade it with
# the aggregated bursts/IBIs.
begin, end = 0, None
plt.figure()
ds_preprocessed.plot(begin=begin, end=end, scale=200)
result_ds.to_aggregate_result().shade_axis(begin=begin, end=end)

A second example: line-length burst detection on a band-pass filtered EegDataset.

# Read the EEG channels into an EegDataset object (default dtype this time).
with EdfReader(filepath) as reader:
    ds = reader.read_eeg_dataset()

# Preprocess: re-reference to Cz, resample to 250 Hz, then apply a Remez FIR filter
# (passband 1-20 Hz, stopband edges 0.5 and 21 Hz) via filtfilt().
fir_filter = RemezFIR(passband=[1, 20], stopband=[0.5, 21])
ds_preprocessed = ds.reference('Cz').resample(fs_new=250).filtfilt(fir_filter)

# Run line-length based burst detection. `bursts` holds the multichannel result object returned by
# the EegDataset wrapper, not a burst mask.
bursts = ds_preprocessed.burst_detection(method='line_length')

# Plot a segment of the preprocessed EEG and shade it with the bursts/IBIs. As in the other examples,
# the multichannel result is aggregated first so that one shading applies to all channels.
begin, end = 940, 1125
plt.figure()
ds_preprocessed.plot(begin=begin, end=end, scale=150)
bursts.to_aggregate_result().shade_axis(begin=begin, end=end)