"""
Sleep stages CNN.

Demonstration code for classification of sleep stages using SleepStagesCnn().
Link to script: feature_extraction/sleep_stages_cnn.py
"""
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np

from nnsa import EdfReader, print_object_summary, SUPPORTED_RESULT_FILE_TYPES, read_result_from_file, \
    assert_equal
from nnsa.feature_extraction.sleep_stages import SleepStagesCnn
from nnsa.preprocessing.saved_filters import filter_saved_filter
from nnsa.utils.segmentation import segment_generator
# Close all currently open matplotlib figures so the demo starts with a clean slate.
plt.close(fig='all')
# %% Parameters.
# Inspect the default parameters of SleepStagesCnn(). Every parameter is described in the code of
# default_parameters().
default_parameters = SleepStagesCnn().default_parameters()
print(default_parameters)

# Instantiate the classifier with custom parameters. Only the parameters passed here overrule the
# defaults (in this case predict_kwargs).
cnn = SleepStagesCnn(predict_kwargs={'batch_size': 100})

# Check whether the custom parameters were accepted.
print('\nCustom parameters:', cnn.parameters, sep='\n')
# %% Main method: sleep_stages_cnn.
# Now that we have initialized a SleepStagesCnn object with certain parameters, we can run the sleep_stages_cnn
# method, which performs the actual sleep stage classification. First we create some random data.
# The required sample frequency (in Hz) and segment length (in seconds) are given by cnn.data_requirements:
fs = cnn.data_requirements['fs']
segment_length = cnn.data_requirements['segment_length']

# Simulate a 1-hour recording (3600 s at raw_fs Hz) of 8-channel EEG data, with values drawn uniformly
# from [-250, 250).
raw_fs = 256
x = np.random.rand(int(3600*raw_fs), 8)*500 - 250

# Preprocess the input as required by the algorithm (e.g. filtering and resampling). This returns the
# preprocessed data together with its sample frequency, which replaces fs from above.
x, fs = cnn.preprocess_recording(raw_eeg=x, raw_fs=raw_fs)

# Cut the data into non-overlapping segments of segment_length seconds (as required by the CNN) along the
# time axis (axis 0).
seg_generator = segment_generator(x, segment_length=segment_length, overlap=0, fs=fs, axis=0)
x = np.asarray(list(seg_generator))  # Dimensions (segments, time, channels).

# Run the CNN. Note that the verbose argument is optional.
result = cnn.sleep_stages_cnn(x=x, verbose=1)

# The returned object is a SleepStagesCnnResult object, which is a high-level interface for
# manipulating/visualizing/saving the result:
print('\nSummary of result object:')
print_object_summary(result)
# %% SleepStagesCnnResult attributes.
# The main attribute is probabilities, which contains the probability per class per segment:
print('result.probabilities: {}'.format(result.probabilities))
# The shape of probabilities corresponds to (classes, segments):
print('result.probabilities.shape: {}. Corresponds to (classes, segments).'.format(result.probabilities.shape))
# The labels of the classes are stored in the class_labels attribute:
print('result.class_labels: {}'.format(result.class_labels))
# The times (in seconds) corresponding to the segments in the data in probabilities:
print('result.segment_times: {}'.format(result.segment_times))
# %% SleepStagesCnnResult methods.
# Class label used in the examples below.
example_class = 'QS'

# Class number per segment; class number i corresponds to result.class_labels[i]:
class_numbers = result.get_classes()
print('class_numbers: {}'.format(class_numbers))

# Per-segment probability of belonging to example_class:
qs_probabilities = result.get_probabilities(example_class)
print('qs_probabilities: {}'.format(qs_probabilities))

# Indices of the segments that were classified as example_class:
qs_indices = result.get_segment_indices(example_class)
print('qs_indices: {}'.format(qs_indices))

# Plot the probabilities of example_class in a new figure:
plt.figure()
result.plot_probabilities(example_class)
# %% Save the result.
# We can save the result to any supported file type. To list the supported file types/extensions:
print('Supported result file types: {}'.format(SUPPORTED_RESULT_FILE_TYPES))

# E.g. save as hdf5 (the file type to save is automatically detected from the extension of the filename).
# The file is written into a fresh temporary directory, so that no existing file in the working directory
# can be overwritten. The directory (including the file) is removed when leaving the with-block, also
# when saving or reading raises an exception.
with tempfile.TemporaryDirectory() as temp_dir:
    filename = os.path.join(temp_dir, 'temp_result.hdf5')
    result.save_to_file(filename)

    # We can reload the result back to a SleepStagesCnnResult object:
    result_loaded = read_result_from_file(filename)

# Verify that the loaded object is equal to the original object:
assert_equal(actual=result_loaded, desired=result)  # Will raise an AssertionError if not equal.
print('Loaded result object is equal to original object.')
# %% SleepStagesCnn on EegDataset.
# Specify a file and open a reader to read the data (e.g. EdfReader to read an EDF(+) file).
# NOTE: this path is machine-specific; point it to an EDF(+) file on your system.
filepath = 'C:/data_temp/test.edf'
if not os.path.isfile(filepath):
    # Fail early with an actionable message instead of an error from deep inside the reader.
    raise FileNotFoundError('EDF file "{}" not found. Set filepath to an existing EDF(+) file to run '
                            'this part of the demo.'.format(filepath))

# Read EEG channels in an EegDataset object.
with EdfReader(filepath) as r:
    ds = r.read_eeg_dataset()

# We can do the sleep stage classification via the wrapper method 'sleep_stages_cnn' of the EegDataset class.
# Preprocessing, i.e. filtering and resampling, can be done in the wrapper method. Specify the parameters of the
# SleepStagesCnn class as additional kwargs:
result_ds = ds.sleep_stages_cnn(predict_kwargs={'batch_size': 1000})
# Again, the returned object is a SleepStagesCnnResult object:
print(result_ds)
# A data info string is appended to the result_ds object, which summarizes the source and preprocessing of the data.
print('result_ds.data_info: {}'.format(result_ds.data_info))
# Plot probabilities of certain class:
plt.figure()
result_ds.plot_probabilities('QS', label='probability QS')
# We can convert the result to a SleepStagesResult object, which provides more methods for analyzing the sleep features.
# (see the example script nnsa/examples/feature_extraction/sleep_stages.py).
sleep_stages = result_ds.to_sleep_stages_result()
# E.g. plot the hypnogram (in the same figure as the probabilities above):
sleep_stages.plot_hypnogram(label='predicted sleep stage')
plt.legend()