Basic wavelet transforms

Wavelet coherence based on the continuous wavelet transform (CWT) can be used to identify local correlations between two signals in time-frequency space. The nnsa package provides tools for wavelet analysis. This script defines some test signals and computes their CWTs, the wavelet coherence and the partial wavelet coherence.

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: cwt/basic_wavelet_transforms.py

import numpy as np
import matplotlib.pyplot as plt

from nnsa import WaveletCoherence
from nnsa.cwt import Morlet, plot_scalogram, cwt
from nnsa.utils.plotting import scale_figsize, save_fig_as

# Start from a clean state: close any figures that are still open.
plt.close('all')

# Reference figure width (in cm); the figure sizes below are scaled from it.
fig_width = 18

Create three test signals. The two signals of interest are x and y. We also define a third signal z, which acts as a confounder: its confounding effect is introduced by adding it to both x and y.

# Seed the random generator so that the example is reproducible.
rng = np.random.RandomState(40)

# Some settings and precomputations.
fs = 100  # Sampling frequency (Hz).
t_max = 10  # Signal duration (s).
N = fs*t_max  # Number of samples.
t = np.arange(N)/fs  # Time axis (s).

# Optional random jitter on the sample times, in units of the sampling interval
# (the standard deviation of the jitter is timing_jitter/fs seconds).
# With the default of 0 there is no jitter. The random draw is made regardless,
# so that the noise realizations added below do not depend on this setting.
timing_jitter = 0
t_rad = (t + timing_jitter*rng.normal(scale=1/fs, size=N))*2*np.pi

# The confounding signal z (only active (non-zero) in the middle half).
z = np.sin(t_rad * 4) * ((t > t_max/4) & (t < t_max*3/4))

# The signals x and y, which include the confounding z component
# (with opposite signs).
x = np.sin(t_rad * 16) + np.sin(t_rad * 2) - z
y = np.sin(t_rad * 16) + np.sin(t_rad * 1 + np.pi) + z

# Finally, add some independent random noise to each signal.
noise_level = 1/10
x += rng.normal(scale=noise_level, size=N)
y += rng.normal(scale=noise_level, size=N)
z += rng.normal(scale=noise_level, size=N)

Plot the signals. You should be able to see the confounding effect that z has on both x and y when z becomes active in the middle half of the recording (between 2.5 s and 7.5 s).

# Overview of the three signals, stacked on one axis with a vertical offset.
fig, ax = plt.subplots(
    1, 1,
    sharex='all',
    tight_layout=True,
    figsize=scale_figsize((4, 9/3), width=fig_width, unit='cm'),
)
spacing = 2
stacked = [(x, '$x$'), (y, '$y$'), (z, '$z$')]
for level, (signal, label) in enumerate(stacked):
    # Shift each subsequent signal down by one extra `spacing`.
    ax.plot(t, signal - level*spacing, label=label)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
# The offsets make absolute y-values meaningless, so hide the y-ticks.
ax.set_yticks([])
ax.set_ylabel('Signal (a.u.)')
ax.set_title('Signals')
_images/basic_wavelet_transforms_1.png

To analyze these signals with wavelets, we use the Morlet mother wavelet.

# Mother wavelet for this example: a Morlet wavelet
# (the argument 6 is presumably the central frequency; confirm in nnsa.cwt.Morlet).
wavelet = Morlet(6)

# Show the mother wavelet in the time domain: its magnitude together with its
# real and imaginary parts, evaluated on a normalized time grid in [-4, 4].
fig = plt.figure(
    tight_layout=True,
    figsize=scale_figsize((4, 3), width=0.6*fig_width, unit='cm'),
)
tau = np.linspace(-4, 4, 200)
psi_values = wavelet.psi(tau)
curves = [
    (np.abs(psi_values), dict(color='C7', label='Magnitude')),
    (np.real(psi_values), dict(label='Real', color='C0')),
    (np.imag(psi_values), dict(linestyle='--', color='C0', label='Imaginary')),
]
for values, line_kwargs in curves:
    plt.plot(tau, values, **line_kwargs)
plt.xlabel('Normalized time')
plt.legend(loc='upper left', bbox_to_anchor=(0.65, 1))
_images/basic_wavelet_transforms_2.png

Define settings for the continuous wavelet transform (see the documentation of the cwt function for more options).

cwt_kwargs = dict(
    # Mother wavelet. Reuse the Morlet instance created (and plotted) above,
    # so that the wavelet shown and the wavelet used for the analysis cannot
    # silently diverge when one of them is edited.
    wavelet=wavelet,
    # Sampling interval (seconds).
    dt=1 / fs,
    # Spacing between discrete scales. Smaller values will result in better
    # scale resolution, but slower calculation and plot.
    dj=1 / 10,
)

Compute CWTs.

