"""
Module for filtering with saved filters, i.e. filters with known coefficients.
"""
import warnings
import scipy.signal as ss
import numpy as np
import os
# Public API of this module. When adding new saved filters (functions whose name starts with 'filter'), add them to
# __all__: SAVED_FILTERS below is derived from it.
__all__ = [
    'filter_saved_filter',
    'filter_bandpassfir_a',
    'filter_cnn_ansari_2019',
    'get_bandpassfir_a_coefs',
    'get_eeg_fir_filter_a',
    'get_eeg_fir_filter_b',
    'get_eeg_fir_filter_c',
    'get_filter',
    'get_matlab_bandpassfir_coefs',
]

# Names of all saved filters (function names) that may be used, i.e. every exported name starting with 'filter'
# except the generic dispatcher filter_saved_filter itself.
SAVED_FILTERS = [name for name in __all__
                 if name.startswith('filter') and name != 'filter_saved_filter']

# Path to the directory with files containing coefficients of saved filters (located next to this module).
SAVED_FILTERS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'filter_coefs')
def filter_saved_filter(x, filter_name, fs=None, axis=-1, **kwargs):
    """
    Filter the signal x with a saved filter (deprecated).

    Dispatches to the filter function matching `filter_name` and returns its output.

    Args:
        x (np.ndarray): signal to be filtered.
        filter_name (str): which saved filter to use. Accepted values are 'cnn_ansari_2019' (or
            'filter_cnn_ansari_2019') and 'bandpassfir_a' (or 'filter_bandpassfir_a').
        fs (float, optional): sample frequency of the signal to be filtered. Both filters that are currently
            dispatched use it: filter_cnn_ansari_2019 only supports 250 or 256 Hz, and filter_bandpassfir_a uses
            it to compute (or load) its coefficients.
            Defaults to None.
        axis (int, optional): the axis of x to which the filter is applied.
            Defaults to -1.
        **kwargs (optional): optional keyword arguments that are passed to the function that does the filtering.

    Returns:
        signal_filtered (np.ndarray): the filtered signal.

    Raises:
        ValueError: if `filter_name` is not one of the accepted values.
    """
    # stacklevel=2 attributes the warning to the code calling this function instead of to this line.
    warnings.warn('Deprecated. Does not handle NaN well. Rewrite.', stacklevel=2)

    # Both the short name and the full function name are accepted for every saved filter.
    filter_functions = {
        'cnn_ansari_2019': filter_cnn_ansari_2019,
        'filter_cnn_ansari_2019': filter_cnn_ansari_2019,
        'bandpassfir_a': filter_bandpassfir_a,
        'filter_bandpassfir_a': filter_bandpassfir_a,
    }
    if filter_name not in filter_functions:
        raise ValueError('Invalid saved_filter "{}". Choose from: {}.'
                         .format(filter_name, SAVED_FILTERS))
    signal_filtered = filter_functions[filter_name](x, fs=fs, axis=axis, **kwargs)
    return signal_filtered
def filter_bandpassfir_a(x, fs, axis=-1, **kwargs):
    """
    Apply the fixed 'bandpassfir_a' FIR bandpass filter to x using forward-backward filtering (filtfilt).

    The filter specifications are fixed and defined in get_bandpassfir_a_coefs(); only the resulting coefficients
    depend on the sample frequency.

    Args:
        x (np.ndarray): signal array to be filtered.
        fs (float): sample frequency of x, used to compute (or load) the coefficients for this frequency while
            keeping the fixed filter specifications.
        axis (int, optional): the axis of x along which to filter. Default is -1.
        **kwargs (optional): extra keyword arguments for scipy.signal.filtfilt().

    Returns:
        (np.ndarray): the filtered signal array (same size as x).
    """
    # A FIR filter only has numerator (b) coefficients; its denominator (a) is the constant 1.
    numerator = get_bandpassfir_a_coefs(fs=fs)
    return ss.filtfilt(numerator, 1, x, axis=axis, **kwargs)
