Source code for nnsa.cwt.transforms

import sys
from functools import partial

import numpy as np
import pycwt
from pycwt.wavelet import _check_parameter_wavelet

from nnsa.cwt.config import DEFAULT_WAVELET
from nnsa.cwt.mothers import Morlet, get_wavelet
from nnsa.preprocessing.detrending import detrend_poly
from nnsa.utils.arrays import interp_nan, slice_along_axis

__all__ = [
    'compute_partial_coherence',
    'compute_wavelet_coherence',
    'cwt',
    'cwt_2d',
    'icwt',
    'pct',
    'wct',
    'xwt_smooth',
]


def compute_partial_coherence(x, y, z, dt, dj=1/10, s0=-1, J=-1, wavelet=None, detrend_order=2,
                              normalize=True, remove_outliers=False, check_reconstruction=False,
                              nan_policy='zeros', coimode='moi', moi_threshold=0.07864960, verbose=1,
                              **kwargs):
    """
    Compute partial wavelet coherence (coherence between x and y, while removing influence from z).

    Args:
        x (array-like): 1D signal array with length N.
        y (array_like): 1D signal array. Must have the same length and sample frequency as `x`.
        z (array_like or list): 1D signal array. Must have the same length and sample frequency as `x`.
            Can also be a list (or tuple) of 1D signal arrays, or a 2D array with shape (n_signals, N),
            in which case the confounding relation from each of these signals will be removed.
        dt (float): sample period of x, y and z in seconds (1/fs).
        dj (float, optional): 2log spacing between discrete scales. Smaller values will result in better
            scale resolution, but slower calculation. See pycwt.cwt. Defaults to 1/10.
        s0 (float, optional): smallest scale (highest frequency). See pycwt.cwt.
        J (int, optional): number of scales - 1. See pycwt.cwt. Determines lowest frequency together
            with s0 and dj.
        wavelet (Mother, optional): mother wavelet (from pycwt.mothers). If None, defaults to
            DEFAULT_WAVELET.
        detrend_order (int, optional): if not None, detrend the signals with polynomial of specified order.
        normalize (bool, optional): bool specifying whether to normalize (zero mean, unit variance) signals
            `x`, `y` and `z` before computation or not. Defaults to True.
        ...: see cwt() for other parameters.
        **kwargs: optional parameters for pct().

    Returns:
        Cxy (np.ndarray): 2D array with the partial wavelet-based magnitude squared coherence (as returned
            by pct()). Shape corresponds to (freqs, time). Use `insidecoi` to mask unreliable pixels.
        Axy (np.ndarray): 2D array with phase angles of the complex partial coherence. Shape corresponds to
            (freqs, time).
        freqs (np.ndarray): frequencies in Hz corresponding to the first axis of Cxy and Axy.
        insidecoi (np.ndarray): boolean mask, True for pixels affected by edge effects/artefacts
            (dependent on coimode) in ANY of the signals x, y or z. Dimensions are (n_freqs, n_time).

    Raises:
        TypeError: if `z` is neither an array (1D or 2D) nor a list/tuple of arrays.
        ValueError: if `z` contains no signals, if any signal is not 1-dimensional or is empty, or if the
            signal lengths differ by more than 0.1 % (relative to their mean length).
    """
    x = np.asarray(x).squeeze()
    y = np.asarray(y).squeeze()

    # Collect the confounding signal(s) in a list of arrays.
    if isinstance(z, (tuple, list)):
        # z are multiple signals.
        z_all = [np.asarray(zi).squeeze() for zi in z]
    else:
        z_arr = np.asarray(z).squeeze()
        if z_arr.ndim == 1:
            # A single confounding signal (its length is validated together with the others below).
            z_all = [z_arr]
        elif z_arr.ndim == 2:
            # Multiple confounding signals stacked as (n_signals, time).
            z_all = list(z_arr)
        else:
            raise TypeError('`z` should be an array or a list of arrays. Got a {}.'.format(type(z)))
    if len(z_all) == 0:
        # Fail before computing any (expensive) CWT.
        raise ValueError('`z` must contain at least one signal.')

    # All signals.
    sig_all = [x, y] + z_all

    # Check dimensions.
    if any(sig.ndim != 1 for sig in sig_all):
        raise ValueError('Input signals should be 1-dimensional. Got signals with shapes {}.'.format(
            [sig.shape for sig in sig_all]))

    # Length of all signals.
    n_all = [sig.shape[-1] for sig in sig_all]
    min_n = min(n_all)
    if min_n == 0:
        raise ValueError('Input signals must not be empty. Got signals with lengths {}.'.format(n_all))

    # Tolerate up to 0.1 % relative length difference; longer signals are truncated to the shortest.
    n_arr = np.asarray(n_all, dtype=float)
    mean_n = n_arr.mean()
    if np.max(np.abs(n_arr - mean_n) / mean_n) > 0.001:
        raise ValueError('All signals must have the same length. Got signals with lengths {}.'.format(
            n_all))

    # Match shapes.
    sig_all = [slice_along_axis(sig, axis=-1, stop=min_n) for sig in sig_all]

    # Collect CWT kwargs.
    wavelet = DEFAULT_WAVELET if wavelet is None else get_wavelet(wavelet)
    cwt_kwargs = dict(dt=dt, dj=dj, wavelet=wavelet, s0=s0, J=J, normalize=normalize,
                      remove_outliers=remove_outliers, detrend_order=detrend_order,
                      check_reconstruction=check_reconstruction, nan_policy=nan_policy,
                      coimode=coimode, moi_threshold=moi_threshold, verbose=verbose)

    # Compute CWTs of all signals (all share the same scales/freqs since they have the same length).
    W_all = []
    coi_all = []
    for sig in sig_all:
        Wi, scales, freqs, coii = cwt(sig, **cwt_kwargs)
        W_all.append(Wi)
        coi_all.append(coii)

    # Compute partial coherence and phase.
    Cxy, Axy = pct(Wx=W_all[0], Wy=W_all[1], Wz=W_all[2:], scales=scales, dt=dt, dj=dj,
                   wavelet=wavelet, **kwargs)

    # A pixel is unreliable if it is unreliable for any of the signals.
    insidecoi = np.any(np.dstack(coi_all), axis=2)

    return Cxy, Axy, freqs, insidecoi
