import os
import numpy as np
import tensorflow as tf
from tqdm import tqdm
class EnsembleModels():
    '''
    Locate and query an ensemble of Sinc models stored as individual '.h5' files.

    Note that it is not a keras model, but it has a predict function with similar inputs/outputs.
    Models are not kept in memory: each one is loaded from disk when it is needed.

    Parameters:
    ===========
    CH: the number of eeg channels {1, 2, 4, or 8}
    trained_model_directory: the directory in which the ensemble models are located. The name of the models must be 'model[x].h5' (e.g. model1.h5, model2.h5, ...). Model numbers 0 to 999 are probed, so it maximally supports 1000 models. If None, no model is registered.
    verbose: if True, it shows the loading/prediction progress; otherwise, it is silent.
    '''

    # Model numbers 0 .. _MAX_MODELS - 1 are probed in the model directory.
    _MAX_MODELS = 1000

    def __init__(self, CH, trained_model_directory, verbose=True):
        if CH not in (1, 2, 4, 8):
            # Raised explicitly instead of using `assert` so that the check is
            # not stripped under `python -O`; the exception type is unchanged.
            raise AssertionError(f'CH must be one of 1, 2, 4 or 8, got {CH!r}')
        self.CH = CH
        self.verbose = verbose
        self.__find_models(trained_model_directory)
        if self.verbose:
            print(' ')

    def __len__(self):
        '''Return the number of ensemble models found in the model directory.'''
        return len(self.models_paths)

    def __find_models(self, trained_model_directory):
        '''
        Fill self.models_paths with {model number: file path} for every existing 'model[x].h5' file.
        The model number is the integer in the file name, so the keys are not
        necessarily contiguous and do not necessarily start at 0.
        '''
        self.models_paths = {}
        if trained_model_directory is None:
            return
        numbers = range(self._MAX_MODELS)
        if self.verbose:
            numbers = tqdm(numbers)
        for number in numbers:
            path = os.path.join(trained_model_directory, f'model{number}.h5')
            if os.path.exists(path):
                self.models_paths[number] = path

    def __get_models(self, index):
        '''
        Load and return the model registered under the model number `index`.
        Raises KeyError if no model with this number was found, and ValueError
        if its file has disappeared since the directory was scanned.
        '''
        fname = self.models_paths[index]
        if not os.path.exists(fname):
            raise ValueError('This index is not correct!')
        return tf.keras.models.load_model(fname)

    @staticmethod
    def _stack_predictions(predictions):
        '''
        Stack the per-model outputs along a new last axis.
        Only singleton axes other than the first (batch) axis are dropped, so a
        single model or a single-epoch batch still yields [batch, number of models].
        '''
        columns = []
        for pred in predictions:
            pred = np.atleast_1d(np.asarray(pred))
            kept = tuple(dim for dim in pred.shape[1:] if dim != 1)
            columns.append(pred.reshape((pred.shape[0],) + kept))
        return np.stack(columns, axis=-1)

    def predict(self, eeg):
        '''
        To predict the outputs of all loaded ensembled models.

        parameters:
        ===========
        eeg: a numpy tensor as [batch, 1920, CH]
             OR
             a datagenerator returning [batch, 1920, CH]

        Return:
        =======
        PMA numpy matrix as [batch, number of models]; the columns follow the
        ascending model numbers.

        Raises:
        =======
        ValueError if no model was found in the model directory.
        '''
        if not self.models_paths:
            raise ValueError('No ensemble models are loaded; nothing to predict with.')
        total = len(self)
        res = []
        # Iterate over the discovered model numbers (the dictionary keys), not
        # over range(len(self)): the files may be numbered from 1 or have gaps.
        for position, index in enumerate(sorted(self.models_paths)):
            if self.verbose:
                print(f'model {position} out of {total}...')
            model = self.__get_models(index)
            res.append(model.predict(eeg, verbose=self.verbose))
        res = self._stack_predictions(res)
        if self.verbose:
            print('prediction values are ready.')
        return res

    def aggregate(self, pmas, recordings_indices):
        '''
        This function takes pmas tensor of all recordings and all models as well as the recording indices/names and returns the aggregated PMA per recording in a dictionary.
        This can be called after the 'predict' function.
        The original notes state that a 'predict_recording' function calls this function internally; that function is not defined in this section.

        parameters:
        ===========
        pmas: array-like as [batch, number of models]; a 1-D input is treated as the output of a single model.
        recordings_indices: array-like of length batch giving the recording index/name of each epoch.

        Return:
        =======
        A dictionary {recording: median across models of the per-model median across epochs}.
        '''
        pmas = np.asarray(pmas)
        if pmas.ndim == 1:
            pmas = pmas[:, np.newaxis]
        # Convert to an array so that `==` is element-wise for list/tuple input too.
        recordings_indices = np.asarray(recordings_indices)
        res = {}
        for rec in np.unique(recordings_indices):
            mask = (recordings_indices == rec)
            per_model = np.median(pmas[mask, :], 0)  # median across epochs
            res[rec] = np.median(per_model, 0)  # median across models
        return res