# Source code for nnsa.feature_extraction.brainagemodel.models.sincnetwork

"""
The Sinc Keras model presented in the paper.
To build the model, call the 'net' function.
"""


import numpy as np
from tensorflow.keras import initializers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.regularizers import *
from nnsa.feature_extraction.brainagemodel.core.config import Config

def __sinc_block(inpLayer, N, M, l2_reg, act, inc_ind):
    '''
    Build one multi-branch block and concatenate the branch outputs.

    Branches (all Conv1D layers use padding='same', activation `act` and an
    L2 kernel penalty of `l2_reg`):
      - a 1-tap projection followed by a chain of int(M) 3-tap convolutions;
        the output of every convolution in the chain is kept as a branch,
      - a stride-1, size-3 max-pool followed by a 1-tap convolution,
      - a single 1-tap convolution applied directly to the input.

    Parameters:
    ===========
    inpLayer: Keras tensor fed to Conv1D/MaxPooling1D (rank-3 Conv1D input).
    N: number of filters in every convolution of the block.
    M: length of the 3-tap convolution chain (cast with int()).
    l2_reg: L2 regularisation factor for all convolution kernels.
    act: activation passed to every Conv1D.
    inc_ind: index used to build the unique layer-name prefix 'inc<inc_ind>_'.

    Return:
    ===========
    The concatenation (axis 2) of [direct 1-tap branch, pooled branch,
    chain outputs...], i.e. N * (int(M) + 2) channels.
    '''
    prefix = f'inc{inc_ind}_'

    def conv(kernel_size, suffix, tensor):
        # Each call creates its own regulariser instance, matching one
        # l2(...) per layer. Layer construction order is significant
        # (topology-based weight loading), so callers below keep it fixed.
        layer = Conv1D(N, kernel_size, name=prefix + suffix, padding='same',
                       activation=act, kernel_regularizer=l2(l2_reg))
        return layer(tensor)

    # Stacked 3-tap convolutions: each step widens the receptive field, and
    # every intermediate output becomes its own branch.
    chain_outputs = []
    current = conv(1, 'conv11_p', inpLayer)
    for depth in range(int(M)):
        current = conv(3, f'conv_{depth}', current)
        chain_outputs.append(current)

    pooled = MaxPooling1D(3, 1, name=prefix + 'max', padding='same')(inpLayer)
    pooled = conv(1, 'maxconv11', pooled)

    direct = conv(1, 'conv11', inpLayer)

    return concatenate([direct, pooled, *chain_outputs], name=prefix + 'concat', axis=2)

def _conv_stage(x, filters, pool_size, act):
    '''
    Apply Conv1D(act) -> Conv1D(linear) -> LayerNormalization -> Activation(act)
    -> MaxPooling1D(pool_size) to `x` and return the result.

    The second convolution is left linear so that LayerNormalization operates on
    pre-activation values; the activation is applied as a separate layer.
    Layers are created in exactly this order (auto-generated layer names depend on it).
    '''
    x = Conv1D(filters, 3, padding='same', activation=act)(x)
    x = Conv1D(filters, 3, padding='same', activation=None)(x)
    x = LayerNormalization()(x)
    x = Activation(act)(x)
    return MaxPooling1D(pool_size=pool_size, padding='same')(x)


def net(config):
    '''
    This function returns the Sinc network presented in the paper.

    Parameters:
    ===========
    config: an object from any class (preferably from brainagemodel.core.config.Config)
        which includes the following attributes:
        - fs: sampling frequency in Hz (according to the paper, it should be 64 Hz).
        - frame_sec: the frame length of the EEG in seconds. According to the paper it should be 30 s.
        - CH: the number of EEG channels. According to the paper it should be 1, 2, 4, or 8.
        Side effect: this function sets ``config.l2_reg = 0.001`` on the passed
        object, overwriting any existing value.

    Example:
        model = net(Config(fs=64, CH=8, frame_sec=30))
        The input signal to the network will be [frame_sec*fs, CH]
        (paper: [1920 (30 s * 64 Hz), 8/4/2/1]).

    Raises:
    ===========
    AttributeError: if config lacks fs, frame_sec or CH.
    ValueError: if frame_sec * fs is not a whole number of samples.

    Return:
    ===========
    a Keras API-based model (not compiled).
    '''
    # Explicit check instead of `assert`, which is removed under `python -O`.
    missing = [attr for attr in ('fs', 'frame_sec', 'CH') if not hasattr(config, attr)]
    if missing:
        raise AttributeError('config is missing required attribute(s): ' + ', '.join(missing))

    # Keras needs an integer time dimension; accept integral floats (e.g. fs=64.0).
    n_samples = config.frame_sec * config.fs
    if n_samples != int(n_samples):
        raise ValueError('frame_sec * fs must be a whole number of samples, got {}'.format(n_samples))
    input_shape_orig = int(n_samples)
    input_shape = (input_shape_orig, config.CH)

    # NOTE: this network contains no BatchNormalization layers, so the optional
    # config flag 'remove_all_batch_neomalization_layers' has no effect here.
    act = 'elu'
    config.l2_reg = 0.001
    # Seeds only NumPy's global RNG (a process-wide side effect).
    np.random.seed(0)
    # NOTE(review): the original code built initializers.he_normal(seed=1) but never
    # passed it to any layer, so it had no effect and has been dropped. Layers use
    # their default kernel initializers.

    # Input and a small additive Gaussian noise (active only in training).
    input_orig = Input(input_shape, name='input_orig')
    x = GaussianNoise(0.001)(input_orig)

    # Convolutional stem. For fs=64, frame_sec=30 the time axis goes 1920 -> 384 -> 96.
    x = _conv_stage(x, 128, pool_size=5, act=act)
    x = _conv_stage(x, 256, pool_size=4, act=act)

    # Two multi-branch blocks, each producing 64 * (5 + 2) = 448 channels.
    # Signature: __sinc_block(inpLayer, N, M, l2_reg, act, inc_ind)
    x = __sinc_block(x, 64, 5, config.l2_reg, act, 0)
    x = LayerNormalization()(x)
    x = __sinc_block(x, 64, 5, config.l2_reg, act, 1)
    x = LayerNormalization()(x)

    # Regression head.
    x = AveragePooling1D(pool_size=4, padding='valid')(x)
    x = Flatten()(x)
    x = Dropout(0.20)(x)
    regress = Dense(1, activation='linear')(x)

    return Model(inputs=[input_orig], outputs=regress)
#%%
if __name__ == '__main__':
    # Example: build the 8-channel, 30 s at 64 Hz configuration and print its
    # layer summary. (Previously this code was collapsed onto a single line
    # beginning with '#%%', so it was a comment and never ran.)
    model = net(Config(fs=64, CH=8, frame_sec=30))
    model.summary()
# %%