"""
This module contains functions dealing with matplotlib plots.
"""
import datetime
import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import math
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import NullFormatter, FuncFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import spearmanr, mannwhitneyu
from sklearn import manifold
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from nnsa.utils.paths import check_directory_exists
from nnsa.utils.arrays import asnumber, get_bin_edges
# Public API of this module: the names exported by a star-import.
# NOTE(review): 'IndexTracker', 'format_time_axis' and 'maximize_figure' are listed here, but their
# definitions are not visible in this part of the file - confirm that they are defined elsewhere.
__all__ = [
    'DEFAULT_COLORS',
    'IndexTracker',
    'color_background',
    'compute_linewidth',
    'enumerate_axes',
    'format_time_axis',
    'maximize_figure',
    'mm2inch',
    'pieplot',
    'PointPicker',
    'remove_ticks',
    'save_fig_as',
    'scale_figsize',
    'set_plot_style',
    'shade_axis',
    'stripboxplot',
    'subplot_rows_columns'
]
# Specify a set of default colors (list of hexadecimal RGB color strings).
DEFAULT_COLORS = ['#6494aa',
                  '#F1B663',
                  '#90A959',
                  '#BD4C50',
                  '#904CA9',
                  '#774940',
                  '#E56BA4',
                  '#474747']
def color_background(x, c, ylim=None, ax=None, **kwargs):
    """
    Color the background of the plot according to `c`.

    The intensities in `c` are drawn as a filled contour that spans the full
    height given by `ylim`, so the color only varies along the x-direction.

    Args:
        x (np.ndarray): x-locations of the levels in c.
        c (np.ndarray): array with color intensities (same shape as x).
        ylim (np.ndarray, optional): optional lower and upper y limit to color.
            If None, the current y limits of the plot are used.
        ax (plt.axes, optional): axes to color. If given, it is made the current axes.
        **kwargs (optional): keyword arguments for plt.contourf. These override the
            defaults (cmap='jet', levels=20).
    """
    if ax is not None:
        plt.sca(ax)

    y_range = plt.ylim() if ylim is None else ylim

    # Defaults first, user-specified options win.
    contour_opts = {'cmap': 'jet', 'levels': 20}
    contour_opts.update(kwargs)

    # Duplicate the intensities for the lower and upper y limit, so that contourf fills the background.
    intensities = np.vstack((c, c))
    plt.contourf(x, y_range, intensities, **contour_opts)
def compute_linewidth(y):
    """
    Compute an appropriate linewidth for a noisy signal (e.g. EEG).

    The linewidth is predicted by a linear model that is fitted (least squares) on four
    hard-coded reference points. The model features are a constant, the mean absolute
    first difference of the signal relative to its range, and the number of samples.

    Args:
        y (np.ndarray or array-like): data array that is plotted. NaN values are ignored.

    Returns:
        linewidth (float): an appropriate linewidth for plotting the data y.
            Returns 1.0 if `y` contains no non-NaN samples.
    """
    y = np.asarray(y, dtype=float)
    y = y[~np.isnan(y)]
    if y.size == 0:
        return 1.0

    # Fitting points: columns are [constant, relative mean absolute difference, number of samples].
    x_fit = np.array([[1, 0.09236669, 2500],
                      [1, 1.68553142e-02, 1.50000000e+04],
                      [1, 1.32140071e-02, 7.50000000e+04],
                      [1, 1.89917847e-03, 1.48500000e+06]])
    y_fit = np.array([1, 0.5, 0.7, 0.3])

    # Compute coefficients for linear regression: y = x*theta.
    theta = np.linalg.lstsq(x_fit, y_fit, rcond=None)[0]

    # Compute features of the current y: linewidth scales with the length of y and with the
    # high frequency content of y (captured by np.diff(y)).
    value_range = np.max(y) - np.min(y)
    if y.size < 2 or value_range == 0:
        # A single sample or a constant signal has no high frequency content. Handling this
        # explicitly avoids a division by zero (NaN linewidth).
        roughness = 0.0
    else:
        roughness = np.mean(np.abs(np.diff(y))) / value_range
    features = np.array([1.0, roughness, y.size])

    # Predict linewidth and return it as a plain float, as documented.
    return float(features @ theta)
def enumerate_axes(axes, xloc=-0.1, yloc=1.05, style='alphabet', capitalize=False, postfix='', **kwargs):
    """
    Add enumeration to axes. E.g. a, b, c.

    Args:
        axes (list, tuple, np.ndarray): axes to enumerate. Nested lists and multi-dimensional
            arrays (e.g. as returned by plt.subplots) are flattened in row-major order.
        xloc (float): x-coordinate for the text. By default this is in normalized axis coordinates.
            Specify `transform` (as kwargs) to use a different coordinate system.
        yloc (float): y-coordinate for the text. By default this is in normalized axis coordinates.
            Specify `transform` (as kwargs) to use a different coordinate system.
        style (str): specify which enumeration style to use. Choose from:
            'alphabet'.
        capitalize (bool): whether to capitalize the enumeration.
        postfix (str): optional postfix to add to the enumeration. E.g. ')'.
        **kwargs (dict, optional): for ax.text().

    Raises:
        ValueError: if `style` is not a supported style.
        NotImplementedError: if there are more axes than available enumeration characters.
    """
    style = style.lower()

    # Create list or array of characters for enumeration.
    if 'alpha' in style:
        chars = 'abcdefghijklmnop'
    else:
        raise ValueError('Invalid style="{}".'.format(style))

    # Flatten, so that every element is a single axes object (also for 2D grids of axes).
    flat_axes = np.asarray(axes, dtype=object).ravel()

    if len(chars) < len(flat_axes):
        raise NotImplementedError('Not implemented for {} axes. Not enough characters.'.format(len(flat_axes)))

    # Loop over axes and add the text.
    for char, ax in zip(chars, flat_axes):
        # Create text. The postfix is added regardless of the capitalization.
        text = (char.capitalize() if capitalize else char) + postfix

        # Update default text kwargs with user-specified kwargs.
        text_kwargs = dict(dict(
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes,
            weight='bold',
        ), **kwargs)

        # Add text.
        ax.text(xloc, yloc, text, **text_kwargs)
