Partial wavelet coherence

This script demonstrates how to compute partial wavelet coherence (PWC) using the WaveletCoherence class from nnsa. It first shows PWC with one confounding variable and then PWC with two confounding variables.

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: cwt/partial_wavelet_coherence.py

import numpy as np
import matplotlib.pyplot as plt

from nnsa import WaveletCoherence
from nnsa.utils.plotting import save_fig_as, scale_figsize

# Width (in cm) used for every figure created in this script.
fig_width = 18

# Start from a clean slate: close any figures that are still open.
plt.close('all')

Settings.

# Data settings: seed the global NumPy RNG so the random test signals are reproducible.
np.random.seed(43)
N = 1024  # Number of samples.
fs = 256  # Sampling frequency (presumably Hz).
t = np.arange(N) / fs  # Time axis (in seconds if fs is in Hz).

# Wavelet settings (forwarded to WaveletCoherence as `cwt_kwargs`).
cwt_kwargs = {
    'wavelet': 'Morlet(6)',
    'coimode': 'moi',
}

# Optional surrogate settings (forwarded to WaveletCoherence as `surrogates`).
# n_surrogates=0 presumably disables surrogate-based testing; raise it to enable.
surrogates_kwargs = {
    'how': 'AR',
    'n_surrogates': 0,
    'seed': 43,
}

Create test data.

# Clean x and y: centred uniform noise plus a sum of sinusoids. Both use the
# same frequencies and amplitudes, but x carries a per-frequency phase offset
# relative to y (-180, -90, 0, 90 and 180 degrees, converted to radians).
f_all = [4, 8, 16, 32, 64]
a_all = np.array([1, 1, 1, 1, 1])
phi_all = np.array([-180, -90, 0, 90, 180]) / 180 * np.pi
x = np.random.rand(N) - 0.5
y = np.random.rand(N) - 0.5
for amp, freq, phase in zip(a_all, f_all, phi_all):
    omega_t = t * 2 * np.pi * freq
    x += amp * np.sin(omega_t + phase)
    y += amp * np.sin(omega_t)

# Confounding signal z (a random walk) is added to both x and y to 'overshadow'
# their true interaction; snr scales the clean signals down before mixing.
z = np.cumsum(5 * (np.random.rand(N) - 0.5))
snr = 0.5
# Note: np.random.rand draws from [0, 1), so this extra noise is not centred
# (mean 0.5, i.e. it adds a small DC offset).
xz = snr * x + z + np.random.rand(N)
yz = snr * y + z + np.random.rand(N)

Plot signals.

fig, axes = plt.subplots(1, 3, tight_layout=True, sharex='all',
                         figsize=scale_figsize((10, 3.5), width=fig_width, unit='cm'))
ax_clean, ax_conf, ax_mixed = axes

# Left: clean signals; middle: the confounder; right: signals contaminated by z.
ax_clean.plot(t, x, label='x')
ax_clean.plot(t, y, label='y')
ax_conf.plot(t, z, color='C3', label='z')
ax_mixed.plot(t, xz, label='x')
ax_mixed.plot(t, yz, label='y')

panel_titles = ['Clean x and y', 'Confounding z', 'x and y + confounding z']
for ax, title in zip(axes, panel_titles):
    ax.set_title(title)
    ax.legend()
_images/partial_wavelet_coherence_1.png

Compute (partial) wavelet coherence with one confounding variable.

# A single WaveletCoherence object is reused for every computation below.
wcoh = WaveletCoherence(cwt_kwargs=cwt_kwargs, surrogates=surrogates_kwargs)

# Reference: coherence between the clean x and y (no confounder present).
exp_result = wcoh.wct(x, y, fs, labels=['x', 'y'])

# Ordinary wavelet coherence (WCT) between x and y contaminated by z.
wct_result = wcoh.wct(xz, yz, fs, labels=['x', 'y'])

# Partial wavelet coherence (PCT) between the contaminated x and y,
# with z passed as the signal to control for.
pct_result = wcoh.pct(x=xz, y=yz, z=z, fs=fs, labels=['x', 'y', 'z'])

Plot wavelet results.

fig, axes = plt.subplots(2, 3, layout="constrained", sharex='all', sharey='all',
                         figsize=scale_figsize((10, 5), width=fig_width, unit='cm'),
                         )
