"""EEG visualization helpers (module nnsa.visualization.eeg)."""

import numpy as np
import matplotlib.pyplot as plt

# Public names exported by ``from nnsa.visualization.eeg import *``.
__all__ = ['get_default_positions', 'topoplot']


def get_default_positions():
    """
    Return a dict with default locations of EEG channels.

    A new dict is built on every call, so callers may modify it freely.

    Returns:
        dict: maps a lower-case channel label to an (x, y) tuple. Odd-numbered labels sit at
        negative x and their even-numbered counterparts at the mirrored positive x; 'cz' is at
        the origin. (Units presumably metres -- confirm.)
    """
    return dict(
        fp1=(-0.024, 0.075),
        fp2=(0.024, 0.075),
        t3=(-0.078, 0),
        t4=(0.078, 0),
        o1=(-0.024, -0.075),
        o2=(0.024, -0.075),
        c3=(-0.046, 0),
        c4=(0.046, 0),
        cz=(0, 0),
    )


def _positions_from_dict(positions, channels):
    """Look up the position of every label in `channels` in `positions` (case-insensitive)."""
    if channels is None:
        raise ValueError('`channels` must be specified.')

    # Make keys lowercase so the lookup ignores the case of channel labels.
    lookup = {k.lower(): v for k, v in positions.items()}

    pos = []
    for chan in channels:
        key = chan.lower()
        if key not in lookup:
            raise ValueError('Position of channel "{}" unknown. Specify this using the `positions` argument'
                             .format(chan))
        pos.append(lookup[key])
    return pos


def topoplot(data, channels=None, positions=None, ax=None, **kwargs):
    """
    Topoplot of EEG data.

    Args:
        data (np.ndarray): values at electrodes, one value per electrode.
        channels (list): list of channel labels. Required when `positions` is a dict or is not
            specified. Also passed to mne.viz.plot_topomap() as `names`.
        positions (list, tuple, np.ndarray, dict): positions of the electrodes. A list, tuple or
            array is used as-is and must have the same length as `data`. A dict maps channel
            labels (case-insensitive) to positions and is indexed with `channels`. If not
            specified, the positions are taken from get_default_positions() using `channels`.
        ax (plt.Axes): axes to plot in. Defaults to the current axes.
        **kwargs (dict): for mne.viz.plot_topomap(). These override the defaults set here.

    Returns:
        The return value of mne.viz.plot_topomap() (presumably the image and contour handles,
        e.g. for adding a colorbar -- confirm for the installed MNE version).

    Raises:
        ValueError: if `channels` is needed but missing or contains an unknown label, if the
            number of positions/channels does not match the number of data values, or if
            `positions` has an unsupported type.
    """
    # Default inputs.
    if positions is None:
        positions = get_default_positions()

    if isinstance(positions, dict):
        pos = _positions_from_dict(positions, channels)
        # Without this check a mismatch only surfaces as an opaque error inside mne.
        if len(pos) != len(data):
            raise ValueError('Length of `data` ({}) does not equal the number of `channels` ({}).'
                             .format(len(data), len(pos)))
    elif isinstance(positions, (list, tuple, np.ndarray)):
        # Check length.
        if len(positions) != len(data):
            raise ValueError('Length of `data` ({}) does not equal length of `positions` ({}).'
                             .format(len(data), len(positions)))
        pos = positions
    else:
        raise ValueError('Invalid input for `positions` ({}).'.format(positions))

    # Imported only after validation, so argument errors are reported even without mne.
    from mne.viz import plot_topomap

    if ax is None:
        # Plot in current axes.
        ax = plt.gca()

    # NOTE(review): `show_names` and image_interp='bilinear' may not be accepted by newer MNE
    # releases -- confirm against the installed version. Callers can override these values via
    # **kwargs, but cannot remove the keys.
    plot_kwargs = dict(
        cmap=plt.cm.jet,
        sensors=True,
        res=64,
        axes=ax,
        show_names=False,
        mask=None,
        mask_params=None,
        outlines='head',
        contours=0,
        image_interp='bilinear',
        show=False,
        ch_type='eeg',
    )
    plot_kwargs.update(kwargs)

    return plot_topomap(np.asarray(data), np.asarray(pos), names=channels, **plot_kwargs)