def fillplot(x, y, data=None, x_order=None, y_order=None, ax=None, **kwargs):
    """
    Line plot that is filled between zero and y.

    Args:
        x (np.ndarray, list or str): x-data, or column name for DataFrame `data`.
        y (np.ndarray, list or str): y-data, or column name for DataFrame `data`.
        data (pd.DataFrame or None): optional container for the data. Only required
            if `x` and/or `y` is a column name.
        x_order (list, None): optional list with the order of the x-values (if non-numeric).
        y_order (list, None): optional list with the order of the y-values (if non-numeric).
        ax (plt.Axes, None): axis to plot in. If None, the current axis is used.
        **kwargs: for plt.fill_between(), e.g. `step`.

    Raises:
        ValueError: if `x` and/or `y` is a column name while `data` is None.
    """
    # Defaults.
    if ax is None:
        ax = plt.gca()

    if data is None and (isinstance(x, str) or isinstance(y, str)):
        raise ValueError('Argument `data` missing while providing column names for `x` and/or `y`.')

    # Check if x and y are column names for `data`. If so, the column name is used as axis label.
    if isinstance(x, str):
        xlabel = x
        x = data[x].values
    else:
        xlabel = ''

    if isinstance(y, str):
        ylabel = y
        y = data[y].values
    else:
        ylabel = ''

    # Make sure we have arrays (needed for the dtype checks below), so that lists are accepted too.
    x = np.asarray(x)
    y = np.asarray(y)

    # If needed, convert non-numeric data to numbers.
    # NOTE(review): the tick positions k/len(order) presumably match the numbers that
    # `asnumber` assigns to the ordered categories - confirm against `asnumber`.
    if not np.issubdtype(x.dtype, np.number):
        if x_order is None:
            # Default order.
            x_order = np.unique(x)
        xlabels = x_order
        xticks = np.arange(1, len(x_order) + 1) / len(x_order)
        x = asnumber(x, order=x_order)
    else:
        xlabels = xticks = None

    if not np.issubdtype(y.dtype, np.number):
        if y_order is None:
            # Default order.
            y_order = np.unique(y)
        ylabels = y_order
        yticks = np.arange(1, len(y_order) + 1) / len(y_order)
        y = asnumber(y, order=y_order)
    else:
        ylabels = yticks = None

    # Plot.
    ax.fill_between(x=x, y1=y, **kwargs)

    # Update ticks (only for axes with non-numeric data).
    if xticks is not None:
        ax.set_xticks(xticks)
        ax.set_xticklabels(xlabels)

    if yticks is not None:
        ax.set_yticks(yticks)
        ax.set_yticklabels(ylabels)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
def get_bin_idx(a, bin_edges):
    """
    Find the index of the bin that contains the value `a`.

    Bins are half-open: bin i contains values with bin_edges[i] <= a < bin_edges[i + 1].

    Args:
        a (float): value to place in a bin.
        bin_edges (np.ndarray): array with the bin edges.

    Returns:
        idx_bin (int): index of the bin containing `a`.

    Raises:
        ValueError: if `a` does not fall in exactly one bin (e.g. it lies outside the
            edges, equals the last edge, or is NaN).
    """
    lower_edges = bin_edges[:-1]
    upper_edges = bin_edges[1:]
    matches = np.nonzero((a >= lower_edges) & (a < upper_edges))[0]
    if matches.size != 1:
        raise ValueError(f'Value {a} cannot be placed inside a bin with bin edges:\n{bin_edges}')
    return matches[0]
def heatmap(x, y, color=None, palette=None, color_range=None, size=None, size_range=None,
            size_scale=500, x_order=None, y_order=None, annotations=None, text_kwargs=None,
            ax=None, **kwargs):
    """
    Plot a heatmap as a grid of markers, in which both the color and the size of a marker can encode a value.

    Adopted from https://github.com/dylan-profiler/heatmaps.

    Args:
        x (iterable): horizontal category of each cell.
        y (iterable): vertical category of each cell (same length as `x`).
        color (iterable, optional): value per cell that determines the marker color.
            If None, all cells get the value 1.
        palette (list, optional): list of colors to map the `color` values to.
            If None, a diverging seaborn palette with 256 colors is used.
        color_range (tuple, optional): (min, max) of the values mapped to the palette.
            If None, the min and max of `color` are used.
        size (iterable, optional): value per cell that determines the marker size.
            If None, all cells get the value 1.
        size_range (tuple, optional): (min, max) of the values mapped to the marker size.
            If None, the min and max of `size` are used.
        size_scale (float): marker size (`s` of ax.scatter) for the largest value.
        x_order (list, optional): order of the categories on the x-axis. If None, sorted unique values of `x`.
        y_order (list, optional): order of the categories on the y-axis. If None, sorted unique values of `y`.
        annotations (iterable, optional): text to put in each cell (converted with str()).
            Empty strings are skipped.
        text_kwargs (dict, optional): keyword arguments for ax.text() used for the annotations.
        ax (plt.Axes, optional): axes to plot in. If None, the current axes is used.
        **kwargs: for ax.scatter(). `marker` defaults to 's'. `xlabel` and `ylabel` are dropped.
    """
    if color is None:
        color = [1]*len(x)

    if palette is None:
        palette = sns.diverging_palette(20, 220, n=256)

    # Range of values that will be mapped to the palette, i.e. min and max possible correlation.
    if color_range is None:
        color_min, color_max = min(color), max(color)
    else:
        color_min, color_max = color_range

    if size is None:
        size = [1]*len(x)

    if size_range is None:
        size_min, size_max = min(size), max(size)
    else:
        size_min, size_max = size_range

    if x_order is None:
        x_order = sorted(set(x))

    if y_order is None:
        y_order = sorted(set(y))

    if ax is None:
        ax = plt.gca()

    n_colors = len(palette)

    def value_to_size(val):
        # Map a size value to a marker size. NaN values are not drawn (size 0).
        if np.isnan(val):
            s = 0
        elif size_min == size_max:
            s = 1 * size_scale
        else:
            # Position of value in the input range, relative to the length of the input range.
            # The offset of 0.01 makes sure that the smallest value still gets a visible marker.
            val_position = (val - size_min) * 0.99 / (size_max - size_min) + 0.01
            val_position = min(max(val_position, 0), 1)  # Bound the position between 0 and 1.
            s = val_position * size_scale
        return s

    # Map each category to an integer position on the axis.
    x_to_num = {label: idx for idx, label in enumerate(x_order)}
    y_to_num = {label: idx for idx, label in enumerate(y_order)}

    marker = kwargs.get('marker', 's')

    # Keyword arguments that are handled here should not be passed on to ax.scatter().
    kwargs_pass_on = {k: v for k, v in kwargs.items() if k not in [
        'color', 'palette', 'color_range', 'size', 'size_range', 'size_scale', 'marker', 'x_order', 'y_order',
        'xlabel', 'ylabel'
    ]}

    x_scatter = [x_to_num[v] for v in x]
    y_scatter = [y_to_num[v] for v in y]
    s_scatter = [value_to_size(v) for v in size]
    c_scatter = [value_to_color(v, color_min=color_min, color_max=color_max, palette=palette)
                 for v in color]
    ax.scatter(
        x=x_scatter, y=y_scatter, marker=marker,
        s=s_scatter, c=c_scatter, **kwargs_pass_on
    )
    ax.set_xticks(list(x_to_num.values()))
    ax.set_xticklabels(list(x_to_num), rotation=45, horizontalalignment='right')
    ax.set_yticks(list(y_to_num.values()))
    ax.set_yticklabels(list(y_to_num))

    # Put the grid lines in between the cells (minor ticks), instead of through the cells (major ticks).
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)

    ax.set_facecolor('#F1F1F1')

    # Add color legend on the right side of the plot.
    if color_min < color_max:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)

        col_x = [0]*len(palette)  # Fixed x coordinate for the bars.
        bar_y = np.linspace(color_min, color_max, n_colors)  # y coordinates for each of the n_colors bars.

        bar_height = bar_y[1] - bar_y[0]
        cax.barh(
            y=bar_y,
            width=[5]*len(palette),  # Make bars 5 units wide.
            left=col_x,  # Make bars start at 0.
            height=bar_height,
            color=palette,
            linewidth=0
        )
        cax.set_xlim(1, 2)  # Bars are going from 0 to 5, so lets crop the plot somewhere in the middle.
        cax.grid(False)  # Hide grid.
        cax.set_facecolor('white')  # Make background white.
        cax.set_xticks([])  # Remove horizontal ticks.
        cax.set_yticks(np.linspace(min(bar_y), max(bar_y), 3))  # Show vertical ticks for min, middle and max.
        cax.yaxis.tick_right()  # Show vertical ticks on the right.
        cax.spines['bottom'].set_visible(False)
        cax.spines['left'].set_visible(False)

    if annotations is not None:
        text_kwargs = dict(
            dict(va='center', ha='center', alpha=0.7),
            **(text_kwargs if text_kwargs is not None else dict()))
        for xi, yi, annot in zip(x_scatter, y_scatter, annotations):
            annot = str(annot)
            if len(annot) > 0:
                ax.text(xi, yi, annot, **text_kwargs)

    ax.axis('equal')
    ax.set_xlim([-0.5, max(x_to_num.values()) + 0.5])
    ax.set_ylim([-0.5, max(y_to_num.values()) + 0.5])