def _normalized_power(coefs, scales):
    """Return the scale-rectified wavelet power of `coefs`, normalized to a maximum of 1.

    The squared magnitude of the coefficients is multiplied by the inverse of
    the scale (one row per scale), see Liu et al. 2007
    (https://doi.org/10.1175/2007JTECHO511.1).
    """
    power = 1/scales.reshape(-1, 1) * np.abs(coefs)**2
    power /= np.max(power)
    return power


# CWT of each signal. All calls use the same settings and signal length, so the
# returned scales, frequencies and cone-of-influence mask are presumably
# identical; the ones from the last call are kept.
Wx, scales, freqs, insidecoi = cwt(x, **cwt_kwargs)
Wy, scales, freqs, insidecoi = cwt(y, **cwt_kwargs)
Wz, scales, freqs, insidecoi = cwt(z, **cwt_kwargs)

# Normalized, scale-rectified power of each signal.
Px = _normalized_power(Wx, scales)
Py = _normalized_power(Wy, scales)
Pz = _normalized_power(Wz, scales)

Plot the wavelet power of each signal.

fig, axes = plt.subplots(
    4, 1, tight_layout=True,
    figsize=scale_figsize((4, 9), width=0.6*fig_width, unit='cm'),
    sharex='all',
)

# Top panel: the signals, stacked with a vertical offset.
ax = axes[0]
spacing = 2
ax.plot(t, x, label='$x$')
ax.plot(t, y - spacing, label='$y$')
ax.plot(t, z - spacing*2, label='$z$')
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
ax.set_yticks([])
ax.set_ylabel('Signal (a.u.)')
ax.set_title('Signals')

# Remaining panels: the normalized wavelet power of each signal.
scalogram_panels = zip(
    axes[1:],
    (Px, Py, Pz),
    ('$|W_x|^2$', '$|W_y|^2$', '$|W_z|^2$'),
)
for ax, power, title in scalogram_panels:
    plot_scalogram(t, freqs, power, insidecoi=insidecoi, ax=ax)
    ax.set_title(title)

# Clean up the x-axis: only the bottom panel keeps its x-label.
for ax in axes[:-1]:
    ax.set_xlabel('')
# Set the ticks on the bottom panel explicitly (instead of relying on the loop
# variable above). The x-axis is shared (sharex='all'), so this applies to all panels.
axes[-1].set_xticks([0, t_max])
_images/basic_wavelet_transforms_3.png

Next, we compute the (partial) wavelet coherence between x and y. The WaveletCoherence class computes both the regular coherence (using its .wct() method) and the partial coherence (using its .pct() method).

# Init wavelet coherence object.
wcoh = WaveletCoherence(
    # Parameters for compute_wavelet_coherence, these are mainly the parameters for cwt, but with some extra
    # options relating to computing the coherence, see documentation of compute_wavelet_coherence().
    cwt_kwargs=cwt_kwargs,

    # Parameters for surrogate analysis.
    # Set n_surrogates to 0 if no surrogate computations are desired.
    # See also nnsa.stats.surrogates.compute_surrogate_fun().
    surrogates=dict(n_surrogates=0))

# Compute regular wavelet coherence.
wct_result = wcoh.wct(x, y, fs, labels=['x', 'y'])

# Compute partial wavelet coherence.
pct_result = wcoh.pct(x, y, z, fs, labels=['xz', 'yz', 'z'])

Plot the coherence. Note that the regular coherence shows a coupling caused by z, whereas this coupling disappears in the partial wavelet coherence.

fig, axes = plt.subplots(
    2, 2, tight_layout=True, sharex='all', sharey='all',
    figsize=scale_figsize((9, 5), width=fig_width, unit='cm'),
)
# Top row: regular coherence; bottom row: partial coherence given z.
# Left column: squared magnitude; right column: phase.
# Titles are raw strings so that mathtext commands such as \arg (rendered as an
# upright operator rather than the italic product a*r*g) reach matplotlib intact;
# in a normal string '\a' would be the ASCII bell character.
ax = axes[0, 0]
wct_result.plot_scalogram(ax=ax)
ax.set_title(r'$|R_{xy}|^2$')

ax = axes[0, 1]
wct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
ax.set_title(r'$\arg(R_{xy})$')

ax = axes[1, 0]
pct_result.plot_scalogram(ax=ax)
ax.set_title(r'$|RP_{xy,z}|^2$')

ax = axes[1, 1]
pct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
ax.set_title(r'$\arg(RP_{xy,z})$')

# Remove axis labels in inner subplots (x-labels above the bottom row,
# y-labels right of the left column).
for ax in axes[:-1, :].ravel():
    ax.set_xlabel('')
for ax in axes[:, 1:].ravel():
    ax.set_ylabel('')
_images/basic_wavelet_transforms_4.png