"""
Author: Tim Hermans (tim-hermans@hotmail.com).
"""
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.ticker import FuncFormatter
from nnsa.cwt.transforms import cwt
from nnsa.utils.conversions import convert_time_scale
# Public plotting helpers exported by this module (`from ... import *`).
__all__ = [
    'plot_cwt_scalogram',
    'plot_scalogram',
    'plot_tf_map',
]
def plot_cwt_scalogram(x, fs=1, ax=None, cwt_kwargs=None, plot_kwargs=None):
    """
    Compute the continuous wavelet transform of `x` and plot its power in dB.

    Args:
        x: input signal, passed to `nnsa.cwt.transforms.cwt`.
        fs (float): sampling frequency of `x` in Hz; the CWT is computed with
            time step `dt = 1/fs` and the time axis is `sample_index / fs`.
        ax: matplotlib axes to draw in. Defaults to the current axes.
        cwt_kwargs (dict, optional): extra keyword arguments for `cwt`.
        plot_kwargs (dict, optional): extra keyword arguments for
            `plot_scalogram`.
    """
    # `None` defaults must be replaced before splatting with `**`.
    cwt_kwargs = {} if cwt_kwargs is None else cwt_kwargs
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    W, scales, freqs, insidecoi = cwt(x, dt=1/fs, **cwt_kwargs)

    # Divide by the scale to remove the bias towards low frequencies
    # (Liu et al. 2007), then express the amplitude in decibels:
    # 20*log10(|W|/s) == 10*log10((|W|/s)**2).
    C = 20*np.log10(np.abs(W)/np.asarray(scales).reshape(-1, 1))

    # Blank out the coefficients flagged by `cwt` through `insidecoi`
    # (presumably those affected by edge effects -- confirm against `cwt`).
    C[insidecoi] = np.nan

    time = np.arange(W.shape[-1])/fs
    plot_scalogram(time, freqs, C, ax=ax, **plot_kwargs)
def plot_scalogram(time, freqs, C, insidecoi=None, alpha=0.75,
                   prange=None, time_scale='seconds', db=False,
                   ax=None, colorbar=False, **kwargs):
    """
    Plot a scalogram as filled contours with a log2 frequency axis.

    Args:
        time: time points of the columns of `C`; converted for display with
            `convert_time_scale(time, time_scale)` (presumably given in
            seconds). If None, the column indices are used.
        freqs: frequencies (Hz) of the rows of `C`. If None, the row indices
            are used.
        C: 2D array of shape (len(freqs), len(time)).
        insidecoi (optional): mask with the same shape as `C`. Where it is
            True, a white overlay with opacity `alpha` is drawn on top.
        alpha (float): opacity of the `insidecoi` overlay.
        prange (optional): (low, high) percentiles used to clip `C`
            (NaNs ignored) before plotting.
        time_scale (str): unit of the time axis, passed to
            `convert_time_scale` and used in the x label.
        db (bool): if True, plot `10*log10(C)`. This is applied after the
            `prange` clipping.
        ax: axes to draw in. Defaults to the current axes.
        colorbar (bool): if True, add a colorbar for the scalogram.
        **kwargs: passed to `plt.contourf`, overriding the defaults
            `levels=21`, `cmap='jet'`, `extend='both'`.

    Raises:
        ValueError: if `prange` does not have length 2.
    """
    # Make `ax` the current axes so the pyplot calls below draw into it.
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    # Fall back to index grids when coordinates are not given. np.shape also
    # accepts lists and keeps masked arrays untouched.
    if time is None:
        time = np.arange(np.shape(C)[1])
    if freqs is None:
        freqs = np.arange(np.shape(C)[0])

    if prange is not None:
        if len(prange) != 2:
            raise ValueError('`prange` must have length 2. Got length {}.'.format(len(prange)))
        lower = np.nanpercentile(C, prange[0])
        upper = np.nanpercentile(C, prange[1])
        C = np.clip(C, lower, upper)

    # Defaults that the caller's kwargs may override.
    plot_kwargs = dict({
        'levels': 21,
        'cmap': 'jet',
        'extend': 'both',
    }, **kwargs)

    if db:
        C = 10*np.log10(C)

    t = convert_time_scale(time, time_scale)
    contour_set = plt.contourf(t, freqs, C, **plot_kwargs)
    plt.xlabel('Time ({})'.format(time_scale))
    plt.ylabel('Frequency (Hz)')

    if colorbar:
        # Explicit mappable: the colorbar belongs to the scalogram, not to
        # the overlay drawn below.
        plt.colorbar(contour_set)

    if insidecoi is not None:
        # Coerce to bool so that non-boolean masks are not negated bitwise.
        coi = np.asarray(insidecoi, dtype=bool)
        # 1 where flagged, NaN elsewhere, so only flagged cells get filled.
        mask = np.where(coi, 1.0, np.nan)
        white = ListedColormap([np.ones(3)])
        plt.contourf(t, freqs, mask, levels=[0.99, 1.01], alpha=alpha, cmap=white)

    # Log2 scale for frequencies, with plain numeric tick labels.
    ax.semilogy(base=2)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.8g}'.format(y)))
