Envelope

Demonstration code for envelope computation using Envelope().

Link to script: feature_extraction/envelope.py

import os

import numpy as np
import matplotlib.pyplot as plt

from nnsa import EdfReader, MovingAverage, Envelope, SUPPORTED_RESULT_FILE_TYPES, \
    read_result_from_file, assert_equal, print_object_summary
from nnsa.utils.dictionaries import itemize_items

# Close all open matplotlib figures before this demo creates new ones.
plt.close('all')

Parameters.

# Show the parameters that Envelope() uses when none are specified:
default_params = Envelope().default_parameters()
print(default_params)

# The meaning of each parameter is documented in the default_parameters() code.

# Create an instance of the Envelope class with custom parameters, overruling some defaults:
custom_params = {'method': 'hilbert', 'method_kwargs': {'n': None}}
envelope = Envelope(**custom_params)

# Check that the custom values ended up in the parameters of the instance:
print('\nCustom parameters:')
print(envelope.parameters)

Main method: envelope.

# Now that we have initialized an Envelope object with certain parameters, we can run its envelope
# method, which performs the envelope computation.
# Test input: 8 signals (8 channels) built from the same sine wave, where the k-th channel
# (counting from 1) is scaled to amplitude k.
fs = 256
t = np.arange(100000) * (1 / fs)
channel_amplitudes = np.arange(1, 9)
# Broadcasting an (8, 1) column of amplitudes against the sine wave yields an (8, 100000) array.
signal = channel_amplitudes[:, np.newaxis] * np.sin(t)

result = envelope.envelope(signal, fs=fs)
mean_envelope = result.envelope.mean(axis=-1)
print('Mean envelope/amplitude of channels:\n{}'.format(mean_envelope))

# Sanity check: rounded to the nearest integer, the mean envelope of each channel equals its amplitude.
assert np.all(np.round(mean_envelope) == channel_amplitudes)

Save the result.

# We can save the result to any supported file. To list the supported file types/extensions:
print('Supported result file types: {}'.format(SUPPORTED_RESULT_FILE_TYPES))

# E.g. save as hdf5 (the file type to save is automatically detected from the extension of the filename).
filename = 'temp_result.hdf5'
try:
    result.save_to_file(filename, overwrite=True)

    # We can reload the saved file back into a result object:
    result_loaded = read_result_from_file(filename)
finally:
    # Remove the temporary file, also when saving or loading raised an error,
    # so that the demo never leaves the file behind on disk.
    if os.path.exists(filename):
        os.remove(filename)

# Verify that the loaded object is equal to the original object:
assert_equal(actual=result_loaded, desired=result)  # Will raise an AssertionError if not equal.
print('Loaded result object is equal to original object.')

The returned object is an EnvelopeResult object, which is a high-level interface for manipulating/visualizing/saving the result:

# Print a header line, then let print_object_summary() print a summary of the result object.
print('\nSummary of result object:')
print_object_summary(result)

EnvelopeResult attributes.

# The main attribute is envelope, which is the array with the envelope values.
# Print it to inspect the values directly:
print('result.envelope: {}'.format(result.envelope))

EnvelopeResult methods.

# Baseline correction can be performed on the envelope (here with a window length of fs):
env_baseline_cor = result.baseline_correction_min(window_length=fs)
# Since the envelope is constant over time, this constant will be detected as baseline and subtracted
# from the envelope in baseline correction. Therefore, the envelope values will be (close to) zero.
print('Mean envelope/amplitude of channels after baseline correction:\n{}'
      .format(env_baseline_cor.envelope.mean(axis=-1)))

# The global features can be extracted:
features = result.extract_global_features(concatenate_channels=False)  # Compute features per channel.
print('Global features:')
print(itemize_items(features.items()))

# One channel with envelope values can be converted to a TimeSeries object.
# A new figure is opened first, so that the plot does not draw into an existing figure:
plt.figure()
ts_env = result.to_time_series(channel='Channel 4')
ts_env.plot()

# The envelope arrays can be converted to an EegDataset object (again plotted in a new figure):
plt.figure()
ds_env = result.to_eeg_dataset()
ds_env.plot(scale=10)

Envelope on a real EegDataset.

# Read an EEG dataset from an EDF file. The path is hard-coded for this demo; point it to an EDF
# file that is available on your own machine.
filepath = 'C:/data_temp/test.edf'
with EdfReader(filepath) as r:
    ds = r.read_eeg_dataset()

# Apply the saved filter 'bandpassfir_a' to the EEG data, compute the envelope of the filtered data
# and convert the envelope result back to an EegDataset.
ds_filtered = ds.filter_saved_filter(filter_name='bandpassfir_a')
envelope_result = ds_filtered.envelope(method='hilbert')
ds_envelope = envelope_result.to_eeg_dataset()

# Smoothen the envelope with a 20-tap moving average.
moving_average = MovingAverage(fs=None, numtaps=20)
ds_envelope_smooth = ds_envelope.filtfilt(moving_average)

# Plot (using the plot method from EegDataset). All three datasets are drawn in the same figure,
# over the same window and with the same scale, so that they can be compared directly.
begin, end = 1000, 1100
plot_window = {'begin': begin, 'end': end, 'scale': 200}
plt.figure()
ds_filtered.plot(**plot_window, label='original (filtered)')
ds_envelope.plot(**plot_window, linestyle='-', color='r', label='envelope')
ds_envelope_smooth.plot(**plot_window, linestyle='--', color='b', label='envelope (smooth)')
plt.legend(loc='upper right')