def compute_wavelet_coherence(
        x, y, dt, dj=1 / 10, s0=-1, J=-1, wavelet=None, detrend_order=2, normalize=True,
        remove_outliers=False, check_reconstruction=False, nan_policy='zeros', coimode='moi',
        moi_threshold=0.07864960, verbose=1, **kwargs):
    """
    Compute wavelet-based coherence.

    Adapted from the wct function of the pycwt package (https://pypi.org/project/pycwt/).

    Args:
        x (array-like): 1D signal array with length N.
        y (array_like): 1D signal array. Must have the same length and sample frequency as `x`
            (lengths may differ by at most 0.1 %; the longer signal is truncated).
        dt (float): sample period of x and y in seconds (1/fs).
        dj (float, optional): 2log spacing between discrete scales. Smaller values will result in better
            scale resolution, but slower calculation. See pycwt.cwt. Defaults to 1/10.
        s0 (float, optional): smallest scale (highest frequency). See pycwt.cwt.
        J (int, optional): number of scales - 1. See pycwt.cwt. Determines lowest frequency together
            with s0 and dj.
        wavelet (Mother, optional): mother wavelet (from pycwt.mothers). If None, defaults to
            DEFAULT_WAVELET.
        detrend_order (int, optional): if not None, detrend the signals with polynomial of specified order.
        normalize (bool, optional): bool specifying whether to normalize (zero mean, unit variance) signals
            `x` and `y` before computation or not. Defaults to True.
        ...: see cwt() for other parameters.
        **kwargs: optional parameters for wct().

    Returns:
        Cxy (np.ndarray): 2D array with wavelet based magnitude squared coherence (as returned by wct()).
            Shape corresponds to (freqs, time). Use `insidecoi` to mask unreliable pixels.
        Axy (np.ndarray): 2D array with phase angles of the cross-spectrum. Shape corresponds to
            (freqs, time).
        freqs (np.ndarray): frequencies in Hz corresponding to the first axis of Cxy and Axy.
        insidecoi (np.ndarray): boolean mask, True for pixels affected by edge effects/artefacts
            (dependent on coimode) in x or y. Dimensions are (n_freqs, n_time).

    Raises:
        ValueError: if `x` or `y` is not 1-dimensional, is empty, or if their lengths differ by more
            than 0.1 %.
    """
    x = np.asarray(x).squeeze()
    y = np.asarray(y).squeeze()

    # Check inputs (consistent with compute_partial_coherence(); fail before any CWT is computed).
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError('`x` and `y` must be 1-dimensional. Got shapes {} and {}.'.format(
            x.shape, y.shape))
    nx = x.shape[-1]
    ny = y.shape[-1]
    if nx == 0 or ny == 0:
        raise ValueError('`x` and `y` must not be empty. Got shapes {} and {}.'.format(x.shape, y.shape))
    if abs(nx - ny) / nx > 0.001:
        raise ValueError('`x` and `y` must have the same time dimension. Got shapes {} and {}.'.format(
            x.shape, y.shape))

    # Match shapes.
    min_len = min(nx, ny)
    x = slice_along_axis(x, axis=-1, stop=min_len)
    y = slice_along_axis(y, axis=-1, stop=min_len)

    # Collect CWT kwargs.
    wavelet = DEFAULT_WAVELET if wavelet is None else get_wavelet(wavelet)
    cwt_kwargs = dict(dt=dt, dj=dj, wavelet=wavelet, s0=s0, J=J, normalize=normalize,
                      remove_outliers=remove_outliers, detrend_order=detrend_order,
                      check_reconstruction=check_reconstruction, nan_policy=nan_policy,
                      coimode=coimode, moi_threshold=moi_threshold, verbose=verbose)

    # Compute CWTs of x and y (same length, so scales and freqs are shared).
    Wx, scales, freqs, coix = cwt(x, **cwt_kwargs)
    Wy, scales, freqs, coiy = cwt(y, **cwt_kwargs)

    # Compute coherence and phase.
    Cxy, Axy = wct(Wx, Wy, scales=scales, dt=dt, dj=dj, wavelet=wavelet, **kwargs)

    # A pixel is unreliable if it is unreliable for either signal.
    insidecoi = coix | coiy

    return Cxy, Axy, freqs, insidecoi
