Source code for nnsa.preprocessing.saved_filters

"""
Module for filtering with saved filters, i.e. filters with known coefficients.
"""
import warnings

import scipy.signal as ss
import numpy as np
import os

# When adding new saved filters, add them to __all__ (and add a dispatch entry in filter_saved_filter()).
__all__ = [
    'filter_saved_filter',

    'filter_bandpassfir_a',
    'filter_cnn_ansari_2019',

    'get_bandpassfir_a_coefs',
    'get_eeg_fir_filter_a',
    'get_filter',
    'get_matlab_bandpassfir_coefs',
]

# Names of all saved filters (filtering functions) that may be used: every public name starting with 'filter',
# except the generic dispatcher filter_saved_filter itself. Built with a comprehension so no loop variable leaks
# into the module namespace and nothing fails if the dispatcher name is absent.
SAVED_FILTERS = [name for name in __all__
                 if name.startswith('filter') and name != 'filter_saved_filter']

# Path to the directory with files containing coefficients of saved filters.
SAVED_FILTERS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'filter_coefs')


def filter_saved_filter(x, filter_name, fs=None, axis=-1, **kwargs):
    """
    Filter the signal x with a saved filter.

    Args:
        x (np.ndarray): signal to be filtered.
        filter_name (str): name of the saved filter to use. Accepted values are 'cnn_ansari_2019' (or
            'filter_cnn_ansari_2019') and 'bandpassfir_a' (or 'filter_bandpassfir_a').
        fs (float, optional): sample frequency of the signal to be filtered. Some saved filters need it to
            select or design their coefficients (both currently available saved filters use it). If uncertain,
            always pass the sample frequency. Defaults to None.
        axis (int, optional): the axis of x to which the filter is applied. Defaults to -1.
        **kwargs (optional): optional keyword arguments that are passed to the function that does the filtering.

    Returns:
        signal_filtered (np.ndarray): the filtered signal.

    Raises:
        ValueError: if filter_name does not correspond to a saved filter.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this module, so users can find the
    # deprecated call site.
    warnings.warn('Deprecated. Does not handle NaN well. Rewrite.', stacklevel=2)

    # Select the filtering function; both the short name and the full function name are accepted.
    if filter_name in ('cnn_ansari_2019', 'filter_cnn_ansari_2019'):
        filter_function = filter_cnn_ansari_2019
    elif filter_name in ('bandpassfir_a', 'filter_bandpassfir_a'):
        filter_function = filter_bandpassfir_a
    else:
        raise ValueError('Invalid saved_filter "{}". Choose from: {}.'
                         .format(filter_name, SAVED_FILTERS))

    signal_filtered = filter_function(x, fs=fs, axis=axis, **kwargs)
    return signal_filtered
def filter_bandpassfir_a(x, fs, axis=-1, **kwargs):
    """
    Forward-backward filter x (scipy.signal.filtfilt) with the fixed 'bandpassfir_a' FIR bandpass filter.

    The filter specifications are fixed in get_bandpassfir_a_coefs(); only the sample frequency is taken from
    the caller, since the coefficients depend on it.

    Args:
        x (np.ndarray): signal array to be filtered.
        fs (float): sample frequency of x, used to compute (or load) the coefficients for this frequency.
        axis (int, optional): the axis of x to which the filter is applied. Default is -1.
        **kwargs (optional): optional keyword arguments for scipy.signal.filtfilt().

    Returns:
        (np.ndarray): filtered signal array (same size as x).
    """
    # A FIR filter only has numerator (b) coefficients; its denominator (a) is the scalar 1.
    denominator = 1
    numerator = get_bandpassfir_a_coefs(fs=fs)
    return ss.filtfilt(numerator, denominator, x, axis=axis, **kwargs)
def filter_cnn_ansari_2019(x, fs, axis=-1, **kwargs):
    """
    Apply to x the bandpass filter that Ansari 2019 used for 2-class sleep stage classification with a CNN.

    The second-order sections (sos) are hard-coded, copied from MATLAB's output of:

        bpfilt = designfilt('bandpassiir',...
            'PassbandFrequency1',1,'PassbandFrequency2',20, ...
            'StopbandFrequency1',0.01,'StopbandFrequency2', 27,...
            'PassbandRipple',1,...
            'StopbandAttenuation1',40,'StopbandAttenuation2',40,...
            'SampleRate',fs);

    Coefficients are stored for fs == 250 and fs == 256 only.

    SciPy's sosfiltfilt and MATLAB's filtfilt differ near the borders of the signal, so the outputs differ there.
    SciPy's output looks better than MATLAB's near the borders.

    Args:
        x (np.ndarray): signal array to be filtered.
        fs (int): sampling frequency of `x` (250 or 256).
        axis (int, optional): the axis of x to which the filter is applied. Default is -1.
        **kwargs (optional): optional keyword arguments for the scipy.signal.sosfiltfilt function.

    Returns:
        (np.ndarray): filtered signal array (same size as x).

    Raises:
        NotImplementedError: if no coefficients are stored for fs.
    """
    # Every section of MATLAB's design has numerator [g, 0, -g] and denominator [1, a1, a2], so each section is
    # stored compactly as (g, a1, a2) and expanded to a full sos row below.
    if fs == 250:
        sections = (
            (0.243216182622267, -1.65436615754352, 0.913668453704997),
            (0.243216182622267, -1.99514097253079, 0.99571270962927),
            (0.23311774723518, -1.52768498803376, 0.764223818147241),
            (0.23311774723518, -1.98664322526997, 0.987219017806187),
            (0.224671708470001, -1.97830692151704, 0.978893031937156),
            (0.224671708470001, -1.42736316626388, 0.643204757205783),
            (0.217806571345434, -1.97025837836953, 0.970860471032313),
            (0.217806571345434, -1.35004578917496, 0.547513650483623),
            (0.212438317965548, -1.96273353716245, 0.963355977476836),
            (0.212438317965548, -1.2926635492432, 0.474376464191818),
            (0.208488135060687, -1.95613859225961, 0.956783202658562),
            (0.208488135060687, -1.25253161477321, 0.421530244051553),
            (0.205891934625482, -1.95108324531832, 0.951747616534758),
            (0.205891934625482, -1.22735401835247, 0.38726771418154),
            (0.204605186294117, -1.94828948035172, 0.948965808747501),
            (0.204605186294117, -1.21530391537097, 0.370433080224339),
        )
    elif fs == 256:
        sections = (
            (0.237574547250801, -1.66804122175605, 0.915499818956205),
            (0.237574547250801, -1.99526432690052, 0.995810503853529),
            (0.227906862481412, -1.54275287790504, 0.768849862063498),
            (0.227906862481412, -1.98695893873719, 0.987509057087125),
            (0.219802206025772, -1.97880904115827, 0.979369099060719),
            (0.219802206025772, -1.44318391215637, 0.649753917415089),
            (0.213202295119539, -1.97093768909577, 0.971513115778716),
            (0.213202295119539, -1.36621700589838, 0.555377243524725),
            (0.208033993568186, -1.96357537416794, 0.964170358255492),
            (0.208033993568186, -1.30894921259301, 0.483129642348542),
            (0.204226865571592, -1.95711973480638, 0.957736028976683),
            (0.204226865571592, -1.26881076107805, 0.430869994630954),
            (0.201722835363977, -1.95216871647257, 0.952804010355437),
            (0.201722835363977, -1.24358498877612, 0.396965983674931),
            (0.200481231385749, -1.94943152036553, 0.950078314462072),
            (0.200481231385749, -1.23149731763112, 0.380302306378394),
        )
    else:
        raise NotImplementedError(f'Not implemented for fs={fs}. Either resample the signal to 250 or 256 Hz '
                                  f'or add MATLABs sos filter coefficients to this function.')

    # Expand to the (n_sections, 6) sos layout [b0, b1, b2, a0, a1, a2] expected by SciPy.
    sos = np.array([[gain, 0, -gain, 1, a1, a2] for gain, a1, a2 in sections])
    return ss.sosfiltfilt(sos, x, axis=axis, **kwargs)
def get_bandpassfir_a_coefs(fs):
    """
    Return the coefficients of the 'bandpassfir_a' FIR bandpass filter for sample frequency fs.

    The design specifications are fixed in this function; only the sample frequency varies. The coefficients
    are obtained through get_matlab_bandpassfir_coefs().

    Args:
        fs (float): sample frequency of the signal that will be filtered.

    Returns:
        coefs (np.ndarray): filter coefficients of the FIR bandpass filter.
    """
    # Fixed design: passband 0.6-40, stopband edges 0.1 and 47, stopband attenuation 40 on both sides and
    # passband ripple 1 (see get_matlab_bandpassfir_coefs for how these map onto MATLAB's designfilt).
    specs = dict(
        pass_low=0.6,
        pass_high=40,
        stop_low=0.1,
        stop_high=47,
        db_low=40,
        db_high=40,
        pass_ripple=1,
        fs=fs,
    )
    return get_matlab_bandpassfir_coefs(specs)
def get_filter(which):
    """
    Helper function to get the filter object of a pre-defined bandpass filter.

    Args:
        which (str): case-insensitive name of the pre-defined filter. One of:
            '1-40':    RemezFIR with passband [0.6, 40] and stopband [0.1, 47].
            '1-32':    RemezFIR with passband [0.5, 32] and stopband [0.1, 35].
            '1-20':    RemezFIR with passband [1, 20] and stopband [0.1, 27].
            '2-16':    RemezFIR with passband [2, 16] and stopband [1, 20].
            'BrainRT': Butterworth with fn=[0.27, 30] and order=1.
            All RemezFIR filters use passband_ripple=1 and stopband_attenuation=40.

    Returns:
        filt (nnsa.FilterBase): nnsa FilterBase-derived object.

    Raises:
        ValueError: if `which` is not one of the options above.
    """
    which = which.lower()

    # (passband, stopband) edges of the pre-defined Remez FIR filters.
    remez_edges = {
        '1-40': ([0.6, 40], [0.1, 47]),
        '1-32': ([0.5, 32], [0.1, 35]),
        '1-20': ([1, 20], [0.1, 27]),
        '2-16': ([2, 16], [1, 20]),
    }
    brainrt = 'brainrt'

    # Validate before importing, so an invalid choice fails with a clear message listing the options.
    if which not in remez_edges and which != brainrt:
        options = list(remez_edges) + ['BrainRT']
        raise ValueError('Invalid choice which="{}". Choose from: {}.'.format(which, options))

    # Local import; presumably to avoid a circular import with nnsa.preprocessing.filter (TODO confirm).
    from nnsa.preprocessing.filter import RemezFIR, Butterworth

    if which == brainrt:
        return Butterworth(fn=[0.27, 30], order=1)

    passband, stopband = remez_edges[which]
    return RemezFIR(passband=passband, stopband=stopband, passband_ripple=1, stopband_attenuation=40)
def get_eeg_fir_filter_a(**kwargs):
    """
    Return the default EEG FIR filter 'a'.

    This is a RemezFIR with passband [0.6, 40], stopband [0.1, 47], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): additional keyword arguments forwarded to RemezFIR.

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR
    design = dict(passband=[0.6, 40], stopband=[0.1, 47], passband_ripple=1, stopband_attenuation=40)
    return RemezFIR(**design, **kwargs)