def filter_cnn_ansari_2019(x, fs, axis=-1, **kwargs):
    """
    Apply the bandpass filter used by Ansari 2019 for 2-class sleep stage classification with a CNN.

    The second-order-section (sos) coefficients are hard-coded, copied from MATLAB's output of:
        bpfilt = designfilt('bandpassiir',...
            'PassbandFrequency1',1,'PassbandFrequency2',20, ...
            'StopbandFrequency1',0.01,'StopbandFrequency2', 27,...
            'PassbandRipple',1,...
            'StopbandAttenuation1',40,'StopbandAttenuation2',40,...
            'SampleRate',fs);
    Coefficients are only available for fs = 250 and fs = 256.

    SciPy's sosfiltfilt and MATLAB's filtfilt handle the signal borders differently, so the outputs differ near the
    borders (SciPy's output looks better there).

    Args:
        x (np.ndarray): signal array to be filtered.
        fs (int): sampling frequency of `x` (250 or 256).
        axis (int, optional): the axis of x along which to filter. Default is -1.
        **kwargs (optional): extra keyword arguments for scipy.signal.sosfiltfilt.

    Returns:
        (np.ndarray): the filtered signal array (same size as x).

    Raises:
        NotImplementedError: if no coefficients are stored for `fs`.
    """
    # MATLAB designfilt output for fs = 250 Hz. Each row is one section: [b0, b1, b2, a0, a1, a2].
    sections_250 = [
        [0.243216182622267, 0, -0.243216182622267, 1, -1.65436615754352, 0.913668453704997],
        [0.243216182622267, 0, -0.243216182622267, 1, -1.99514097253079, 0.99571270962927],
        [0.23311774723518, 0, -0.23311774723518, 1, -1.52768498803376, 0.764223818147241],
        [0.23311774723518, 0, -0.23311774723518, 1, -1.98664322526997, 0.987219017806187],
        [0.224671708470001, 0, -0.224671708470001, 1, -1.97830692151704, 0.978893031937156],
        [0.224671708470001, 0, -0.224671708470001, 1, -1.42736316626388, 0.643204757205783],
        [0.217806571345434, 0, -0.217806571345434, 1, -1.97025837836953, 0.970860471032313],
        [0.217806571345434, 0, -0.217806571345434, 1, -1.35004578917496, 0.547513650483623],
        [0.212438317965548, 0, -0.212438317965548, 1, -1.96273353716245, 0.963355977476836],
        [0.212438317965548, 0, -0.212438317965548, 1, -1.2926635492432, 0.474376464191818],
        [0.208488135060687, 0, -0.208488135060687, 1, -1.95613859225961, 0.956783202658562],
        [0.208488135060687, 0, -0.208488135060687, 1, -1.25253161477321, 0.421530244051553],
        [0.205891934625482, 0, -0.205891934625482, 1, -1.95108324531832, 0.951747616534758],
        [0.205891934625482, 0, -0.205891934625482, 1, -1.22735401835247, 0.38726771418154],
        [0.204605186294117, 0, -0.204605186294117, 1, -1.94828948035172, 0.948965808747501],
        [0.204605186294117, 0, -0.204605186294117, 1, -1.21530391537097, 0.370433080224339],
    ]
    # MATLAB designfilt output for fs = 256 Hz, same layout.
    sections_256 = [
        [0.237574547250801, 0, -0.237574547250801, 1, -1.66804122175605, 0.915499818956205],
        [0.237574547250801, 0, -0.237574547250801, 1, -1.99526432690052, 0.995810503853529],
        [0.227906862481412, 0, -0.227906862481412, 1, -1.54275287790504, 0.768849862063498],
        [0.227906862481412, 0, -0.227906862481412, 1, -1.98695893873719, 0.987509057087125],
        [0.219802206025772, 0, -0.219802206025772, 1, -1.97880904115827, 0.979369099060719],
        [0.219802206025772, 0, -0.219802206025772, 1, -1.44318391215637, 0.649753917415089],
        [0.213202295119539, 0, -0.213202295119539, 1, -1.97093768909577, 0.971513115778716],
        [0.213202295119539, 0, -0.213202295119539, 1, -1.36621700589838, 0.555377243524725],
        [0.208033993568186, 0, -0.208033993568186, 1, -1.96357537416794, 0.964170358255492],
        [0.208033993568186, 0, -0.208033993568186, 1, -1.30894921259301, 0.483129642348542],
        [0.204226865571592, 0, -0.204226865571592, 1, -1.95711973480638, 0.957736028976683],
        [0.204226865571592, 0, -0.204226865571592, 1, -1.26881076107805, 0.430869994630954],
        [0.201722835363977, 0, -0.201722835363977, 1, -1.95216871647257, 0.952804010355437],
        [0.201722835363977, 0, -0.201722835363977, 1, -1.24358498877612, 0.396965983674931],
        [0.200481231385749, 0, -0.200481231385749, 1, -1.94943152036553, 0.950078314462072],
        [0.200481231385749, 0, -0.200481231385749, 1, -1.23149731763112, 0.380302306378394],
    ]

    # Pick the table whose sample rate equals fs (checked in order, using ==, so e.g. 250.0 also matches).
    for rate, sections in ((250, sections_250), (256, sections_256)):
        if fs == rate:
            sos = np.array(sections)
            break
    else:
        raise NotImplementedError(f'Not implemented for fs={fs}. Either resample the signal to 250 or 256 Hz '
                                  f'or add MATLABs sos filter coefficients to this function.')

    return ss.sosfiltfilt(sos, x, axis=axis, **kwargs)