def cwt(x, dt, dj=1/10, s0=-1, J=-1, wavelet=None, detrend_order=2, normalize=True,
        remove_outliers=False, check_reconstruction=False, nan_policy='zeros', coimode='moi',
        moi_threshold=0.07864960, verbose=1):
    """
    Wrapper of pycwt.cwt, which adds handling of nans, and optional detrending and normalization
    prior to CWT.

    The input array is never modified in place.

    Args:
        x (array-like): 1D signal.
        dt (float): sampling interval (seconds).
        dj (float): spacing between discrete scales. Smaller values will result in better scale
            resolution, but slower calculation and plot.
        s0 (float): smallest scale of the wavelet. If -1 (default), wavelet.get_min_scale(dt) is used.
        J (float): Number of scales less one. Scales range from s0 up to s0 * 2**(J * dj), which gives a
            total of (J + 1) scales. If -1 (default), J is derived from wavelet.get_min_freq(dt, N) via
            wavelet.get_J().
        wavelet (str or instance of a wavelet class): mother wavelet. If None, DEFAULT_WAVELET is used;
            otherwise it is resolved with get_wavelet().
        detrend_order (int): order for polynomial detrending. If None, does not perform any detrending.
        normalize (bool): normalize (center mean, unit std) prior to CWT computation.
        remove_outliers (bool): if normalize is True, values outside mean +- 5*SD are excluded when
            computing the mean and SD used for normalization (the values themselves are kept).
        check_reconstruction (bool): reconstruct the signal with icwt() and raise an AssertionError if the
            reconstruction error exceeds 5 %.
        nan_policy (str): how to handle nans. Choose from:
            - 'zeros': replace with zeros.
            - 'interp' (or 'interpolate'): interpolate (linearly).
            - 'error': raise error if there are any nans in the input.
        coimode (str or None, optional): the mask in which the output may be considered significantly
            affected by edge effects or artefacts (case-insensitive). Choose from:
            - None: no mask (all False).
            - 'coi': mask where edge effects are important (Cone Of Influence as defined by Torrence
              et al. 1998).
            - 'moi': mask where edge effects or artefacts are important (Mask Of Influence, see thesis
              Tim Hermans section 4.2.3). Any value containing 'moi' selects the MOI; if it also contains
              'full', short nan segments are not ignored (ignore_short_nans=False in compute_moi()); if it
              also contains 'cor' and nan_policy == 'zeros', Wx is amplified by 1 / (1 - smoothed nan
              mask), clipped to [1, 2], to compensate for the inserted zeros.
        moi_threshold (float, optional): if coimode == 'moi', this can be used to define the threshold
            for MOI (see compute_moi()).
        verbose (int): verbosity level. Nothing is printed if verbose <= 0.

    Returns:
        Wx (np.ndarray): array with dimensions (scales, time) containing the wavelet coefficients.
        scales (np.ndarray): scales corresponding to Wx.
        freqs (np.ndarray): Fourier frequencies corresponding to the scales.
        insidecoi (np.ndarray): boolean array with same size as Wx, containing True where the value is
            considered affected by edge effects/artefacts according to `coimode`.

    Raises:
        ValueError: if `x` is not 1-dimensional, if `nan_policy` or `coimode` is invalid, or if x contains
            nans while nan_policy == 'error'.
        AssertionError: if check_reconstruction is True and the reconstruction error exceeds 5 %.
    """
    # Check inputs. np.array makes a private copy, so the in-place nan replacement below can never
    # modify the caller's array.
    x = np.array(x)
    if x.ndim != 1:
        raise ValueError('`x` must be 1-dimensional. Got shape {}.'.format(x.shape))

    # Preprocess: detrend and normalize.
    x_old = x.copy()  # Unprocessed signal, used to assess the reconstruction error.
    if detrend_order is not None:
        if verbose > 0:
            print('Detrending with a polynomial of order {}...'.format(detrend_order))
        x = detrend_poly(x, order=detrend_order)
        poly = x_old - x  # For reconstruction.
    else:
        poly = 0

    if normalize:
        if verbose > 0:
            print('Normalization...')
        x_mean = np.nanmean(x)
        x_std = np.nanstd(x)
        if remove_outliers:
            std_cut_off = 5
            inliers = np.logical_and(x >= (x_mean - std_cut_off * x_std),
                                     x <= (x_mean + std_cut_off * x_std))
            x_n = x[inliers]
            if verbose > 0:
                print('Removing {} % outliers for normalization...'.format(
                    (len(x) - len(x_n)) / len(x) * 100))
            x_mean = np.nanmean(x_n)
            x_std = np.nanstd(x_n)
        x = (x - x_mean) / x_std
    else:
        x_mean, x_std = None, None

    # Replace nans or raise error.
    nan_mask = np.isnan(x)
    if nan_policy in ['interp', 'interpolate']:
        x = interp_nan(x)
    elif nan_policy == 'zeros':
        x[nan_mask] = 0  # Safe: x is a private copy or a freshly computed array.
    elif nan_policy == 'error':
        if np.any(nan_mask):
            raise ValueError('Nans in input while `nan_policy` is "error".')
    else:
        raise ValueError('Invalid value for `nan_policy`.')

    # Resolve the mother wavelet (None -> default; names/instances via get_wavelet(), consistent with
    # compute_wavelet_coherence() and compute_partial_coherence()).
    wavelet = DEFAULT_WAVELET if wavelet is None else get_wavelet(wavelet)

    if s0 == -1:
        s0 = wavelet.get_min_scale(dt)
    if J == -1:
        fmin = wavelet.get_min_freq(dt, x.size)
        J = wavelet.get_J(s0, dj, fmin)

    # CWT computation.
    if verbose > 0:
        print('CWT computation...')
    Wx, scales, freqs, coi, _, _ = pycwt.cwt(x, dt, dj=dj, s0=s0, J=J, wavelet=wavelet)

    if check_reconstruction:
        from nnsa.cwt.utils import autoscale_rec
        x_rec = icwt(Wx, scales, dt, dj, wavelet)
        x_rec = autoscale_rec(x_rec, x)  # Find best scaling.

        # Undo normalization and detrending.
        if x_mean is not None:
            x_rec *= x_std
            x_rec += x_mean
        x_rec += poly  # Add back polynomial.

        # nan-aware statistics: with plain mean/var any nan in the input made the error nan, so the
        # check below could never trigger.
        error = np.nanmean((x_old - x_rec) ** 2) / np.nanvar(x_old) * 100
        if verbose > 0:
            print('Reconstruction error = {:.2f}%'.format(error))
        if error > 5:
            raise AssertionError('Reconstruction error too large ({:.2f} %)! '
                                 'Set check_reconstruction to False or improve the wavelet '
                                 'parameters.'.format(error))

    coimode = coimode.lower() if coimode is not None else None
    if coimode is None:
        # Do not compute mask.
        insidecoi = np.full(Wx.shape, fill_value=False)
    elif coimode == 'coi':
        # Cone Of Influence: a pixel suffers from edge effects when its Fourier period exceeds the COI
        # at that time point. Broadcasting yields shape (n_scales, n_time).
        insidecoi = (1 / freqs)[:, None] > coi[None, :]
    elif 'moi' in coimode:
        from nnsa.cwt.utils import compute_moi
        # Create Mask Of Influence, i.e. values suffering edge or artefact effects.
        ignore_short_nans = 'full' not in coimode
        insidecoi = compute_moi(nan_mask=nan_mask, scales=scales, dt=dt, dj=dj, wavelet=wavelet,
                                threshold=moi_threshold, ignore_short_nans=ignore_short_nans)
        if 'cor' in coimode and nan_policy == 'zeros':
            # Fraction of zero-filled (originally nan) samples under the wavelet, per pixel (without
            # edge effects).
            moi = wavelet.smooth(nan_mask.reshape(1, -1), dt=dt, dj=dj, scales=scales, time=True,
                                 scale=False, time_window=1)
            # Correct amplitude (amplify where zeros were inserted, proportionally). Where moi == 1 the
            # factor is inf, which the clip below limits to 2.
            with np.errstate(divide='ignore'):
                amplification = 1 / (1 - moi)
            amplification = np.clip(amplification, 1, 2)
            Wx *= amplification
    else:
        raise ValueError('Invalid value for `coimode` "{}". Choose from {}.'.format(
            coimode, ['coi', 'moi', None]))

    return Wx, scales, freqs, insidecoi