def get_eeg_fir_filter_b(**kwargs):
    """
    Return the default EEG FIR filter 'b'.

    This is a RemezFIR with passband [1, 20], stopband [0.01, 27], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): additional keyword arguments forwarded to RemezFIR.

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR
    fir = RemezFIR(passband=[1, 20], stopband=[0.01, 27], passband_ripple=1, stopband_attenuation=40, **kwargs)
    return fir


def get_eeg_fir_filter_c(**kwargs):
    """
    Return the default EEG FIR filter 'c'.

    This is a RemezFIR with passband [0.1, 32], stopband [0.01, 35], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): additional keyword arguments forwarded to RemezFIR (accepted for consistency with
            get_eeg_fir_filter_a/b; calling without arguments behaves as before).

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR
    fir = RemezFIR(passband=[0.1, 32], stopband=[0.01, 35], passband_ripple=1, stopband_attenuation=40, **kwargs)
    return fir
def get_matlab_bandpassfir_coefs(filter_parameters):
    """
    Load the coefficients of the FIR bandpass filter designed by MATLAB's designfilt with the specifications in
    filter_parameters.

    The coefficients are cached in a csv file in SAVED_FILTERS_DIR whose name encodes the filter parameters. If
    that file does not exist yet, it is created by running designfilt in a MATLAB engine.

    Args:
        filter_parameters (dict): parameters for MATLAB's designfilt. Must contain the numeric entries 'pass_low',
            'pass_high', 'stop_low', 'stop_high', 'db_low', 'db_high', 'pass_ripple' and 'fs'. These map to
            designfilt's PassbandFrequency1/2, StopbandFrequency1/2, StopbandAttenuation1/2, PassbandRipple and
            SampleRate. The dict is not modified.

    Returns:
        (np.ndarray): the coefficients of the filter (the Coefficients field of MATLAB's designfilt object).
            These are the b coefficients; the a coefficients of a FIR filter equal 1.
    """
    # Create the filename from the parameters, so that every distinct design gets its own cache file.
    filename = ('bandpassfir_{pass_low:.2f}_{pass_high:.2f}_{stop_low:.2f}_{stop_high:.2f}_'
                '{db_low:.2f}_{db_high:.2f}_{pass_ripple:.2f}_{fs:.2f}.csv').format(**filter_parameters)
    filter_coefs_path = os.path.join(SAVED_FILTERS_DIR, filename)

    # If the file does not exist, compute the filter coefficients with MATLAB and save them to that file.
    if not os.path.exists(filter_coefs_path):
        # Make sure the output directory exists before MATLAB tries to write into it.
        os.makedirs(SAVED_FILTERS_DIR, exist_ok=True)

        # Work on a copy so the caller's dict is not modified. The path is embedded in a single-quoted MATLAB
        # char array, in which a literal single quote must be written twice.
        matlab_parameters = dict(filter_parameters)
        matlab_parameters['filter_coefs_path'] = filter_coefs_path.replace("'", "''")

        # Generate the MATLAB code based on the filter parameters.
        matlab_code = """
        bpfilt = designfilt('bandpassfir', ...
            'PassbandFrequency1', {pass_low}, 'PassbandFrequency2', {pass_high}, ...
            'StopbandFrequency1', {stop_low}, 'StopbandFrequency2', {stop_high}, ...
            'StopbandAttenuation1', {db_low}, 'StopbandAttenuation2', {db_high}, ...
            'PassbandRipple', {pass_ripple}, ...
            'SampleRate', {fs})
        info(bpfilt)

        % Write coefficients to csv.
        writematrix(bpfilt.Coefficients, '{filter_coefs_path}')
        """.format(**matlab_parameters)

        from nnsa.matlab.utils import matlab_engine
        with matlab_engine() as eng:
            eng.eval(matlab_code, nargout=0)

    # Load the file with coefficients and return.
    return np.genfromtxt(filter_coefs_path, delimiter=',')