# Source code for nnsa.cwt.plotting

"""
Author: Tim Hermans (tim-hermans@hotmail.com).
"""
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.ticker import FuncFormatter

from nnsa.cwt.transforms import cwt
from nnsa.utils.conversions import convert_time_scale


# Public API of this module: the names exported by a star-import.
__all__ = [
    'plot_cwt_scalogram',
    'plot_scalogram',
    'plot_tf_map',
]


def plot_cwt_scalogram(x, fs=1, ax=None, cwt_kwargs=None, plot_kwargs=None):
    """
    Compute the CWT of a signal and plot its scale-normalized power in dB.

    Args:
        x: signal, passed to `cwt` as its first argument.
        fs (float): sampling frequency of `x`. `1/fs` is passed to `cwt` as `dt` and is also used to build
            the time axis of the plot.
        ax (plt.Axes, optional): axes to plot in, forwarded to `plot_scalogram`.
        cwt_kwargs (dict, optional): extra keyword arguments for `cwt`. If None, no extra arguments are passed.
        plot_kwargs (dict, optional): extra keyword arguments for `plot_scalogram`. If None, no extra arguments
            are passed.
    """
    # The defaults are None, which cannot be unpacked with `**`: fall back to empty dicts.
    cwt_kwargs = {} if cwt_kwargs is None else cwt_kwargs
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    W, scales, freqs, insidecoi = cwt(x, dt=1/fs, **cwt_kwargs)

    # Multiply by 1/s to remove bias to low frequencies (Liu et al. 2007).
    amplitude = 1/scales.reshape(-1, 1)*np.abs(W)

    # Decibels are defined with the base-10 logarithm: 20*log10(amplitude) == 10*log10(power).
    C = 20*np.log10(amplitude)

    # Blank out the coefficients flagged by `insidecoi` so they are not drawn.
    C[insidecoi] = np.nan

    time = np.arange(W.shape[-1])/fs
    plot_scalogram(time, freqs, C, ax=ax, **plot_kwargs)
def plot_scalogram(time, freqs, C, insidecoi=None, alpha=0.75, prange=None, time_scale='seconds', db=False,
                   ax=None, colorbar=False, **kwargs):
    """
    Draw a filled-contour scalogram with a base-2 logarithmic frequency axis.

    Args:
        time: time axis belonging to the second dimension of `C`. If None, the column indices of `C` are used.
        freqs: frequency axis belonging to the first dimension of `C`. If None, the row indices of `C` are used.
        C: array with the values to plot, indexed as (frequency, time).
        insidecoi (optional): mask that is drawn as a white overlay on top of the scalogram where it is True.
            Expected to be a boolean numpy array with the shape of `C`.
        alpha (float): opacity of the `insidecoi` overlay.
        prange (optional): two percentiles (low, high). `C` is clipped to the values found at these
            percentiles (NaNs are ignored when computing them).
        time_scale (str): passed to `convert_time_scale` together with `time`, and shown in the x-axis label.
        db (bool): if True, `10*log10(C)` is plotted instead of `C`.
        ax (plt.Axes, optional): axes to plot in (made the current axes). If None, the current axes are used.
        colorbar (bool): if True, add a colorbar for the scalogram.
        **kwargs: keyword arguments for `plt.contourf`. They override the defaults
            (`levels=21`, `cmap='jet'`, `extend='both'`).

    Raises:
        ValueError: if `prange` is given and does not have length 2.
    """
    # All drawing below goes through pyplot, so the requested axes must be the current axes.
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    # Fall back to plain indices for missing axes.
    if time is None:
        time = np.arange(C.shape[1])
    if freqs is None:
        freqs = np.arange(C.shape[0])

    # Optionally limit the dynamic range to the requested percentiles.
    if prange is not None:
        if len(prange) != 2:
            raise ValueError('`prange` must have length 2. Got length {}.'.format(len(prange)))
        lower = np.nanpercentile(C, prange[0])
        upper = np.nanpercentile(C, prange[1])
        C = np.clip(C, lower, upper)

    # Defaults for contourf, overridden by whatever the caller supplied.
    contour_kwargs = {
        'levels': 21,
        'cmap': 'jet',
        'extend': 'both',
    }
    contour_kwargs.update(kwargs)

    # Convert to decibels.
    if db:
        C = 10*np.log10(C)

    # Scalogram itself.
    t = convert_time_scale(time, time_scale)
    plt.contourf(t, freqs, C, **contour_kwargs)
    plt.xlabel('Time ({})'.format(time_scale))
    plt.ylabel('Frequency (Hz)')
    if colorbar:
        plt.colorbar()

    # Semi-transparent white overlay on the masked region: 1 where the mask is True, NaN (not drawn) elsewhere.
    if insidecoi is not None:
        coi_mask = insidecoi.astype(float)
        coi_mask[~insidecoi] = np.nan
        white = ListedColormap([np.ones(3)])
        plt.contourf(t, freqs, coi_mask, levels=[0.99, 1.01], alpha=alpha, cmap=white)

    def _plain_number(value, _pos):
        # Show tick values as plain numbers instead of powers of two.
        return '{:.8g}'.format(value)

    # Log2 scale for frequencies.
    current_ax = plt.gca()
    current_ax.semilogy(base=2)
    current_ax.yaxis.set_major_formatter(FuncFormatter(_plain_number))