def cwt_2d(x, **kwargs):
    """
    Compute cwt for 1D or 2D array.

    Args:
        x (array-like): either 1D or 2D array. If 2D, should have shape (channels, time).
        **kwargs: keyword arguments for cwt() (e.g. `dt`).

    Returns:
        W (np.ndarray): wavelet coefficients. 2D (scales, time) if x is 1D. If x is 2D, the per-channel
            coefficients are stacked along a new last axis, giving (scales, time, channels); singleton
            dimensions are squeezed (e.g. a single channel yields a 2D array).
        scales (np.ndarray): scales (from the last processed channel when x is 2D).
        freqs (np.ndarray): frequencies corresponding to `scales`.
        insidecoi (np.ndarray): boolean mask with shape (scales, time). For 2D input a pixel is True
            only if it is True for all channels.

    Raises:
        ValueError: if x is not 1D or 2D, or if it is 2D with zero channels.
    """
    # Accept array-likes (e.g. nested lists) as well as arrays.
    x = np.asarray(x)

    if x.ndim == 1:
        W, scales, freqs, insidecoi = cwt(x, **kwargs)
    elif x.ndim == 2 and len(x) > 0:
        W_all = []
        coi_all = []
        for xi in x:
            Wi, scales, freqs, insidecoi_i = cwt(xi, **kwargs)
            W_all.append(Wi)
            coi_all.append(insidecoi_i)
        W = np.dstack(W_all).squeeze()
        # If all channels are inside COI, return True.
        insidecoi = np.all(np.dstack(coi_all), axis=-1)
    else:
        raise ValueError('Input must be a 1 or 2 dimensional array. Got shape {}.'.format(x.shape))

    return W, scales, freqs, insidecoi