def mm2inch(mm):
    """
    Convert a length in millimeters to inches.

    Args:
        mm (float or np.ndarray): length(s) in millimeters.

    Returns:
        length(s) in inches (`mm` multiplied by 0.0393700787).
    """
    inch_per_mm = 0.0393700787
    return mm * inch_per_mm
def pieplot(x, weight=None, data=None, ax=None, add_legend=True,
            order=None, palette=None, normalize=True, add_labels=True,
            legend_kwargs=None, **kwargs):
    """
    Make a pie chart for the data x using similar syntax as seaborn.

    Args:
        x (array-like or str): category of each observation, or column name in `data`.
            Values are converted to strings.
        weight (array-like or str, optional): weight of each observation, or column name in `data`.
            If None, every observation has weight 1.
        data (pd.DataFrame, optional): container for the data, required if `x` or `weight` is a column name.
        ax (plt.Axes, optional): axes to plot in. If None, the current axes is used.
        add_legend (bool): if True, add a legend with the full category names.
        order (list, optional): order of the categories (as strings). If None, sorted unique values of `x`.
        palette (dict, optional): maps each category to a color. If None, the colors 'C0', 'C1', ...
            are assigned in the order of `order`.
        normalize (bool): passed to ax.pie().
        add_labels (bool): if True, label each wedge that is larger than 5% with the first three
            characters of the last word of the category name.
        legend_kwargs (dict, optional): keyword arguments for ax.legend(), overriding the defaults.
        **kwargs: for ax.pie() (`counterclock` defaults to False).
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(x, str):
        x = data[x]
    x = np.asarray(x).astype(str)

    if weight is None:
        weight = np.ones(len(x))
    elif isinstance(weight, str):
        weight = data[weight]
    # Convert to array, so that the boolean masks below also work for lists/tuples.
    weight = np.asarray(weight)

    # Defaults.
    if order is None:
        order = sorted(np.unique(x))

    if palette is None:
        palette = {lab: f'C{i}' for i, lab in enumerate(order)}

    # Collect sizes and colors.
    sizes = []
    colors = []
    order_new = []
    labels = []
    tot_size = np.sum(weight)
    for lab in order:
        size_i = np.sum(weight[x == lab])/tot_size
        if size_i == 0:
            # Skip categories without any weight.
            continue
        order_new.append(lab)
        sizes.append(size_i)
        colors.append(palette[lab])
        # Short label (only for wedges that are large enough to fit a label).
        labels.append(lab.split(' ')[-1][:3] if size_i > 0.05 else '')

    # Plot.
    ax.pie(sizes, labels=labels if add_labels else None,
           colors=colors, normalize=normalize, **dict(dict(counterclock=False), **kwargs))
    ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.

    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    if add_legend:
        legend_opts = dict(bbox_to_anchor=(1, 1))
        if legend_kwargs is not None:
            legend_opts.update(legend_kwargs)
        legend_elements = [mpatches.Patch(color=col, label=lab)
                           for lab, col in zip(order_new, colors)]
        ax.legend(handles=legend_elements, **legend_opts)
def plot_correlations(x, y_all, data, axes=None, **kwargs):
    """
    Plot a regression plot of `x` versus each variable in `y_all`, one subplot per variable.

    The title of each subplot reports the Spearman correlation (r) and its p-value, computed on
    the rows where both variables are not missing. The title is bold if p < 0.05.

    Args:
        x (str): column name in `data` for the x-axis.
        y_all (list): column names in `data`, one per subplot.
        data (pd.DataFrame): container for the data.
        axes (list or np.ndarray, optional): axes to plot in. If None, a new figure is created.
            Axes that are not needed are removed.
        **kwargs: for sns.regplot().
    """
    n_plots = len(y_all)

    if axes is None:
        # Create a new figure with a suitable grid (fixed 2 x 4 layout for 8 plots).
        grid_shape = (2, 4) if n_plots == 8 else subplot_rows_columns(n_plots)
        _, axes = plt.subplots(grid_shape[0], grid_shape[1], sharex='all',
                               tight_layout=True, figsize=1.8 * np.array([7, 5]))
    axes = np.reshape([axes], -1)

    for y, ax in zip(y_all, axes):
        sns.regplot(x=x, y=y, data=data, ax=ax, **kwargs)

        # Spearman correlation on complete pairs only.
        pairs = data[[x, y]].dropna()
        r, p = spearmanr(pairs[x], pairs[y])

        title_weight = 'bold' if p < 0.05 else None
        ax.set_title(f'{y}\nr={r:.2f} (p={p:.3f})', weight=title_weight)
        ax.grid()

    # Remove axes that are not used.
    for unused_ax in axes[n_plots:]:
        unused_ax.remove()
def plot_boxplots(x, y_all, data, axes=None, **kwargs):
    """
    Plot `x` grouped by each variable in `y_all` as strip-boxplots, one subplot per variable.

    If a grouping variable has exactly two unique (non-missing) values, the title of the subplot
    reports the p-value of a Mann-Whitney U test between the two groups (bold if p < 0.05).
    Otherwise the reported p-value is NaN.

    Args:
        x (str): column name in `data` with the values to plot (on the y-axis of each subplot).
        y_all (list): column names in `data` with the grouping variables, one per subplot.
        data (pd.DataFrame): container for the data.
        axes (list or np.ndarray, optional): axes to plot in. If None, a new figure is created.
            Axes that are not needed are removed.
        **kwargs: for stripboxplot().
    """
    n_plots = len(y_all)

    if axes is None:
        n_rows, n_cols = subplot_rows_columns(n_plots)
        _, axes = plt.subplots(n_rows, n_cols, sharex='all',
                               tight_layout=True, figsize=1.8 * np.array([5, 5]))
    axes = np.reshape([axes], -1)

    for y, ax in zip(y_all, axes):
        # Note that the grouping variable goes on the x-axis of the plot.
        stripboxplot(x=y, y=x, data=data, ax=ax, **kwargs)

        complete = data[[x, y]].dropna()
        groups = complete[y].unique()
        if len(groups) == 2:
            first = complete[complete[y] == groups[0]][x]
            second = complete[complete[y] == groups[1]][x]
            _, p = mannwhitneyu(first, second)
        else:
            p = np.nan

        title_weight = 'bold' if p < 0.05 else None
        ax.set_title(f'{y}\n(p={p:.3f})', weight=title_weight)
        ax.grid()

    # Remove axes that are not used.
    for unused_ax in axes[n_plots:]:
        unused_ax.remove()
def plot_pca(X, y, n_components=2, order=None, normalize=True, ax=None, sample_size=None, random_state=None, palette=None, **kwargs):
    """
    Helper function to quickly plot a PCA visualization of high-dimensional data.

    Args:
        X (np.ndarray): data with one row per sample.
        y (np.ndarray): 1-dimensional array with a value/label per sample, used to color the points.
            If the number of unique values exceeds 10% of the number of samples, `y` is treated as
            continuous. Otherwise, every unique value is plotted as a separate category.
        n_components (int): number of principal components (2 or 3).
        order (list, optional): order in which to plot the categories. If None, the sorted unique values.
        normalize (bool): if True, standardize `X` (StandardScaler) before PCA.
        ax (plt.Axes, optional): axes to plot in. If None, the current axes is used.
        sample_size (int, optional): if specified and smaller than the number of samples, plot a
            random subset (without replacement) of this size.
        random_state (optional): seed for the random subset.
        palette (dict, optional): maps each category to a color.
        **kwargs: currently not used.

    Raises:
        NotImplementedError: if `n_components` is not 2 or 3.
        ValueError: if `y` is not 1-dimensional or if `X` and `y` differ in length.
    """
    if n_components not in [2, 3]:
        raise NotImplementedError(f'Not implemented for n_components={n_components}.')

    if ax is None:
        ax = plt.gca()

    if y.ndim != 1:
        raise ValueError('`y` should be 1-dimensional. Got shape {}.'.format(y.shape))

    if len(X) != len(y):
        raise ValueError('`X` and `y` should have same lengths. Got shapes {} and {}.'
                         .format(X.shape, y.shape))

    # Optionally take a random subset of the samples.
    if sample_size is not None and sample_size < len(X):
        subset = np.random.default_rng(random_state).choice(len(X), size=sample_size, replace=False)
        X, y = X[subset], y[subset]

    if normalize:
        X = StandardScaler().fit_transform(X)

    # Do PCA.
    T = PCA(n_components=n_components).fit_transform(X)

    # Determine if y is continuous or not.
    y_unique = np.unique(y)
    is_continuous = len(y_unique) > len(y)/10
    if is_continuous:
        ax.scatter(*T.T, c=y)
    else:
        # Plot with categories.
        categories = y_unique if order is None else order
        for category in categories:
            category_color = None if palette is None else palette[category]
            ax.scatter(*T[y == category].T, c=category_color)

    ax.set_xlabel('Principal component 1')
    ax.set_ylabel('Principal component 2')
def plot_per_bin(x, y, data, n_bins=10, levels=None, aggregate_fun='mean',
                 plot_error_bars=True, plot_hist=True,
                 bar_kwargs=None, plot_kwargs=None, err_kwargs=None, ax=None):
    """
    Plot the mean or median of y per bin of x.

    Args:
        x (str): column name in `data` of the variable to bin.
        y (str): column name in `data` of the variable to aggregate per bin.
        data (pd.DataFrame): container for the data (a copy is made, `data` is not changed).
        n_bins (int): number of bins (passed to get_bin_edges). Overruled by `levels`, if specified.
        levels (int, optional): if specified, the values of x are first replaced by their
            quantile level (1, ..., levels) and `levels` bins are used.
        aggregate_fun (str): 'mean' (error bars show mean +/- standard deviation) or
            'median' (error bars show the 25% and 75% quantiles).
        plot_error_bars (bool): if True, add error bars.
        plot_hist (bool): if True, plot the number of rows per bin as bars behind the plot.
        bar_kwargs (dict, optional): for ax.bar() (histogram).
        plot_kwargs (dict, optional): for ax.plot().
        err_kwargs (dict, optional): for ax.errorbar().
        ax (plt.Axes, optional): axes to plot in. If None, the current axes is used.

    Raises:
        NotImplementedError: if `aggregate_fun` is not 'mean' or 'median'.
    """
    df = data.copy()

    if ax is None:
        ax = plt.gca()

    if levels is not None:
        # Convert to levels with equal sample sizes. Going from the highest to the lowest level,
        # each value ends up with the lowest level whose quantile it does not exceed.
        original_x = df[x].copy()
        df[x] = levels
        for lev in range(levels, 0, -1):
            threshold = np.nanquantile(original_x, q=lev / levels)
            df.loc[original_x <= threshold, x] = lev
        n_bins = levels

    # Assign each row to a bin based on x.
    bin_edges = get_bin_edges(df[x], n_bins)
    df['bin'] = df[x].apply(lambda a: get_bin_idx(a, bin_edges))
    bin_ids = np.arange(len(bin_edges) - 1)

    # Get bin centres (average of each pair of consecutive edges).
    xi = np.convolve([0.5, 0.5], bin_edges, mode='valid')
    binwidth = bin_edges[1] - bin_edges[0]

    # Collect the y values per bin.
    y_per_bin = [df[y][df.bin == b] for b in bin_ids]

    # Compute the mean or median y per bin, with a lower and upper value for the error bars.
    if aggregate_fun == 'mean':
        mid = np.array([values.mean() for values in y_per_bin])
        sd = np.array([values.std() for values in y_per_bin])
        low, high = mid - sd, mid + sd
    elif aggregate_fun == 'median':
        low = np.array([values.quantile(0.25) for values in y_per_bin])
        mid = np.array([values.quantile(0.50) for values in y_per_bin])
        high = np.array([values.quantile(0.75) for values in y_per_bin])
    else:
        raise NotImplementedError(aggregate_fun)

    if plot_hist:
        # Number of rows in each bin.
        counts = [np.sum(df['bin'] == b) for b in bin_ids]

        # Plot the sizes of the bins as a histogram behind the plot.
        ax2 = ax.twinx()
        bar_opts = dict(edgecolor='C7', color='#F7DAB0')
        if bar_kwargs is not None:
            bar_opts.update(bar_kwargs)
        bar_kwargs = bar_opts
        ax2.bar(x=xi, height=counts, width=binwidth, **bar_kwargs)
        ax2.axes.get_yaxis().set_visible(False)
        ax.patch.set_visible(False)
        ax.set_zorder(ax2.get_zorder() + 1)  # Move ax in front.

    # Plot the bin vs aggregated y.
    yi = mid
    yerr = np.abs(np.vstack((low, high)) - yi[np.newaxis, :])

    line_opts = dict(color='k', marker='o')
    if plot_kwargs is not None:
        line_opts.update(plot_kwargs)
    plot_kwargs = line_opts
    ax.plot(xi, yi, **plot_kwargs)

    if plot_error_bars:
        # By default, the error bars get the same color as the line.
        err_opts = dict(capsize=3, color=plot_kwargs.get('color', 'k'))
        if err_kwargs is not None:
            err_opts.update(err_kwargs)
        err_kwargs = err_opts
        ax.errorbar(xi, yi, yerr=yerr, **err_kwargs)

    ax.set_xlabel(x)
    ax.set_ylabel(y)
def plot_tsne(X, y, n_components=2, order=None, ax=None, sample_size=None, random_state=None, palette=None, **kwargs):
    """
    Helper function to quickly plot a t-SNE visualization of high-dimensional data.

    Args:
        X (np.ndarray): data with one row per sample.
        y (np.ndarray): 1-dimensional array with a value/label per sample, used to color the points.
            If the number of unique values exceeds 10% of the number of samples, `y` is treated as
            continuous. Otherwise, every unique value is plotted as a separate category.
        n_components (int): dimension of the embedding (2 or 3).
        order (list, optional): order in which to plot the categories. If None, the sorted unique values.
        ax (plt.Axes, optional): axes to plot in. If None, the current axes is used. If specified,
            it is made the current axes.
        sample_size (int, optional): if specified and smaller than the number of samples, use a
            random subset (without replacement) of this size.
        random_state (optional): seed for the random subset and for t-SNE.
        palette (dict, optional): maps each category to a color.
        **kwargs: for sklearn's manifold.TSNE, overriding the defaults
            (init="pca", learning_rate="auto", n_iter=300).

    Raises:
        NotImplementedError: if `n_components` is not 2 or 3.
        ValueError: if `y` is not 1-dimensional or if `X` and `y` differ in length.
    """
    if n_components not in [2, 3]:
        raise NotImplementedError(f'Not implemented for n_components={n_components}.')

    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    if y.ndim != 1:
        raise ValueError('`y` should be 1-dimensional. Got shape {}.'.format(y.shape))

    if len(X) != len(y):
        raise ValueError('`X` and `y` should have same lengths. Got shapes {} and {}.'
                         .format(X.shape, y.shape))

    # Optionally take a random subset of the samples.
    if sample_size is not None and sample_size < len(X):
        subset = np.random.default_rng(random_state).choice(len(X), size=sample_size, replace=False)
        X, y = X[subset], y[subset]

    # Default tSNE options, updated with user-specified options.
    tsne_kwargs = {
        'init': "pca",
        'learning_rate': "auto",
        'n_iter': 300,
        **kwargs,
    }

    embedder = manifold.TSNE(n_components=n_components, random_state=random_state, **tsne_kwargs)
    T = embedder.fit_transform(X)

    # Determine if y is continuous or not.
    y_unique = np.unique(y)
    is_continuous = len(y_unique) > len(y)/10
    if is_continuous:
        ax.scatter(*T.T, c=y)
    else:
        # Plot with categories.
        categories = y_unique if order is None else order
        for category in categories:
            category_color = None if palette is None else palette[category]
            ax.scatter(*T[y == category].T, c=category_color)

    # The t-SNE coordinates have no meaningful scale: hide the tick labels.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(NullFormatter())
class PointPicker(object):
    """
    Select and deselect points in a matplotlib plot.

    Collect selected points in a list (self.points).
    Highlights selected points in the plot.
    Clicking on a selected point deselects it.

    Selected points can be saved to an Excel file using self.save_points('filename.xlsx').
    Saved points in Excel (with `x` and `y` columns) can be loaded using self.load_points('filename.xlsx').

    Args:
        fig (plt.Figure): figure whose pick events are handled. The artists to pick from must
            be created with picking enabled (see example).
        points (list): optional list with tuples of (x, y) coordinates that should already
            be included in the selection.
        ax (plt.Axes): axes in which to highlight the specified points in `points`.
            If None, the current axes of `fig` is used.
        **kwargs (optional): kwargs for ax.scatter() to control how the selected points
            are highlighted (e.g. c='y', s=100).

    Examples:
        >>> fig, ax = plt.subplots()
        >>> ax.set_title('click on points')
        Text(0.5, 1.0, 'click on points')
        >>> line, = ax.plot(np.random.rand(100), 'o',
        ...                 picker=True, pickradius=5)  # 5 points tolerance
        >>> picker = PointPicker(fig)
    """
    def __init__(self, fig, points=None, ax=None, **kwargs):
        if ax is None:
            # Use the axes of the given figure (not of whichever figure happens to be current).
            ax = fig.gca()

        scatter_kwargs = dict({'c': 'r', 'zorder': 100}, **kwargs)

        self._fig = fig
        self._points = []  # List of selected points.
        self._point_scatters = []  # Scatter objects that highlight selected points.
        self._ax = ax  # Axes in which to plot.
        self._scatter_kwargs = scatter_kwargs

        if points is not None:
            for p in points:
                # Store as tuple, so that it compares equal to the (x, y) tuples of picked points.
                self._add_point(tuple(p))

        # Link figure to the onpick function.
        self._cid = fig.canvas.mpl_connect('pick_event', self.onpick)

    @property
    def points(self):
        """
        Return list of (x, y) tuples of selected points.
        """
        return self._points

    @property
    def xpoints(self):
        """
        Return list of x coordinates of selected points.
        """
        return [p[0] for p in self.points]

    @property
    def ypoints(self):
        """
        Return list of y coordinates of selected points.
        """
        return [p[1] for p in self.points]

    def onpick(self, event):
        """
        Handle a pick event: toggle the selection of the picked data point.

        If one of the data points within the pick tolerance is already selected, that point is
        toggled (i.e. deselected). Otherwise the data point closest to the click is selected.

        Args:
            event: matplotlib pick event of an artist with get_xdata() and get_ydata().
        """
        # Set current axes.
        self._ax = event.artist.axes

        # Get x, y coordinates.
        thisline = event.artist
        xdata = thisline.get_xdata()
        ydata = thisline.get_ydata()
        idx_in_tol = event.ind

        x_click = event.mouseevent.xdata
        y_click = event.mouseevent.ydata

        # If any points are already in the collection, choose that one.
        dist_to_click = []
        ind = None
        for idx in idx_in_tol:
            # Point.
            p_i = (xdata[idx], ydata[idx])

            # Add (squared) distance of point to click.
            d_i = (p_i[0] - x_click)**2 + (p_i[1] - y_click)**2
            dist_to_click.append(d_i)

            # Check if point in collection.
            if p_i in self.points:
                ind = idx
                break

        if ind is None:
            # If not in collection already, take closest one.
            ind = idx_in_tol[np.argmin(dist_to_click)]

        # Get point.
        p = (xdata[ind], ydata[ind])

        # Add/remove point.
        self._toggle_point(p)

    def load_points(self, filepath):
        """
        Load points from an Excel file with columns x and y.

        Hint: save as .xls to be able to have the Excel file open in Excel while loading in Python.
        """
        df = pd.read_excel(filepath)
        for _, row in df.iterrows():
            p = (row.x, row.y)
            self._add_point(p)

    def save_points(self, filepath, **kwargs):
        """
        Save point coordinates to an Excel file.

        Args:
            filepath (str): path of the Excel file to write.
            **kwargs: for pd.DataFrame.to_excel().
        """
        df = pd.DataFrame(data=dict(x=self.xpoints, y=self.ypoints))
        df.to_excel(filepath, index=False, **kwargs)

    def _toggle_point(self, p):
        # Remove the point if it is selected, add it otherwise.
        if p in self.points:
            self._remove_point(p)
        else:
            self._add_point(p)

    def _add_point(self, p):
        # Only add if not already in collection.
        if p not in self.points:
            # Highlight point.
            s = self._ax.scatter(p[0], p[1], **self._scatter_kwargs)
            self._redraw()

            self._points.append(p)
            self._point_scatters.append(s)

    def _remove_point(self, p):
        # Find index.
        idx = self.points.index(p)

        # Remove highlight.
        self._point_scatters[idx].remove()
        self._redraw()

        # Remove from lists.
        del self._points[idx]
        del self._point_scatters[idx]

    def _redraw(self):
        # Keep the axis limits fixed and schedule a redraw of the figure this picker is bound to.
        self._ax.autoscale(False)
        self._fig.canvas.draw_idle()
def remove_ticks(axis, **kwargs):
    """
    Remove axis ticks and label of specified axis (of the current axes).

    Args:
        axis (str or tuple): the axis to remove the ticks and label from. Either 'x', 'y', 'xy' or ('x', 'y').
        **kwargs (optional): for plt.tick_params(), overriding the defaults (which switch off
            all ticks and tick labels).
    """
    tick_opts = dict(
        which='both',  # Both major and minor ticks are affected.
        bottom=False,  # Ticks along the bottom edge are off.
        top=False,  # Ticks along the top edge are off.
        left=False,
        right=False,
        labelbottom=False,  # Labels along the bottom edge are off.
        labelleft=False,
    )
    tick_opts.update(kwargs)

    # Function that sets the axis label, per axis name.
    set_label = {'x': plt.xlabel, 'y': plt.ylabel}

    for entry in axis:
        for axis_name in ('x', 'y'):
            if axis_name in entry:
                # Remove the ticks and the label of this axis.
                plt.tick_params(axis=axis_name, **tick_opts)
                set_label[axis_name]('')
def save_fig_as(figname=None, directory='', filepath=None, info=None,
                formats=None, verbose=1, **kwargs):
    """
    Save the current figure to several different output formats.

    Args:
        figname (str): name of the figure, will be the filename.
        directory (str): directory in which to save.
        filepath (str): instead of specifying figname and directory, you can specify filepath.
            If not specified, filepath will be os.path.join(directory, figname).
        info (str): info which will be written to a .txt file with the same name
            (e.g. the path to the script creating the figure). If None, no .txt file will be created.
        formats (str, tuple, list): format or list with formats to save to,
            e.g. ("eps", "tiff", "png", "pdf", "svg").
            If specified, this overrides any extension that was specified in `figname` or `filepath`.
            If not specified, the extension of the filename is used or, if there is no extension,
            the figure is saved as pdf and png.
        verbose (int): if 1, prints a message on success.
        **kwargs: for plt.savefig(). For svg, `transparent` defaults to True.

    Raises:
        ValueError: if neither `filepath` nor `figname` is specified, or if a format is not supported.
    """
    if filepath is None:
        # Determine filepath from figname and directory.
        if figname is None:
            raise ValueError('No filepath specified. Either specify `figname` and `directory` or `filepath`.')
        filepath = os.path.join(directory, figname)

    # Check if directory exist.
    check_directory_exists(filepath=filepath)

    # Separate extension (if given).
    filepath, ext = os.path.splitext(filepath)

    # Determine formats if not specified.
    if formats is None:
        # Use the extension in the filename or filepath if there is one, otherwise the default formats.
        formats = (ext[1:],) if ext else ('pdf', 'png')  # ext[1:] skips the dot.
    elif isinstance(formats, str):
        # A single format: prevent looping over the characters of the string.
        formats = (formats,)

    # Save figures in several formats.
    for form in formats:
        if 'eps' in form:
            plt.savefig(filepath + '.eps', format='eps', **kwargs)
        elif 'tiff' in form:
            plt.savefig(filepath + '.tiff', **kwargs)
        elif 'png' in form:
            plt.savefig(filepath + '.png', **kwargs)
        elif 'pdf' in form:
            plt.savefig(filepath + '.pdf', **kwargs)
        elif 'svg' in form:
            # Transparent by default, but pass on the user-specified options like for the other formats.
            svg_kwargs = dict({'transparent': True}, **kwargs)
            plt.savefig(filepath + '.svg', format='svg', **svg_kwargs)
        else:
            raise ValueError(f'Invalid format "{form}". Choose from: "eps", "tiff", "png", "pdf", "svg".')

    if info is not None:
        # Add .txt file with same name containing the info text.
        with open(filepath + '.txt', 'w') as f:
            current_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
            f.write(info + f'\nCreate at {current_time}')

    if verbose:
        print(f'Saved figure to {filepath}!')
def scale_figsize(figsize, width, unit='cm'):
    """
    Scale figsize based on predefined width.

    Args:
        figsize (list, tuple): (width, height) ratio.
        width (float): desired width of the figure (in centimeters by default, see `unit`).
        unit (str): 'inch' or 'cm'. The unit of `width`.

    Returns:
        new_figsize (np.ndarray): rescaled figsize in inches, with the width equal to `width`
            (converted to inches) and the same width/height ratio as `figsize`.

    Raises:
        ValueError: if `unit` is not supported.
    """
    figsize = np.asarray(figsize)
    init_width = figsize[0]

    if unit == 'cm':
        # To inch (1 inch is exactly 2.54 cm).
        width = width / 2.54
    elif 'inch' in unit:
        pass
    else:
        raise ValueError(f'Invalid unit="{unit}". Choose from: "inch", "cm".')

    scale_factor = width / init_width
    return scale_factor * figsize
def set_plot_style(backend=None):
    """
    Apply the standard plot style (seaborn's 'whitegrid' style).

    Args:
        backend (str, optional): matplotlib backend to switch to. If None (or empty),
            the current backend is kept.
            Defaults to None.
    """
    switch_backend = bool(backend)
    if switch_backend:
        matplotlib.use(backend)

    sns.set_style('whitegrid')
def shade_axis(onsets, durations, labels=None, color=None, alpha=0.4, orientation='horizontal',
               add_legend=True, legend_kwargs=None, ax=None):
    """
    Shade the background of an axis for epochs defined by onsets and durations.

    Args:
        onsets (iterable): onsets for the epochs to shade (in the dimension of the x-axis,
            or y-axis if `orientation` is 'vertical').
        durations (iterable): durations of the epochs to shade.
        labels (iterable or str, optional): labels corresponding to the epochs to shade, or one
            label for all epochs. If None, no legend is added.
        color (dict or str, optional): dict mapping a label to a color or one color for all.
            Epochs with a label that is not in the dict are not shaded.
            If None, each unique label gets a color ('C0', 'C1', ...).
        alpha (float, optional): the transparency level of the shading.
            Defaults to 0.4.
        orientation (str, optional): whether the plot is 'horizontal' (onsets are on the x-axis),
            or 'vertical' (onsets are on the y-axis).
            Defaults to 'horizontal'.
        add_legend (bool, optional): if True, add a legend explaining the shading colours (only if `color` is a dict).
            If False, adds no legend.
            Defaults to True.
        legend_kwargs (dict, optional): if add_legend is True, legend_kwargs are passed to the plt.legend()
            function as optional keyword arguments.
            Defaults to None.
        ax (plt.Axes, optional): matplotlib axis to shade (it is made the current axis).
            If None, the current axis will be used.
            Defaults to None.

    Returns:
        h (list): list with handles for the shading spans.

    Raises:
        ValueError: if `orientation` is invalid (raised when the first epoch is shaded).
    """
    n_epochs = len(onsets) if labels is None or isinstance(labels, str) else None
    if labels is None:
        labels = [''] * n_epochs
        add_legend = False
    elif isinstance(labels, str):
        # Single same label for all shades.
        labels = [labels] * n_epochs

    if color is None:
        # Default mapping: one color of the color cycle per unique label.
        color = {lab: 'C{}'.format(i) for i, lab in enumerate(np.unique(labels))}

    if ax is not None:
        plt.sca(ax)

    color_per_label = isinstance(color, dict)
    span_functions = {'horizontal': plt.axvspan, 'vertical': plt.axhspan}

    h = []
    for on, dur, lab in zip(onsets, durations, labels):
        if color_per_label:
            if lab not in color:
                # Skip labels that are not in the color mapping.
                continue
            col = color[lab]
        else:
            col = color

        add_span = span_functions.get(orientation)
        if add_span is None:
            raise ValueError('Invalid `orientation` "{}". Choose from "horizontal", "vertical".'
                             .format(orientation))
        h.append(add_span(on, on + dur, facecolor=col, alpha=alpha))

    if add_legend and color_per_label:
        # Get handle to existing legend.
        old_legend = plt.gca().get_legend()

        # Create new legend (only with labels that are actually present).
        legend_opts = {'loc': 'center right'}
        if legend_kwargs is not None:
            legend_opts.update(legend_kwargs)
        patches = [mpatches.Patch(color=col, label=lab, alpha=alpha)
                   for lab, col in color.items() if lab in labels]
        plt.legend(handles=patches, **legend_opts)

        # Re-add the old existing legend (plt.legend replaces it).
        if old_legend is not None:
            plt.gca().add_artist(old_legend)

    return h
def stripboxplot(x=None, y=None, hue=None, style=None, data=None,
                 order=None, hue_order=None, orient=None,
                 color=None, palette=None, markers=None, ax=None, mediansize=0.75,
                 boxkwargs=None, stripkwargs=None, legendkwargs=None):
    """
    Plot a stripplot and overlay the box of a boxplot.

    Args:
        Most inputs: Common inputs to seaborn's boxplot and stripplot functions.
            See seaborn.boxplot and/or seaborn.stripplot.
        style: column in data for which to use different markers in the stripplot.
        markers: list or dict specifying the markers to be used for each unique marker category.
        mediansize: length of the median line in the boxplot, as a fraction of the boxplot width.
        boxkwargs: kwargs for seaborn's boxplot function.
        stripkwargs: kwargs for seaborn's stripplot function.
        legendkwargs: kwargs for matplotlib's legend.

    Returns:
        ax: axes handle.

    Raises:
        ValueError: if `style` is specified and there are fewer `markers` than unique style
            values, or if `markers` is not a list or dict.
    """
    # By default, plot in the current axis.
    if ax is None:
        ax = plt.gca()

    user_box_opts = {} if boxkwargs is None else boxkwargs
    user_strip_opts = {} if stripkwargs is None else stripkwargs
    user_legend_opts = {} if legendkwargs is None else legendkwargs

    # Arguments that are shared by the boxplot and the stripplot.
    shared_opts = dict(x=x, y=y, hue=hue, data=data,
                       order=order, hue_order=hue_order, orient=orient,
                       color=color, palette=palette, ax=ax)

    # Defaults for the boxplot, overridden by the user-specified options.
    # The boxplot kwargs get passed to matplotlib's boxplot function.
    box_opts = {
        'boxprops': {'edgecolor': 'k', 'linewidth': 0.5, 'alpha': 0.5},
        'medianprops': {'color': 0.15 * np.ones(3)},
        'whiskerprops': {'color': 0.15 * np.ones(3), 'linewidth': 1},
        'capprops': {'color': 0.3 * np.ones(3), 'linewidth': 0},  # Do not show caps (at end of whiskers).
        'width': 0.8,  # Width of the boxplot.
        'saturation': 1,  # Saturation of the color.
        'whis': False,  # Disable whiskers.
    }
    box_opts.update(user_box_opts)

    # Defaults for the stripplot, overridden by the user-specified options.
    strip_opts = {'linewidth': 0.6, 'size': 6, 'alpha': 1, 'jitter': True}
    strip_opts.update(user_strip_opts)

    # Defaults for the legend, overridden by the user-specified options.
    legend_opts = {
        'loc': 'upper left',
        'bbox_to_anchor': (1, 1),  # Plot next to plot (places `loc` at this position).
        'title': hue,
    }
    legend_opts.update(user_legend_opts)

    # Add common arguments (these always win over the specific kwargs).
    strip_opts.update(shared_opts)
    box_opts.update(shared_opts)

    # Plot boxplots.
    sns.boxplot(fliersize=0,  # Do not plot outliers.
                **box_opts)

    # Change the width of the median lines, by scaling the x-data of the lines around their centre.
    if mediansize != 1:
        for line in ax.get_lines():
            xdata = line.get_xdata()
            centre = xdata.sum() / 2
            line.set_xdata((xdata - centre) * mediansize + centre)

    # Get the labels and handles (before the stripplot adds its own).
    handles, labels = ax.get_legend_handles_labels()

    # Strip plot.
    if style is None:
        sns.stripplot(dodge=True,  # Split on the hue value.
                      zorder=0.5,  # Make it appear behind the boxplot.
                      **strip_opts)
    else:
        # Default markers.
        if markers is None:
            markers = ['o', 'X', 'P', 's', '*', 'v', '^', '>', '<']

        style_values = data[style].unique()
        if len(markers) < len(style_values):
            msg = 'Too many different levels for `mark_by`. Specify `markers` with at least ' \
                  'the number of unique marker values.'
            raise ValueError(msg)

        # Plot the data of each style value separately, with its own marker.
        per_style_opts = strip_opts.copy()
        for position, style_value in enumerate(style_values):
            if isinstance(markers, list):
                marker = markers[position]
            elif isinstance(markers, dict):
                marker = markers[style_value]
            else:
                raise ValueError('Invalid input for `markers`. Should be a list or dict. Got a {}.'
                                 .format(type(markers)))

            # Update plot input.
            per_style_opts['data'] = data[data[style] == style_value]
            per_style_opts['marker'] = marker

            # Plot.
            sns.stripplot(dodge=True,  # Split on the hue value.
                          zorder=0.5,  # Make it appear behind the boxplot.
                          **per_style_opts)

    # Add legend.
    if hue is not None:
        ax.legend(handles, labels, **legend_opts)

    # Fit figure to window.
    plt.tight_layout()

    return ax
def subplot_rows_columns(n, minimize='rows'):
    """
    Return a suitable number of rows and columns for a subplot figure with n plots.

    The grid is (close to) square, based on the square root of n. As an exception, 3 plots
    are put in a single row if `minimize` is 'rows'.

    Args:
        n (int): number of plots in the subplot figure.
        minimize (str, optional): if 'rows', the number of rows will be <= number of columns.
            If 'columns', the number of columns will be <= number of rows.

    Returns:
        nrows (int): number of rows for the subplot figure.
        ncols (int): number of columns for the subplot figure.

    Raises:
        ValueError: if `minimize` is not 'rows' or 'columns'.
    """
    root = n ** 0.5

    # Rows and columns depend on which one we want to keep to a minimum.
    if minimize == 'rows':
        if n == 3:
            return 1, 3
        return round(root), math.ceil(root)

    if minimize == 'columns':
        return math.ceil(root), round(root)

    raise ValueError('Invalid argument minimize="{}". Choose from "rows", "columns".'.format(minimize))
def trendline(x, y, ax=None, **kwargs):
    """
    Plot a linear trendline for (x, y).

    A first-degree polynomial is fitted to the points (least squares) and evaluated
    at the given x-points.

    Args:
        x (np.ndarray): x-points.
        y (np.ndarray): y-points.
        ax (plt.Axes): matplotlib axes to plot in.
            If not specified, plots in the current axes.
        **kwargs: for ax.plot().
    """
    target_ax = plt.gca() if ax is None else ax

    # Compute trendline.
    coefficients = np.polyfit(x, y, 1)
    fitted_line = np.poly1d(coefficients)

    # Plot.
    target_ax.plot(x, fitted_line(x), **kwargs)
def value_to_color(val, color_min, color_max, palette):
    """
    Convert a value (`val`) to a color according to a color map (`palette`).

    Args:
        val (float): value to convert. NaN is mapped to the first color in `palette`.
        color_min (float): minimum value in the color scale. Values below this will get the first color in `palette`.
        color_max (float): maximum value in the color scale. Values above this will get the last color in `palette`.
            If equal to `color_min`, every value gets the last color in `palette`.
        palette (list, np.ndarray): list with the colors in the color map (typically is a list with 256 colors).

    Returns:
        color: the color for the `val`.
    """
    last_index = len(palette) - 1

    # Degenerate color scale: everything gets the last color.
    if color_min == color_max:
        return palette[-1]

    if np.isnan(val):
        return palette[0]

    # Position of the value in the input range, relative to the length of the input range.
    fraction = float((val - color_min)) / (color_max - color_min)

    # Bound the position between 0 and 1.
    fraction = min(max(fraction, 0), 1)

    # Target index in the color palette.
    return palette[int(round(fraction * last_index))]