Partial wavelet coherence ========================= This script demonstrates how to compute partial wavelet coherence (PWC) using the WaveletCoherence class from `nnsa`. It first shows PWC with one confounding variable, and then PWC with two confounding variables. Author: Tim Hermans (tim-hermans@hotmail.com). Source script: ``cwt/partial_wavelet_coherence.py`` .. code-block:: python import numpy as np import matplotlib.pyplot as plt from nnsa import WaveletCoherence from nnsa.utils.plotting import save_fig_as, scale_figsize plt.close('all') fig_width = 18 Settings. .. code-block:: python # Data settings. np.random.seed(43) N = 1024 fs = 256 t = np.arange(N)/fs # Wavelet settings. cwt_kwargs = dict( wavelet='Morlet(6)', coimode='moi', ) # Optional surrogate settings. surrogates_kwargs = dict( how='AR', n_surrogates=0, seed=43, ) Create test data. .. code-block:: python # x and y. f_all = [4, 8, 16, 32, 64] a_all = np.array([1, 1, 1, 1, 1]) phi_all = np.array([-180, -90, 0, 90, 180])/180*np.pi x = np.random.rand(N) - 0.5 y = np.random.rand(N) - 0.5 for a, f, phi in zip(a_all, f_all, phi_all): x += a * np.sin(t*2*np.pi*f + phi) y += a * np.sin(t*2*np.pi*f) # Add a confounding signal z to x and y to 'overshadow' their true interaction. z = np.cumsum(5*(np.random.rand(N)-0.5)) snr = 0.5 xz = snr*x + z + np.random.rand(N) yz = snr*y + z + np.random.rand(N) Plot signals. .. code-block:: python fig, axes = plt.subplots(1, 3, tight_layout=True, sharex='all', figsize=scale_figsize((10, 3.5), width=fig_width, unit='cm'), ) axes[0].plot(t, x, label='x') axes[0].plot(t, y, label='y') axes[1].plot(t, z, color='C3', label='z') axes[2].plot(t, xz, label='x') axes[2].plot(t, yz, label='y') axes[0].set_title('Clean x and y') axes[0].legend() axes[1].set_title('Confounding z') axes[1].legend() axes[2].set_title('x and y + confounding z') axes[2].legend() .. figure:: ../examples/cwt/figs/partial_wavelet_coherence_1.png Compute (partial) wavelet coherence with one confounding variable. ..
code-block:: python # Init wavelet coherence object. wcoh = WaveletCoherence(cwt_kwargs=cwt_kwargs, surrogates=surrogates_kwargs) # Wavelet coherence between original x and y (without confounding z). exp_result = wcoh.wct(x, y, fs, labels=['x', 'y']) # Wavelet coherence between disturbed x and y (by confounding z). wct_result = wcoh.wct(xz, yz, fs, labels=['x', 'y']) # Partial wavelet coherence between disturbed x and y, controlling for confounding z. pct_result = wcoh.pct(xz, yz, z, fs, labels=['x', 'y', 'z']) Plot wavelet results. .. code-block:: python fig, axes = plt.subplots(2, 3, layout="constrained", sharex='all', sharey='all', figsize=scale_figsize((10, 5), width=fig_width, unit='cm'), ) exp_result.plot_scalogram(ax=axes[0, 0], add_colorbar=False) exp_result.plot_phase(ax=axes[1, 0], add_colorbar=False, cmap='twilight', mask_non_significant=False) wct_result.plot_scalogram(ax=axes[0, 1], add_colorbar=False) wct_result.plot_phase(ax=axes[1, 1], add_colorbar=False, cmap='twilight', mask_non_significant=False) pct_result.plot_scalogram(ax=axes[0, 2], add_colorbar=True) pct_result.plot_phase(ax=axes[1, 2], add_colorbar=True, cmap='twilight', mask_non_significant=False) for ax in np.reshape([axes[:, 1:]], -1): ax.set_ylabel('') for ax in np.reshape([axes[:-1, :]], -1): ax.set_xlabel('') for ax in axes[1, :]: ax.set_title('') axes[0, 0].set_title('Expected\n(clean x and y)') axes[0, 1].set_title('WCT\n(x and y +\nconfounding z)') axes[0, 2].set_title('PCT\n(x and y +\nconfounding z)') .. figure:: ../examples/cwt/figs/partial_wavelet_coherence_2.png To compute partial wavelet coherence with two (or more) confounding signals, simply pass a list of signals as the ``z`` argument of the ``pct()`` function. .. code-block:: python # Create extra confounding signal. z2 = np.cumsum(5*(np.random.rand(N)-0.5)) xz2 = snr*x + 0.5*(z + z2) + np.random.rand(N) yz2 = snr*y + 0.5*(z + z2) + np.random.rand(N) Plot signals. ..
code-block:: python fig, axes = plt.subplots(1, 3, tight_layout=True, sharex='all', figsize=scale_figsize((10, 3.5), width=fig_width, unit='cm'), ) axes[0].plot(t, x, label='x') axes[0].plot(t, y, label='y') axes[1].plot(t, z, color='C3', label='z') axes[1].plot(t, z2, color='C3', ls='--', label='z2') axes[2].plot(t, xz2, label='x') axes[2].plot(t, yz2, label='y') axes[0].set_title('Clean x and y') axes[0].legend() axes[1].set_title('Confounding z and z2') axes[1].legend() axes[2].set_title('x and y + confounding z and z2') axes[2].legend() .. figure:: ../examples/cwt/figs/partial_wavelet_coherence_3.png Compute wavelet coherences. .. code-block:: python # Compute regular wavelet coherence (WCT). wct2_result = wcoh.wct(xz2, yz2, fs, labels=['xz', 'yz']) # Compute partial wavelet coherence (PCT). pct2_result = wcoh.pct(x=xz2, y=yz2, z=[z, z2], fs=fs, labels=['xz', 'yz', 'z', 'z2']) Plot wavelet results. .. code-block:: python fig, axes = plt.subplots(2, 3, layout="constrained", sharex='all', sharey='all', figsize=scale_figsize((10, 5), width=fig_width, unit='cm'), ) exp_result.plot_scalogram(ax=axes[0, 0], add_colorbar=False) exp_result.plot_phase(ax=axes[1, 0], add_colorbar=False, cmap='twilight') wct2_result.plot_scalogram(ax=axes[0, 1], add_colorbar=False) wct2_result.plot_phase(ax=axes[1, 1], add_colorbar=False, cmap='twilight') pct2_result.plot_scalogram(ax=axes[0, 2], add_colorbar=True) pct2_result.plot_phase(ax=axes[1, 2], add_colorbar=True, cmap='twilight') for ax in np.reshape([axes[:, 1:]], -1): ax.set_ylabel('') for ax in np.reshape([axes[:-1, :]], -1): ax.set_xlabel('') for ax in axes[1, :]: ax.set_title('') axes[0, 0].set_title('Expected\n(clean x and y)') axes[0, 1].set_title('WCT\n(x and y +\nconfounding z and z2)') axes[0, 2].set_title('PCT\n(x and y +\nconfounding z and z2)') .. figure:: ../examples/cwt/figs/partial_wavelet_coherence_4.png