Sleep stages (4-class)
This script demonstrates how to use the nnsa package to compute robust 4-class sleep stages as described in the thesis of Tim Hermans. 4-class sleep staging is only suitable for neonates with a postmenstrual age (PMA) of at least 36 weeks.
Author: Tim Hermans (tim-hermans@hotmail.com).
Link to script: feature_extraction/sleep_stages_4class.py
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import nnsa
Load EEG data.
# Typically the sampling frequency is around 250 Hz.
fs = 250
# For the sleep model, the following channels are needed:
channel_labels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']
# Create random numbers to simulate 30 minutes of EEG data (not realistic at all),
# with shape (n_time, n_channels).
np.random.seed(43)
eeg = (np.random.rand(fs*30*60, len(channel_labels)) - 0.5)*300
# Scale the signal by a linear ramp (0 at the start, close to 1 at the end) so the voltage increases over time.
eeg *= (np.arange(len(eeg))/len(eeg)).reshape(-1, 1)
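As a quick sanity check of the simulated array, the sketch below rebuilds it with the same recipe and inspects its shape and amplitude envelope. The 5-minute block size is an arbitrary choice for illustration; the peak amplitude per block should grow over time because of the ramp.

```python
import numpy as np

fs = 250  # sampling frequency (Hz)
channel_labels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']

# Same recipe as above: uniform noise in [-150, 150), scaled by a 0 -> 1 ramp.
np.random.seed(43)
eeg = (np.random.rand(fs*30*60, len(channel_labels)) - 0.5)*300
eeg *= (np.arange(len(eeg))/len(eeg)).reshape(-1, 1)

# 250 Hz * 30 min * 60 s/min = 450,000 samples per channel.
print(eeg.shape)  # (450000, 8)

# Peak absolute amplitude in each 5-minute block: should grow over time.
block_peaks = np.abs(eeg.reshape(6, -1, len(channel_labels))).max(axis=(1, 2))
print(np.round(block_peaks))
```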
Initialize.
# Create a SleepStagesRobust object.
sleep_stager = nnsa.SleepStagesRobust()
# Check the data requirements to see the expected channel_order and reference_channel:
# 1) make sure that your EEG data contains the same channels, in the same order.
# 2) make sure that your EEG data is (re-)referenced accordingly.
print('Data requirements:', sleep_stager.data_requirements)
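If a recording stores the channels in a different order, or contains extra channels, the columns can be rearranged before processing. The helper `reorder_channels` below is hypothetical (it is not part of nnsa) and assumes the required order is available as a plain list of labels such as `channel_labels`; consult `sleep_stager.data_requirements` for the actual order. Re-referencing is not covered by this sketch.

```python
import numpy as np


def reorder_channels(eeg, labels, required_order):
    """Hypothetical helper: return eeg with its columns in required_order.

    eeg: array of shape (n_time, n_channels); labels: label of each column.
    Channels not listed in required_order are dropped.
    """
    labels = list(labels)
    missing = [ch for ch in required_order if ch not in labels]
    if missing:
        raise ValueError(f'Missing required channels: {missing}')
    idx = [labels.index(ch) for ch in required_order]
    return eeg[:, idx]


# Example: a recording with a different channel order and an extra 'Cz' channel.
required = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']
recording_labels = ['O1', 'O2', 'C3', 'C4', 'Cz', 'Fp1', 'Fp2', 'T3', 'T4']
# Each column holds its own column index, so the reordering is easy to inspect.
recording = np.tile(np.arange(len(recording_labels)), (5, 1))
eeg_ordered = reorder_channels(recording, recording_labels, required)
print(eeg_ordered[0])
```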
Process.
# If everything is in order, pass the (raw) EEG data to the process function.
# For long recordings where memory may be an issue, set `batch_size` to an integer
# (e.g. 7200) to limit the number of segments processed at a time.
result = sleep_stager.process(eeg, fs, verbose=2)
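The memory-saving idea behind `batch_size` can be illustrated generically: instead of passing all segments through a model at once, process them in fixed-size chunks and concatenate the outputs, so only one chunk's intermediate results are held in memory at a time. This is a sketch of the general technique, not nnsa's implementation; `process_in_batches` and the toy per-segment function are hypothetical.

```python
import numpy as np


def process_in_batches(segments, func, batch_size=None):
    """Apply func to segments (shape (n_segments, ...)) in chunks of batch_size.

    With batch_size=None all segments are processed in one call.
    """
    if batch_size is None:
        return func(segments)
    outputs = [func(segments[i:i + batch_size])
               for i in range(0, len(segments), batch_size)]
    return np.concatenate(outputs, axis=0)


# Toy data: 10 segments of 4 samples; the "model" just sums each segment.
segments = np.arange(40, dtype=float).reshape(10, 4)
per_segment_sum = lambda x: x.sum(axis=1)

full = process_in_batches(segments, per_segment_sum)
# batch_size=3 processes the 10 segments in chunks of 3, 3, 3 and 1.
batched = process_in_batches(segments, per_segment_sum, batch_size=3)
print(np.array_equal(full, batched))
```

The result does not depend on the batch size as long as each segment is processed independently; only peak memory use changes.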
# The result is a SleepStagesRobustResult object.
# The sleep stages are contained in the attribute `df` as a pandas dataframe:
df = result.df
print(df)
# Each row in the dataframe is a segment.
# The dataframe has the following columns:
# sleep_label_cnn: the sleep label predicted by the CNN (i.e., the predictions used in Ansari's paper).
# sleep_label_hmm: the sleep label predicted by the HMM (a postprocessing of the CNN output).
# quality_label: a label indicating the quality of the segment.
# sleep_label_robust: the final label for each segment (combines sleep_label_hmm and quality_label).
# is_sleep: bool indicating which segments are sleep (True) or non-sleep (False), i.e. artefact/wake/uncertain.
# is_usable: bool indicating whether the segment belongs to a longer epoch with predominantly good-quality sleep segments.
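The boolean columns can be combined to filter the segments before further analysis. A minimal sketch, assuming a dataframe with the column names described above; the values and the integer label codes are made up for illustration and do not reflect nnsa's actual label encoding:

```python
import pandas as pd

# Toy stand-in for result.df (made-up values, placeholder label codes).
df = pd.DataFrame({
    'sleep_label_robust': [1, 1, 2, 2, 3, 4],
    'is_sleep': [True, True, True, False, True, True],
    'is_usable': [True, True, True, True, False, False],
})

# Keep segments that are sleep and part of a usable epoch.
selected = df[df['is_sleep'] & df['is_usable']]

# Fraction of the selected segments assigned to each sleep label.
fractions = selected['sleep_label_robust'].value_counts(normalize=True).sort_index()
print(fractions)
```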
Simple plot of the results.
fig, ax = plt.subplots(tight_layout=True)
sns.lineplot(x='start_time', y='sleep_label_cnn', data=df, ax=ax, label='CNN')
sns.lineplot(x='start_time', y='sleep_label_hmm', data=df, ax=ax, label='HMM')
sns.lineplot(x='start_time', y='sleep_label_robust', data=df, ax=ax, label='Robust (final)')
ax.set_xlabel('Time onset (s)')
ax.set_ylabel('Sleep label')
nnsa.format_time_axis()
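Because each segment carries a single label, `sns.lineplot` draws sloped lines between consecutive segments. A step plot may represent these piecewise-constant labels more faithfully. A minimal sketch using matplotlib's `Axes.step`, with a small made-up dataframe in place of `result.df` (times and label codes are placeholders):

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for result.df.
df = pd.DataFrame({
    'start_time': [0, 30, 60, 90, 120],
    'sleep_label_robust': [1, 2, 2, 3, 1],
})

fig, ax = plt.subplots(tight_layout=True)
# where='post' holds each label constant from its onset until the next segment.
ax.step(df['start_time'], df['sleep_label_robust'], where='post', label='Robust (final)')
ax.set_xlabel('Time onset (s)')
ax.set_ylabel('Sleep label')
ax.legend()
```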
The sleep results can easily be saved using pandas’ DataFrame export methods (e.g. `to_csv`, `to_excel`, …).
fp_out = 'test.xlsx'
df.to_excel(fp_out, index=False)
df_loaded = pd.read_excel(fp_out)
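# The loaded and original dataframes should have the same values
# (unlike an element-wise ==, assert_frame_equal treats NaNs in the same position as equal).
pd.testing.assert_frame_equal(df, df_loaded, check_dtype=False)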
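An element-wise check such as `(df == df_loaded).all().all()` returns False whenever the dataframe contains NaN, because NaN never compares equal to itself. `pd.testing.assert_frame_equal` treats NaNs in matching positions as equal. The sketch below shows this with a CSV round trip in a temporary directory; whether `result.df` actually contains NaNs is not stated here, so the toy dataframe is an assumption.

```python
import os
import tempfile

import numpy as np
import pandas as pd

# Toy stand-in for result.df; the NaN mimics a missing value.
df = pd.DataFrame({
    'start_time': [0.0, 30.0, 60.0],
    'sleep_label_robust': [1.0, np.nan, 2.0],
    'is_sleep': [True, False, True],
})

with tempfile.TemporaryDirectory() as tmp_dir:
    fp_out = os.path.join(tmp_dir, 'sleep_stages.csv')
    df.to_csv(fp_out, index=False)
    df_loaded = pd.read_csv(fp_out)

# Element-wise comparison is False in the NaN position.
print((df == df_loaded).all().all())  # False
# assert_frame_equal passes: NaNs in the same position count as equal.
pd.testing.assert_frame_equal(df, df_loaded)
```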
Clean up. Remove the saved file.
os.remove(fp_out)