Basic wavelet transforms
========================

Wavelet coherence, based on the continuous wavelet transform (CWT), can be used to identify local correlations between two signals in time-frequency space. The `nnsa` package contains useful tools for wavelet analysis. This script defines some test signals and computes their CWTs, wavelet coherence and partial wavelet coherence.

Author: Tim Hermans (tim-hermans@hotmail.com).

Link to script: `cwt/basic_wavelet_transforms.py `_

.. code-block:: python

    import numpy as np
    import matplotlib.pyplot as plt

    from nnsa import WaveletCoherence
    from nnsa.cwt import Morlet, plot_scalogram, cwt
    from nnsa.utils.plotting import scale_figsize, save_fig_as

    plt.close('all')

    # Figure width (cm).
    fig_width = 18

Create three test signals. The two signals of interest are x and y. We also define a third signal z, which acts as a confounding signal. Its confounding effect is introduced by subtracting z from x and adding it to y.

.. code-block:: python

    # Seed the random generator.
    rng = np.random.RandomState(40)

    # Some settings and precomputations.
    fs = 100
    t_max = 10
    N = fs*t_max
    t = np.arange(N)/fs
    # Time in radians (the timing jitter term is multiplied by 0, i.e. disabled).
    t_rad = (t + 0*rng.normal(scale=1/fs, size=N))*2*np.pi

    # The confounding signal z (only active (non-zero) in the middle half).
    z = np.sin(t_rad * 4) * ((t > t_max/4) & (t < t_max*3/4))

    # The signals x and y, which include the confounding z component.
    x = np.sin(t_rad * 16) + np.sin(t_rad * 2) - z
    y = np.sin(t_rad * 16) + np.sin(t_rad * 1 + np.pi) + z

    # Finally, add some random noise to each signal.
    noise_level = 1/10
    x += rng.normal(scale=noise_level, size=N)
    y += rng.normal(scale=noise_level, size=N)
    z += rng.normal(scale=noise_level, size=N)

Plot the signals. You should be able to see the confounding effect that z has on both x and y when it becomes active in the middle part.


.. code-block:: python

    fig, ax = plt.subplots(
        1, 1, tight_layout=True,
        figsize=scale_figsize((4, 9/3), width=fig_width, unit='cm'),
        sharex='all',
    )
    spacing = 2
    ax.plot(t, x, label='$x$')
    ax.plot(t, y - spacing, label='$y$')
    ax.plot(t, z - spacing*2, label='$z$')
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_yticks([])
    ax.set_ylabel('Signal (a.u.)')
    ax.set_title('Signals')

.. figure:: ../examples/cwt/figs/basic_wavelet_transforms_1.png

To analyze these signals with wavelets, we use the Morlet mother wavelet.

.. code-block:: python

    # Create Morlet wavelet instance.
    wavelet = Morlet(6)

    # Plot the mother wavelet in the time domain.
    fig = plt.figure(
        tight_layout=True,
        figsize=scale_figsize((4, 3), width=0.6*fig_width, unit='cm'),
    )
    tau = np.linspace(-4, 4, 200)
    w = wavelet.psi(tau)
    plt.plot(tau, np.abs(w), color='C7', label='Magnitude')
    plt.plot(tau, np.real(w), label='Real', color='C0')
    plt.plot(tau, np.imag(w), linestyle='--', color='C0', label='Imaginary')
    plt.xlabel('Normalized time')
    plt.legend(loc='upper left', bbox_to_anchor=(0.65, 1))

.. figure:: ../examples/cwt/figs/basic_wavelet_transforms_2.png

Define settings for the continuous wavelet transform (see the documentation of the ``cwt`` function for more options).

.. code-block:: python

    cwt_kwargs = dict(
        # Mother wavelet.
        wavelet=Morlet(6),
        # Sampling interval (seconds).
        dt=1 / fs,
        # Spacing between discrete scales. Smaller values give better
        # scale resolution, but slower computation and plotting.
        dj=1 / 10,
    )

Compute the CWTs.

.. code-block:: python

    Wx, scales, freqs, insidecoi = cwt(x, **cwt_kwargs)
    Wy, scales, freqs, insidecoi = cwt(y, **cwt_kwargs)
    Wz, scales, freqs, insidecoi = cwt(z, **cwt_kwargs)

    # Convert wavelet coefficients to power. Note that the power is scaled by the inverse of the scale,
    # see Liu et al. 2007 (https://doi.org/10.1175/2007JTECHO511.1).
    Px = 1/scales.reshape(-1, 1) * np.abs(Wx)**2
    Py = 1/scales.reshape(-1, 1) * np.abs(Wy)**2
    Pz = 1/scales.reshape(-1, 1) * np.abs(Wz)**2

    # Normalize.
    Px /= np.max(Px)
    Py /= np.max(Py)
    Pz /= np.max(Pz)

Plot the wavelet power of each signal.

.. code-block:: python

    fig, axes = plt.subplots(
        4, 1, tight_layout=True,
        figsize=scale_figsize((4, 9), width=0.6*fig_width, unit='cm'),
        sharex='all',
    )

    ax = axes[0]
    spacing = 2
    ax.plot(t, x, label='$x$')
    ax.plot(t, y - spacing, label='$y$')
    ax.plot(t, z - spacing*2, label='$z$')
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    ax.set_yticks([])
    ax.set_ylabel('Signal (a.u.)')
    ax.set_title('Signals')

    ax = axes[1]
    plot_scalogram(t, freqs, Px, insidecoi=insidecoi, ax=ax)
    ax.set_title('$|W_x|^2$')

    ax = axes[2]
    plot_scalogram(t, freqs, Py, insidecoi=insidecoi, ax=ax)
    ax.set_title('$|W_y|^2$')

    ax = axes[3]
    plot_scalogram(t, freqs, Pz, insidecoi=insidecoi, ax=ax)
    ax.set_title('$|W_z|^2$')

    # Clean up the x-axis (the axes share x, so the ticks apply to all subplots).
    for ax in axes[:-1]:
        ax.set_xlabel('')
        ax.set_xticks([0, t_max])

.. figure:: ../examples/cwt/figs/basic_wavelet_transforms_3.png

Next, we compute the (partial) wavelet coherence between x and y. The ``WaveletCoherence`` class supports both regular coherence (using the ``wct()`` method) and partial coherence (using the ``pct()`` method).

.. code-block:: python

    # Init wavelet coherence object.
    wcoh = WaveletCoherence(
        # Parameters for compute_wavelet_coherence; these are mainly the parameters for cwt, but with some extra
        # options relating to computing the coherence, see the documentation of compute_wavelet_coherence().
        cwt_kwargs=cwt_kwargs,
        # Parameters for surrogate analysis.
        # Set n_surrogates to 0 if no surrogate computations are desired.
        # See also nnsa.stats.surrogates.compute_surrogate_fun().
        surrogates=dict(n_surrogates=0))

    # Compute regular wavelet coherence.
    wct_result = wcoh.wct(x, y, fs, labels=['x', 'y'])

    # Compute partial wavelet coherence.
    pct_result = wcoh.pct(x, y, z, fs, labels=['xz', 'yz', 'z'])

Plot the coherence. Note that the regular coherence shows a spurious coupling around 4 Hz in the middle half of the recording, caused by z, whereas this coupling disappears when using partial wavelet coherence. The coupling at 16 Hz, which x and y genuinely share, is present in both.

.. code-block:: python

    fig, axes = plt.subplots(
        2, 2, tight_layout=True, sharex='all', sharey='all',
        figsize=scale_figsize((9, 5), width=fig_width, unit='cm'),
    )

    ax = axes[0, 0]
    wct_result.plot_scalogram(ax=ax)
    ax.set_title('$|R_{xy}|^2$')

    ax = axes[0, 1]
    wct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
    ax.set_title(r'$\arg(R_{xy})$')

    ax = axes[1, 0]
    pct_result.plot_scalogram(ax=ax)
    ax.set_title('$|RP_{xy,z}|^2$')

    ax = axes[1, 1]
    pct_result.plot_phase(ax=ax, add_colorbar=True, cmap='twilight')
    ax.set_title(r'$\arg(RP_{xy,z})$')

    # Remove axis labels in inner subplots.
    for ax in np.reshape(axes[:-1, :], -1):
        ax.set_xlabel('')
    for ax in np.reshape(axes[:, 1:], -1):
        ax.set_ylabel('')

.. figure:: ../examples/cwt/figs/basic_wavelet_transforms_4.png
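
The ``wct()`` and ``pct()`` calls hide the actual computation. As a rough illustration of what such estimators do, the following standalone, NumPy-only sketch recomputes ordinary and partial wavelet coherence from scratch. It is not the nnsa implementation and rests on several assumptions: an FFT-based Morlet transform with the Torrence and Compo (1998) normalization, smoothing along time only with a Gaussian whose standard deviation is twice the scale (an arbitrary choice; nnsa may smooth differently, e.g. also across scales), and the standard partial-coherence formula

.. math::

   |RP_{xy,z}|^2 = \frac{|R_{xy} - R_{xz} R_{zy}|^2}{(1 - |R_{xz}|^2)(1 - |R_{zy}|^2)},

where :math:`R_{ab}` is the complex coherency computed from smoothed wavelet cross-spectra. The helper functions ``morlet_cwt``, ``smooth_time``, ``coherence`` and ``partial_coherence`` are hypothetical names; running the sketch in the same session overwrites ``Wx``, ``scales`` and ``freqs`` from above.

.. code-block:: python

    import numpy as np


    def morlet_cwt(x, dt, scales, w0=6.0):
        """Morlet CWT computed in the Fourier domain (Torrence & Compo, 1998 normalization).

        Returns complex coefficients with shape (len(scales), len(x)).
        """
        x = np.asarray(x, dtype=float)
        n = x.size
        x_hat = np.fft.fft(x - x.mean())
        omega = 2 * np.pi * np.fft.fftfreq(n, d=dt)  # Angular frequencies (rad/s).
        W = np.empty((len(scales), n), dtype=complex)
        for i, s in enumerate(scales):
            # Fourier transform of the Morlet wavelet at scale s (zero for omega <= 0).
            psi_hat = (np.pi ** -0.25 * np.sqrt(2 * np.pi * s / dt)
                       * np.exp(-0.5 * (s * omega - w0) ** 2) * (omega > 0))
            W[i] = np.fft.ifft(x_hat * psi_hat)
        return W


    def smooth_time(A, scales, dt, width=2.0):
        """Smooth each row of A along time with a Gaussian (std = width * scale)."""
        n = A.shape[1]
        out = np.empty_like(A)
        for i, s in enumerate(scales):
            sigma = width * s / dt  # Standard deviation in samples.
            half = min(int(np.ceil(4 * sigma)), (n - 1) // 2)  # Keep kernel <= n samples.
            kernel = np.exp(-0.5 * (np.arange(-half, half + 1) / sigma) ** 2)
            out[i] = np.convolve(A[i], kernel / kernel.sum(), mode='same')
        return out


    def coherence(Wa, Wb, scales, dt, width=2.0):
        """Complex coherency R_ab; |R_ab|**2 is the wavelet coherence.

        With time-only smoothing, the usual 1/scale weighting cancels, so it is omitted.
        """
        Sab = smooth_time(Wa * np.conj(Wb), scales, dt, width)
        Saa = smooth_time(np.abs(Wa) ** 2, scales, dt, width)
        Sbb = smooth_time(np.abs(Wb) ** 2, scales, dt, width)
        return Sab / np.sqrt(Saa * Sbb)


    def partial_coherence(Wx, Wy, Wz, scales, dt, width=2.0):
        """Squared partial coherence of x and y after removing the linear effect of z."""
        Rxy = coherence(Wx, Wy, scales, dt, width)
        Rxz = coherence(Wx, Wz, scales, dt, width)
        Rzy = coherence(Wz, Wy, scales, dt, width)
        num = np.abs(Rxy - Rxz * Rzy) ** 2
        den = (1 - np.abs(Rxz) ** 2) * (1 - np.abs(Rzy) ** 2)
        return num / den


    # Re-create the test signals of this example (the zero-scaled jitter term is omitted).
    rng = np.random.RandomState(40)
    fs, t_max = 100, 10
    N = fs * t_max
    t = np.arange(N) / fs
    t_rad = t * 2 * np.pi
    z = np.sin(t_rad * 4) * ((t > t_max / 4) & (t < t_max * 3 / 4))
    x = np.sin(t_rad * 16) + np.sin(t_rad * 2) - z
    y = np.sin(t_rad * 16) + np.sin(t_rad * 1 + np.pi) + z
    x += rng.normal(scale=0.1, size=N)
    y += rng.normal(scale=0.1, size=N)
    z += rng.normal(scale=0.1, size=N)

    # Scales matching a few Fourier frequencies (Morlet scale-to-period relation).
    w0 = 6.0
    freqs = np.array([1.0, 2.0, 4.0, 16.0])
    scales = (w0 + np.sqrt(2 + w0 ** 2)) / (4 * np.pi * freqs)

    Wx, Wy, Wz = (morlet_cwt(sig, 1 / fs, scales, w0) for sig in (x, y, z))
    R2_xy = np.abs(coherence(Wx, Wy, scales, 1 / fs)) ** 2
    RP2_xy_z = partial_coherence(Wx, Wy, Wz, scales, 1 / fs)

    # Compare both measures at 4 Hz, well inside the interval where z is active.
    mid = (t > 4) & (t < 6)
    print('coherence at 4 Hz:        ', R2_xy[2, mid].mean())
    print('partial coherence at 4 Hz:', RP2_xy_z[2, mid].mean())

With these signals, the 4 Hz row of the ordinary coherence should be close to 1 in the middle of the recording, while the partial coherence there should be clearly lower, in line with the figure above. Because time-only smoothing leaves few effective degrees of freedom, coherence estimates between unrelated components are biased upward; a larger ``width`` reduces this bias at the cost of time resolution.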