# -*- coding: utf-8 -*-
"""Connectivity analysis using Dynamic Imaging of Coherent Sources (DICS).
Authors: Susanna Aro <susanna.aro@aalto.fi>
Marijn van Vliet <w.m.vanvliet@gmail.com>
"""
import copy
import numpy as np
from h5io import read_hdf5, write_hdf5
from mne import BiHemiLabel, Forward, Label, SourceSpaces, pick_channels_forward
from mne.parallel import parallel_func
from mne.source_estimate import _make_stc
from mne.source_space._source_space import (
_ensure_src,
_ensure_src_subject,
_get_morph_src_reordering,
)
from mne.time_frequency import pick_channels_csd
from mne.utils import copy_function_doc_to_method_doc, logger, verbose
from scipy import sparse
from scipy.spatial.distance import cdist, pdist
from .forward import forward_to_tangential
from .utils import reg_pinv
from .viz import plot_connectivity
class BaseConnectivity(object):
"""Base class for connectivity objects.
Contains implementation of methods that are defined for all connectivity
objects.
Parameters
----------
data : ndarray, shape (n_pairs,)
For each connectivity source pair, a value describing the connection.
For example, this can be the strength of the connection between the
sources.
    pairs : ndarray, shape (2, n_pairs)
        The sources involved in each from-to connectivity pair, given as two
        arrays of source indices: the first array lists the "from" sources,
        the second the "to" sources. Indices range from 0 to
        ``n_sources - 1``.
n_sources : int
The number of sources between which connectivity is defined.
source_degree : tuple of lists (out_degree, in_degree) | None
For each source, the total number of possible connections from and to
the source. This information is needed to perform weighting on the
number of connections during visualization and statistics. If ``None``,
it is assumed that all possible connections are defined in the
``pairs`` parameter and the out- and in-degree of each source is
computed.
subject : str | None
The subject-id. Defaults to ``None``.
directed : bool
Whether the connectivity is directed (from->to != to->from). Defaults
to False.
Attributes
----------
n_connections : int
The number of connections.
"""
def __init__(
self, data, pairs, n_sources, source_degree=None, subject=None, directed=False
):
self.data = np.asarray(data)
pairs = np.asarray(pairs)
if pairs.shape[1] != len(data):
raise ValueError(
"The number of pairs does not match the number "
"of items in the data list."
)
        if pairs.shape[1] > 0 and pairs.max() >= n_sources:
            raise ValueError(
                "Pairs are defined between non-existent sources "
                "(n_sources=%d)." % n_sources
            )
self.pairs = pairs
self.n_sources = n_sources
self.subject = subject
self.directed = directed
if source_degree is not None:
source_degree = np.asarray(source_degree)
if source_degree.shape[1] != n_sources:
raise ValueError(
"The length of the source_degree list does "
"not match the number of sources."
)
self.source_degree = source_degree
else:
self.source_degree = np.asarray(_compute_degree(pairs, n_sources))
def __repr__(self):
return "<{} | n_sources={}, n_conns={}, subject={}>".format(
self.__class__.__name__, self.n_sources, self.n_connections, self.subject
)
@property
def n_connections(self):
"""The number of connections."""
return len(self.data)
def copy(self):
"""Return copy of the Connectivity object."""
return copy.deepcopy(self)
def __setstate__(self, state): # noqa: D105
self.data = state["data"]
self.pairs = state["pairs"]
self.n_sources = state["n_sources"]
if "source_degree" in state:
self.source_degree = state["source_degree"]
else:
self.source_degree = _compute_degree(self.pairs, self.n_sources)
self.subject = state["subject"]
self.directed = state["directed"]
def __getstate__(self): # noqa: D105
return dict(
data=self.data,
pairs=self.pairs,
subject=self.subject,
n_sources=self.n_sources,
source_degree=self.source_degree,
directed=self.directed,
)
def save(self, fname):
"""Save the connectivity object to an HDF5 file.
Parameters
----------
fname : str
The name of the file to save the connectivity to. The extension
'.h5' will be appended if the given filename doesn't have it
already.
See Also
--------
read_connectivity : For reading connectivity objects from a file.
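        Examples
        --------
        A minimal sketch of a save/load round-trip (the file name is
        hypothetical; the '.h5' extension is appended automatically):
        >>> con.save('sub01-coh')  # doctest: +SKIP
        >>> from conpy import read_connectivity
        >>> con2 = read_connectivity('sub01-coh')  # doctest: +SKIP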
"""
if not fname.endswith(".h5"):
fname += ".h5"
write_hdf5(fname, self.__getstate__(), overwrite=True, title="conpy")
def get_adjacency(self):
"""Get a source-to-source adjacency matrix.
Each non-zero element in the matrix indicates a connection exists
between the sources. The value of the element is the strength of the
connection.
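        Examples
        --------
        A minimal sketch of reading off the strength of a single connection,
        here from source 3 to source 5 (the indices are hypothetical):
        >>> A = con.get_adjacency()  # doctest: +SKIP
        >>> strength = A[3, 5]  # doctest: +SKIP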
"""
A = sparse.csr_matrix(
(self.data, self.pairs),
shape=(self.n_sources, self.n_sources),
)
if self.directed:
return A
else:
return A + A.T
def threshold(self, thresh, crit=None, direction="above", copy=False):
"""Threshold the connectivity.
Only retain the connections which exceed a given threshold.
Parameters
----------
        thresh : float
            The threshold value.
crit : None | ndarray, shape (n_connections,)
An array containing for each connection, a value which must pass
the threshold for the connection to be retained. By default, this
is the data value of the connection. Common uses for this parameter
include thresholding connections based on t-values or p-values.
        direction : 'above' | 'below'
            Whether the ``crit`` value must be above or below the given
            threshold for a connection to be retained. Defaults to 'above'.
copy : bool
Whether to operate in place (``False``, the default) or on a copy
(``True``).
Returns
-------
thresholded_con : instance of Connectivity
The thresholded version of the connectivity.
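        Examples
        --------
        A minimal sketch, assuming ``ts`` holds one t-value per connection:
        >>> strong = con.threshold(0.1, copy=True)  # doctest: +SKIP
        >>> significant = con.threshold(2.0, crit=ts, copy=True)  # doctest: +SKIP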
"""
if crit is None:
crit = self.data
elif len(crit) != self.n_connections:
raise ValueError(
"The number of items in `crit` does not match "
"the number of connections."
)
# Convert crit into a binary mask
if direction == "above":
mask = crit > thresh
elif direction == "below":
mask = crit < thresh
else:
            raise ValueError(
                'The direction parameter must be either "above" or "below".'
            )
if copy:
thresholded_con = self.copy()
else:
thresholded_con = self
thresholded_con.data = self.data[mask]
thresholded_con.pairs = self.pairs[:, mask]
return thresholded_con
def __getitem__(self, index):
"""Select connections without making a deep copy."""
# Create an "empty" connection object
con = self.__class__.__new__(self.__class__)
        # Construct the fields for the new connection object.
state = self.__getstate__()
state["data"] = self.data[index]
state["pairs"] = self.pairs[:, index]
# Set the fields of the new connection object
con.__setstate__(state)
return con
def is_compatible(self, other):
"""Check compatibility with another connectivity object.
Two connectivity objects are compatible if they define the same
connectivity pairs.
Returns
-------
is_compatible : bool
Whether the given connectivity object is compatible with this one.
"""
return (
isinstance(other, BaseConnectivity)
and other.n_sources == self.n_sources
and np.array_equal(other.pairs, self.pairs)
)
    def __iadd__(self, other):  # noqa: D105
        if not self.is_compatible(other):
            raise ValueError("Connectivity objects are not compatible.")
        self.data += other.data
        return self
def __add__(self, other): # noqa: D105
return self.copy().__iadd__(other)
    def __isub__(self, other):  # noqa: D105
        if not self.is_compatible(other):
            raise ValueError("Connectivity objects are not compatible.")
        self.data -= other.data
        return self
def __sub__(self, other): # noqa: D105
return self.copy().__isub__(other)
    def __idiv__(self, other):  # noqa: D105
        if not self.is_compatible(other):
            raise ValueError("Connectivity objects are not compatible.")
        self.data /= other.data
        return self
def __div__(self, other): # noqa: D105
con = self.copy()
# Always use floating point for division
con.data = con.data.astype("float")
return con.__idiv__(other)
def __truediv__(self, other): # noqa: D105
return self.__div__(other)
def __itruediv__(self, other): # noqa: D105
con = self.copy()
# Always use floating point for division
con.data = con.data.astype("float")
return con.__idiv__(other)
    def __imul__(self, other):  # noqa: D105
        if not self.is_compatible(other):
            raise ValueError("Connectivity objects are not compatible.")
        self.data *= other.data
        return self
def __mul__(self, other): # noqa: D105
return self.copy().__imul__(other)
    def __ipow__(self, other):  # noqa: D105
        if not self.is_compatible(other):
            raise ValueError("Connectivity objects are not compatible.")
        self.data **= other.data
        return self
def __pow__(self, other): # noqa: D105
return self.copy().__ipow__(other)
    def __neg__(self):  # noqa: D105
        con = self.copy()
        con.data = -con.data
        return con
def __radd__(self, other): # noqa: D105
return self + other
    def __rsub__(self, other):  # noqa: D105
        return -(self - other)  # reflected: computes other - self
def __rmul__(self, other): # noqa: D105
return self * other
    def __rdiv__(self, other):  # noqa: D105
        con = other.copy()  # reflected: computes other / self
        con.data = con.data.astype("float")
        return con.__idiv__(self)
def _compute_degree(pairs, n_sources):
"""Compute out- and in- degree of each source.
Computes for each source, the number of connections from and to the source.
Parameters
----------
    pairs : ndarray, shape (2, n_pairs)
        The indices of the sources involved in each from-to connectivity pair.
n_sources : int
The total number of sources.
Returns
-------
out_degree : ndarray, shape (n_sources,)
The number of outgoing connections for each source.
in_degree : ndarray, shape (n_sources,)
The number of incoming connections for each source.
"""
out_degree = np.zeros(n_sources, dtype=int)
ind, degree = np.unique(pairs[0], return_counts=True)
out_degree[ind] = degree
in_degree = np.zeros(n_sources, dtype=int)
ind, degree = np.unique(pairs[1], return_counts=True)
in_degree[ind] = degree
return out_degree, in_degree
class VertexConnectivity(BaseConnectivity):
"""Estimation of connectivity between vertices.
Parameters
----------
data : ndarray, shape (n_pairs,)
For each connectivity source pair, a value describing the connection.
For example, this can be the strength of the connection between the
sources.
    pairs : ndarray, shape (2, n_pairs)
        The vertices involved in each from-to connectivity pair. The vertices
        are listed as indices into the concatenated vertex array
        ``np.hstack(vertices)``, i.e. left-hemisphere vertices come first,
        followed by right-hemisphere vertices.
vertices : list of two arrays of shape (n_vertices,)
For each hemisphere, the vertex numbers of sources defined in the
corresponding source space.
vertex_degree : tuple of lists (out_degree, in_degree) | None
For each vertex, the total number of possible connections from and to
the vertex. This information is needed to perform weighting on the
number of connections during visualization and statistics. If ``None``,
it is assumed that all possible connections are defined in the
``pairs`` parameter and the out- and in-degree of each vertex is
computed.
subject : str | None
The subject-id.
directed : bool
Whether the connectivity is directed (from->to != to->from). Defaults
to False.
Attributes
----------
n_connections : int
The number of connections.
n_sources : int
        The number of sources between which connectivity was computed.
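    Examples
    --------
    A minimal sketch of constructing a small connectivity object by hand (all
    values are hypothetical):
    >>> import numpy as np
    >>> from conpy import VertexConnectivity
    >>> con = VertexConnectivity(
    ...     data=[0.5, 0.2],         # connection strengths
    ...     pairs=[[0, 1], [2, 2]],  # connections 0->2 and 1->2
    ...     vertices=[np.array([10, 20]), np.array([30])],
    ... )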
"""
def __init__(
self, data, pairs, vertices, vertex_degree=None, subject=None, directed=False
):
if len(vertices) != 2:
            raise ValueError(
                "The `vertices` parameter should be a list of two arrays."
            )
self.vertices = [np.asarray(v) for v in vertices]
n_vertices = len(self.vertices[0]) + len(self.vertices[1])
super().__init__(
data=data,
pairs=pairs,
n_sources=n_vertices,
source_degree=vertex_degree,
subject=subject,
directed=directed,
)
def make_stc(self, summary="sum", weight_by_degree=True):
"""Obtain a summary of the connectivity as a SourceEstimate object.
Parameters
----------
summary : 'sum' | 'degree' | 'absmax'
How to summarize the adjacency data:
            'sum' : sum the strengths of both the incoming and outgoing connections
for each source.
'degree': count the number of incoming and outgoing connections for each
source.
'absmax' : show the strongest coherence across both incoming and outgoing
connections at each source. In this setting, the
``weight_by_degree`` parameter is ignored.
Defaults to ``'sum'``.
weight_by_degree : bool
Whether to weight the summary by the number of possible
connections. Defaults to ``True``.
Returns
-------
stc : instance of SourceEstimate
The summary of the connectivity.
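        Examples
        --------
        A minimal sketch of summarizing the connectivity and plotting it on
        the cortex (the plot arguments are hypothetical):
        >>> stc = con.make_stc(summary='degree')  # doctest: +SKIP
        >>> stc.plot(hemi='split')  # doctest: +SKIP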
"""
        if self.vertices is None:
            raise ValueError("Cannot create a SourceEstimate without vertices.")
if summary == "degree":
vert_inds, data = np.unique(self.pairs, return_counts=True)
n_vert_lh = len(self.vertices[0])
lh_inds = vert_inds < n_vert_lh
vertices = [
self.vertices[0][vert_inds[lh_inds]],
self.vertices[1][vert_inds[~lh_inds] - n_vert_lh],
]
elif summary == "sum":
A = self.get_adjacency()
data = A.sum(axis=0).T + A.sum(axis=1)
vertices = self.vertices
# These are needed later in order to weight by degree
vert_inds = np.arange(len(self.vertices[0]) + len(self.vertices[1]))
# For undirected connectivity objects, all connections have been
# counted twice.
if not self.directed:
data = data / 2.0
elif summary == "absmax":
A = self.get_adjacency()
in_max = A.max(axis=0).toarray().ravel()
out_max = A.max(axis=1).toarray().ravel()
data = np.maximum(in_max, out_max)
vertices = self.vertices
else:
            raise ValueError(
                'The summary parameter must be "degree", "sum", or "absmax".'
            )
data = np.asarray(data, dtype="float").ravel()
if weight_by_degree and summary != "absmax":
degree = self.source_degree[:, vert_inds].sum(axis=0)
# Prevent division by zero
zero_mask = degree == 0
data[~zero_mask] /= degree[~zero_mask]
data[zero_mask] = 0
return _make_stc(
data[:, np.newaxis],
vertices=vertices,
tmin=0,
tstep=1,
subject=self.subject,
)
@verbose
def parcellate(self, labels, summary="sum", weight_by_degree=True, verbose=None):
"""Get the connectivity parcellated according to the given labels.
        The connections between each pair of labels are summarized into a
        single value.
Parameters
----------
labels : list of (Label | BiHemiLabel)
The labels to use to parcellate the connectivity.
summary : 'sum' | 'degree' | 'absmax' | function
How to summarize the connectivity within a label. Either the
summation of the connection values ('sum'), the number of
connections from and to the label is used ('degree'), the absolute
maximum value of the connections ('absmax'), or a function can be
specified, which is called for each label with the following
signature:
>>> def summary(adjacency, vert_from, vert_to):
... '''Summarize the connections within a label.
...
... Parameters
... ----------
... adjacency : sparse matrix, shape (n_sources, n_sources)
... The adjacency matrix that defines the connection
... between the sources.
... src_from : list of int
... Indices of sources that are outside of the label.
... src_to : list of int
... Indices of sources that are inside the label.
...
... Returns
... -------
... coh : float
... Summarized coherence of the parcel.
weight_by_degree : bool
Whether to weight the summary of each label by the number of
possible connections from and to that label. Defaults to ``True``.
verbose : bool | str | int | None
If not None, override default verbose level (see
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
for more).
Returns
-------
coh_parc : LabelConnectivity
The parcellated connectivity.
See Also
--------
mne.read_labels_from_annot : To read a list of labels from a FreeSurfer
annotation.
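        Examples
        --------
        A minimal sketch using a FreeSurfer parcellation (the subject name is
        hypothetical):
        >>> import mne
        >>> labels = mne.read_labels_from_annot('sub01', parc='aparc')  # doctest: +SKIP
        >>> label_con = con.parcellate(labels, summary='absmax')  # doctest: +SKIP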
"""
        if not isinstance(labels, list):
            raise ValueError("labels must be a list of Label objects.")
# Make sure labels and connectivity are compatible
if (
labels[0].subject is not None
and self.subject is not None
and labels[0].subject != self.subject
):
raise RuntimeError(
"label and connectivity must have same subject names, "
'currently "%s" and "%s"' % (labels[0].subject, self.subject)
)
if summary == "degree":
def summary(c, f, t):
return float(c[f, :][:, t].nnz)
elif summary == "sum":
def summary(c, f, t):
return c[f, :][:, t].sum()
elif summary == "absmax":
def summary(c, f, t):
if len(f) == 0 or len(t) == 0:
return 0.0
else:
return np.abs(c[f, :][:, t]).max()
        elif not callable(summary):
            raise ValueError(
                'The summary parameter must be "degree", "sum", '
                '"absmax", or a function.'
            )
logger.info("Computing out- and in-degree for each label...")
n_labels = len(labels)
label_degree = np.zeros((2, n_labels), dtype=int)
for i, label in enumerate(labels):
vert_ind = _get_vert_ind_from_label(self.vertices, label)
label_degree[:, i] = self.source_degree[:, vert_ind].sum(axis=1)
logger.info("Summarizing connectivity...")
adjacency = self.get_adjacency()
pairs = np.triu_indices(n_labels, k=1)
n_pairs = len(pairs[0])
summary_parc = np.zeros(n_pairs)
prev_from = -1
for pair_i, (lab_from, lab_to) in enumerate(zip(*pairs)):
if lab_from != prev_from:
logger.info(" in %s" % labels[lab_from].name)
prev_from = lab_from
vert_from = _get_vert_ind_from_label(self.vertices, labels[lab_from])
vert_to = _get_vert_ind_from_label(self.vertices, labels[lab_to])
val = summary(adjacency, vert_from, vert_to)
if weight_by_degree:
degree = label_degree[0, lab_from] + label_degree[1, lab_to]
if degree == 0:
# Prevent division by 0
val = 0
else:
val /= degree
summary_parc[pair_i] = val
# Drop connections with a value of zero. We take this to mean that no
# connection exists.
nonzero_inds = np.flatnonzero(summary_parc)
pairs = np.array(pairs)[:, nonzero_inds]
summary_parc = summary_parc[nonzero_inds]
logger.info("[done]")
return LabelConnectivity(
data=summary_parc,
pairs=pairs,
labels=labels,
label_degree=label_degree,
subject=self.subject,
)
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
self.vertices = state["vertices"]
def __getstate__(self): # noqa: D105
state = super().__getstate__()
state.update(
type="all-to-all",
vertices=self.vertices,
)
return state
def is_compatible(self, other):
"""Check compatibility with another connectivity object.
Two connectivity objects are compatible if they define the same
connectivity pairs.
Returns
-------
is_compatible : bool
Whether the given connectivity object is compatible with this one.
"""
return (
isinstance(other, VertexConnectivity)
and np.array_equal(other.vertices[0], self.vertices[0])
and np.array_equal(other.vertices[1], self.vertices[1])
and np.array_equal(other.pairs, self.pairs)
)
def to_original_src(
self, src_orig, subject_orig=None, subjects_dir=None, verbose=None
):
"""Get the connectivity from a morphed source to the original subject.
Parameters
----------
src_orig : instance of SourceSpaces
The original source spaces that were morphed to the current
subject.
subject_orig : str | None
The original subject. For most source spaces this shouldn't need
to be provided, since it is stored in the source space itself.
        subjects_dir : str | None
            Path to SUBJECTS_DIR if it is not set in the environment.
verbose : bool | str | int | None
If not None, override default verbose level (see
:func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
for more).
Returns
-------
con : instance of VertexConnectivity
The transformed connectivity.
See Also
--------
mne.morph_source_spaces
"""
if self.subject is None:
raise ValueError("con.subject must be set")
src_orig = _ensure_src(src_orig, kind="surface")
subject_orig = _ensure_src_subject(src_orig, subject_orig)
data_idx, vertices = _get_morph_src_reordering(
vertices=self.vertices,
src_from=src_orig,
subject_from=subject_orig,
subject_to=self.subject,
subjects_dir=subjects_dir,
verbose=verbose,
)
# Map the pairs to new vertices
mapping = np.argsort(data_idx)
pairs = [[mapping[p_] for p_ in p] for p in self.pairs]
vertex_degree = self.source_degree[:, data_idx]
return VertexConnectivity(
data=self.data,
pairs=pairs,
vertices=vertices,
vertex_degree=vertex_degree,
subject=subject_orig,
)
class LabelConnectivity(BaseConnectivity):
"""Estimation of all-to-all connectivity, parcellated into labels.
Parameters
----------
data : ndarray, shape (n_pairs,)
For each connectivity source pair, a value describing the connection.
For example, this can be the strength of the connection between the
sources.
    pairs : ndarray, shape (2, n_pairs)
        The indices of the labels involved in each from-to connectivity pair.
labels : list of instance of Label
The labels between which connectivity has been computed.
label_degree : tuple of lists (out_degree, in_degree) | None
For each label, the total number of possible connections from and to
the label. This information is needed to perform weighting on the
number of connections during visualization and statistics. If ``None``,
it is assumed that all possible connections are defined in the
``pairs`` parameter and the out- and in-degree of each label is
computed.
subject : str | None
The subject-id.
Attributes
----------
n_connections : int
The number of connections.
"""
def __init__(self, data, pairs, labels, label_degree=None, subject=None):
        if not isinstance(labels, list):
            raise ValueError("labels must be a list of Label objects.")
super().__init__(
data=data,
pairs=pairs,
n_sources=len(labels),
source_degree=label_degree,
subject=subject,
)
self.labels = labels
@copy_function_doc_to_method_doc(plot_connectivity)
def plot( # noqa
self,
n_lines=None,
node_angles=None,
node_width=None,
node_colors=None,
facecolor="black",
textcolor="white",
node_edgecolor="black",
linewidth=1.5,
colormap="hot",
vmin=None,
vmax=None,
colorbar=True,
title=None,
colorbar_size=0.2,
colorbar_pos=(-0.3, 0.1),
fontsize_title=12,
fontsize_names=8,
fontsize_colorbar=8,
padding=6.0,
fig=None,
subplot=111,
interactive=True,
node_linewidth=2.0,
show=True,
):
return plot_connectivity(
self,
n_lines=n_lines,
node_angles=node_angles,
node_width=node_width,
node_colors=node_colors,
facecolor=facecolor,
textcolor=textcolor,
node_edgecolor=node_edgecolor,
linewidth=linewidth,
colormap=colormap,
vmin=vmin,
vmax=vmax,
colorbar=colorbar,
title=title,
colorbar_size=colorbar_size,
colorbar_pos=colorbar_pos,
fontsize_title=fontsize_title,
fontsize_names=fontsize_names,
fontsize_colorbar=fontsize_colorbar,
padding=padding,
fig=fig,
subplot=subplot,
interactive=interactive,
node_linewidth=node_linewidth,
show=show,
)
def is_compatible(self, other):
"""Check compatibility with another connectivity object.
Two connectivity objects are compatible if they define the same
connectivity pairs.
Returns
-------
is_compatible : bool
Whether the given connectivity object is compatible with this one.
"""
return (
isinstance(other, LabelConnectivity)
and np.array_equal(other.pairs, self.pairs)
and np.all(
[
np.array_equal(l1.vertices, l2.vertices)
for l1, l2 in zip(other.labels, self.labels)
]
)
and np.all(
[
np.array_equal(l1.values, l2.values)
for l1, l2 in zip(other.labels, self.labels)
]
)
)
    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        self.labels = [Label(**label) for label in state["labels"]]
    def __getstate__(self):  # noqa: D105
        state = super().__getstate__()
state.update(
type="label",
labels=[label.__getstate__() for label in self.labels],
)
return state
def _get_vert_ind_from_label(vertices, label):
"""Get the indices of the vertices that fall within a given label.
Parameters
----------
vertices : list of ndarray
For each hemisphere, the vertex numbers.
label : instance of Label | BiHemiLabel
The label for which to get the vertex indices.
Returns
-------
vertex_ind : ndarray
The indices of the vertices that fall within the given label.
"""
    if not isinstance(label, (Label, BiHemiLabel)):
        raise TypeError("Expected Label or BiHemiLabel; got %r" % label)
if label.hemi == "both":
vertex_ind_lh = _get_vert_ind_from_label(vertices, label.lh)
vertex_ind_rh = _get_vert_ind_from_label(vertices, label.rh)
return np.hstack((vertex_ind_lh, vertex_ind_rh))
elif label.hemi == "lh":
verts_present = np.intersect1d(vertices[0], label.vertices)
return np.searchsorted(vertices[0], verts_present)
elif label.hemi == "rh":
verts_present = np.intersect1d(vertices[1], label.vertices)
return np.searchsorted(vertices[1], verts_present) + len(vertices[0])
def read_connectivity(fname):
"""Read a Connectivity object from an HDF5 file.
Parameters
----------
fname : str
The name of the file to read the connectivity from. The extension '.h5'
will be appended if the given filename doesn't have it already.
Returns
-------
connectivity : instance of Connectivity
The Connectivity object that was stored in the file.
See Also
--------
Connectivity.save : For saving connectivity objects
"""
if not fname.endswith(".h5"):
fname += ".h5"
con_dict = read_hdf5(fname, title="conpy")
con_type = con_dict["type"]
del con_dict["type"]
if con_type == "all-to-all":
return VertexConnectivity(
data=con_dict["data"],
pairs=con_dict["pairs"],
vertices=con_dict["vertices"],
vertex_degree=con_dict["source_degree"],
subject=con_dict["subject"],
)
elif con_type == "label":
labels = [Label(**label) for label in con_dict["labels"]]
return LabelConnectivity(
data=con_dict["data"],
pairs=con_dict["pairs"],
labels=labels,
label_degree=con_dict["source_degree"],
subject=con_dict["subject"],
        )
    else:
        raise ValueError("Unknown connectivity type: %s" % con_type)
def all_to_all_connectivity_pairs(src_or_fwd, min_dist=0.04):
"""Obtain pairs of vertices to compute all-to-all connectivity for.
This is needed for all-to-all connectivity. Calculates all the pairs of
vertices that are further away from each other than the selected distance
limit.
Parameters
----------
    src_or_fwd : instance of SourceSpaces | instance of Forward
The source space or forward model to obtain vertex pairs for.
    min_dist : float
        The minimum distance between vertices (in meters). Defaults to 0.04.
Returns
-------
vert_from : ndarray, shape (n_pairs,)
For each pair, the index of the first vertex.
vert_to : ndarray, shape (n_pairs,)
For each pair, the index of the second vertex.
See Also
--------
one_to_all_connectivity_pairs : Obtain pairs for one-to-all connectivity.
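    Examples
    --------
    A minimal sketch, assuming ``fwd`` is a previously loaded Forward object:
    >>> from conpy import all_to_all_connectivity_pairs
    >>> vert_from, vert_to = all_to_all_connectivity_pairs(fwd)  # doctest: +SKIP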
"""
# Get coordinates of the vertices
if isinstance(src_or_fwd, SourceSpaces):
vertno_lh = src_or_fwd[0]["vertno"]
vertno_rh = src_or_fwd[1]["vertno"]
grid_points = np.vstack(
(src_or_fwd[0]["rr"][vertno_lh], src_or_fwd[1]["rr"][vertno_rh])
)
elif isinstance(src_or_fwd, Forward):
grid_points = src_or_fwd["source_rr"]
    else:
        raise ValueError(
            "src_or_fwd must be an instance of Forward or SourceSpaces."
        )
n_sources = len(grid_points)
# Compute indices of all pairs
vert_from, vert_to = np.triu_indices(n_sources, k=1)
# Select the pairs that are further away than the distance limit
selection = pdist(grid_points) >= min_dist
# Converting this to a list of tuples is very slow, so let's keep it like
# this for now.
return vert_from[selection], vert_to[selection]
def one_to_all_connectivity_pairs(src_or_fwd, ref_point, min_dist=0):
"""Obtain pairs of vertices to compute one-to-all connectivity for.
This is needed for one-to-all connectivity. Calculates all the pairs where
the vertex is further away from reference point than the selected distance
limit.
Parameters
----------
src_or_fwd : instance of SourceSpaces | instance of Forward
The source space or forward model to obtain vertex pairs for.
    ref_point : int
        Index of the vertex that will serve as the reference point.
    min_dist : float
        The minimum distance between vertices (in meters). Defaults to 0.
Returns
-------
vert_from : ndarray, shape (n_pairs,)
        For each pair, the index of the first vertex. This is always the
        index of the reference point.
vert_to : ndarray, shape (n_pairs,)
For each pair, the index of the second vertex.
See Also
--------
all_to_all_connectivity_pairs : Obtain pairs for all-to-all connectivity.
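    Examples
    --------
    A minimal sketch using the first source as the reference point (``fwd``
    is assumed to be a previously loaded Forward object):
    >>> from conpy import one_to_all_connectivity_pairs
    >>> vert_from, vert_to = one_to_all_connectivity_pairs(fwd, ref_point=0)  # doctest: +SKIP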
"""
# Get coordinates of the vertices
if isinstance(src_or_fwd, SourceSpaces):
vertno_lh = src_or_fwd[0]["vertno"]
vertno_rh = src_or_fwd[1]["vertno"]
grid_points = np.vstack(
(src_or_fwd[0]["rr"][vertno_lh], src_or_fwd[1]["rr"][vertno_rh])
)
elif isinstance(src_or_fwd, Forward):
grid_points = src_or_fwd["source_rr"]
    else:
        raise ValueError(
            "src_or_fwd must be an instance of Forward or SourceSpaces."
        )
# Select the pairs that are further away than the distance limit
dist = cdist(grid_points[ref_point][np.newaxis], grid_points)
vert_to = np.flatnonzero(dist >= min_dist)
n_pairs = len(vert_to)
vert_from = np.asarray(ref_point).repeat(n_pairs)
return vert_from, vert_to
try:
import numba as nb
@nb.jit(nb.complex128[:, :, :](nb.complex128[:, :, :], nb.complex128[:, :]))
def _compute_opt1(x, y):
r = np.zeros((x.shape[0], y.shape[1], y.shape[1]), dtype=nb.complex128)
for i in range(len(x)):
r[i, :, :] = np.dot(np.dot(y.T, x[i, :, :]), y)
return r
@nb.jit(
nb.complex128[:, :, :](
nb.complex128[:, :, :], nb.complex128[:, :, :], nb.int64[:], nb.int64[:]
)
)
def _compute_power_cross_inv(x, y, x_ind, y_ind):
r = np.zeros((x_ind.shape[0], x.shape[0], y.shape[2]), dtype=nb.complex128)
i = 0
for x_i, y_i in zip(x_ind, y_ind):
r[i, :, :] = np.dot(x[:, x_i, :], y[:, y_i, :])
i += 1
return r
@nb.jit(nb.complex128[:, :, :](nb.complex128[:, :, :], nb.complex128[:, :, :]))
def _compute_power_cross_inv2(x, y):
r = np.zeros((x.shape[1], x.shape[0], y.shape[2]), dtype=nb.complex128)
for i in range(x.shape[1]):
r[i, :, :] = np.dot(x[:, i, :], y[:, i, :])
return r
numba_enabled = True
except Exception:
numba_enabled = False
def _compute_dics_coherence(
W,
G,
vert_ind_from,
vert_ind_to,
spec_power_inv,
orientations,
coh_metric="absolute",
):
"""Compute the coherence between two sources using a DICS beamformer.
Computes the coherence between two dipoles for different angles and returns
the maximum value.
Parameters
----------
W : ndarray, shape (n_orient, n_sources, n_sensors)
The beamformer filter weights.
G : ndarray, shape (n_sensors, n_sources, n_orient)
The leadfield.
vert_ind_from : ndarray, shape (n_pairs,)
For each vertex-pair to compute the connectivity for, the index of the
first vertex.
vert_ind_to : ndarray, shape (n_pairs,)
For each vertex-pair to compute the connectivity for, the index of the
second vertex.
spec_power_inv : ndarray, shape (n_sources, n_orient, n_orient)
Inverse of cross-spectral power between the dipoles at each source
location.
orientations : ndarray, shape (n_orient, n_angles)
For each angle to try, a unit vector pointing in the direction of the
angle.
coh_metric : 'absolute' | 'imaginary'
The coherence metric to use. Either the square of absolute coherence
('absolute') or the square of the imaginary part of the coherence
('imaginary'). Defaults to 'absolute'.
Returns
-------
coherence : ndarray, shape (n_pairs,)
For each vertex-pair, the coherence in the direction of maximum
coherence.
"""
power_from_inv = spec_power_inv[vert_ind_from]
power_to_inv = spec_power_inv[vert_ind_to]
if numba_enabled:
power_cross_inv = _compute_power_cross_inv(
W, G.astype("complex"), vert_ind_from, vert_ind_to
)
opt1 = _compute_opt1(power_cross_inv, orientations.astype("complex"))
else:
# Computes W @ G
power_cross_inv = np.einsum(
"ijk,kjl->jil", W[:, vert_ind_from, :], G[:, vert_ind_to, :]
)
# Computes orientations.T @ power_cross_inv @ orientations
opt1 = power_cross_inv.dot(orientations)
opt1 = opt1.transpose(0, 2, 1).dot(orientations).transpose(0, 2, 1)
if coh_metric == "absolute":
opt1 = np.abs(opt1)
elif coh_metric == "imaginary":
opt1 = np.imag(opt1)
# Computes np.diag(orientations.T @ power_from_inv @ orientations)
opt2 = np.sum(orientations * power_from_inv.dot(orientations), axis=1)
# Computes np.diag(orientations.T @ power_to_inv @ orientations)
opt3 = np.sum(orientations * power_to_inv.dot(orientations), axis=1)
# Compute coherence for each orientation
opt = (opt1**2) / (opt2[:, :, np.newaxis] * opt3[:, np.newaxis, :])
# Pick the best orientation as the final coherence value
return np.real(np.max(opt, axis=(1, 2)))
@verbose
def dics_connectivity(
vertex_pairs,
fwd,
data_csd,
reg=0.05,
coh_metric="absolute",
n_angles=50,
block_size=10000,
n_jobs=1,
verbose=None,
):
"""Compute spectral connectivity using a DICS beamformer.
Calculates the connectivity between the given vertex pairs using a DICS
beamformer [1]_ [2]_. Connectivity is defined in terms of coherence:
C = Sxy^2 [Sxx * Syy]^-1
Where Sxy is the cross-spectral density (CSD) between dipoles x and y, Sxx
is the power spectral density (PSD) at dipole x and Syy is the PSD at
dipole y.
Parameters
----------
vertex_pairs : pair of lists (vert_from_idx, vert_to_idx)
Vertex pairs between which connectivity is calculated. The pairs are
specified using two lists: the first list contains, for each pair, the
index of the first vertex. The second list contains, for each pair, the
index of the second vertex.
fwd : instance of Forward
        Subject's forward solution, possibly restricted to only include
        vertices that are close to the sensors. The dipole orientations must
        be free or tangential.
data_csd : instance of CrossSpectralDensity
The cross spectral density of the data.
reg : float
Tikhonov regularization parameter to control for trade-off between
spatial resolution and noise sensitivity. Defaults to 0.05.
coh_metric : 'absolute' | 'imaginary'
The coherence metric to use. Either the square of absolute coherence
('absolute') or the square of the imaginary part of the coherence
('imaginary'). Defaults to 'absolute'.
n_angles : int
Number of angles to try when optimizing dipole orientations. Defaults
to 50.
block_size : int
Number of pairs to process in a single batch. Beware of memory
requirements, which are ``n_jobs * block_size``. Defaults to 10000.
n_jobs : int
Number of blocks to process simultaneously. Defaults to 1.
verbose : bool | str | int | None
If not None, override default verbose level (see :func:`mne.verbose`
and :ref:`Logging documentation <tut_logging>` for more).
Returns
-------
    connectivity : instance of VertexConnectivity
        The estimated connectivity between the given vertex pairs.
See Also
--------
all_to_all_connectivity_pairs : Obtain pairs for all-to-all connectivity.
one_to_all_connectivity_pairs : Obtain pairs for one-to-all connectivity.
References
----------
.. [1] Gross, J., Kujala, J., Hamalainen, M., Timmermann, L., Schnitzler,
A., & Salmelin, R. (2001). Dynamic imaging of coherent sources:
Studying neural interactions in the human brain. Proceedings of the
National Academy of Sciences, 98(2), 694–699.
.. [2] Kujala, J., Gross, J., & Salmelin, R. (2008). Localization of
correlated network activity at the cortical level with MEG.
NeuroImage, 39(4), 1706–1720.
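    Examples
    --------
    A minimal sketch of an all-to-all coherence estimate, assuming ``fwd``
    and ``csd`` are a previously computed Forward and CrossSpectralDensity:
    >>> from conpy import all_to_all_connectivity_pairs, dics_connectivity
    >>> pairs = all_to_all_connectivity_pairs(fwd, min_dist=0.04)  # doctest: +SKIP
    >>> con = dics_connectivity(pairs, fwd, csd, reg=0.05)  # doctest: +SKIP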
"""
fwd = pick_channels_forward(fwd, data_csd.ch_names)
data_csd = pick_channels_csd(data_csd, fwd["info"]["ch_names"])
vertex_from, vertex_to = vertex_pairs
if len(vertex_from) != len(vertex_to):
raise ValueError("Lengths of the two lists of vertices do not match.")
n_pairs = len(vertex_from)
G = fwd["sol"]["data"].copy()
n_orient = G.shape[1] // fwd["nsource"]
    if n_orient == 1:
        raise ValueError(
            "A forward operator with free or tangential orientation "
            "must be used."
        )
elif n_orient == 3:
# Convert forward to tangential orientation for more speed.
fwd = forward_to_tangential(fwd)
G = fwd["sol"]["data"]
n_orient = 2
G = G.reshape(G.shape[0], fwd["nsource"], n_orient)
# Normalize the lead field
G /= np.linalg.norm(G, axis=0)
Cm = data_csd.get_data()
Cm_inv, alpha, _ = reg_pinv(Cm, reg)
del Cm
W = np.dot(G.T, Cm_inv)
# Pre-compute spectral power at each unique vertex
unique_verts, vertex_map = np.unique(
np.r_[vertex_from, vertex_to], return_inverse=True
)
spec_power_inv = np.array(
[np.dot(W[:, vert, :], G[:, vert, :]) for vert in unique_verts]
)
# Map vertex indices to unique indices, so the pre-computed spectral power
# can be retrieved
vertex_from_map = vertex_map[: len(vertex_from)]
vertex_to_map = vertex_map[len(vertex_from) :]
# Define a search space for dipole orientations
angles = np.arange(n_angles) * np.pi / n_angles
orientations = np.vstack((np.sin(angles), np.cos(angles)))
# Create chunks of pairs to evaluate at once
n_blocks = int(np.ceil(n_pairs / float(block_size)))
blocks = [
slice(i * block_size, min((i + 1) * block_size, n_pairs))
for i in range(n_blocks)
]
parallel, my_compute_dics_coherence, _ = parallel_func(
_compute_dics_coherence, n_jobs, verbose
)
logger.info(
"Computing coherence between %d source pairs in %d blocks..."
% (n_pairs, n_blocks)
)
if numba_enabled:
logger.info("Using numba optimized code path.")
coherence = np.hstack(
parallel(
my_compute_dics_coherence(
W,
G,
vertex_from_map[block],
vertex_to_map[block],
spec_power_inv,
orientations,
coh_metric,
)
for block in blocks
)
)
logger.info("[done]")
return VertexConnectivity(
data=coherence,
pairs=[v[: len(coherence)] for v in vertex_pairs],
vertices=[s["vertno"] for s in fwd["src"]],
vertex_degree=None, # Compute this in the constructor
subject=fwd["src"][0]["subject_his_id"],
)