Robust FBA

This script demonstrates how to apply the (robust) FBA algorithm.

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: feature_extraction/robust_fba.py

import numpy as np

from nnsa import CleanDetectorCnn, BrainAgeSinc
from nnsa.utils.arrays import stepwise_reduce

Options.

# Select the version of the FBA model. Choose 'v1' or 'v2':
# 'v1' is the original Sinc model trained by Amir and Kirubin.
# 'v2' is the model retrained with more older neonates.

which_model = 'v2'

Load EEG data.

# Typically the sampling frequency is around 250 Hz.
fs = 250

# For the brain age model, the following channels are needed:
channel_labels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']

# Create random numbers to simulate 30 minutes of EEG data (not realistic at all),
# with shape (n_time, n_channels).
np.random.seed(43)
eeg = (np.random.rand(fs*30*60, len(channel_labels)) - 0.5)*300

Initialize.

# Instantiate a BrainAgeSinc object.
brain_age_sinc = BrainAgeSinc(which=which_model)

# Inspect the data requirements to find the expected channel order and reference channel:
# 1) Make sure that your EEG data consists of the same channels, in the same order.
# 2) Make sure that your EEG data is (re-)referenced accordingly.
print('Data requirements:', brain_age_sinc.data_requirements)
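
As a minimal sketch of point 1, assuming your recording comes with its own list of channel labels, the columns can be rearranged with plain NumPy indexing. The helper `reorder_channels` and the toy variables below are hypothetical and not part of nnsa; in practice, read the required order from `brain_age_sinc.data_requirements`.

```python
import numpy as np


def reorder_channels(data, labels, required_order):
    """Return `data` (n_time, n_channels) with columns arranged as `required_order`.

    Hypothetical helper for illustration only, not part of nnsa.
    """
    missing = [ch for ch in required_order if ch not in labels]
    if missing:
        raise ValueError('Missing channels: {}'.format(missing))
    # Column index in `data` for each required channel, in the required order.
    idx = [labels.index(ch) for ch in required_order]
    return data[:, idx]


# Toy example: 3 channels stored in a different order than required.
toy_labels = ['C4', 'Fp1', 'C3']
toy_required = ['Fp1', 'C3', 'C4']
toy_eeg = np.array([[4.0, 1.0, 3.0],
                    [40.0, 10.0, 30.0]])
toy_reordered = reorder_channels(toy_eeg, toy_labels, toy_required)
print(toy_reordered)
```

Raising an error on missing channels, rather than silently dropping them, keeps a wrong montage from reaching the model unnoticed.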

Process.

# If all is ok, we can pass the (raw) EEG data to the process function.
# The `axis` parameter specifies the time axis of `eeg`.
# For a long recording where memory may be an issue, you can set `batch_size` to an
# integer (e.g. 7200) to limit the number of segments processed at a time.
result = brain_age_sinc.process(eeg, fs, batch_size=None, axis=0, verbose=2)

# The result is a BrainAgeResult object.
print(result)

# The predictions can be directly obtained by result.y_pred,
# which has shape (n_segments, n_models).
# Note that we are using an ensemble of 10 models, so n_models=10.
y_pred = result.y_pred

# We can aggregate the predictions to get a naive estimate of the FBA:
fba_naive = np.nanmedian(y_pred)
print('-'*100)
print('Naive FBA: {:.2f} weeks'.format(fba_naive))
print('-'*100)
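
To illustrate what the naive aggregation does, here is a small sketch on a made-up prediction array with shape (n_segments, n_models). The values are invented for illustration and are not model output. `np.nanmedian` without an `axis` argument pools all segments and models and ignores NaNs; passing `axis=1` gives one ensemble median per segment instead.

```python
import numpy as np

# Made-up predictions (weeks) for 4 segments and 3 models; one value is missing.
toy_pred = np.array([
    [38.0, 39.0, 40.0],
    [37.0, 38.0, 39.0],
    [np.nan, 36.0, 37.0],
    [30.0, 31.0, 32.0],
])

# Pooled median over all segments and models, ignoring NaNs.
toy_fba_pooled = np.nanmedian(toy_pred)

# Ensemble median per segment, shape (n_segments,).
toy_fba_per_segment = np.nanmedian(toy_pred, axis=1)

print('Pooled median: {:.2f} weeks'.format(toy_fba_pooled))
print('Per-segment medians:', toy_fba_per_segment)
```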

Robust FBA.

# For version v2 of the model, the result also contains result.is_novelty, which indicates
# which segments were out-of-distribution with respect to the training data (i.e., input
# segments that differ from the training data). We can discard these segments when taking
# the median to get a more robust estimate of the FBA, as implemented in the robust_fba() method.
fba_robust = result.robust_fba()

# This result contains a median FBA of only the reliable (non-novelty) segments.
print('-'*100)
print('Robust FBA (novelties discarded):')
print(fba_robust)
print('-'*100)
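
The idea of discarding novelty segments before taking the median can be sketched with a boolean mask. This is an illustration on made-up numbers, not the implementation of robust_fba(); in particular, it assumes one novelty flag per segment, which may differ from the actual layout of result.is_novelty.

```python
import numpy as np

# Made-up predictions (weeks) for 5 segments and 2 models.
toy_pred = np.array([
    [38.0, 39.0],
    [37.0, 38.0],
    [25.0, 26.0],
    [24.0, 27.0],
    [39.0, 40.0],
])

# Assumed layout: one boolean flag per segment, True for out-of-distribution segments.
toy_is_novelty = np.array([False, False, True, True, False])

# Median over all segments versus only the segments not flagged as novelty.
toy_fba_all = np.nanmedian(toy_pred)
toy_fba_kept = np.nanmedian(toy_pred[~toy_is_novelty])

print('All segments: {:.2f} weeks'.format(toy_fba_all))
print('Novelties discarded: {:.2f} weeks'.format(toy_fba_kept))
```

In this toy case the two flagged segments pull the pooled median down, and removing them shifts the estimate back towards the remaining segments.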

Detect artefacts for even more robust FBA.

# It's also worth discarding segments that contain artefacts.
# We first detect artefacts using an automated method. In this case a CNN is used
# (see also: https://gitlab.com/timhermans/artefact_detection_public).
# If you are using different montages/channels, you may get better results
# by setting `multi_channel` to False.
cd = CleanDetectorCnn(multi_channel=True)

# The returned `clean_mask` has the same shape as `eeg`: its values are zeros at locations
# where an artefact was detected and ones where no artefact was detected.
clean_mask, probs = cd.predict(
    eeg, fs=fs, preprocess=True)

# Convert the clean mask to an artefact mask. Shape (n_time, n_channels).
af_mask = clean_mask == 0

# Compute the ACI (artefact contamination index) per 30s segment per channel.
# This is the percentage of samples in a segment classified as artefact.
# Has shape (n_segments, n_channels).
aci_per_chan = stepwise_reduce(
    af_mask, window=30, fs=fs, reduce_fun=np.nanmean, axis=0)[0] * 100

# We further average over channels to get ACI per segment (n_segments,).
aci_per_seg = np.nanmean(aci_per_chan, axis=1)
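
For readers without nnsa at hand, the ACI computation can be approximated in plain NumPy. The sketch below is a stand-in, not the implementation of `stepwise_reduce`: it assumes non-overlapping 30 s windows and drops a trailing partial window, which may differ from what the library does. The function name `aci_per_window` and the toy data are hypothetical.

```python
import numpy as np


def aci_per_window(artefact_mask, fs, window_s=30):
    """Percentage of artefact samples per window and channel.

    Assumes `artefact_mask` is boolean with shape (n_time, n_channels),
    uses non-overlapping windows, and drops a trailing partial window.
    """
    n_per_win = int(round(window_s * fs))
    n_win = artefact_mask.shape[0] // n_per_win
    trimmed = artefact_mask[:n_win * n_per_win]
    # Shape (n_win, n_per_win, n_channels); average over the samples in each window.
    windows = trimmed.reshape(n_win, n_per_win, -1)
    return windows.mean(axis=1) * 100


# Toy mask: 2 Hz sampling, 2 full windows of 30 s plus 5 leftover samples, 2 channels.
toy_fs = 2
toy_mask = np.zeros((125, 2), dtype=bool)
toy_mask[:30, 0] = True     # first half of window 0 in channel 0
toy_mask[60:120, 1] = True  # all of window 1 in channel 1

toy_aci_per_chan = aci_per_window(toy_mask, toy_fs)  # shape (n_windows, n_channels)
toy_aci_per_seg = toy_aci_per_chan.mean(axis=1)      # shape (n_windows,)
print(toy_aci_per_chan)
print(toy_aci_per_seg)
```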

# Again we call robust_fba(), but now we pass the ACI values.
fba_robust = result.robust_fba(aci_per_seg=aci_per_seg)

# This result contains a median FBA of only the reliable (non-novelty and non-artefact) segments.
print('-'*100)
print('Robust FBA (novelties and artefacts discarded):')
print(fba_robust)
print('-'*100)
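
Conceptually, the two criteria combine into a single per-segment selection. The sketch below shows one way this could look on made-up numbers. The ACI threshold `max_aci` is a hypothetical value chosen for illustration, and the actual criterion used inside robust_fba() is not described here and may differ.

```python
import numpy as np

# Made-up predictions (weeks) for 4 segments and 2 models.
toy_pred = np.array([
    [38.0, 39.0],
    [37.0, 38.0],
    [25.0, 26.0],
    [45.0, 46.0],
])
toy_is_novelty = np.array([False, False, True, False])  # assumed: one flag per segment
toy_aci = np.array([0.0, 5.0, 2.0, 80.0])               # ACI per segment, in percent

max_aci = 10.0  # hypothetical threshold, not taken from nnsa

# Keep segments that are neither novelties nor heavily contaminated by artefacts.
toy_keep = ~toy_is_novelty & (toy_aci <= max_aci)
toy_fba = np.nanmedian(toy_pred[toy_keep])

print('Segments kept: {} of {}'.format(toy_keep.sum(), toy_keep.size))
print('Median of kept segments: {:.2f} weeks'.format(toy_fba))
```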