import sys
import numpy as np
import pyprind
from nnsa.feature_extraction.brain_age_cnn import BrainAgeResult
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.feature_extraction.common import check_multichannel_data_matrix
from nnsa.utils.segmentation import segment_generator, compute_n_segments
# Names exported by `from <this module> import *`.
# BrainAgeResult is re-exported from nnsa.feature_extraction.brain_age_cnn (imported above).
__all__ = ['BrainAge', 'BrainAgeResult']
class BrainAge(ClassWithParameters):
    """
    Class for prediction of the (functional) brain age based on EEG-data.

    References:
        K. Pillay, A. Dereymaeker, K. Jansen, G. Naulaers, and M. De Vos,
        “Applying a data-driven approach to quantify EEG maturational deviations in preterms with normal and
        abnormal neurodevelopmental outcomes,”
        Scientific Reports, vol. 10, no. 1, Apr. 2020,
        doi: 10.1038/s41598-020-64211-0.

    Args:
        **kwargs (optional): see nnsa.ClassWithParameters.

    Examples:
        TODO
    """

    @staticmethod
    def default_parameters():
        """
        Return the default parameters.

        Returns:
            (nnsa.Parameters): a default set of parameters for the object, with keys
                'segmentation' (nnsa.Parameters with 'segment_length' and 'overlap'),
                'method' (str) and 'method_kwargs' (dict).
        """
        # Parameters for segmentation of EEG recording into large epochs.
        segmentation = Parameters(**{
            # Segment length in seconds.
            # If None, uses the entire recording to predict one brain age.
            'segment_length': None,
            # Overlap between consecutive segments in seconds.
            'overlap': 0 * 3600,
        })

        # The method/algorithm to use for computation of the brain age (case-insensitive).
        # Choose from: 'Kirubin'*.
        # *this option runs on MATLAB in the background.
        # See the functions in this module for each of these methods for more information
        # (e.g. compute_brain_age_kirubin()).
        method = 'Kirubin'

        # Optional additional keyword arguments/parameters for the method/function that computes the brain age.
        # These keyword arguments depend on the method specified above. E.g. if method is set to 'Kirubin', see the
        # function compute_brain_age_kirubin() for the optional keyword arguments that you can specify here.
        method_kwargs = {}

        pars = {
            'segmentation': segmentation,
            'method': method,
            'method_kwargs': method_kwargs,
        }

        return Parameters(**pars)

    def brain_age(self, data_matrix, fs, channel_labels, verbose=1):
        """
        Predict the brain age.

        If parameters['segmentation']['segment_length'] is None, the entire recording is used to predict a single
        brain age. Otherwise, the recording is segmented along the time axis (with the configured overlap) and one
        brain age is predicted per segment.

        Args:
            data_matrix (np.ndarray): EEG data (channels, time), see check_multichannel_data_matrix().
            fs (float): sample frequency of EEG data.
            channel_labels (list): list with channels labels of data_matrix, see check_multichannel_data_matrix().
            verbose (int, optional): verbose level. If > 0, prints per-segment progress and shows a progress bar.
                Also forwarded to the brain age method.
                Defaults to 1.

        Returns:
            result (BrainAgeResult): the predicted brain age result; its brain_age is a 1D array with one value
                per segment (length 1 when no segmentation is used).

        Raises:
            ValueError: if parameters['method'] is not a supported method.
        """
        # Check input.
        data_matrix, channel_labels = check_multichannel_data_matrix(data_matrix, channel_labels)

        # Extract some parameters.
        seg_pars = self.parameters['segmentation']
        method = self.parameters['method'].lower()
        method_kwargs = self.parameters['method_kwargs']

        # Resolve the method once, before any segment is processed, so that an invalid method fails fast
        # (also when there are no segments to process) instead of being detected inside the loop.
        if method == 'kirubin':
            # Local import, kept from the original implementation (the function lives in the old brain_age module).
            from nnsa.feature_extraction.old.brain_age import compute_brain_age_kirubin
            compute_fun = compute_brain_age_kirubin
        else:
            raise ValueError('Invalid method "{}". Choose from {}.'
                             .format(method, ['Kirubin']))

        segment_length = seg_pars['segment_length']
        if segment_length is None:
            # Use entire signal.
            seg_generator = [data_matrix]
            n_segments = 1
        else:
            # Segment the data in the time axis (create a generator).
            seg_generator = segment_generator(data_matrix, segment_length=segment_length,
                                              overlap=seg_pars['overlap'], fs=fs, axis=-1)
            n_segments = compute_n_segments(data_matrix, segment_length=segment_length,
                                            overlap=seg_pars['overlap'], fs=fs, axis=-1)

        # Only create the progress bar when it will actually be updated (verbose > 0), so that a silent run does
        # not instantiate a bar bound to stdout that is never advanced.
        bar = pyprind.ProgBar(n_segments, stream=sys.stdout) if verbose > 0 else None

        # Loop over segments.
        brain_ages = np.zeros(n_segments)
        for i_seg, seg in enumerate(seg_generator):
            if verbose > 0:
                print('\nPart {} of {}.'.format(i_seg + 1, n_segments))

            # Compute brain age of segment with requested method and store it.
            brain_ages[i_seg] = compute_fun(seg, fs=fs, channel_labels=channel_labels,
                                            verbose=verbose, **method_kwargs)

            # Update progress bar.
            if bar is not None:
                bar.update()

        # Create result object.
        return BrainAgeResult(brain_age=brain_ages,
                              algorithm_parameters=self.parameters)

    def process(self, *args, **kwargs):
        """
        Generic processing entry point; forwards all arguments to brain_age().

        Returns:
            result (BrainAgeResult): see brain_age().
        """
        return self.brain_age(*args, **kwargs)