def get_bandpassfir_a_coefs(fs):
    """
    Return the coefficients of the 'bandpassfir_a' FIR bandpass filter for sample frequency fs.

    The filter specifications are fixed in this function: passband 0.6-40, stopband edges 0.1 and 47, stopband
    attenuation 40 on both sides and passband ripple 1. The coefficients are obtained through
    get_matlab_bandpassfir_coefs() (MATLAB's designfilt).

    Args:
        fs (float): sample frequency of the signal to be filtered. Only this parameter varies; the rest of the
            specification stays fixed.

    Returns:
        coefs (np.ndarray): filter coefficients for the FIR bandpass filter.
    """
    specs = dict(
        pass_low=0.6,
        pass_high=40,
        stop_low=0.1,
        stop_high=47,
        db_low=40,
        db_high=40,
        pass_ripple=1,
        fs=fs,
    )
    return get_matlab_bandpassfir_coefs(specs)
def get_filter(which):
    """
    Helper function to get the filter object of a pre-defined bandpass filter.

    Args:
        which (str): name of the pre-defined filter (case-insensitive). Options:
            '1-40': RemezFIR with passband [0.6, 40] and stopband [0.1, 47].
            '1-32': RemezFIR with passband [0.5, 32] and stopband [0.1, 35].
            '1-20': RemezFIR with passband [1, 20] and stopband [0.1, 27].
            '2-16': RemezFIR with passband [2, 16] and stopband [1, 20].
            'BrainRT': Butterworth with fn=[0.27, 30] and order=1.
            All RemezFIR options use passband_ripple=1 and stopband_attenuation=40.

    Returns:
        filt (nnsa.FilterBase): nnsa FilterBase-derived object (RemezFIR or Butterworth).

    Raises:
        ValueError: if `which` is not one of the options above.
    """
    which = which.lower()

    # Passband and stopband per pre-defined RemezFIR option.
    remez_bands = {
        '1-40': ([0.6, 40], [0.1, 47]),
        '1-32': ([0.5, 32], [0.1, 35]),
        '1-20': ([1, 20], [0.1, 27]),
        '2-16': ([2, 16], [1, 20]),
    }
    if which not in remez_bands and which != 'brainrt':
        raise ValueError('Invalid choice which="{}".'.format(which))

    # Import only once the choice is known to be valid.
    from nnsa.preprocessing.filter import RemezFIR, Butterworth

    if which == 'brainrt':
        return Butterworth(fn=[0.27, 30], order=1)

    passband, stopband = remez_bands[which]
    return RemezFIR(passband=passband,
                    stopband=stopband,
                    passband_ripple=1,
                    stopband_attenuation=40)
def get_eeg_fir_filter_a(**kwargs):
    """
    Return default FIR filter 'a' for filtering EEG.

    This is a RemezFIR bandpass with passband [0.6, 40], stopband [0.1, 47], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): extra keyword arguments passed on to RemezFIR.

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR

    design = dict(passband=[0.6, 40],
                  stopband=[0.1, 47],
                  passband_ripple=1,
                  stopband_attenuation=40)
    return RemezFIR(**design, **kwargs)
def get_eeg_fir_filter_b(**kwargs):
    """
    Return default FIR filter 'b' for filtering EEG.

    This is a RemezFIR bandpass with passband [1, 20], stopband [0.01, 27], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): extra keyword arguments passed on to RemezFIR.

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR

    design = dict(passband=[1, 20],
                  stopband=[0.01, 27],
                  passband_ripple=1,
                  stopband_attenuation=40)
    return RemezFIR(**design, **kwargs)
