Robust FBA
==========

This script demonstrates how to apply the (robust) FBA algorithm.

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: ``feature_extraction/robust_fba.py``

.. code-block:: python

    import numpy as np

    from nnsa import CleanDetectorCnn, BrainAgeSinc
    from nnsa.utils.arrays import stepwise_reduce

Options.

.. code-block:: python

    # Select the version of the FBA model. Choose 'v1' or 'v2':
    # 'v1' is the original Sinc model trained by Amir and Kirubin.
    # 'v2' is the model retrained with more older neonates.
    which_model = 'v2'

Load EEG data.

.. code-block:: python

    # Typically the sampling frequency is around 250 Hz.
    fs = 250

    # For the brain age model, the following channels are needed:
    channel_labels = ['Fp1', 'Fp2', 'C3', 'C4', 'T3', 'T4', 'O1', 'O2']

    # Create random numbers to simulate 30 minutes of EEG data (not realistic at all).
    # With shape (n_time, n_channels).
    np.random.seed(43)
    eeg = (np.random.rand(fs*30*60, len(channel_labels)) - 0.5)*300

Initialize.

.. code-block:: python

    # Instantiate a BrainAgeSinc object.
    brain_age_sinc = BrainAgeSinc(which=which_model)

    # Inspect the data requirements to verify the channel order and reference channel:
    # 1) make sure that your EEG data consists of the same channels, in the same order.
    # 2) make sure that your EEG data is (re-)referenced correspondingly.
    print('Data requirements:', brain_age_sinc.data_requirements)

Process.

.. code-block:: python

    # If all is ok, we can pass the (raw) EEG data to the process function.
    # The `axis` parameter specifies the time axis of `eeg`.
    # For long recordings where memory may be an issue, you can set `batch_size`
    # to an integer (e.g. 7200) to limit the number of segments processed at a time.
    result = brain_age_sinc.process(eeg, fs, batch_size=None, axis=0, verbose=2)

    # The result is a BrainAgeResult object.
    print(result)

    # The predictions can be directly obtained by result.y_pred,
    # which has shape (n_segments, n_models).
    # Note that we are using an ensemble of 10 models, so n_models=10.
    y_pred = result.y_pred

    # We can aggregate the predictions to get a naive estimate of the FBA:
    fba_naive = np.nanmedian(y_pred)
    print('-'*100)
    print('Naive FBA: {:.2f} weeks'.format(fba_naive))
    print('-'*100)

Robust FBA.

.. code-block:: python

    # For version v2 of the model, the result also contains result.is_novelty, which indicates
    # which segments were out-of-distribution with respect to the training data (i.e., input
    # segments that differ from the training data). We can discard these segments when taking the
    # median to get a more robust estimate of FBA, as implemented in the robust_fba() method.
    fba_robust = result.robust_fba()

    # This result contains a median FBA of only the reliable (non-novelty) segments.
    print('-'*100)
    print('Robust FBA (novelties discarded):')
    print(fba_robust)
    print('-'*100)

Detect artefacts for even more robust FBA.

.. code-block:: python

    # It's also worth discarding segments that contain artefacts.
    # We first detect artefacts using an automated method. In this case a CNN is used
    # (see also: https://gitlab.com/timhermans/artefact_detection_public).
    # If you are using different montages/channels, you may get better results
    # by setting multi_channel to False.
    cd = CleanDetectorCnn(multi_channel=True)

    # This predicts an array with the same shape as `eeg` where the values are zeros at locations
    # where it detected an artefact and ones where it did not detect an artefact.
    clean_mask, probs = cd.predict(
        eeg, fs=fs, preprocess=True)

    # Convert the clean mask to an artefact mask. Shape (n_time, n_channels).
    af_mask = clean_mask == 0

    # Compute the ACI (artefact contamination index) per 30s segment per channel.
    # This is the percentage of samples in a segment classified as artefact.
    # Has shape (n_segments, n_channels).
    aci_per_chan = stepwise_reduce(
        af_mask, window=30, fs=fs, reduce_fun=np.nanmean, axis=0)[0] * 100

    # We further average over channels to get ACI per segment (n_segments,).
    aci_per_seg = np.nanmean(aci_per_chan, axis=1)

    # Again we call robust_fba() but now we pass the ACI values.
    fba_robust = result.robust_fba(aci_per_seg=aci_per_seg)

    # This result contains a median FBA of only the reliable (non-novelty and non-artefact) segments.
    print('-'*100)
    print('Robust FBA (novelties and artefacts discarded):')
    print(fba_robust)
    print('-'*100)
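
Illustration of the idea (NumPy only). The sketch below is not the ``nnsa`` implementation; it only illustrates, under stated assumptions, what the ACI computation and the "median over reliable segments" amount to. The helper names ``aci_per_segment`` and ``sketch_robust_median`` and the ``aci_threshold`` value are hypothetical, and it assumes non-overlapping 30 s segments that line up with the rows of ``y_pred``. The actual selection criteria used by ``robust_fba()`` may differ.

.. code-block:: python

    import numpy as np


    def aci_per_segment(af_mask, fs, window=30):
        """Percentage of artefact samples per non-overlapping window (hypothetical helper)."""
        n_per_seg = int(window * fs)
        n_segments = af_mask.shape[0] // n_per_seg
        # Drop trailing samples that do not fill a whole segment.
        trimmed = af_mask[:n_segments * n_per_seg].astype(float)
        segments = trimmed.reshape(n_segments, n_per_seg, -1)
        aci_per_chan = segments.mean(axis=1) * 100   # (n_segments, n_channels)
        return aci_per_chan.mean(axis=1)             # (n_segments,)


    def sketch_robust_median(y_pred, is_novelty, aci_per_seg, aci_threshold=10.0):
        """Median over segments that are neither novelties nor heavily contaminated.

        `aci_threshold` (in percent) is an assumed value chosen for illustration.
        """
        keep = ~np.asarray(is_novelty, dtype=bool)
        keep &= np.asarray(aci_per_seg) <= aci_threshold
        if not keep.any():
            return float('nan')
        return float(np.nanmedian(np.asarray(y_pred)[keep]))


    # Toy data: 3 segments of 30 s at 4 Hz, 2 channels.
    fs_toy = 4
    af_mask_toy = np.zeros((3 * 30 * fs_toy, 2), dtype=bool)
    af_mask_toy[30 * fs_toy:60 * fs_toy, 0] = True   # segment 1, channel 0: all artefact

    y_pred_toy = np.array([[38.0, 39.0], [20.0, 21.0], [30.0, 31.0]])  # (n_segments, n_models)
    is_novelty_toy = np.array([False, False, True])

    aci_toy = aci_per_segment(af_mask_toy, fs_toy)
    naive_toy = float(np.nanmedian(y_pred_toy))
    robust_toy = sketch_robust_median(y_pred_toy, is_novelty_toy, aci_toy)
    print('ACI per segment:', aci_toy)
    print('Naive median:', naive_toy, '| robust median:', robust_toy)

In this toy case only the first segment survives both filters, so the robust median is taken over its two model predictions, while the naive median mixes in the contaminated and out-of-distribution segments.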