Basic wavelet transforms
Wavelet coherence, computed from the continuous wavelet transform (CWT), can be used to identify local correlations between two signals in time-frequency space. The nnsa package provides tools for wavelet analysis. This script defines some test signals and computes their CWTs, their wavelet coherence, and their partial wavelet coherence.
Author: Tim Hermans (tim-hermans@hotmail.com).
Link to script: cwt/basic_wavelet_transforms.py
import numpy as np
import matplotlib.pyplot as plt
from nnsa import WaveletCoherence
from nnsa.cwt import Morlet, plot_scalogram, cwt
from nnsa.utils.plotting import scale_figsize, save_fig_as
# Width used to size every figure below (passed to scale_figsize with unit='cm').
fig_width = 18
# Start from a clean slate: close any figures left open from a previous run.
plt.close('all')
Create three test signals. The two signals of interest are x and y. A third signal, z, acts as a confounder: its confounding effect is introduced by subtracting it from x and adding it to y.
# Seed the random generator so that the noise (and hence every figure) is reproducible.
rng = np.random.RandomState(40)

# Some settings and precomputations.
fs = 100  # Sampling frequency (Hz); the CWT below uses dt = 1/fs.
t_max = 10  # Signal duration (s).
N = fs*t_max  # Number of samples.
t = np.arange(N)/fs  # Time axis (s).

# Optional random timing jitter, expressed as the standard deviation in units of
# the sampling interval 1/fs (0 = no jitter). The normal draw is performed
# regardless of this value, so the noise added further below is identical
# whether jitter is enabled or not.
jitter_amount = 0
t_rad = (t + jitter_amount*rng.normal(scale=1/fs, size=N))*2*np.pi

# The confounding signal z: a 4 Hz sine, only active (non-zero) in the middle half.
z = np.sin(t_rad * 4) * ((t > t_max/4) & (t < t_max*3/4))

# The signals x and y share a 16 Hz component, each has its own slow component
# (2 Hz for x, phase-shifted 1 Hz for y), and both include the confounding z
# component with opposite signs.
x = np.sin(t_rad * 16) + np.sin(t_rad * 2) - z
y = np.sin(t_rad * 16) + np.sin(t_rad * 1 + np.pi) + z

# Finally, add some random noise to each signal. Noise is added to z only after
# x and y were built, so z's own noise does not leak into x and y.
noise_level = 1/10
x += rng.normal(scale=noise_level, size=N)
y += rng.normal(scale=noise_level, size=N)
z += rng.normal(scale=noise_level, size=N)
Plot the signals. You should be able to see the confounding effect that z has on both x and y when it becomes active in the middle part.
# Single-panel overview of the three signals, stacked with a vertical offset so
# the traces do not overlap.
fig, ax = plt.subplots(
    1, 1,
    sharex='all',
    tight_layout=True,
    figsize=scale_figsize((4, 9/3), width=fig_width, unit='cm'),
)
spacing = 2
for level, (trace, trace_label) in enumerate([(x, '$x$'), (y, '$y$'), (z, '$z$')]):
    ax.plot(t, trace - level*spacing, label=trace_label)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
# Amplitudes are arbitrary (and offset), so tick values carry no meaning.
ax.set_yticks([])
ax.set_ylabel('Signal (a.u.)')
ax.set_title('Signals')
To analyze these signals with wavelets, we use the Morlet mother wavelet.
# Morlet mother wavelet instance.
wavelet = Morlet(6)

# Show the mother wavelet against normalized time: its magnitude together with
# its real and imaginary parts.
fig = plt.figure(
    figsize=scale_figsize((4, 3), width=0.6*fig_width, unit='cm'),
    tight_layout=True,
)
tau = np.linspace(-4, 4, 200)
w = wavelet.psi(tau)
curves = [
    (np.abs, dict(color='C7', label='Magnitude')),
    (np.real, dict(color='C0', label='Real')),
    (np.imag, dict(color='C0', linestyle='--', label='Imaginary')),
]
for component, line_style in curves:
    plt.plot(tau, component(w), **line_style)
plt.xlabel('Normalized time')
plt.legend(loc='upper left', bbox_to_anchor=(0.65, 1))
Define settings for the continuous wavelet transform (see the documentation of the cwt function for more options).
# Settings for the continuous wavelet transform; they are passed to cwt() and
# to WaveletCoherence below (see the documentation of cwt for more options).
cwt_kwargs = dict(
    # Mother wavelet: reuse the Morlet instance created and plotted above, so the
    # wavelet shown in the figure is exactly the one used for the analysis.
    wavelet=wavelet,
    # Sampling interval (seconds).
    dt=1 / fs,
    # Spacing between discrete scales. Smaller values will result in better
    # scale resolution, but slower calculation and plot.
    dj=1 / 10,
)
Compute CWTs.
# CWT of each signal. The scales, frequencies and cone-of-influence mask are
# taken from the last call (z); they presumably coincide for all three signals,
# since these share length and cwt_kwargs -- confirm if that ever changes.
cwt_outputs = [cwt(sig, **cwt_kwargs) for sig in (x, y, z)]
(Wx, _, _, _), (Wy, _, _, _), (Wz, scales, freqs, insidecoi) = cwt_outputs


