"""Functions related to visualization of RDMs."""
import types
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import get_cmap
from mne.viz import Brain
from mne.viz.topo import _iter_topography
from scipy.spatial import distance
[docs]
def plot_rdms(rdms, names=None, items=None, n_rows=1, cmap="viridis", title=None):
"""Plot one or more RDMs.
Parameters
----------
rdms : ndarray | list of ndarray
The RDM or list of RDMs to plot. The RDMs can either be two-dimensional (n_items
x n_items) matrices or be in condensed form.
names : str | list of str | None
For each given RDM, a name to show above it. Defaults to no names.
items : list of str | None
The each item (row/col) in the RDM, a string description. This will be displayed
along the axes. Defaults to None which means the items will be numbered.
n_rows : int
Number of rows to use when plotting multiple RDMs at once. Defaults to 1.
cmap : str
Matplotlib colormap to use. See
https://matplotlib.org/gallery/color/colormap_reference.html
for all possibilities. Defaults to 'viridis'.
title : str | None
Title for the entire figure. Defaults to no title.
Returns
-------
fig : matplotlib figure
The figure produced by matplotlib
"""
if not isinstance(rdms, list):
rdms = [rdms]
if isinstance(names, str):
names = [names]
if names is not None and len(names) != len(rdms):
raise ValueError(
f"Number of given names ({len(names)}) does not "
f"match the number of RDMs ({len(rdms)})"
)
n_cols = int(np.ceil(len(rdms) / n_rows))
fig = plt.figure(figsize=(2 * n_cols, 2 * n_rows))
ax = fig.subplots(n_rows, n_cols, sharex=True, sharey=True, squeeze=False)
for row in range(n_rows):
for col in range(n_cols):
i = row * n_cols + col % n_cols
if i < len(rdms):
rdm = rdms[i]
if rdm.ndim == 1:
rdm = distance.squareform(rdm)
elif rdm.ndim > 2:
raise ValueError(f"Invalid shape {rdm.shape} for RDM")
im = ax[row, col].imshow(rdm, cmap=cmap)
if names is not None:
name = names[i]
ax[row, col].set_title(name)
if items is not None:
ax[row, col].set_xticks(np.arange(len(items)))
ax[row, col].set_xticklabels(items)
ax[row, col].set_yticks(np.arange(len(items)))
ax[row, col].set_yticklabels(items)
else:
ax[row, col].set_visible(False)
plt.colorbar(im, ax=ax)
if title is not None:
plt.suptitle(title)
return fig
def _click_func(ax, ch_idx, rdms, cmap):
"""Plot a single RDM interactively.
Parameters
----------
ax: matplotlib.Axes.axes
Axes.axes object on which a new single RDM is plotted.
ch_idx: int
Index of a channel.
rdms: ndarray, shape (n_sensors, n_rdm_datapoint)
RDMs of MEG recordings; there's one RDM for each sensor.
cmap: str
Colormap used for plotting RDMs. Check ``matplotlib.pyplot.imshow`` for details.
"""
rdm = rdms[ch_idx]
rdm = distance.squareform(rdm)
ax.imshow(rdm, cmap=cmap)
def _plot_rdms_topo_timepoint(
rdms,
info,
layout=None,
fig=None,
title=None,
axis_facecolor="w",
axis_spinecolor="w",
fig_facecolor="w",
figsize=(6.4, 4.8),
cmap="viridis",
show=False,
):
"""Plot RDMs on 2D MEG topography.
Parameters
----------
rdms: ndarray, shape (n_sensors, n_rdm_datapoints) | generator
RDMs of MEG recordings; one RDM for each sensor. Can also be a generator of RDMs
as produced by the :func:`rdm_epochs`,
:func:`rdm_evokeds` or :func:`rdm_array` functions.
info: mne.io.meas_info.Info
Info object that contains meta data of MEG recordings.
layout: mne.channels.layout.Layout | None
Layout objects containing sensor layout info. The default (``None``) will figure
out layout based on info.
fig: matplotlib.pyplot.Figure | None
Figure object on which RDMs on 2D MEG topography are plotted. The default
(``None``) creates a new Figure object.
title: str | None
Title of the plot, used only when ``fig=None``. The default (``None``) puts no
title in the figure.
axis_facecolor: str
Face color of the each RDM. Defaults to 'w', white.
axis_spinecolor: str
Spine color of each RDM. Defaults to 'w', white.
fig_facecolor: str
Face color of the entire topography. Defaults to 'w', white.
figsize: tuple of float
Figure size. The first element specify width and the second height.
Defaults to (6.4, 4.8).
cmap: str
Colormap used for plotting RDMs. Defaults to 'viridis'.
Check :func:`matplotlib.pyplot.imshow` for details.
show: bool
Whether to display the generated figure. Defaults to False.
Returns
-------
fig: matplotlib.pyplot.Figure
Figure object in which RDMs are plotted on 2D MEG topography.
"""
on_pick = partial(_click_func, rdms=rdms, cmap=cmap)
if fig is None:
fig = plt.figure(figsize=figsize)
if title is not None:
fig.suptitle(title, x=0.98, horizontalalignment="right")
else:
fig = plt.figure(fig.number)
my_topo_plot = _iter_topography(
info=info,
layout=layout,
on_pick=on_pick,
fig=fig,
axis_facecolor=axis_facecolor,
axis_spinecolor=axis_spinecolor,
fig_facecolor=fig_facecolor,
unified=False,
)
for i, (ax, _) in enumerate(my_topo_plot):
rdms_i = rdms[i]
rdms_i = distance.squareform(rdms_i)
ax.imshow(rdms_i, cmap=cmap)
if show:
fig.show()
return fig
def plot_rdms_topo(
rdms,
info,
time=None,
layout=None,
fig=None,
axis_facecolor="w",
axis_spinecolor="w",
fig_facecolor="w",
figsize=(6.4, 4.8),
cmap="viridis",
show=True,
):
"""Plot RDMs on 2D sensor topography.
Parameters
----------
rdms: ndarray | numpy.memmap, shape (n_sensors,[ n_times,] n_rdm_datapts)
RDMs of MEG/EEG recordings; one RDM for each sensor and time point.
info: mne.io.meas_info.Info
Info object that contains meta data of MEG/EEG recordings.
time: int | [int, int] | None
A time point (int) or time window ([int, int]) for which RDMs are plotted. When
a time window is given, averge RDMs for the window are plotted. The default
(``None``) plots the average RDMs of all the time points. Start of the time
window is inclusive, while the end is exclusive.
layout: mne.channels.layout.Layout, optional
Layout objects containing sensor layout info. The default, ``layout=None``, will
figure out layout based on info.
fig: matplotlib.pyplot.Figure | None, optional
Figure object on which RDMs on 2D sensor topography are plotted. The default
(``None``) creates a new Figure object with a title based on time parameter.
axis_facecolor: str, optional
Face color of the each RDM. Defaults to 'w', white.
axis_spinecolor: str, optional
Spine color of each RDM. Defaults to 'w', white.
fig_facecolor: str, optional
Face color of the entire topography. Defaults to 'w', white.
figsize: tuple of float, optional
Figure size. The first element specify width and the second height.
Defaults to (6.4, 4.8).
cmap: str, optional
Colormap used for plotting RDMs. Defaults to 'viridis'.
Check :func:`matplotlib.pyplot.imshow` for details.
show: bool, optional
Whether to display the generated figure. Defaults to ``True``.
Returns
-------
fig: matplotlib.pyplot.Figure
Figure object in which RDMs are plotted on 2D sensor topography.
"""
if isinstance(rdms, types.GeneratorType):
rdms = np.array(list(rdms))
if rdms.ndim != 2 and rdms.ndim != 3:
raise ValueError(
"rdms have to be a 2D or 3D ndarray or numpy.memmap, "
"[n_sensors,[ n_times,] n_rdm_datapoints]"
)
if len(rdms.shape) == 2:
rdms = rdms[:, np.newaxis, :]
if time is None:
time = [0, rdms.shape[1]]
if isinstance(time, int):
time = [time, time + 1]
if not isinstance(time, list):
raise TypeError("time has to be int, list of [int, int] or None.")
if (not all(isinstance(i, int) for i in time)) or (len(time) != 2):
raise TypeError("time has to be int, list of [int, int] or None.")
if time[0] >= time[1]:
raise ValueError(
"The start of the time window has to be smaller "
"than the end of the time window."
)
if time[0] < 0 or time[1] > rdms.shape[1]:
raise ValueError(
"The time window is out of range. The minimum is 0 "
f"and the maximum is {rdms.shape[1]}"
)
if (fig is not None) and (not isinstance(fig, plt.Figure)):
raise TypeError("fig has to be matplotlib.pyplot.Figure or None.")
rdms_cropped = rdms[:, time[0] : time[1], :]
rdms_avg = rdms_cropped.mean(axis=1)
# set title to time window
if time[0] + 1 != time[1]:
title = f"From {time[0]} (inclusive) to {time[1]} (exclusive)"
else:
title = f"Time point: {time[0]}"
fig = _plot_rdms_topo_timepoint(
rdms_avg,
info,
fig=fig,
layout=layout,
title=title,
axis_facecolor=axis_facecolor,
axis_spinecolor=axis_spinecolor,
fig_facecolor=fig_facecolor,
figsize=figsize,
cmap=cmap,
show=show,
)
return fig
def plot_roi_map(values, rois, subject, subjects_dir, cmap="plasma", alpha=1.0):
"""Plot ROI values on a FreeSurfer brain.
Parameters
----------
values : array-like, shape (n_rois,)
The values to plot. One value per ROI.
rois : list of mne.Label
The labels corrsponding to the ROIs.
subject : str
The name of the FreeSurfer subject to plot the brain for.
subjects_dir : str
The folder in which the FreeSurfer subject data is kept. Inside this folder
should be a folder with the same name as the `subject` parameter.
cmap : str
The name of the matplotlib colormap to use. Defaults to 'plasma'.
alpha : float
The alpha (opacity, 1.0 is fully opaque, 0.0 is fully transparant) of the data
being plotted on top of the brain.
Returns
-------
brain : mne.viz.Brain
The MNE-Python brain plotting object that was created and currently
being shown. You can use this to modify the plot.
"""
cmap = get_cmap(cmap)
max_val = np.max(values)
brain = Brain(
subject=subject, subjects_dir=subjects_dir, surf="inflated", hemi="both"
)
labels_lh = np.zeros(len(brain.geo["lh"].coords), dtype=int)
labels_rh = np.zeros(len(brain.geo["rh"].coords), dtype=int)
ctab_lh = list()
ctab_rh = list()
for i, (roi, value) in enumerate(zip(rois, values), 1):
if roi.hemi == "lh":
labels = labels_lh
ctab = ctab_lh
else:
labels = labels_rh
ctab = ctab_rh
labels[roi.vertices] = i
ctab.append([int(x * 255) for x in cmap(value / max_val)[:4]] + [i])
ctab_lh = np.array(ctab_lh)
ctab_rh = np.array(ctab_rh)
brain.add_annotation(
[(labels_lh, ctab_lh), (labels_rh, ctab_rh)], borders=False, alpha=alpha
)
return brain