def icwt(W, scales, dt, dj, wavelet, cdelta=None):
    """
    Inverse wavelet transform.

    Note that the implementation in pycwt is NOT correct! See Torrence and Compo (1998), eq. (11).
    This function implements that equation:

        x_n = dj * sqrt(dt) / (C_delta * psi_0(0)) * sum_j Re(W_n(s_j)) / sqrt(s_j)

    Note that possible effects of detrending of the signal prior to CWT are not reconstructed.

    Args:
        W (np.ndarray): wavelet coefficients; the first axis corresponds to `scales`, the second to time.
        scales (np.ndarray): scales corresponding to the first axis of W.
        dt (float): sampling interval.
        dj (float): spacing between discrete scales.
        wavelet: mother wavelet (anything accepted by pycwt's _check_parameter_wavelet()).
        cdelta (float, optional): reconstruction factor C_delta. If None, the wavelet's `cdelta` is used.

    Returns:
        np.ndarray: real-valued reconstructed signal, one value per time sample of W.
    """
    mother = _check_parameter_wavelet(wavelet)
    c_delta = mother.cdelta if cdelta is None else cdelta

    # Weight the real part of the coefficients at each scale by 1 / sqrt(s_j), then sum over scales.
    weighted = np.real(W) / np.sqrt(scales).reshape(-1, 1)
    scale_sum = weighted.sum(axis=0)

    # Normalisation constant of Torrence and Compo (1998), eq. (11).
    constant = dj * np.sqrt(dt) / (c_delta * mother.psi(0))
    return np.real(constant * scale_sum)