def get_eeg_fir_filter_c(**kwargs):
    """
    Return default FIR filter 'c' for filtering EEG.

    This is a RemezFIR bandpass with passband [0.1, 32], stopband [0.01, 35], passband_ripple=1 and
    stopband_attenuation=40.

    Args:
        **kwargs (optional): extra keyword arguments passed on to RemezFIR, consistent with
            get_eeg_fir_filter_a() and get_eeg_fir_filter_b(). Calling without arguments behaves as before.

    Returns:
        fir (nnsa.RemezFIR): nnsa filter object.
    """
    from nnsa.preprocessing.filter import RemezFIR
    fir = RemezFIR(passband=[0.1, 32],
                   stopband=[0.01, 35],
                   passband_ripple=1,
                   stopband_attenuation=40,
                   **kwargs)
    return fir
def get_matlab_bandpassfir_coefs(filter_parameters):
    """
    Load the coefficients of a FIR bandpass filter designed by MATLAB's designfilt('bandpassfir', ...).

    The coefficients are cached in a CSV file in SAVED_FILTERS_DIR whose name encodes the filter parameters. If that
    file does not exist yet, it is created by running designfilt in a MATLAB engine and writing the filter's
    Coefficients to the file.

    Args:
        filter_parameters (dict): specification for designfilt. It must contain the keys:
            'pass_low', 'pass_high' (PassbandFrequency1/2), 'stop_low', 'stop_high' (StopbandFrequency1/2),
            'db_low', 'db_high' (StopbandAttenuation1/2), 'pass_ripple' (PassbandRipple) and 'fs' (SampleRate).
            All values must be numeric (they are formatted with '.2f' in the filename). The dict is not modified.

    Returns:
        (np.ndarray): the coefficients of the filter (the Coefficients field of MATLAB's designfilt object).
            Those correspond to the b coefficients (since the a coefficients equal 1 in a FIR filter).
    """
    # Create the cache filename from the parameters.
    filename = 'bandpassfir_{pass_low:.2f}_{pass_high:.2f}_{stop_low:.2f}_{stop_high:.2f}_' \
               '{db_low:.2f}_{db_high:.2f}_{pass_ripple:.2f}_{fs:.2f}.csv' \
        .format(**filter_parameters)
    filter_coefs_path = os.path.join(SAVED_FILTERS_DIR, filename)

    # If the file does not exist, compute the filter coefficients with MATLAB and save them to it.
    if not os.path.exists(filter_coefs_path):
        # Make sure the target folder exists before MATLAB writes into it.
        os.makedirs(SAVED_FILTERS_DIR, exist_ok=True)

        # Work on a copy so the caller's dict is not modified. The path is embedded in a MATLAB single-quoted
        # char array, where a literal single quote must be written as two single quotes.
        matlab_parameters = dict(filter_parameters,
                                 filter_coefs_path=filter_coefs_path.replace("'", "''"))

        # Generate the MATLAB code based on the filter parameters.
        matlab_code = """
bpfilt = designfilt('bandpassfir', ...
'PassbandFrequency1', {pass_low}, 'PassbandFrequency2', {pass_high}, ...
'StopbandFrequency1', {stop_low}, 'StopbandFrequency2', {stop_high}, ...
'StopbandAttenuation1', {db_low}, 'StopbandAttenuation2', {db_high}, ...
'PassbandRipple', {pass_ripple}, ...
'SampleRate', {fs})
info(bpfilt)
% Write coefficients to csv.
writematrix(bpfilt.Coefficients, '{filter_coefs_path}')
""".format(**matlab_parameters)

        from nnsa.matlab.utils import matlab_engine
        with matlab_engine() as eng:
            eng.eval(matlab_code, nargout=0)

    # Load the file with coefficients and return.
    return np.genfromtxt(filter_coefs_path, delimiter=',')