Sleep stages cnn ================ Demonstration code for classification of sleep stages using SleepStagesCnn(). Link to script: `feature_extraction/sleep_stages_cnn.py `_ .. code-block:: python import os import matplotlib.pyplot as plt import numpy as np from nnsa import EdfReader, print_object_summary, SUPPORTED_RESULT_FILE_TYPES, read_result_from_file, \ assert_equal from nnsa.feature_extraction.sleep_stages import SleepStagesCnn from nnsa.preprocessing.saved_filters import filter_saved_filter from nnsa.utils.segmentation import segment_generator plt.close('all') Parameters. .. code-block:: python # Print the default parameters of SleepStagesCnn(): print(SleepStagesCnn().default_parameters()) # Descriptions of the parameters are documented in the default_parameters() code. # Create an instance of the SleepStagesCnn class with custom parameters, overriding some defaults: cnn = SleepStagesCnn(predict_kwargs={'batch_size': 100}) # See if the custom parameters were accepted: print('\nCustom parameters:') print(cnn.parameters) Main method: sleep_stages_cnn. .. code-block:: python # Now that we have initialized a SleepStagesCnn object with certain parameters, we can run the sleep_stages_cnn method, # which performs the actual sleep stage classification. First we create some random data. # The required sample frequency (in Hz) and segment length (in seconds) are given by cnn.data_requirements: fs = cnn.data_requirements['fs'] segment_length = cnn.data_requirements['segment_length'] # Create random signals to simulate a recording of 8-channel EEG data with maximum amplitude of 250. raw_fs = 256 x = np.random.rand(int(3600*raw_fs), 8)*500 - 250 # We can preprocess the input as required by the algorithm: x, fs = cnn.preprocess_recording(raw_eeg=x, raw_fs=raw_fs) # We need to segment the data into 30-second segments. 
seg_generator = segment_generator(x, segment_length=segment_length, overlap=0, fs=fs, axis=0) x = np.asarray([seg for seg in seg_generator]) # Dimensions (segments, time, channels). # Run the CNN. result = cnn.sleep_stages_cnn(x=x, verbose=1) # Note that the verbose arg is optional. # The returned object is a SleepStagesCnnResult object, which is a high-level interface for # manipulating/visualizing/saving the result: print('\nSummary of result object:') print_object_summary(result) SleepStagesCnnResult attributes. .. code-block:: python # The main attribute is probabilities, which contains the probabilities per class per segment: print('result.probabilities: {}'.format(result.probabilities)) # The shape of probabilities corresponds to (classes, segments): print('result.probabilities.shape: {}. Corresponds to (classes, segments).'.format(result.probabilities.shape)) # The labels of the classes are stored in the class_labels attribute: print('result.class_labels: {}'.format(result.class_labels)) # The times (in seconds) corresponding to the segments of the data in probabilities: print('result.segment_times: {}'.format(result.segment_times)) SleepStagesCnnResult methods. .. code-block:: python # Get the class number for each segment, where class number i corresponds to result.class_labels[i]: class_numbers = result.get_classes() print('class_numbers: {}'.format(class_numbers)) # Get probability per segment that it belongs to a certain class: qs_probabilities = result.get_probabilities('QS') print('qs_probabilities: {}'.format(qs_probabilities)) # Get the segment indices that were classified as a certain class: qs_indices = result.get_segment_indices('QS') print('qs_indices: {}'.format(qs_indices)) # Plot probabilities of a certain class: plt.figure() result.plot_probabilities('QS') Save the result. .. code-block:: python # We can save the result to any supported file. 
To list the supported file types/extensions: print('Supported result file types: {}'.format(SUPPORTED_RESULT_FILE_TYPES)) # E.g. save as hdf5 (the file type to save is automatically detected from the extension of the filename). filename = 'temp_result.hdf5' result.save_to_file(filename) # We can reload the result back to a SleepStagesCnnResult object: result_loaded = read_result_from_file(filename) # Remove temporary file. os.remove(filename) # Verify that the loaded object is equal to the original object: assert_equal(actual=result_loaded, desired=result) # Will raise an AssertionError if not equal. print('Loaded result object is equal to original object.') SleepStagesCnn on EegDataset. .. code-block:: python # Specify a file and open a reader to read the data (e.g. EdfReader to read an EDF(+) file). filepath = 'C:/data_temp/test.edf' # Read EEG channels into an EegDataset object. with EdfReader(filepath) as r: ds = r.read_eeg_dataset() # We can do the sleep stage classification via the wrapper method 'sleep_stages_cnn' of the EegDataset class. # Preprocessing, i.e. filtering and resampling, can be done in the wrapper function. Specify the parameters of the # SleepStagesCnn class as additional kwargs: result_ds = ds.sleep_stages_cnn(predict_kwargs={'batch_size': 1000}) # Again, the returned object is a SleepStagesCnnResult object: print(result_ds) # A data info string is appended to the result object, which summarizes the source and preprocessing of the data. print('result_ds.data_info: {}'.format(result_ds.data_info)) # Plot probabilities of a certain class: plt.figure() result_ds.plot_probabilities('QS', label='probability QS') # We can convert the result to a SleepStagesResult object, which provides more methods for analyzing the sleep features. # (see the example script nnsa/examples/feature_extraction/sleep_stages.py). sleep_stages = result_ds.to_sleep_stages_result() # E.g. 
plot the hypnogram: sleep_stages.plot_hypnogram(label='predicted sleep stage') plt.legend()