# Columns: expected (clean) result, WCT on contaminated signals, PCT on
# contaminated signals. Top row: coherence scalograms; bottom row: phase.
# Non-significant regions are not masked in the phase plots, and only the
# rightmost column draws colorbars.
exp_result.plot_scalogram(ax=axes[0, 0], add_colorbar=False)
exp_result.plot_phase(ax=axes[1, 0], add_colorbar=False, cmap='twilight', mask_non_significant=False)
wct_result.plot_scalogram(ax=axes[0, 1], add_colorbar=False)
wct_result.plot_phase(ax=axes[1, 1], add_colorbar=False, cmap='twilight', mask_non_significant=False)
pct_result.plot_scalogram(ax=axes[0, 2], add_colorbar=True)
pct_result.plot_phase(ax=axes[1, 2], add_colorbar=True, cmap='twilight', mask_non_significant=False)

# Remove redundant labels on the shared axes: y-labels on all but the first
# column, x-labels on all but the bottom row, and titles on the bottom row.
for ax in axes[:, 1:].flat:
    ax.set_ylabel('')
for ax in axes[:-1, :].flat:
    ax.set_xlabel('')
for ax in axes[1, :]:
    ax.set_title('')

axes[0, 0].set_title('Expected\n(clean x and y)')
axes[0, 1].set_title('WCT\n(x and y +\nconfounding z)')
axes[0, 2].set_title('PCT\n(x and y +\nconfounding z)')
_images/partial_wavelet_coherence_2.png

To compute partial wavelet coherence with two (or more) confounding signals, simply pass a list of signals as the z argument of the pct() method.

# Second confounding signal z2, generated the same way as z (a random walk).
z2 = np.cumsum(5 * (np.random.rand(N) - 0.5))
# Both confounders contribute equally (their average) to x and y.
z_common = 0.5 * (z + z2)
xz2 = snr * x + z_common + np.random.rand(N)
yz2 = snr * y + z_common + np.random.rand(N)

Plot signals.

fig, axes = plt.subplots(1, 3, tight_layout=True, sharex='all',
                         figsize=scale_figsize((10, 3.5), width=fig_width, unit='cm'))
ax_clean, ax_conf, ax_mixed = axes

# Left: clean signals; middle: both confounders; right: contaminated signals.
ax_clean.plot(t, x, label='x')
ax_clean.plot(t, y, label='y')
ax_conf.plot(t, z, color='C3', label='z')
ax_conf.plot(t, z2, color='C3', ls='--', label='z2')
ax_mixed.plot(t, xz2, label='x')
ax_mixed.plot(t, yz2, label='y')

panel_titles = ['Clean x and y', 'Confounding z and z2', 'x and y + confounding z and z2']
for ax, title in zip(axes, panel_titles):
    ax.set_title(title)
    ax.legend()
_images/partial_wavelet_coherence_3.png

Compute wavelet coherences.

# Regular wavelet coherence (WCT) of the signals contaminated by z and z2.
wct2_result = wcoh.wct(xz2, yz2, fs, labels=['xz', 'yz'])

# Partial wavelet coherence (PCT): multiple confounders are passed as a list.
confounders = [z, z2]
pct2_result = wcoh.pct(x=xz2, y=yz2, z=confounders, fs=fs,
                       labels=['xz', 'yz', 'z', 'z2'])

Plot wavelet results.

fig, axes = plt.subplots(2, 3, layout="constrained", sharex='all', sharey='all',
                         figsize=scale_figsize((10, 5), width=fig_width, unit='cm'),
                         )
# Columns: expected (clean) result, WCT on contaminated signals, PCT on
# contaminated signals. Top row: coherence scalograms; bottom row: phase.
# Phase plots disable masking exactly as in the single-confounder figure, so
# the shared "Expected" panel (exp_result) renders identically in both figures.
# Only the rightmost column draws colorbars.
exp_result.plot_scalogram(ax=axes[0, 0], add_colorbar=False)
exp_result.plot_phase(ax=axes[1, 0], add_colorbar=False, cmap='twilight', mask_non_significant=False)
wct2_result.plot_scalogram(ax=axes[0, 1], add_colorbar=False)
wct2_result.plot_phase(ax=axes[1, 1], add_colorbar=False, cmap='twilight', mask_non_significant=False)
pct2_result.plot_scalogram(ax=axes[0, 2], add_colorbar=True)
pct2_result.plot_phase(ax=axes[1, 2], add_colorbar=True, cmap='twilight', mask_non_significant=False)

# Remove redundant labels on the shared axes: y-labels on all but the first
# column, x-labels on all but the bottom row, and titles on the bottom row.
for ax in axes[:, 1:].flat:
    ax.set_ylabel('')
for ax in axes[:-1, :].flat:
    ax.set_xlabel('')
for ax in axes[1, :]:
    ax.set_title('')

axes[0, 0].set_title('Expected\n(clean x and y)')
axes[0, 1].set_title('WCT\n(x and y +\nconfounding z and z2)')
axes[0, 2].set_title('PCT\n(x and y +\nconfounding z and z2)')
_images/partial_wavelet_coherence_4.png