def pct(Wx, Wy, Wz, scales, dt, dj, wavelet, smooth_kwargs=None):
    """
    Compute partial coherence between Wx and Wy, while controlling for influence of variable(s) Wz.

    References:
        X. Meng, “The time-frequency dependence of unemployment on real input prices: a wavelet
        coherency and partial coherency approach,” Applied Economics, vol. 52, no. 10, pp. 1124–1140,
        Sep. 2019.

    Args:
        Wx (np.ndarray): 2D array with wavelet coefficients for x (scales, time).
        Wy (np.ndarray): wavelet coefficients for y (scales, time).
        Wz (np.ndarray, list or tuple): wavelet coefficients of the confounding variable (scales, time).
            Can also be a list or tuple when there is more than 1 confounding variable. Each element
            should then contain the wavelet coefficients of one of the confounding variables.
        scales, dt, dj, wavelet: cwt parameters/output.
        smooth_kwargs: dict with optional parameters for wavelet.smooth().

    Returns:
        Cxyz (np.ndarray): partial coherence values (between 0 and 1).
        Axyz (np.ndarray): phase of complex partial coherence.

    Raises:
        TypeError: if `Wz` is not a numpy array, list or tuple.
        ValueError: if `Wz` contains no confounding variables.
    """
    if isinstance(Wz, np.ndarray):
        # Only one confounding variable.
        Wz = [Wz]
    elif isinstance(Wz, tuple):
        # Multiple confounding variables passed as a tuple: convert to a list.
        Wz = list(Wz)

    # Verify Wz is a list.
    if not isinstance(Wz, list):
        raise TypeError('`Wz` should be a numpy array or a list. Got a {}.'.format(type(Wz)))
    if len(Wz) == 0:
        raise ValueError('`Wz` is empty.')

    # Collect common kwargs.
    pct_kwargs = dict(scales=scales, dt=dt, dj=dj, wavelet=wavelet, smooth_kwargs=smooth_kwargs)

    # Dispatch on the number of confounding variables: _pct_1 is the memory-efficient special case.
    if len(Wz) == 1:
        return _pct_1(Wx=Wx, Wy=Wy, Wz=Wz[0], **pct_kwargs)
    return _pct_n(Wx=Wx, Wy=Wy, Wz_all=Wz, **pct_kwargs)
