"""
Wavlet-based feature extraction.
"""
import copy
import sys
import warnings
import h5py
import numpy as np
import pyprind
from matplotlib import cm
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.ticker import FuncFormatter
from nnsa.cwt.config import DEFAULT_WAVELET
from nnsa.cwt.mothers import Morlet, get_wavelet
from nnsa.cwt.plotting import plot_scalogram
from nnsa.cwt.transforms import compute_wavelet_coherence, cwt, compute_partial_coherence
from nnsa.containers.time_series import TimeSeries
from nnsa.cwt.utils import compute_moi, compute_coi
from nnsa.parameters.parameters import ClassWithParameters, Parameters
from nnsa.feature_extraction.result import ResultBase
from nnsa.stats.surrogates import compute_surrogate_fun, precompute_surrogate_data, compute_surrogate
from nnsa.utils.arrays import interp_nan, slice_along_axis, cummean, moving_mean
from nnsa.utils.config import HORIZONTAL_RULE
import matplotlib.pyplot as plt
__all__ = [
'CWT',
'CWTResult',
'WaveletCoherence',
'WaveletCoherenceResult',
'WaveletResult',
'convert_freq_scale',
'revert_freq_scale',
]
from nnsa.utils.event_detections import get_onsets_offsets
from nnsa.utils.plotting import maximize_figure, shade_axis
from nnsa.utils.conversions import convert_time_scale, revert_time_scale
from nnsa.utils.segmentation import get_all_segments, segment_generator
[docs]class CWT(ClassWithParameters):
"""
Class for computing continuous wavelet transorm.
Main method: compute_wavelet_coherence() or process().
Args:
see nnsa.ClassWithParameters.
Examples:
>>> np.random.seed(43)
>>> N = 1024
>>> fs = 1
>>> t = np.arange(0, N, 1/fs)
>>> f_cos = 0.1
>>> x = np.cos(2*np.pi*t*f_cos) + np.random.rand(N)
>>> cwt = CWT()
>>> result = cwt.wct(x, fs=1, verbose=0)
>>> print(type(result).__name__)
CWTResult
>>> result.W.shape
(76, 1024)
# Print most prominent frequency (should be close to f_cos).
>>> result.freqs[np.argmax(result.wps())].round(1)
0.1
"""
[docs] @staticmethod
def default_parameters():
"""
Return the default parameters.
Returns:
(nnsa.Parameters): a default set of parameters for the object.
"""
# Parameters for CWT computation, see cwt().
cwt_kwargs = dict(
wavelet=DEFAULT_WAVELET,
coimode='moi',
)
pars = {
'cwt_kwargs': cwt_kwargs,
}
return Parameters(**pars)
[docs] def cwt(self, x, fs, time_offset=0, label=None, verbose=1):
"""
Compute continuous wavelet transform of x.
Args:
x (array-like): 1D signal array.
fs (float): sample frequency of x and y in Hz.
time_offset (float, optional): time offset in seconds (i.e. time of the first sample).
Defaults to 0.
label (tuple, optional): label for the output.
verbose (int, optional): verbose level.
Defaults to 1.
Returns:
(nnsa.CWTResult): object containing the result.
"""
if verbose > 0:
print(HORIZONTAL_RULE)
print('Running cwt with parameters:')
print(self.parameters)
if label is None:
# See if we passed an object that has labels to create the label.
if hasattr(x, 'label'):
label = x.label
# As array.
x = np.asarray(x).squeeze()
# Extract some parameters.
cwt_kwargs = self.parameters['cwt_kwargs']
# Convert to numpy arrays.
x = np.asarray(x)
nan_mask_x = np.isnan(x)
# Compute cwt.
W, scales, freqs, insidecoi = cwt(
x=x, dt=1/fs, verbose=verbose, **cwt_kwargs)
result = CWTResult(W=W, scales=scales,
wavelet=str(cwt_kwargs.get('wavelet', str(DEFAULT_WAVELET))),
insidecoi=insidecoi,
freqs=freqs, fs=fs,
signal=x,
nan_mask=nan_mask_x,
label=label,
time_offset=time_offset,
algorithm_parameters=self.parameters)
return result
[docs] def process(self, *args, **kwargs):
return self.cwt(*args, **kwargs)
[docs]class CWTResult(ResultBase):
"""
High-level interface for processing continuous wavelet transform as computed by nnsa.CWT().
Args:
W (np.ndarray): 2D array with wavelet coefficients (n_freqs, n_time).
scales (np.ndarray): scales corresponding to the first axis of W.
wavelet (str): string indicating whihc wavelet was used. E.g. 'Morlet(6)'.
insidecoi (np.ndarray): boolean array with same shape as W that is True at locations in
the cone of influence (COI).
freqs (np.ndarray): frequencies in Hz corresponding to the first axis of W.
algorithm_parameters (nnsa.Parameters): see ResultBase.
signal (np.ndarray, optional): array of shape (n_time) containing the time signal on which the
CWT was computed.
Defaults to None.
nan_mask (np.ndarray, optional): boolean array of shape (n_time) containing True values
where the signal originally contained nans which where interpolated prior to
wavelet transform. If given, the nan mask can be used to exclude values at time
points that where nan originally.
Defaults to None.
label (tuple, optional): label for the signal.
Defaults to 'CWT'.
data_info (str, optional): see ResultBase.
segment_start_times (np.ndarray, optional): see ResultBase.
segment_end_times (np.ndarray, optional): see ResultBase.
fs (flaot, optional): see ResultBase.
"""
def __init__(self, W, scales, wavelet, insidecoi, freqs, *args,
signal=None, nan_mask=None, label='CWT', **kwargs):
super().__init__(*args, **kwargs)
# Check inputs.
if W.shape[0] != len(freqs):
raise ValueError('Lengths of `W` and `freqs` should be equal.')
if signal is not None:
signal = np.asarray(signal).squeeze()
if signal.shape != (W.shape[1],):
raise ValueError('`signal` should have shape {}. Got {}.'
.format((W.shape[1],), signal.shape))
if nan_mask is not None:
nan_mask = np.asarray(nan_mask).squeeze()
if nan_mask.shape != (W.shape[1],):
raise ValueError('`nan_mask` should have shape {}. Got {}.'
.format((W.shape[1],), nan_mask.shape))
# Store variables that are not already stored by the parent class (ResultBase).
self.W = W
self.scales = scales
self._wavelet = wavelet
self.insidecoi = insidecoi
self.freqs = freqs
self.signal = signal
self.nan_mask = nan_mask
self.label = label
@property
def num_segments(self):
"""
Return the number of segments.
Returns:
(int): number of segments/samples.
"""
return self.W.shape[-1]
@property
def Wp(self):
"""
Return the wavelet power.
The power is scaled by the inverse of the scale, as proposed by Liu et al. 2007.
Returns:
Wp (np.ndarray): wavelet power (same size as self.W).
"""
return 1/self.scales.reshape(-1, 1) * np.abs(self.W)**2
@property
def wavelet(self):
return get_wavelet(self._wavelet)
@property
def dj(self):
dj_all = np.diff(np.log2(self.scales))
dj = np.mean(dj_all)
if np.any(np.abs(dj_all - dj) > 1e-10):
raise ValueError('Non equidistant log2 spacing!')
return dj
@property
def dt(self):
return 1/self.fs
[docs] def wps(self):
"""
Compute wavelet power spectrum: the wavelet power averaged across time.
Returns:
wps (np.ndarray): 1D array with powers.
freqs (np.ndarray): 1D arrays with frequencies corresponding to wps.
"""
freqs = self.freqs
Wp = self.Wp
wps = np.nanmean(Wp, axis=-1)
return wps, freqs
[docs] def plot(self, signal_kwargs=None, **kwargs):
"""
General plot of the results.
"""
if signal_kwargs is None:
signal_kwargs = {}
x = self.signal
t = self.time
fig, axes = plt.subplots(2, 1, sharex='all',
constrained_layout=True)
axes = np.reshape([axes], -1)
maximize_figure()
# Plot time signal.
plt.sca(axes[0])
plt.plot(t, x, **signal_kwargs)
plt.ylabel('Signal (a.u.)')
# Plot frequency domain.
self.plot_scalogram(ax=axes[1], **kwargs)
[docs] def plot_scalogram(self, ax=None, **kwargs):
"""
Plot time-frequency scalogram.
"""
freqs = self.freqs
Wp = self.Wp
t = self.time
if ax is None:
# Current axis.
ax = plt.gca()
else:
plt.sca(ax)
plot_scalogram(t, freqs, Wp, insidecoi=self.insidecoi,
time_scale='seconds', ax=ax, **kwargs)
plt.title(self.label)
[docs] def to_wavelet_power(self):
"""
Converts to scale-normalized power and phase.
Returns:
result (WaveletResult): WaveetResult instance conatining normailized wavelet power.
"""
P = self.Wp
A = np.angle(self.W)
# Save as WaveletResult.
result = WaveletResult(
P=P, A=A, freqs=self.freqs, fs=self.fs,
wavelet=self.wavelet,
insidecoi=self.insidecoi,
name='Wavelet power',
signals=(self.signal),
labels=[self.label],
time_offset=self.time_offset,
algorithm_parameters=self.algorithm_parameters)
return result
[docs]class WaveletCoherence(ClassWithParameters):
"""
Class for computing wavelet coherence.
Main method: compute_wavelet_coherence() or process().
Args:
see nnsa.ClassWithParameters.
Examples:
>>> np.random.seed(43)
>>> N = 1024
>>> fs = 1
>>> t = np.arange(0, N, 1/fs)
>>> f_mutual = 0.1
>>> x = np.cos(2*np.pi*t*f_mutual) + np.random.rand(N)
>>> y = np.sin(2*np.pi*t*f_mutual) + np.random.rand(N)
>>> wc = WaveletCoherence(surrogates={'n_surrogates': 0})
>>> result = wc.wct(x, y, fs=1, verbose=0)
>>> print(type(result).__name__)
WaveletCoherenceResult
>>> result.P.shape
(76, 1024)
# Print most prominent coupled frequency (should be close to the mutual frequency).
>>> result.freqs[np.argmax(np.nanmean(result.P, axis=1))].round(1)
0.1
"""
[docs] @staticmethod
def default_parameters():
"""
Return the default parameters.
Returns:
(nnsa.Parameters): a default set of parameters for the object.
"""
# Parameters for compute_wavelet_coherence, see documentation of compute_wavelet_coherence().
cwt_kwargs = dict(
wavelet=DEFAULT_WAVELET,
coimode='moi',
)
# Parameters for surrogate analysis. Set n_surrogates to 0 if no surrogate computations are desired.
# See also nnsa.stats.surrogates.compute_surrogate_fun().
surrogates = Parameters(**{
# Number of surrogates to compute. Set to 0 to not compute any surrogate values:
'n_surrogates': 100,
# How to generate surrogates (see nnsa.stats.surrogates.compute_surrogate() for options):
'how': 'IAAFT',
# Seed for the random generator:
'seed': 43,
})
pars = {
'cwt_kwargs': cwt_kwargs,
'surrogates': surrogates,
}
return Parameters(**pars)
[docs] def wavelet_coherence(self, x, y, fs, time_offset=0, labels=None, verbose=1):
"""
Compute wavelet-based coherence between x and y.
Args:
x (array-like): 1D signal array.
y (array_like): 1D signal array. Must have the same length and sample frequency as `x`.
fs (float): sample frequency of x and y in Hz.
time_offset (float, optional): time offset in seconds (i.e. time of the first sample).
Defaults to 0.
labels (tuple, optional): labels for the output.
If labels is not given and TimeSeries are passed, infers the labels from these objects.
Otherwise, if labels is None, uses default labels ('x', 'y').
Defaults to None.
verbose (int, optional): verbose level.
Defaults to 1.
Returns:
(nnsa.WaveletResult): object containing the result.
"""
if verbose > 0:
print(HORIZONTAL_RULE)
print('Running wct with parameters:')
print(self.parameters)
if labels is None:
# See if we passed an object that has labels to create the label.
if hasattr(x, 'label') and hasattr(y, 'label'):
labels = (x.label, y.label)
else:
labels = ('x', 'y')
else:
if len(labels) != 2:
raise ValueError('Did not get 2 labels for the signals. Got labels={}.'
.format(labels))
# As array.
x = np.asarray(x).squeeze()
y = np.asarray(y).squeeze()
# Extract some parameters.
cwt_kwargs = self.parameters['cwt_kwargs']
surrogates_kwargs = self.parameters['surrogates']
cwt_kwargs['dt'] = 1 / fs
cwt_kwargs['verbose'] = verbose
compute_surrogates = (surrogates_kwargs is not None and surrogates_kwargs['n_surrogates'] > 0)
# Compute coherence.
P, A, freqs, insidecoi = compute_wavelet_coherence(x, y, **cwt_kwargs)
# Compute surrogate coherences.
if compute_surrogates:
cwt_kwargs['verbose'] = 0
fun = lambda xs, ys: compute_wavelet_coherence(xs, ys, **cwt_kwargs)[0]
# Compute WCT surrogates.
P_sur = compute_surrogate_fun(inputs=(x, y), fun=fun, verbose=verbose,
**dict(surrogates_kwargs, return_af_errors=False))
# Compute significance and mean surrogate coupling value.
significance = np.nanmean(P > P_sur, axis=0)
P_surrogates = np.nanmean(P_sur, axis=0)
else:
significance = None
P_surrogates = None
extra = None
# Save as WaveletResult.
result = WaveletResult(
P=P, A=A, freqs=freqs, fs=fs,
wavelet=cwt_kwargs['wavelet'],
insidecoi=insidecoi,
extra=extra,
name='Wavelet coherence',
significance=significance,
P_surrogates=P_surrogates,
signals=(x, y),
labels=labels,
time_offset=time_offset,
algorithm_parameters=self.parameters)
return result
[docs] def wct(self, *args, **kwargs):
"""
Alias for self.wavelet_coherence().
"""
return self.wavelet_coherence(*args, **kwargs)
[docs] def pct(self, x, y, z, fs, time_offset=0, labels=None, verbose=1):
"""
Compute partial wavelet coherence: coherence between x and y, while ignoring common correlation with z.
Args:
x (array-like): 1D signal array.
y (array_like): 1D signal array. Must have the same length and sample frequency as `x`.
z (array_like or list): 1D signal array. Must have the same length and sample frequency as `x`.
Can also be a list of 1D signal arrays, in which case the confounding relation from each of these
signals will be removed.
fs (float): sample frequency of x and y in Hz.
time_offset (float, optional): time offset in seconds (i.e. time of the first sample).
Defaults to 0.
labels (tuple, optional): labels for the output, see WaveletCoherence.wct().
If labels is not given and TimeSeries are passed, infers the labels from these objects.
Otherwise, if labels is None, uses default labels.
Defaults to None.
verbose (int, optional): verbose level.
Defaults to 1.
Returns:
(nnsa.WaveletResult): object containing the result.
"""
if verbose > 0:
print(HORIZONTAL_RULE)
print('Running pct with parameters:')
print(self.parameters)
if not isinstance(z, (list, tuple)):
z_all = [z]
else:
z_all = list(z)
# Labels.
sig_all = [x, y] + z_all
if labels is None:
# See if we passed an object that has labels to create the label.
labels = []
for ii, sig in enumerate(sig_all):
if hasattr(sig, 'label'):
lab = sig.label
else:
lab = ''
labels.append(lab)
if len(labels) != len(sig_all):
raise ValueError('Did not get labels for all {} signals. Got labels={}.'
.format(len(sig_all), labels))
# Add prefix x, y, z to labels.
default_labels = ['x', 'y', 'z']
for ii in range(len(labels)):
prefix = default_labels[ii] if ii < 3 else f'z{ii-1}'
labels[ii] = f'{prefix}:' + labels[ii]
# To arrays.
x = np.asarray(x).squeeze()
y = np.asarray(y).squeeze()
if isinstance(z, (tuple, list)):
# z are multiple signals, collect to list.
z_all = [np.asarray(zi).squeeze() for zi in z]
elif np.asarray(z).squeeze().shape == x.shape:
# z is a single signal.
z_all = [np.asarray(z).squeeze()]
else:
raise TypeError('`z` should be an array or a list of arrays. Got a {}.'.format(type(z)))
# Extract some parameters.
cwt_kwargs = self.parameters['cwt_kwargs']
surrogates_kwargs = self.parameters['surrogates']
cwt_kwargs['dt'] = 1 / fs
cwt_kwargs['verbose'] = verbose
compute_surrogates = (surrogates_kwargs is not None and surrogates_kwargs['n_surrogates'] > 0)
# Compute coherence.
P, A, freqs, insidecoi = compute_partial_coherence(x=x, y=y, z=z_all, **cwt_kwargs)
# Compute surrogate coherences.
if compute_surrogates:
cwt_kwargs['verbose'] = 0
fun = lambda *inp: compute_partial_coherence(inp[0], inp[1], inp[2:], **cwt_kwargs)[0]
# Compute WCT surrogates.
P_sur = compute_surrogate_fun(inputs=(x, y) + tuple(z_all), fun=fun, verbose=verbose,
**dict(surrogates_kwargs, return_af_errors=False))
# Compute significance and mean surrogate coupling value.
significance = np.nanmean(P > P_sur, axis=0)
P_surrogates = np.nanmean(P_sur, axis=0)
else:
significance = None
P_surrogates = None
extra = None
# Save as WaveletResult.
result = WaveletResult(
P=P, A=A, freqs=freqs, fs=fs,
wavelet=cwt_kwargs['wavelet'],
insidecoi=insidecoi,
extra=extra,
name='Partial wavelet coherence',
significance=significance,
P_surrogates=P_surrogates,
signals=(x, y) + tuple(z_all),
labels=labels,
time_offset=time_offset,
algorithm_parameters=self.parameters)
return result
[docs]class WaveletResult(ResultBase):
"""
High-level interface for processing wavelet maps. Can handle CWT, WCT, PCT.
Args:
P (np.ndarray): 2D array with wavelet power or squared coherence values with size (n_freqs, n_time).
E.g. for a CWT: P = 1/scales.reshape(-1, 1) * np.abs(W)**2
A (np.ndarray): 2D array with wavelet phase angles with size (n_freqs, n_time).
E.g. for a CWT: A = np.angle(W)
freqs (np.ndarray): frequencies in Hz corresponding to the first axis of P and A.
fs (float): sampling frequency (Hz).
wavelet (Mother, str): Mother object or string specifying which wavelet was used (e.g. 'Morlet(6)').
insidecoi (np.ndarray, optional): boolean array of shape (n_freqs, n_time) containing True
values at the cone/mask of influence. True pixels correspond to regions affected by adge/artefact effects.
Defaults to None.
name (str, optional): name for the result, e.g. 'WCT EEG-NIRS'.
algorithm_parameters (nnsa.Parameters, optional): see ResultBase.
significance (np.ndarray, optional): array with same shape as P containing significance values for
P between 0 - 1 (0 not significant, 1 significant).
If not specified, methods that require significance values will raise errors.
Defaults to None.
P_surrogates (np.ndarray, optional): array with same size as P, with mean surrogate values.
Defaults to None.
signals (np.ndarray, optional): array of shape (n_signals, n_time) containing an arbitrary number
of signals that were used in the computation (for visualization purposes).
Defaults to None.
extra (dict, optional): dict with optional extra arrays. E.g. to save some other surrogate values.
labels (tuple, optional): tuple of length n_signals, with a label for each element in `signals`,
e.g. ('EEG', 'NIRS').
Defaults to None.
data_info (str, optional): see ResultBase.
segment_start_times (np.ndarray, optional): see ResultBase.
segment_end_times (np.ndarray, optional): see ResultBase.
time_offset (float, optional): see ResultBase.
"""
def __init__(self, P, A, freqs, fs, wavelet, insidecoi=None, name=None,
algorithm_parameters=None,
significance=None, P_surrogates=None,
signals=None, labels=None,
extra=None,
data_info=None,
segment_start_times=None, segment_end_times=None,
time_offset=0):
super().__init__(algorithm_parameters=algorithm_parameters, data_info=data_info,
segment_start_times=segment_start_times, segment_end_times=segment_end_times,
fs=fs, time_offset=time_offset)
# Check inputs.
if P.shape != A.shape:
raise ValueError('Shapes of `P` and `freqs` should be equal. Got {} and {}.'
.format(P.shape, A.shape))
if len(P) != len(freqs):
raise ValueError('Lengths of `P` and `freqs` should be equal. Got {} and {}.'
.format(len(P), len(freqs)))
if insidecoi is not None:
insidecoi = np.asarray(insidecoi).squeeze()
if insidecoi.shape != P.shape:
raise ValueError('`insidecoi` should have shape {}. Got shape {}.'
.format(P.shape, insidecoi.shape))
if name is None:
name = ''
if significance is not None:
if significance.shape != P.shape:
raise ValueError('`significance` does not have the same shape as `P`.')
if P_surrogates is not None:
if P_surrogates.shape != P.shape:
raise ValueError('`P_surrogates` does not have the same shape as `P`.')
if signals is not None:
signals = np.asarray(signals).squeeze()
if signals.ndim == 1 and not isinstance(signals[0], np.ndarray):
signals = np.expand_dims(signals, 0)
if isinstance(labels, str):
labels = (labels,)
if labels is not None and len(labels) != len(signals):
raise ValueError('Did not get the same amount of labels for the signals. Got {} labels and {} signals.'
.format(len(labels), len(signals)))
if labels is None:
labels = ['{}'.format(i) for i in range(len(signals))]
# Store variables that are not already stored by the parent class (ResultBase).
self.P = P
self.A = A
self.freqs = freqs
self.insidecoi = insidecoi
self._wavelet = str(wavelet)
self.name = name
self.significance = significance
self.P_surrogates = P_surrogates
self.signals = signals
self.labels = labels
self.extra = extra
# Check whether the wavelet can be loaded (raises an error otherwise).
_ = self.wavelet
@property
def num_segments(self):
"""
Return the number of segments/samples.
Returns:
(int): number of segments/samples.
"""
return self.P.shape[-1]
@property
def dj(self):
"""
Return 2log-space scale spacing.
"""
dj_all = np.diff(np.log2(self.scales))
dj = np.mean(dj_all)
if np.any(np.abs(dj_all - dj) > 1e-10):
raise ValueError('Non equidistant log2 spacing!')
return dj
@property
def dt(self):
"""
Return sampling period.
"""
return 1/self.fs
@property
def W(self):
"""
Return the complex valued wavelet coefficients.
"""
return np.sqrt(self.P) * np.exp(1j*self.A)
@property
def wavelet(self):
"""
Return wavelet Mother object.
"""
return get_wavelet(self._wavelet)
@property
def scales(self):
"""
Return the wavelet scales equivalent to self.freqs.
"""
return self.wavelet.freq2scal(self.freqs)
[docs] def compute_coi(self):
"""
Return mask for the cone of influence at edges.
"""
n0 = self.P.shape[-1]
dt = self.dt
freqs = self.freqs
wavelet = self.wavelet
insidecoi = compute_coi(n0=n0, dt=dt, freqs=freqs, wavelet=wavelet)
return insidecoi
[docs] def downsample(self, ratio=2, how='decimate', which='time', inplace=False):
"""
Downsample the spectra along the frequency (axis=0) or time (axis=1) axis.
Does also downsample the signals.
Args:
ratio (int): downsampling ratio.
how (str, optional): method for downsampling. Options:
- 'decimate': take every `ratio`th sample.
- 'mean': collapse `ratio` consecutive samples into the mean of those samples.
- 'median': collapse `ratio` consecutive samples into the median of those samples.
Defaults to 'decimate'.
which (str, optional): dimension which to downsample. Choose from:
- 'time' or 't': downsample time dimension.
- 'freq' or 'f': downsample frequency dimension.
- 'both': downsample both time and frequency dimensions.
Defaults to 'time.
inplace (bool, optional): if True, operates inplace, else a new object is returned.
Defaults to False.
Returns:
wcr (WaveletCoherenceResult): new WaveletCoherenceResult object containing the filtered signal
(if inplace is False).
"""
def _downsample(x, axis):
n = x.shape[axis]
if how == 'decimate':
slices = [slice(None)] * x.ndim
slices[axis] = slice(ratio, n, ratio)
xd = x[tuple(slices)]
else:
xs = get_all_segments(x, segment_length=ratio, overlap=0, fs=1, axis=axis)
# Swap axis: the segment axis becomes the new time/frequency axis.
xs = np.swapaxes(xs, 0, axis)
if how == 'mean':
xd = np.nanmean(xs, axis=0)
elif how == 'median':
xd = np.nanmedian(xs, axis=0)
else:
raise ValueError('Invalid parameter how="{}". Choose from {}.'
.format(how, ['decimate', 'mean', 'median']))
return xd
if inplace:
wr = self
else:
wr = copy.deepcopy(self)
which = which.lower()
if which == 'both':
wr.downsample(ratio=ratio, how=how, which='time', inplace=True)
wr.downsample(ratio=ratio, how=how, which='freq', inplace=True)
else:
if which in ['t', 'time']:
axis = -1
wr._fs /= ratio
elif which in ['f', 'freq', 'frequency']:
axis = -2
wr.freqs = wr.freqs[ratio::ratio]
else:
raise ValueError('Invalid parameter which="{}". Choose from {}.'.format(
which, ['time', 'freq']
))
attrs = ['P', 'insidecoi', 'significance', 'P_surrogates']
for at in attrs:
x = getattr(wr, at)
if x is not None:
xd = _downsample(x=x, axis=axis)
setattr(wr, at, xd)
wr.insidecoi = wr.insidecoi.astype(bool)
# Angle is different. Convert to a vector in the complex plane, average the vactor and compute the angle.
Av = np.exp(1j*wr.A)
Avd = _downsample(Av, axis=axis)
wr.A = np.angle(Avd)
# if axis == -1:
# # Downsample time -> downsample signals.
# if wr.signals is not None:
# wr.signals = _downsample(x=wr.signals, axis=-1)
if not inplace:
return wr
[docs] def get_downsampled(self, which='mean', phase_center=None, phase_tol=np.pi/4,
time_start=None, time_stop=None, time_win=6*3600, time_step=None,
freq_low=None, freq_high=None, freq_scale='Hz', freq_win=None, freq_step=None,
mask_coi=True, max_nan_frac=0.5):
if which == 'frac_phase' and phase_center is None:
warnings.warn('Select a phase (specify `phase_center`) when which="frac_phase".')
time = self.time
if time_start is None:
# From beginning of data.
time_start = time[0]
if time_stop is None:
# From beginning of data.
time_stop = time[-1]
if time_step is None:
# No overlap.
time_step = time_win
if freq_win is None:
# Average per decade.
freq_win = 1 / self.dj
if freq_step is None:
# No overlap.
freq_step = freq_win
freq_step = int(freq_step)
# Extract data.
freqs, _, P, significance, A, _, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
mask_coi=mask_coi)
phase_mask = self.get_phase_mask(A, phase_center=phase_center, phase_tol=phase_tol)
if which == 'mean':
# Average P.
P[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
Y = P
elif which == 'frac_sig':
# Percentage of significant coherence.
sig_mask = self.get_significance_mask(significance).astype(float)
sig_mask[~phase_mask] = 0 # Significant other phases count as insignificant when focussing on one phase.
sig_mask[np.isnan(significance)] = np.nan # Ignore COI.
Y = sig_mask
elif which == 'frac_phase':
# Fraction of certain phase.
phase_mask = phase_mask.astype(float)
phase_mask[np.isnan(A)] = np.nan # Ignore COI.
Y = phase_mask
else:
raise ValueError('Invalid choice "{}" for `which` parameter.'.format(which))
all_start_times = np.arange(time_start, time_stop - time_win / 2, time_step)
all_freqs_start = np.arange(0, len(freqs) - freq_win / 2, freq_step)
# Allocate output.
out_shape = (len(all_freqs_start), len(all_start_times))
out = np.full(out_shape, fill_value=np.nan)
# Loop over time.
for i_t, t in enumerate(all_start_times):
# Find time indices.
t_idx_start = np.argmin(np.abs(time - t))
t_idx_end = np.argmin(np.abs(time - (t + time_win)))
# Skip if segment is too short.
if (t_idx_end - t_idx_start) / self.fs < 0.5 * time_win:
continue
# Loop over scales.
for j_f, f in enumerate(all_freqs_start):
f_idx_start = int(f)
f_idx_end = int(f_idx_start + freq_step)
# Extract patch and compute mean value.
patch = Y[f_idx_start: f_idx_end, t_idx_start: t_idx_end]
if np.isnan(patch).mean() > max_nan_frac:
val = np.nan
else:
val = np.nanmean(patch)
out[j_f, i_t] = val
freqs_out = revert_freq_scale(freqs[(all_freqs_start + freq_win / 2).astype(int)], freq_scale)
time_out = all_start_times + time_win / 2
return out, freqs_out, time_out
[docs] def get_cum_freq_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
which='mean', phase_center=None, phase_tol=np.pi/4,
mask_coi=True):
"""
Return the frequency profile.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
which (str, optional): what to plot as a function of frequency. Choose from:
- 'mean': mean of P.
- 'mean_sig': mean significance.
- 'frac_sig': fraction of significant P.
- 'frac_phase': occurence fraction of the specified phase.
phase_center and phase_tol: select phase. See self.get_phase_mask. If None, include all phases.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
Returns:
freqs (np.ndarray): frequencies.
freq_profile (np.ndarray): time averaged value per frequency.
"""
# Extract data.
freqs, time, P, significance, A, P_surrogates, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=mask_coi)
phase_mask = self.get_phase_mask(A, phase_center=phase_center, phase_tol=phase_tol)
# I expect to see "RuntimeWarning: Mean of empty slice" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
if which == 'mean':
# Average P.
P[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
freq_profile = cummean(P, axis=-1)
elif which == 'mean_sig':
# Average significance.
significance[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
freq_profile = cummean(significance, axis=-1)
elif which == 'frac_sig':
# Percentage of significant P.
sig_mask = self.get_significance_mask(significance).astype(float)
sig_mask[~phase_mask] = 0 # Significant other phases count as insignificant when focussing on one phase.
sig_mask[np.isnan(significance)] = np.nan # Ignore COI.
freq_profile = cummean(sig_mask, axis=-1)
elif which == 'frac_phase':
# Fraction of certain phase.
phase_mask = phase_mask.astype(float)
phase_mask[np.isnan(A)] = np.nan # Ignore COI.
freq_profile = cummean(phase_mask, axis=-1)
else:
raise ValueError('Invalid choice "{}" for `which` parameter.'.format(which))
return freqs, time, freq_profile
[docs] def get_freq_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
how='mean', phase_center=None, phase_tol=np.pi / 4,
mask_coi=True, max_nan=None, **kwargs):
"""
Return the frequency profile.
Args:
See self.get_profile().
Returns:
freqs (np.ndarray): frequencies in the specified scale.
freq_profile (np.ndarray): time averaged value per frequency.
"""
freqs, freq_profile = self.get_profile(which='freq',
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how=how, phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi, max_nan=max_nan, **kwargs)
return freqs, freq_profile
[docs] def get_phase_mask(self, A=None, phase_center=None, phase_tol=np.pi/4):
"""
Return a boolean mask based on the phase.
Args:
A (np.ndarray, optional): array with phases. If None, takes self.A.
phase_center (float, optional): center phase.
phase_tol (float, optional): tolerance around center phase. Phase will be included if
larger than (phase_center - phase_tol) and smaller than (phase_center + phase_tol).
Returns:
phase_mask (np.ndarray): boolean mask for `A`.
"""
if A is None:
A = self.A
if phase_center is None:
return np.full(A.shape, fill_value=True)
phase_diff = np.abs(np.angle(np.exp(1j * (A - phase_center))))
# I expect to see "RuntimeWarning: invalid value encountered in less_equal" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
phase_mask = phase_diff <= phase_tol
return phase_mask
[docs] def get_profile(self, which, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
how='mean', phase_center=None, phase_tol=np.pi / 4,
mask_coi=True, max_nan=None):
"""
Return the time or frequency profile.
Args:
which (str): 'time' or ('freq', 'frequency', 'scale'). Determines
which profile is returned.
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
how (str, optional): what to plot as a function of frequency. Choose from:
- 'mean': mean of P.
- 'median': median of P.
- 'mean_sig': mean significance.
- 'median_sig': median significance.
- 'frac_sig': fraction of significant P.
- 'frac_phase': occurence fraction of the specified phase.
phase_center and phase_tol: select phase. See self.get_phase_mask. If None, include all phases.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
Returns:
x (np.ndarray): time or frequencies in the specified scale.
profile (np.ndarray): time or frequency profile.
"""
# Extract data.
freqs, time, P, significance, A, P_surrogates, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=False)
if which in ['time']:
axis = 0
x = time
elif which in ['freq', 'frequency', 'scale']:
axis = -1
x = freqs
else:
raise ValueError('Invalid parameter which="{}". Choose from: {}.'
.format(which, ['time', 'freq']))
phase_mask = self.get_phase_mask(A, phase_center=phase_center, phase_tol=phase_tol)
if mask_coi and insidecoi is not None:
P[insidecoi] = np.nan
A[insidecoi] = np.nan
if significance is not None:
significance[insidecoi] = np.nan
if P_surrogates is not None:
P_surrogates[insidecoi] = np.nan
# I expect to see "RuntimeWarning: Mean of empty slice" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
if how == 'mean':
# Average P.
P[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
profile = np.nanmean(P, axis=axis)
elif how == 'median':
# Median P.
P[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
profile = np.nanmedian(P, axis=axis)
elif how == 'mean_sig':
# Average significance.
significance[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
profile = np.nanmean(significance, axis=axis)
elif how == 'median_sig':
# Median significance.
significance[~phase_mask] = np.nan # Ignore other phases when focussing on one phase.
profile = np.nanmedian(significance, axis=axis)
elif how == 'frac_sig':
# Percentage of significant P.
sig_mask = self.get_significance_mask(significance).astype(float)
sig_mask[~phase_mask] = 0 # Significant other phases count as insignificant when focussing on one phase.
sig_mask[np.isnan(significance)] = np.nan # Ignore COI.
profile = np.nanmean(sig_mask, axis=axis)
elif how == 'frac_phase':
# Fraction of certain phase.
phase_mask = phase_mask.astype(float)
phase_mask[np.isnan(A)] = np.nan # Ignore COI.
profile = np.nanmean(phase_mask, axis=axis)
elif how == 'complex':
# Average the complex coherence. Will return a complex value with magnitude which is large in
# case of phase locking and a corresponding dominant phase.
C = np.sqrt(P) * np.exp(A * 1j)
profile = np.nanmean(C, axis=axis)
elif how == 'complex_sig':
# Average the complex coherence. Ignore non significant.
sig_mask = self.get_significance_mask(significance).astype(bool)
P[~sig_mask] = np.nan
C = np.sqrt(P) * np.exp(A * 1j)
profile = np.nanmean(C, axis=axis)
elif how == 'complex_sig_frac':
# Average complex coherence times fraction of significant coherence.
x, p_complex = self.get_profile(which=which,
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how='complex_sig', phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi)
x, p_frac = self.get_profile(which=which,
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how='frac_sig', phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi)
return x, p_complex * p_frac
else:
raise ValueError('Invalid parameter how="{}".'.format(how))
if max_nan is not None:
nan_frac = np.nanmean(insidecoi, axis=axis)
profile[nan_frac > max_nan] = np.nan
assert len(x) == len(profile)
return x, profile
[docs] def get_significance_mask(self, significance=None, alpha=0.05):
"""
Return mask where the data is significant.
"""
if significance is None:
significance = self.significance
# I expect to see "RuntimeWarning: invalid value encountered in greater_equal" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
mask = (significance >= (1 - alpha)).astype(bool)
return mask
[docs] def get_smoothed(self, window, stepsize, time_scale='seconds',
which='abs_sq', mask_coi=True, max_nan=1,
phase_center=None, phase_tol=np.pi/4):
"""
Return smoothed values. Choose from abs_sq, abs, real, imag, frac_sig.
"""
P = self.P.copy()
A = self.A.copy()
significance = self.significance.copy()
if which == 'frac_sig':
phase_mask = self.get_phase_mask(self.A, phase_center=phase_center, phase_tol=phase_tol)
x = self.get_significance_mask(significance).astype(float)
x[~phase_mask] = 0 # Significant other phases count as insignificant when focussing on one phase.
elif which == 'abs_sq':
x = P
elif which == 'abs':
x = np.sqrt(P)
elif which == 'real':
x = np.real(self.to_complex())
elif which == 'imag':
x = np.imag(self.to_complex())
elif which == 'complex':
# Average the complex coherence. Will return a complex value.
x = np.sqrt(P) * np.exp(A * 1j)
elif which == 'complex_sig':
# Average the complex coherence. Ignore non significant.
sig_mask = self.get_significance_mask(significance).astype(bool)
P[~sig_mask] = np.nan
x = np.sqrt(P) * np.exp(A * 1j)
elif which == 'complex_sig_frac':
# Average complex coherence times fraction of significant coherence.
time, freqs, cs = self.get_smoothed(
window, stepsize, time_scale=time_scale,
which='complex_sig', mask_coi=mask_coi, max_nan=max_nan,
phase_center=phase_center, phase_tol=phase_tol)
time, freqs, sf = self.get_smoothed(
window, stepsize, time_scale=time_scale,
which='frac_sig', mask_coi=mask_coi, max_nan=max_nan,
phase_center=phase_center, phase_tol=phase_tol)
return time, freqs, cs*sf
else:
raise ValueError('Invalid parameter "{}" for `which`. Choose from {}.'
.format(which, ['abs_sq', 'abs', 'real', 'imag', 'frac_sig',
'complex', 'complex_sig', 'complex_sig_frac']))
if mask_coi:
x[self.insidecoi] = np.nan # Ignore COI.
window = revert_time_scale(window, time_scale=time_scale)
stepsize = revert_time_scale(stepsize, time_scale=time_scale)
seg_gen = segment_generator(x, segment_length=window,
overlap=window - stepsize, fs=self.fs, axis=-1)
time_gen = segment_generator(self.time, segment_length=window,
overlap=window - stepsize, fs=self.fs, axis=-1)
C = []
time = []
for t_seg, x_seg in zip(time_gen, seg_gen):
ci = np.nanmean(x_seg, axis=-1)
ti = (t_seg[0] + t_seg[-1]) / 2
nan_mask = np.mean(np.isnan(x_seg), axis=-1) > max_nan
ci[nan_mask] = np.nan
C.append(ci)
time.append(ti)
freqs = self.freqs
time = np.array(time)
C = np.array(C).T
return time, freqs, C
[docs] def get_time_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
how='mean', phase_center=None, phase_tol=np.pi / 4,
mask_coi=True, max_nan=None, **kwargs):
"""
Return the frequency profile.
Args:
See self.get_profile().
Returns:
time (np.ndarray): time in the specified scale.
time_profile (np.ndarray): frequency averaged value per time.
"""
time, time_profile = self.get_profile(which='time',
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how=how, phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi, max_nan=max_nan, **kwargs)
return time, time_profile
[docs] def match_shapes(self):
"""
Match shapes of signals (truncate longest one) (inplace).
"""
len_all = [s.shape[-1] for s in self.signals]
min_len = min(len_all)
for i, x in enumerate(self.signals):
self.signals[i] = slice_along_axis(x, axis=-1, stop=min_len)
[docs] def plot(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
scalogram_kwargs=None, time_profile_kwargs=None,
freq_profile_kwargs=None, signals_kwargs=None,
mask_coi=True, fig_kwargs=None):
"""
Plot the results.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
scalogram_kwargs (dict, optional): keyword arguments for self.plot_scalogram().
Defaults to None.
time_profile_kwargs (dict, optional): keyword arguments for self.plot_time_profile().
Defaults to None.
freq_profile_kwargs (dict, optional): keyword arguments for self.plot_freq_profile().
Defaults to None.
signals_kwargs (dict, optional): keyword arguments for self.plot_signals().
Defaults to None.
mask_coi (bool, optional): mask samples that were are inside the cone of influence (COI).
Defaults to True.
Returns:
axes handles foe each subplot.
"""
if scalogram_kwargs is None:
scalogram_kwargs = dict()
if time_profile_kwargs is None:
time_profile_kwargs = dict()
if freq_profile_kwargs is None:
freq_profile_kwargs = dict()
if signals_kwargs is None:
signals_kwargs = dict()
if fig_kwargs is None:
fig_kwargs = dict()
# Create figure.
fig = plt.figure(**dict(tight_layout=True, **fig_kwargs))
maximize_figure()
if self.signals is None:
gs = plt.GridSpec(3, 4, figure=fig)
ax_scal = fig.add_subplot(gs[0:-1, :-1])
ax_time = fig.add_subplot(gs[-1, :-1], sharex=ax_scal)
ax_freq = fig.add_subplot(gs[0:-1, -1], sharey=ax_scal)
ax_sigs = None
else:
gs = plt.GridSpec(4, 4, figure=fig)
ax_scal = fig.add_subplot(gs[1:-1, :-1])
ax_time = fig.add_subplot(gs[-1, :-1], sharex=ax_scal)
ax_freq = fig.add_subplot(gs[1:-1, -1], sharey=ax_scal)
ax_sigs = fig.add_subplot(gs[0, :-1], sharex=ax_scal)
plt.setp(ax_sigs.get_xticklabels(), visible=False)
# ax_sigs.tick_params(axis='both', which='both', length=0)
plt.setp(ax_scal.get_xticklabels(), visible=False)
# ax_scal.tick_params(axis='both', which='both', length=0)
ax_freq.yaxis.tick_right()
ax_freq.yaxis.set_label_position("right")
# Plot original signals.
if self.signals is not None:
self.plot_signals(time_start=time_start, time_stop=time_stop, time_scale=time_scale,
ax=ax_sigs, **signals_kwargs)
ax_sigs.set_xlabel('') # Remove xlabel.
# Plot time profile.
self.plot_time_profile(freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
ax=ax_time, mask_coi=mask_coi, **time_profile_kwargs)
# Plot frequency profile.
self.plot_freq_profile(freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
ax=ax_freq, orientation='vertical', mask_coi=mask_coi,
**freq_profile_kwargs)
# Plot scalogram.
self.plot_scalogram(freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=mask_coi, ax=ax_scal, **scalogram_kwargs)
ax_scal.set_xlabel('') # Remove xlabel.
return ax_scal, ax_time, ax_freq, ax_sigs
[docs] def plot_complex_freq_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
which='amp_phase', significant_only=False,
ax=None, mask_coi=True, **kwargs):
# Set current axis.
if ax is not None:
plt.sca(ax)
else:
ax = plt.gca()
which = which.lower()
# Compute frequency profile.
freqs, freq_profile = self.get_freq_profile(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how='complex' + ('_sig' if significant_only else ''),
mask_coi=mask_coi
)
plot_kwargs = dict({'linewidth': 3, 'color': 'k'}, **kwargs)
if 'amp' in which or 'abs' in which:
fpn = np.abs(freq_profile)
plt.semilogx(freqs, fpn, **dict(plot_kwargs, linestyle='-', label='Mean absolute coherence'))
plt.legend(loc='upper left')
plt.ylabel('Coherence (-)')
if 'real' in which:
fpr = np.real(freq_profile)
plt.semilogx(freqs, fpr, **dict(plot_kwargs, linestyle='--', label='Real coherence'))
plt.legend(loc='upper left')
plt.ylabel('Coherence (-)')
if 'imag' in which:
fpi = np.imag(freq_profile)
plt.semilogx(freqs, fpi, **dict(plot_kwargs, linestyle=':', label='Imaginary coherence'))
plt.legend(loc='upper left')
plt.ylabel('Coherence (-)')
if 'phase' in which:
fpa = np.angle(freq_profile) * 180 / np.pi
ylim = plt.ylim()
plt.pcolormesh(freqs, np.tile(ylim, [len(freqs), 1]).T, np.tile(fpa, [2, 1]), cmap='hsv',
alpha=1, shading='nearest', vmin=-180, vmax=180)
plt.ylim(ylim)
ax2 = plt.twinx(ax)
ax2.plot(freqs, fpa, **dict(plot_kwargs, linestyle='-', label='Angle', color='w'))
plt.legend(loc='upper right')
plt.ylabel('Angle (degrees)')
# Make legend text white.
for text in ax2.get_legend().get_texts():
text.set_color('w')
[docs] def plot_cum_freq_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
which='mean', phase_center=None, phase_tol=np.pi/4,
ax=None, mask_coi=True, **kwargs):
"""
Plot the cumulative frequency profile (is a 2d image).
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
which (str, optional): what to plot as a function of frequency. Choose from:
- 'mean': mean of P.
- 'mean_sig': mean of significance.
- 'frac_sig': fraction of significant P.
- 'frac_phase': occurence fraction of the specified phase.
phase_center and phase_tol: select phase. See self.get_phase_mask. If None, include all phases.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
**kwargs (optional): optional keyword arguments for the plt.contourf function.
"""
if which == 'frac_phase' and phase_center is None:
raise ValueError('Select a phase (specify `phase_center`) when which="frac_phase".')
# Set current axis.
if ax is not None:
plt.sca(ax)
# Compute frequency profile.
freqs, time, freq_profile = self.get_cum_freq_profile(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
which=which, phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi
)
# Default kwargs.
plot_kwargs = dict({
'levels': 21,
'cmap': 'jet',
'extend': 'both',
}, **kwargs)
# Plot.
plt.contourf(time, freqs, freq_profile,
**plot_kwargs)
plt.xlabel('Time ({})'.format(time_scale))
plt.ylabel('{} ({})'.format(
'Frequency' if 'hz' in freq_scale.lower() else 'Scale',
freq_scale))
plt.title('Cumulative {}'.format(which))
# Log2 scale for frequencies.
ax = plt.gca()
ax.semilogy(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.yaxis.set_major_formatter(formatter)
# Set limits.
plt.xlim([time[0], time[-1]])
nan_freqs = np.all(np.isnan(freq_profile), axis=-1)
plt.ylim([freqs[~nan_freqs].min(), freqs[~nan_freqs].max()])
# Revert y axis if plotting as scale.
if not 'hz' in freq_scale.lower():
ax.invert_yaxis()
[docs] def plot_freq_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
how='mean', phase_center=None, phase_tol=np.pi / 4,
ax=None, orientation='horizontal', mask_coi=True, **kwargs):
"""
Plot the frequency profile.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
how (str, optional): what to plot as a function of frequency. Choose from:
- 'mean': mean of P.
- 'median': median of P.
- 'frac_sig': fraction of significant P.
- 'frac_phase': occurence fraction of the specified phase.
phase_center and phase_tol: select phase. See self.get_phase_mask. If None, include all phases.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
orientation (str, optional): whether to plot 'horitontal' (i.e. frequency on the x-axis),
or 'vertical' (i.e. frequency on the y-axis).
Defaults to 'horizontal'.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
**kwargs (optional): optional keyword arguments for the plt.contourf function.
"""
if how == 'frac_phase' and phase_center is None:
raise ValueError('Select a phase (specify `phase_center`) when which="frac_phase".')
# Set current axis.
if ax is not None:
plt.sca(ax)
else:
ax = plt.gca()
# Compute frequency profile.
freqs, freq_profile = self.get_freq_profile(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
how=how, phase_center=phase_center, phase_tol=phase_tol,
mask_coi=mask_coi
)
# Default kwargs.
plot_kwargs = {
}
plot_kwargs.update(kwargs)
label = '{} (a.u.)'.format(how)
freq_label = '{} ({})'.format(
'Frequency' if 'hz' in freq_scale.lower() else 'Scale',
freq_scale)
# Plot.
if orientation == 'horizontal':
plt.plot(freqs, freq_profile, **plot_kwargs)
plt.ylabel(label)
plt.xlabel(freq_label)
# Log2 scale for frequencies.
ax.semilogx(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.xaxis.set_major_formatter(formatter)
elif orientation == 'vertical':
plt.plot(freq_profile, freqs, **plot_kwargs)
plt.xlabel(label)
plt.ylabel(freq_label)
# Log2 scale for frequencies.
ax.semilogy(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.yaxis.set_major_formatter(formatter)
else:
raise ValueError('Invalid `orientation` "{}". Choose from "horizontal", "vertical".'
.format(orientation))
plt.title('Frequency profile')
[docs] def plot_phase(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
mask_non_significant=True, alpha=0.05, add_colorbar=True,
ax=None, cmap=None, **kwargs):
"""
Plot the phase.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
cmap (list, optional): colopmap.
**kwargs (optional): optional keyword arguments for the plt.pcolormesh function.
"""
if cmap is None:
cmap = 'twilight'
# # Cyclic colormap.
# cmap_1_sided = cm.get_cmap('RdYlGn', 180)
# neg_colors = cmap_1_sided(np.linspace(0, 1, 180))
# pos_colors = neg_colors[::-1]
# new_colors = np.concatenate([neg_colors, pos_colors])
# cmap = ListedColormap(new_colors)
# Set current axis.
if ax is not None:
plt.sca(ax)
else:
ax = plt.gca()
# Extract data.
freqs, time, P, significance, A, _, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=True)
if mask_non_significant and self.significance is not None:
# Remove non significant.
sig_mask = self.get_significance_mask(significance=significance, alpha=alpha) == 1
A[~sig_mask] = np.nan
else:
sig_mask = np.ones(P.shape, dtype=bool)
# Shade insidecoi.
cmap_coi = ListedColormap([0.5 * np.ones(3)])
coi_mask = insidecoi.copy().astype(float)
coi_mask[~insidecoi] = np.nan
plt.contourf(time, freqs, insidecoi, levels=[0.5, 1.5], alpha=1, cmap=cmap_coi)
# Plot.
phase = A * 180 / np.pi
plot_kwargs = dict(dict(
vmin=-180,
vmax=180,
cmap=cmap,
), **kwargs)
plt.pcolormesh(time, freqs, phase, shading='auto', **plot_kwargs)
# plot_kwargs = dict(dict(
# levels=np.linspace(-180, 180, 21),
# cmap=cmap,
# ), **kwargs)
# plt.contourf(time, freqs, phase, **plot_kwargs)
plt.xlabel('Time ({})'.format(time_scale))
plt.ylabel('{} ({})'.format(
'Frequency' if 'hz' in freq_scale.lower() else 'Scale',
freq_scale))
plt.title('{} Phase'.format(self.name if self.name is not None else ''))
# Log2 scale for frequencies.
ax = plt.gca()
ax.semilogy(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.yaxis.set_major_formatter(formatter)
# Set limits.
plt.xlim([time[0], time[-1]])
nan_freqs = np.all(np.isnan(A), axis=-1)
plt.ylim([freqs[~nan_freqs].min(), freqs[~nan_freqs].max()])
# Revert y axis if plotting as scale.
if not 'hz' in freq_scale.lower():
ax.invert_yaxis()
if add_colorbar:
cbar = plt.colorbar(ax=ax)
cbar.set_label('Phase ($\degree$)')
[docs] def plot_phase_discrete(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
plot_significance=True, significance_mode='contour', alpha=0.05,
ax=None, mask_coi=True, mask_mode='shade', phases=None, colors=None,
**kwargs):
"""
Plot the phase.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
plot_surrogates (bool, optional): plot mean profile of the surrogates (if surrogates available).
Defaults to False.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
orientation (str, optional): whether to plot 'horitontal' (i.e. frequency on the x-axis),
or 'vertical' (i.e. frequency on the y-axis).
Defaults to 'horizontal'.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
phases (list, optional): list with phases classes.
colors (list, optional): list with colors corresponding to phases.
colors:
**kwargs (optional): optional keyword arguments for the plt.pcolormesh function.
"""
if colors is None:
if phases is None:
# Default colors and phases.
colors = np.asarray(plt.get_cmap('Paired').colors)[[1, 0, 5, 4]]
else:
colors = ['C{}'.format(i) for i in range(len(phases))]
if phases is None:
phases = np.array([0, np.pi / 2, np.pi, -np.pi / 2])
else:
phases = np.asarray(phases)
# Set current axis.
if ax is not None:
plt.sca(ax)
# Extract data.
freqs, time, P, significance, A, _, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=True if mask_mode == 'remove' else False)
if colors in ['hsv', 'continuous']:
# Use continuous cyclic color map hsv.
phase = A*180/np.pi
plot_kwargs = dict(
vmin=-180,
vmax=180,
cmap='hsv',
)
else:
# Use discrete color map.
if len(phases) != len(colors):
raise ValueError('Length of `phases` ({}) and `colors` ({}) must be equal.'
.format(len(phases), len(colors)))
# Determine phase number.
phase_diff = np.angle(
np.exp(1j * (A[:, :, np.newaxis] - phases[np.newaxis, np.newaxis, :])))
phase = np.argmin(np.abs(phase_diff), axis=-1)
cmap = ListedColormap(colors=colors)
levels = np.arange(len(phases) + 1) - 0.5
plot_kwargs = dict(
cmap=cmap,
norm=BoundaryNorm(levels, ncolors=cmap.N, clip=True),
)
# Plot.
plt.pcolormesh(time, freqs, phase, shading='nearest', **dict(plot_kwargs, **kwargs))
plt.xlabel('Time ({})'.format(time_scale))
plt.ylabel('{} ({})'.format(
'Frequency' if 'hz' in freq_scale.lower() else 'Scale',
freq_scale))
plt.title('{} Phase'.format(self.name if self.name is not None else ''))
# Circle significant parts with black line if significance is available.
if significance is not None and plot_significance:
sig_mask = self.get_significance_mask(significance=significance, alpha=alpha).astype(int)
# I expect to see "UserWarning: No contour levels were found within the data range." in this block
if significance_mode == 'contour':
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
plt.contour(time, freqs, sig_mask, levels=[1], colors='k', linewidths=2)
elif significance_mode in ['hide', 'shade', 'contourf']:
cmap = ListedColormap([0.5 * np.ones(3)])
plt.contourf(time, freqs, sig_mask, levels=[-0.01, 0.01], alpha=1, cmap=cmap)
if mask_coi and mask_mode == 'shade' and insidecoi is not None:
coi_mask = insidecoi.copy().astype(float)
coi_mask[~insidecoi] = np.nan
cmap = ListedColormap([0.0 * np.ones(3)])
plt.contourf(time, freqs, coi_mask, levels=[0.99, 1.01], alpha=0.75, cmap=cmap)
# Log2 scale for frequencies.
ax = plt.gca()
ax.semilogy(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.yaxis.set_major_formatter(formatter)
# Set limits.
plt.xlim([time[0], time[-1]])
nan_freqs = np.all(np.isnan(A), axis=-1)
plt.ylim([freqs[~nan_freqs].min(), freqs[~nan_freqs].max()])
# Revert y axis if plotting as scale.
if not 'hz' in freq_scale.lower():
ax.invert_yaxis()
[docs] def plot_scalogram(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
plot_significance=True, significance_mode='contour', alpha=0.05,
ax=None, mask_coi=True, mask_mode='shade', add_colorbar=False,
**kwargs):
"""
Plot the scalogram (contour plot of the wavelet coherence as a function of frequency and time).
If surrogate data is available, a significance threshold is computed using the significance level `alpha`
and regions where the coherence exceeds the threshold are encircled with black contours.
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
plot_significance (bool, optional): plot contour lines alng the significant regions (True) or not (False).
Defaults to True.
significance_mode (str, optional): how to display significance. Choose from:
- 'contour': draws contour lines at the significance level.
- 'hide': hides non-significant pixels (shades them darkgrey).
Defaults to 'contour'.
alpha (float, optional): float between 0 and 1, defining the significance level when computing the
significance threshold if surrogate data is available.
Defaults to 0.05.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
mask_coi (bool, optional): mask samples in the COI.
Defaults to True.
mask_mode (str, optional): choose to 'remove' or 'shade' nans.
Defaults to 'shade'.
add_colorbar (bool): whether to add a colorbar to the plot (True) or not (False).
**kwargs (optional): optional keyword arguments for the plt.contourf function.
"""
# Set current axis.
if ax is not None:
plt.sca(ax)
else:
ax = plt.gca()
# Extract data.
freqs, time, P, significance, A, _, insidecoi = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=True if mask_mode == 'remove' else False)
# Default kwargs.
plot_kwargs = dict({
'levels': 21,
'cmap': 'jet',
'extend': 'both',
}, **kwargs)
# Plot.
h = plt.contourf(time, freqs, P,
**plot_kwargs)
plt.xlabel('Time ({})'.format(time_scale))
plt.ylabel('{} ({})'.format(
'Frequency' if 'hz' in freq_scale.lower() else 'Scale',
freq_scale))
plt.title('{}'.format(self.name if self.name is not None else 'Scalogram'))
if add_colorbar:
plt.colorbar(ax=ax)
# Circle significant parts with black line if significance is available.
if significance is not None and plot_significance:
sig_mask = self.get_significance_mask(significance=significance, alpha=alpha).astype(int)
# I expect to see "UserWarning: No contour levels were found within the data range." in this block
if significance_mode == 'contour':
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
plt.contour(time, freqs, sig_mask, levels=[1], colors='k', linewidths=2)
elif significance_mode in ['hide', 'shade', 'contourf']:
cmap = ListedColormap([0.5 * np.ones(3)])
plt.contourf(time, freqs, sig_mask, levels=[-0.01, 0.01], alpha=1, cmap=cmap)
if mask_coi and mask_mode == 'shade' and insidecoi is not None:
coi_mask = insidecoi.copy().astype(float)
coi_mask[~insidecoi] = np.nan
cmap = ListedColormap([0.0*np.ones(3)])
plt.contourf(time, freqs, coi_mask, levels=[0.99, 1.01], alpha=0.65, cmap=cmap)
# Log2 scale for frequencies.
ax = plt.gca()
ax.semilogy(base=2)
formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
ax.yaxis.set_major_formatter(formatter)
# Set limits.
plt.xlim([time[0], time[-1]])
nan_freqs = np.all(np.isnan(P), axis=-1)
plt.ylim([freqs[~nan_freqs].min(), freqs[~nan_freqs].max()])
# Revert y axis if plotting as scale.
if not 'hz' in freq_scale.lower():
ax.invert_yaxis()
return h
[docs] def plot_signals(self, time_start=None, time_stop=None, time_scale='seconds',
normalize=True, smooth_window=None, ax=None, colors=None, **kwargs):
"""
Plot the original signals as function of time.
Args:
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
normalize (bool, optional): normalize signals (zero mean, unit SD) prior to plotting.
Defaults to True.
smooth_window (float): optional window for moving mean to smooth time signals
(in seconds).
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
colors (optional): Specify one color for all signals (str, array), or one color for each
signals (list, tuple, dict).
**kwargs (optional): optional keyword arguments for the plt.plot function.
"""
if self.signals is None:
raise ValueError('Original signals are missing. Set the `signals` attribute.')
# Set current axis.
if ax is not None:
plt.sca(ax)
else:
ax = plt.gca()
# Extract data.
signals = self.signals.copy()
time = np.linspace(self.time[0], self.time[-1] + 1/self.fs, signals[0].shape[-1]) # Downsampling could have happened.
# Convert time scale.
time = convert_time_scale(time, time_scale)
# Extract time period.
if time_start is not None or time_stop is not None:
if time_start is None:
time_start = time[0]
if time_stop is None:
time_stop = time[-1]
time_mask = np.logical_and(time_start <= time, time <= time_stop)
signals = signals[:, time_mask]
time = time[time_mask]
# I expect to see "RuntimeWarning: Mean of empty slice" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
if normalize:
for i in range(len(signals)):
signals[i] = signals[i] - np.nanmean(signals[i], axis=-1, keepdims=True)
signals[i] = signals[i] / np.nanstd(signals[i], axis=-1, keepdims=True)
# Default kwargs.
plot_kwargs = {
}
plot_kwargs.update(kwargs)
# Plot.
for i, sig in enumerate(signals):
if self.labels is not None:
label = self.labels[i]
else:
label = ''
if not normalize and i != 0:
ax_i = ax.twinx()
else:
ax_i = ax
# Smooth.
if smooth_window is not None:
n_smooth = int(smooth_window * self.fs)
sig = moving_mean(sig, n_smooth)
if colors is None:
color_i = 'C{}'.format(i)
elif isinstance(colors, (list, tuple)):
color_i = colors[i]
elif isinstance(colors, dict):
color_i = colors[label]
else:
color_i = colors
ax_i.plot(time, sig, label=label, color=color_i, **plot_kwargs)
if not normalize:
ax_i.set_ylabel(label, color=color_i)
else:
ax_i.set_ylabel('Signal (a.u.)')
ax_i.tick_params(axis='y', labelcolor=color_i)
ax.set_xlabel('Time ({})'.format(time_scale))
plt.title('Input signals')
if self.labels is not None and normalize is True:
plt.legend()
[docs] def plot_time_profile(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
plot_surrogates=False,
ax=None, mask_coi=True, **kwargs):
"""
Plot the time profile (i.e. coherence averaged over a range of frequencies).
Args:
freq_low (float, optional): low frequency to include (in the unit of `freq_scale`).
If None, the minimum available frequency is used.
Defaults to None.
freq_high (float, optional): high frequency to include (in the unit of `freq_scale`).
If None, the maximum available frequency is used.
Defaults to None.
freq_scale (str, optional): frequency scale. Use 'Hz', 'mHz', 'seconds', 'minutes'.
Defaults to 'Hz'.
time_start (float, optional): start time to include (in the unit of `time_scale`).
If None, the first available timestamp is used.
Defaults to None.
time_stop (float, optional): end time to include (in the unit of `time_scale`).
If None, the last available timestamp is used.
Defaults to None.
time_scale (str, optional): time scale, see nnsa.data.plotting.convert_time_scale().
Defaults to 'seconds'.
plot_surrogates (bool, optional): plot mean profile of the surrogates (if surrogates available).
Defaults to False.
ax (plt.Axes, optional): axes object to plot in. If None, plots in the current axes.
Defaults to None.
mask_coi (bool, optional): ignore samples that are in the COI.
Defaults to True.
**kwargs (optional): optional keyword arguments for the plt.plot function.
"""
# Set current axis.
if ax is not None:
plt.sca(ax)
# Extract data.
freqs, time, P, significance, A, P_surrogates, _ = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale, mask_coi=mask_coi)
plot_surrogates = plot_surrogates and (P_surrogates is not None)
# Compute time profile.
# I expect to see "RuntimeWarning: Mean of empty slice" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
time_profile = np.nanmean(P, axis=0)
if plot_surrogates:
P_surrogates = np.nanmean(P_surrogates, axis=0)
# Default kwargs.
plot_kwargs = {
}
plot_kwargs.update(kwargs)
# Plot.
if plot_surrogates:
plt.plot(time, P_surrogates, **dict(
plot_kwargs, linestyle='--', color=0.5 * np.ones(3)))
plt.plot(time, time_profile, **plot_kwargs)
plt.xlabel('Time ({})'.format(time_scale))
plt.ylabel('Frequency average (a.u.)')
plt.title('Time profile')
plt.ylim(0, 1)
[docs] def to_complex(self):
"""
Return complex values C.
C = sqrt(P) * exp(Aj)
Returns:
C (np.ndarray): complex values, came shape as self.P and self.A.
"""
C = np.sqrt(self.P) * np.exp(1j * self.A)
return C
[docs] def update_coi(self, nan_mask, **kwargs):
"""
Update cone of influence by combining it with a MOI based on nan_mask (inplace).
Args:
nan_mask: see compute_moi.
**kwargs: for compute_moi().
"""
new_coi = compute_moi(nan_mask, self.scales, self.dt, self.dj, self.wavelet, **kwargs)
self.insidecoi = np.logical_or(self.insidecoi, new_coi)
def _extract_data(self, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
mask_coi=True):
"""
Extract frequency, time, P, significance and A arrays for the specified region.
If mask_coi is True, samples that are in the COI are replaced with
nans in P, A and significance arrays.
"""
freqs = convert_freq_scale(self.freqs, freq_scale).copy()
time = convert_time_scale(self.time, time_scale).copy()
P = self.P.copy()
significance = self.significance.copy() if self.significance is not None else None
A = self.A.copy()
P_surrogates = self.P_surrogates.copy() if self.P_surrogates is not None else None
insidecoi = self.insidecoi.copy() if self.insidecoi is not None else None
# Select frequencies.
if freq_low is not None or freq_high is not None:
if freq_low is None:
freq_low = np.min(freqs)
if freq_high is None:
freq_high = np.max(freqs)
freq_mask = np.logical_and(freq_low <= freqs, freqs <= freq_high)
freqs = freqs[freq_mask]
P = P[freq_mask, :]
A = A[freq_mask, :]
if significance is not None:
significance = significance[freq_mask, :]
if P_surrogates is not None:
P_surrogates = P_surrogates[freq_mask, :]
if insidecoi is not None:
insidecoi = insidecoi[freq_mask, :]
# Select times.
if time_start is not None or time_stop is not None:
if time_start is None:
time_start = time[0]
if time_stop is None:
time_stop = time[-1]
time_mask = np.logical_and(time_start <= time, time <= time_stop)
time = time[time_mask]
P = P[:, time_mask]
A = A[:, time_mask]
if significance is not None:
significance = significance[:, time_mask]
if P_surrogates is not None:
P_surrogates = P_surrogates[:, time_mask]
if insidecoi is not None:
insidecoi = insidecoi[:, time_mask]
if mask_coi and insidecoi is not None:
P[insidecoi] = np.nan
A[insidecoi] = np.nan
if significance is not None:
significance[insidecoi] = np.nan
if P_surrogates is not None:
P_surrogates[insidecoi] = np.nan
return freqs, time, P, significance, A, P_surrogates, insidecoi
[docs] def shade(self, mask, freq_low=None, freq_high=None, freq_scale='Hz',
time_start=None, time_stop=None, time_scale='seconds',
color='k', alpha=0.5, ax=None):
"""
Plot filled contour mask to shade a part of the scalogram.
Args:
mask (np.ndarray): binary mask, where True values will be shaded.
"""
if ax is None:
ax = plt.gca()
else:
plt.sca(ax)
freqs, time, _, _, _, _, _ = self._extract_data(
freq_low=freq_low, freq_high=freq_high, freq_scale=freq_scale,
time_start=time_start, time_stop=time_stop, time_scale=time_scale,
mask_coi=True)
cmap = ListedColormap([color])
plt.contourf(time, freqs, mask, levels=[0.99, 1.01], alpha=alpha, cmap=cmap)
def _merge(self, other, index=None):
"""
See ResultBase.
"""
# Check if the frequency array of self and other are the same.
if len(self.freqs) != len(other.freqs) or np.any(np.abs(self.freqs - other.freqs) > 1e-10):
raise ValueError('Cannot merge results with different frequency arrays.')
# # Add nan frequencies if freq arrays are not of same length.
# if len(self.freqs) < len(other.freqs):
# if abs(self.freqs[-1] - other.freqs[len(self.freqs) - 1]) > 1e-10:
# raise ValueError('Cannot merge results with different frequency arrays.')
# else:
# # Add nans to self.
# n_samples = self.Cxy.shape[-1]
# n_add = len(other.freqs) - len(self.freqs)
# self.Cxy = np.concatenate([self.Cxy,
# np.full((n_add, n_samples), fill_value=np.nan)], axis=0)
# self.Axy = np.concatenate([self.Axy,
# np.full((n_add, n_samples), fill_value=np.nan)], axis=0)
# if self.significance is not None:
# self.significance = np.concatenate([self.significance,
# np.full((n_add, n_samples),
# fill_value=np.nan)], axis=1)
# elif len(self.freqs) > len(other.freqs):
# if abs(other.freqs[-1] - self.freqs[len(other.freqs) - 1]) > 1e-10:
# raise ValueError('Cannot merge results with different frequency arrays.')
# else:
# # Add nans to other.
# n_samples = other.Cxy.shape[-1]
# n_add = len(self.freqs) - len(other.freqs)
# other.Cxy = np.concatenate([other.Cxy,
# np.full((n_add, n_samples), fill_value=np.nan)], axis=0)
# other.Axy = np.concatenate([other.Axy,
# np.full((n_add, n_samples), fill_value=np.nan)], axis=0)
# if other.significance is not None:
# other.significance = np.concatenate([other.significance,
# np.full((n_add, n_samples),
# fill_value=np.nan)], axis=1)
if index is not None:
n_freqs, n_samples = self.P.shape
if index < n_samples:
# Cut piece off.
msg = 'Overwriting data while merging.'
warnings.warn(msg)
self.P = self.P[:, :index]
self.A = self.A[:, :index]
if self.significance is not None:
self.significance = self.significance[:, :index]
if self.P_surrogates is not None:
self.P_surrogates = self.P_surrogates[:, :index]
if self.signals is not None:
self.signals = self.signals[:, :index]
if self.insidecoi is not None:
self.insidecoi = self.insidecoi[:, :index]
else:
# Add nans.
self.P = np.concatenate([self.P,
np.full((n_freqs, index - n_samples), fill_value=np.nan)], axis=-1)
self.A = np.concatenate([self.A,
np.full((n_freqs, index - n_samples), fill_value=np.nan)], axis=-1)
if self.significance is not None:
self.significance = np.concatenate([self.significance,
np.full((n_freqs, index - n_samples),
fill_value=np.nan)], axis=-1)
if self.P_surrogates is not None:
self.P_surrogates = np.concatenate([self.P_surrogates,
np.full((n_freqs, index - n_samples),
fill_value=np.nan)], axis=-1)
if self.signals is not None:
self.signals = np.concatenate([self.signals,
np.full((len(self.signals), index - n_samples), fill_value=np.nan)],
axis=-1)
if self.insidecoi is not None:
self.insidecoi = np.concatenate([self.insidecoi,
np.full((n_freqs, index - n_samples), fill_value=True)],
axis=-1)
self.P = np.concatenate((self.P, other.P), axis=-1)
self.A = np.concatenate((self.A, other.A), axis=-1)
if self.significance is not None and other.significance is not None:
self.significance = np.concatenate((self.significance, other.significance), axis=-1)
if self.P_surrogates is not None and other.P_surrogates is not None:
self.P_surrogates = np.concatenate((self.P_surrogates, other.P_surrogates), axis=-1)
if self.signals is not None and other.signals is not None:
self.signals = np.concatenate((self.signals, other.signals), axis=-1)
if self.insidecoi is not None and other.insidecoi is not None:
self.insidecoi = np.concatenate((self.insidecoi, other.insidecoi), axis=-1)
@staticmethod
def _read_from_hdf5(filepath):
"""
Read result from hdf5 file into a WaveletResult class.
Args:
filepath (str): see ResultBase._read_from_hdf5().
Returns:
result (nnsa.WaveletResult): instance of DynamicCouplingResult containing the
DynamicCoupling result.
"""
# Read standard hdf5 header (use the ResultBase method).
algorithm_parameters, data_info, segment_start_times, segment_end_times, fs, time_offset = \
ResultBase._read_hdf5_header(filepath)[1:]
# Re-open the file and read the rest of the file.
with h5py.File(filepath, 'r') as f:
# Read array data.
P = f['P'][:]
A = f['A'][:]
freqs = f['freqs'][:]
attrs = f['P'].attrs
if 'significance' in f:
significance = f['significance'][:]
else:
significance = None
if 'P_surrogates' in f:
P_surrogates = f['P_surrogates'][:]
else:
P_surrogates = None
if 'signals' in f:
signals = tuple()
labels = []
if isinstance(f['signals'], h5py.Dataset):
signals = f['signals'][:]
if 'labels' in attrs:
labels = [lab.decode() for lab in attrs['labels']]
else:
labels = None
else:
for key in f['signals']:
signals += (f['signals'][key][:],)
labels.append(key)
else:
signals = None
labels = None
if 'insidecoi' in f:
insidecoi = f['insidecoi'][:]
else:
insidecoi = None
if 'extra' in f:
extra = dict()
for key in f['extra']:
extra[key] = f['extra'][key][:]
else:
extra = None
# Read non-array data.
if 'name' in attrs:
name = attrs['name'].decode()
else:
name = None
if 'wavelet' in attrs:
wavelet = attrs['wavelet'].decode()
else:
wavelet = None
# Create a result object.
result = WaveletResult(P=P, A=A, freqs=freqs, fs=fs, wavelet=wavelet,
insidecoi=insidecoi, name=name,
algorithm_parameters=algorithm_parameters,
significance=significance,
P_surrogates=P_surrogates,
signals=signals, labels=labels,
extra=extra,
data_info=data_info,
segment_start_times=segment_start_times,
segment_end_times=segment_end_times,
time_offset=time_offset)
return result
def _write_to_hdf5(self, filepath):
"""
Write the contents of the object to an hdf5 file.
Args:
filepath (str): see ResultBase._write_to_hdf5().
"""
# Write standard hdf5 header (use the ResultBase method).
self._write_hdf5_header(filepath)
# Append attributes to the hdf5 file.
with h5py.File(filepath, 'a') as f:
# Write array data.
f.create_dataset('P', data=self.P)
f.create_dataset('A', data=self.A)
f.create_dataset('freqs', data=self.freqs)
if self.significance is not None:
f.create_dataset('significance', data=self.significance)
if self.P_surrogates is not None:
f.create_dataset('P_surrogates', data=self.P_surrogates)
if self.signals is not None:
for label, sig in zip(self.labels, self.signals):
f.create_dataset('signals/{}'.format(label), data=sig)
if self.insidecoi is not None:
f.create_dataset('insidecoi', data=self.insidecoi)
if self.extra is not None:
for label, sig in self.extra.items():
f.create_dataset('extra/{}'.format(label), data=sig)
# Write non-array data as attributes to the 'Cxy' group.
# Convert strings to np.string_ type as recommended for compatibility.
# if self.labels is not None:
# f['P'].attrs['labels'] = [np.string_(lab) for lab in self.labels]
if self.name is not None:
f['P'].attrs['name'] = np.string_(self.name)
if self._wavelet is not None:
f['P'].attrs['wavelet'] = np.string_(self._wavelet)
[docs]class WaveletCoherenceResult(WaveletResult):
"""
High-level interface for processing wavelet coherence as computed by nnsa.WaveletCoherence().
Alias for WaveletResult().
Args:
see WaveletResult.
"""
[docs]def convert_freq_scale(freqs, freq_scale, current='Hz'):
"""
Convert frequencies in the current scale to some other frequency or time scale.
"""
# Convert from current scale to Hz.
freqs = revert_freq_scale(freqs, current)
# Convert from Hz to the desired freq scale.
freq_scale = freq_scale.lower()
if freq_scale == 'hz':
return freqs
elif freq_scale == 'mhz':
return freqs * 1000
elif freq_scale in ['seconds', 's']:
return 1 / freqs
elif freq_scale in ['minutes', 'min']:
return 1 / freqs / 60
elif freq_scale in ['hours', 'h']:
return 1 / freqs / 60 / 60
else:
raise ValueError('Invalid freq_scale "{}".'.format(freq_scale))
def get_significance_mask(P, P_surrogates, alpha=0.05, axis=0):
"""
Determines whether the value P is greater than (1- alpha) percent in the surrogate values P_surrogates.
Args:
P:
P_surrogates:
alpha:
axis: axis in P_surrogates that corresponding to the surrogates.
Returns:
"""
# I expect to see "RuntimeWarning: invalid value encountered in greater_equal" in this block
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)
# Proportion of surrogates with lower values.
significance = np.nanmean(P > P_surrogates, axis=axis)
# Where significance is greater than the alpha level.
mask = (significance >= (1 - alpha)).astype(int)
return mask
[docs]def revert_freq_scale(freqs, freq_scale):
"""
Convert frequencies in some other frequency or time scale to Hz.
"""
freq_scale = freq_scale.lower()
if freq_scale == 'hz':
return freqs
elif freq_scale == 'mhz':
return freqs / 1000
elif freq_scale in ['seconds', 's']:
return 1 / freqs
elif freq_scale in ['minutes', 'min']:
return 1 / freqs / 60
elif freq_scale in ['hours', 'h']:
return 1 / freqs / 60 / 60
else:
raise ValueError('Invalid freq_scale "{}".'.format(freq_scale))