def plot_tf_map(C, time=None, freqs=None, time_scale='seconds',
                freq_scale='Hz', colorbar=True, ax=None, **kwargs):
    """
    Plot a time-frequency map with plt.pcolormesh on a log2 y axis.

    Args:
        C: 2D array whose rows correspond to `freqs` and whose columns
            correspond to `time`.
        time (optional): time points of the columns of `C`; converted with
            `convert_time_scale(time, time_scale)`. If None, the column
            indices are used.
        freqs (optional): frequencies of the rows of `C`; converted with
            `convert_freq_scale(freqs, freq_scale)`. If None, the rows are
            numbered 1..n so that every row fits on the logarithmic axis.
        time_scale (str): unit of the time axis, also used in the x label.
        freq_scale (str): unit of the y axis. If it contains 'hz'
            (case-insensitive), the axis is labelled 'Frequency'. Otherwise
            it is labelled 'Scale' and inverted.
        colorbar (bool): if True, add a colorbar for the map.
        ax: axes to draw in. Defaults to the current axes.
        **kwargs: passed to `plt.pcolormesh`, overriding the default
            `cmap='jet'`.
    """
    # Local import, as in the original (possibly to avoid a circular import).
    from nnsa.feature_extraction.wavelets import convert_freq_scale

    # Make `ax` the current axes so the pyplot calls below draw into it.
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    # Fall back to index grids when coordinates are not given. Without this,
    # the None defaults crash (e.g. on `freqs[0]`).
    if time is None:
        time = np.arange(np.shape(C)[1])
    if freqs is None:
        freqs = np.arange(1, np.shape(C)[0] + 1)

    time = convert_time_scale(time, time_scale)
    freqs = convert_freq_scale(freqs, freq_scale)

    plot_kwargs = dict({
        'cmap': 'jet'
    }, **kwargs)
    mesh = plt.pcolormesh(time, freqs, C, shading='auto', **plot_kwargs)

    is_frequency = 'hz' in freq_scale.lower()
    plt.xlabel('Time ({})'.format(time_scale))
    plt.ylabel('{} ({})'.format('Frequency' if is_frequency else 'Scale', freq_scale))

    # Log2 scale for the y axis, with plain numeric tick labels.
    ax.semilogy(base=2)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.8g}'.format(y)))

    # Use ascending limits so the orientation does not depend on whether
    # `freqs` is sorted ascending or descending. Frequencies then increase
    # upwards, and the inversion below puts small scales on top.
    y_values = np.asarray(freqs, dtype=float)
    plt.ylim([np.nanmin(y_values), np.nanmax(y_values)])

    # Invert the y axis when plotting as scale.
    if not is_frequency:
        ax.invert_yaxis()

    if colorbar:
        plt.colorbar(mesh)