def wct(Wx, Wy, scales, dt, dj, wavelet, smooth_kwargs=None):
    """
    Compute the wavelet coherence transform from the raw CWTs of x and y.

    Args:
        Wx, Wy: 2D arrays with wavelet coefficients (scales, time).
        scales, dt, dj, wavelet: cwt parameters/output.
        smooth_kwargs: dict with optional parameters for wavelet.smooth().

    Returns:
        Cxy (np.ndarray): float32 magnitude-squared coherence values (between 0 and 1).
        Axy (np.ndarray): float32 cross-wavelet phase between Wx and Wy.
    """
    extra_smooth_kwargs = smooth_kwargs if smooth_kwargs is not None else {}
    Sxy, Sx, Sy, Wxy = xwt_smooth(Wx, Wy, scales=scales, dt=dt, dj=dj, wavelet=wavelet,
                                  **extra_smooth_kwargs)

    # Magnitude-squared coherence; the epsilon guards against division by zero.
    power_product = Sx * Sy + sys.float_info.epsilon
    coherence = np.abs(Sxy) ** 2 / power_product

    # Phase from the (non-smoothed) cross spectrum, following Maraun et al. (2004).
    phase = np.angle(Wxy)

    return coherence.astype(np.float32), phase.astype(np.float32)
def xwt_smooth(Wx, Wy, scales, dt, dj, wavelet, **kwargs):
    """
    Compute smoothed cross-wavelet transform and auto-transforms of Wx and Wy.

    Used for (partial) wavelet coherence. All spectra are divided by their scale before smoothing.

    Args:
        Wx, Wy: 2D arrays with wavelet coefficients (scales, time).
        scales, dt, dj, wavelet: cwt parameters/output.
        **kwargs: optional parameters forwarded to wavelet.smooth().

    Returns:
        Sxy, Sx, Sy: smoothed cross and auto spectra.
        Wxy: non-smoothed (but scale-normalised) cross-spectrum.
    """
    inv_scales = 1 / scales.reshape(-1, 1)

    # Scale-normalised cross spectrum and auto spectra.
    Wxy = inv_scales * Wx * Wy.conj()
    Wx_power = inv_scales * np.abs(Wx) ** 2  # Same as inv_scales * Wx * Wx.conj(), but real-valued.
    Wy_power = inv_scales * np.abs(Wy) ** 2

    def smooth(spectrum):
        # Apply the mother wavelet's smoothing operator (options forwarded via **kwargs).
        return wavelet.smooth(spectrum, dt=dt, dj=dj, scales=scales, **kwargs)

    return smooth(Wxy), smooth(Wx_power), smooth(Wy_power), Wxy