def _normalized_power(coefs, scale_values):
    """Return scale-corrected wavelet power, normalized to a maximum of 1.

    The squared modulus |coefs|**2 is divided by the scale of each row (rows
    assumed to correspond to scales), see Liu et al. 2007
    (https://doi.org/10.1175/2007JTECHO511.1).
    """
    power = 1/scale_values.reshape(-1, 1) * np.abs(coefs)**2
    return power / np.max(power)


Px = _normalized_power(Wx, scales)
Py = _normalized_power(Wy, scales)
Pz = _normalized_power(Wz, scales)
Plot the wavelet power of each signal.
# Four stacked panels sharing the time axis: the signals on top, followed by
# the normalized wavelet power of x, y and z.
fig, axes = plt.subplots(
    4, 1,
    sharex='all',
    tight_layout=True,
    figsize=scale_figsize((4, 9), width=0.6*fig_width, unit='cm'),
)

# Top panel: the signals, vertically offset so they do not overlap.
ax = axes[0]
spacing = 2
for level, (trace, trace_label) in enumerate([(x, '$x$'), (y, '$y$'), (z, '$z$')]):
    ax.plot(t, trace - level*spacing, label=trace_label)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
ax.set_yticks([])
ax.set_ylabel('Signal (a.u.)')
ax.set_title('Signals')

# Remaining panels: one scalogram per signal, with the cone of influence.
panel_contents = zip(
    axes[1:],
    (Px, Py, Pz),
    ('$|W_x|^2$', '$|W_y|^2$', '$|W_z|^2$'),
)
for panel, power, panel_title in panel_contents:
    plot_scalogram(t, freqs, power, insidecoi=insidecoi, ax=panel)
    panel.set_title(panel_title)

# Clean up the x-axis: only the bottom panel keeps its x-label.
for ax in axes[:-1]:
    ax.set_xlabel('')
    ax.set_xticks([0, t_max])
Next, we compute the (partial) wavelet coherence between x and y. The WaveletCoherence class supports both regular coherence (via its .wct() method) and partial coherence (via its .pct() method).
# Surrogate-analysis options (see nnsa.stats.surrogates.compute_surrogate_fun()).
# n_surrogates=0 means no surrogate computations are performed.
surrogate_options = dict(n_surrogates=0)

# Wavelet coherence object. cwt_kwargs holds the parameters for
# compute_wavelet_coherence -- mostly cwt parameters, plus optional
# coherence-specific ones (see the documentation of compute_wavelet_coherence()).
wcoh = WaveletCoherence(cwt_kwargs=cwt_kwargs, surrogates=surrogate_options)

# Regular wavelet coherence between x and y.
wct_result = wcoh.wct(x, y, fs, labels=['x', 'y'])

# Partial wavelet coherence between x and y, with z as the third signal.
pct_result = wcoh.pct(x, y, z, fs, labels=['xz', 'yz', 'z'])
Plot the coherence. The regular coherence shows coupling between x and y that is caused by z, whereas this coupling disappears in the partial wavelet coherence, which accounts for z.
# 2x2 grid: top row shows regular coherence, bottom row partial coherence;
# left column shows squared magnitude, right column the phase.
fig, axes = plt.subplots(
    2, 2, tight_layout=True, sharex='all', sharey='all',
    figsize=scale_figsize((9, 5), width=fig_width, unit='cm'),
)

ax = axes[0, 0]
wct_result.plot_scalogram(ax=ax)
ax.set_title('$|R_{xy}|^2$')

ax = axes[0, 1]
# 'twilight' is a cyclic colormap, suited to phase angles.
wct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
# Use the \arg operator so mathtext typesets it upright, not as italic a*r*g.
ax.set_title(r'$\arg(R_{xy})$')

ax = axes[1, 0]
pct_result.plot_scalogram(ax=ax)
ax.set_title('$|RP_{xy,z}|^2$')

ax = axes[1, 1]
pct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
ax.set_title(r'$\arg(RP_{xy,z})$')

# Remove axis labels in inner subplots (axes are shared).
for ax in axes[:-1, :].ravel():
    ax.set_xlabel('')
for ax in axes[:, 1:].ravel():
    ax.set_ylabel('')