def plot_tf_map(C, time=None, freqs=None, time_scale='seconds', freq_scale='Hz', colorbar=True, ax=None, **kwargs):
    """
    Plot time frequency map using plt.pcolormesh, with a base-2 logarithmic y-axis.

    Args:
        C: array with the values to plot, indexed as (frequency, time).
        time (optional): time axis belonging to the second dimension of `C`. If None, the column indices
            of `C` are used.
        freqs (optional): frequency (or scale) axis belonging to the first dimension of `C`. If None, the
            1-based row indices of `C` are used.
        time_scale (str): passed to `convert_time_scale` together with `time`, and shown in the x-axis label.
        freq_scale (str): passed to `convert_freq_scale` together with `freqs`, and shown in the y-axis label.
            If it contains 'hz' (case-insensitive) the y-axis is labeled 'Frequency', otherwise it is labeled
            'Scale' and the y-axis is inverted.
        colorbar (bool): if True, add a colorbar.
        ax (plt.Axes, optional): axes to plot in (made the current axes). If None, the current axes are used.
        **kwargs: keyword arguments for `plt.pcolormesh`. They override the defaults
            (`cmap='jet'`, `shading='auto'`).
    """
    from nnsa.feature_extraction.wavelets import convert_freq_scale

    # All drawing below goes through pyplot, so the requested axes must be the current axes.
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    # `time` and `freqs` default to None, but are indexed further down: fall back to index axes.
    # The frequency fallback starts at 1 because 0 cannot be shown on the logarithmic y-axis.
    # NOTE(review): the fallbacks are still passed through the converters below, like explicit axes are.
    if time is None:
        time = np.arange(np.shape(C)[1])
    if freqs is None:
        freqs = np.arange(1, np.shape(C)[0] + 1)

    time = convert_time_scale(time, time_scale)
    freqs = convert_freq_scale(freqs, freq_scale)

    # The y-axis shows frequencies if the unit mentions Hz, and scales otherwise.
    is_frequency = 'hz' in freq_scale.lower()

    # Defaults for pcolormesh, overridden by whatever the caller supplied
    # (`shading` is a default rather than a fixed argument, so callers can override it too).
    plot_kwargs = dict({
        'cmap': 'jet',
        'shading': 'auto',
    }, **kwargs)

    plt.pcolormesh(time, freqs, C, **plot_kwargs)
    plt.xlabel('Time ({})'.format(time_scale))
    plt.ylabel('{} ({})'.format('Frequency' if is_frequency else 'Scale', freq_scale))

    # Log2 scale for frequencies, with tick values shown as plain numbers.
    ax = plt.gca()
    ax.semilogy(base=2)
    formatter = FuncFormatter(lambda y, _: '{:.8g}'.format(y))
    ax.yaxis.set_major_formatter(formatter)
    plt.ylim([freqs[0], freqs[-1]])

    # Revert y axis if plotting as scale.
    if not is_frequency:
        ax.invert_yaxis()

    if colorbar:
        plt.colorbar()