def _pct_1(Wx, Wy, Wz, scales, dt, dj, wavelet, smooth_kwargs=None): """ Compute partial coherence between Wx and Wy, while controlling for influence of a third variable Wz. This is equivalent to _pct_n(), but more memory efficient. References: H. Mihanović, M. Orlic, and Z. Pasarić, “Diurnal thermocline oscillations driven by tidal flow around an island in the Middle Adriatic,” (2009). X. Meng, “The time-frequency dependence of unemployment on real input prices: a wavelet coherency and partial coherency approach,” Applied Economics, vol. 52, no. 10, pp. 1124–1140, Sep. 2019. Args: Wx (np.ndarray): 2D array with wavelet coefficients for x (scales, time). Wy (np.ndarray): wavelet coefficients for y (scales, time). Wz (np.ndarray): wavelet coefficients for z (scales, time). scales, dt, dj, wavelet: cwt parameters/output. smooth_kwargs: dict with optional parameters for wavelet.smooth(). Returns: Cxyz (np.ndarray): partial coherence values (between 0 and 1). Axyz (np.ndarray): phase of complex partial coherence. """ def compute_R(W1, W2): Sxy, Sx, Sy, Wxy = xwt_smooth(W1, W2, scales=scales, dt=dt, dj=dj, wavelet=wavelet, **(smooth_kwargs if smooth_kwargs is not None else dict())) Rxy = Sxy/np.sqrt(Sx * Sy + sys.float_info.epsilon) # Prevent division by zero. # Phase (use cross spectrum to estimate phase spectrum, Maraun et al. 2004). Axy = np.angle(Wxy) return Rxy, Axy # Compute partial coherence from pairwise coherences (Mihanovic et al. 2009). Rxy, Axy = compute_R(Wx, Wy) Rzy, _ = compute_R(Wz, Wy) Rxz, _ = compute_R(Wx, Wz) # Do not square it yet to keep the phase information. num = Rxy - Rzy * Rxz den = np.sqrt((1 - np.abs(Rzy)**2) * (1 - np.abs(Rxz)**2)) # Complex partial coherence (equivalent to Eq. 18 in Meng 2019). Rxyz = num/den # Squared magnitude coherence and angle of complex partial coherence. 
Cxyz = np.abs(Rxyz)**2 Axyz = np.angle(Rxyz) return Cxyz.astype(np.float32), Axyz.astype(np.float32) def _pct_n(Wx, Wy, Wz_all, scales, dt, dj, wavelet, smooth_kwargs=None): """ Compute partial coherence between Wx and Wy, while controlling for influence of multiple variables in Wz_all. References: X. Meng, “The time-frequency dependence of unemployment on real input prices: a wavelet coherency and partial coherency approach,” Applied Economics, vol. 52, no. 10, pp. 1124–1140, Sep. 2019. Args: Wx (np.ndarray): 2D array with wavelet coefficients for x (scales, time). Wy (np.ndarray): wavelet coefficients for y (scales, time). Wz_all (lsit): list with wavelet coefficients of all confounding variables (scales, time). scales, dt, dj, wavelet: cwt parameters/output. smooth_kwargs: dict with optional parameters for wavelet.smooth(). Returns: Cxyz (np.ndarray): partial coherence values (between 0 and 1). Axyz (np.ndarray): phase of complex partial coherence. """ if isinstance(Wz_all, tuple): Wz_all = list(tuple) if not isinstance(Wz_all, list): raise TypeError('`Wz_all` should be a list. Got a {}.'.format(type(Wz_all))) # First compute all smoothed cross-wavelet transforms to compute the matrix in Eq. 16. Sm = partial(wavelet.smooth, dt=dt, dj=dj, scales=scales, **smooth_kwargs if smooth_kwargs is not None else dict()) s = scales.reshape(-1, 1) W_all = [Wx, Wy] + Wz_all S_all = [] for i in range(len(W_all)): Si = [] for j in range(len(W_all)): if j < i: # Use the fact that the S matrix is conjugate symmetric. Sij = S_all[j][i].conj() else: Wij = 1 / s * W_all[i] * W_all[j].conj() Sij = Sm(Wij) Si.append(Sij) S_all.append(Si) # To array (n_sig, n_sig, scales, time). S = np.asarray(S_all) # Compute the partial correlation between x and y while controlling for all others. # Use formula 18 in Meng 2019 with j=2. if len(Wz_all) == 2: # Write out the cofactors to make the more memory efficient. 
CSyx = -1 * ( S[0, 1] * (S[2, 2] * S[3, 3] - S[2, 3] * S[3, 2]) - S[0, 2] * (S[2, 1] * S[3, 3] - S[2, 3] * S[3, 1]) + S[0, 3] * (S[2, 1] * S[3, 2] - S[2, 2] * S[3, 1])) CSxx = S[1, 1] * (S[2, 2] * S[3, 3] - S[2, 3] * S[3, 2]) - S[1, 2] * (S[2, 1] * S[3, 3] - S[2, 3] * S[3, 1]) + \ S[1, 3] * (S[2, 1] * S[3, 2] - S[2, 2] * S[3, 1]) CSyy = S[0, 0] * (S[2, 2] * S[3, 3] - S[2, 3] * S[3, 2]) - S[0, 2] * (S[2, 0] * S[3, 3] - S[2, 3] * S[3, 0]) + \ S[0, 3] * (S[2, 0] * S[3, 2] - S[2, 2] * S[3, 0]) else: # Reshape to (scales, time, n_sig, n_sig). S = S.swapaxes(0, 2).swapaxes(1, 3) # Minor matrices (this makes copies of the (potentially large) array S). Myx = np.delete(np.delete(S, 1, axis=2), 0, axis=3) Mxx = np.delete(np.delete(S, 0, axis=2), 0, axis=3) Myy = np.delete(np.delete(S, 1, axis=2), 1, axis=3) # Cofactors. CSyx = -1 * np.linalg.det(Myx) CSxx = np.linalg.det(Mxx) CSyy = np.linalg.det(Myy) # Eq. 18 in Meng 2019. Rxyz = -CSyx / np.sqrt(CSxx * CSyy) # Squared magnitude coherence and angle of complex partial coherence. Cxyz = np.abs(Rxyz)**2 Axyz = np.angle(Rxyz) return Cxyz.astype(np.float32), Axyz.astype(np.float32)