Optimizing hyperparameters using L-BFGS-B

This is a version of the post-hoc example, where all parameters are tuned using a general-purpose gradient-based minimizer (L-BFGS-B) with an inner leave-one-out cross-validation loop.

Author: Marijn van Vliet <w.m.vanvliet@gmail.com>

# Required imports
import numpy as np
from scipy.stats import zscore, norm
import mne
from posthoc import (Workbench, WorkbenchOptimizer, cov_estimators,
                     normalizers)
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from matplotlib import pyplot as plt

import warnings
warnings.simplefilter('ignore')
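Before loading any data, here is a minimal sketch of the general idea, using a toy ridge regression rather than the posthoc machinery (Ridge, the random toy data, and the log-space parametrization are illustrative stand-ins): a leave-one-out cross-validation score is wrapped in a function of the hyperparameters and handed to scipy's L-BFGS-B minimizer.

from scipy.optimize import minimize
from sklearn.linear_model import Ridge
from sklearn.model_selection import LeaveOneOut, cross_val_score

rng = np.random.RandomState(0)
X_toy = rng.randn(30, 5)
y_toy = X_toy @ rng.randn(5) + rng.randn(30)

def neg_loo_score(params):
    """Negative leave-one-out score of a ridge model with penalty exp(params[0])."""
    alpha = np.exp(params[0])  # optimize in log-space so the penalty stays positive
    scores = cross_val_score(Ridge(alpha=alpha), X_toy, y_toy,
                             cv=LeaveOneOut(),
                             scoring='neg_mean_squared_error')
    return -scores.mean()

result = minimize(neg_loo_score, x0=[0.0], method='L-BFGS-B', bounds=[(-5, 5)])
print('Best ridge penalty found:', np.exp(result.x[0]))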

We will use the MNE sample dataset. It is an MEG recording of a participant listening to auditory beeps and looking at visual stimuli. For this example, we attempt to discriminate between auditory beeps presented to the left versus the right of the head. The following code reads in the sample dataset.

path = mne.datasets.sample.data_path()
raw = mne.io.read_raw_fif(path + '/MEG/sample/sample_audvis_raw.fif',
                          preload=True)
events = mne.find_events(raw)
event_id = dict(left=1, right=2)
raw.pick_types(meg='grad')
raw.filter(None, 20)
raw, events = raw.resample(50, events=events)

# Create epochs
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.5,
                    baseline=(-0.2, 0), preload=True)
n_epochs, n_channels, n_samples = epochs.get_data().shape

The data is now loaded as an mne.Epochs object. To use the sklearn and posthoc packages effectively, we need to reshape this data into an (observations x features) matrix X and a corresponding (observations x targets) matrix y.

X = epochs.get_data().reshape(len(epochs), -1)

# There is currently a bug in the intercept calculation for the
# kernel-formulated leave-one-out code path (for the paper [1], the slower
# 'traditional' path was used). For this example, we z-score the data and
# skip the intercept calculation altogether.
X = zscore(X, axis=0)

# Create training labels based on the event codes: map the codes used during
# the experiment (left=1, right=2) onto the labels -1 and +1.
y = epochs.events[:, [2]]
y = y - 1.5
y *= 2
y = y.astype(int)
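As a quick sanity check, we can confirm that the mapping produced exactly the two classes and count how many epochs fall into each:

print(np.unique(y, return_counts=True))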

Now, we are ready to define a logistic regression model and apply it to the data. We split the data 50/50 into a training set and a test set: the model learns from the training set, and its performance is measured on the held-out test set.

# Split the data 50/50, making sure the number of left/right epochs is
# balanced across the two halves.
folds = StratifiedKFold(n_splits=2)
train_index, test_index = next(folds.split(X, y))
X_train, y_train = X[train_index], y[train_index]
X_test, y_test = X[test_index], y[test_index]

# The parameter `C` of the logistic regression model controls the inverse of
# the regularization strength: larger values mean less regularization. A
# setting of 25 means "only light regularization". We also specify the seed
# for the random number generator, so that this example replicates exactly
# every time.
base_model = LogisticRegression(C=25, solver='lbfgs', random_state=0,
                                fit_intercept=False)

# Train on the training data and predict the test data.
base_model.fit(X_train, y_train)
y_hat = base_model.predict(X_test)

# How many epochs did we decode correctly?
base_model_accuracy = accuracy_score(y_test, y_hat)
print('Base model accuracy:', base_model_accuracy)

Out:

Base model accuracy: 0.7534246575342466

To inspect the pattern that the model has learned, we wrap the model in a posthoc.Workbench object. After fitting, this object exposes the .pattern_ attribute.

base_model = Workbench(base_model).fit(X_train, y_train)

# Plot the pattern
plt.figure()
plt.plot(epochs.times, base_model.pattern_.reshape(n_channels, n_samples).T,
         color='black', alpha=0.2)
plt.xlabel('Time (s)')
plt.ylabel('Signal (normalized units)')
plt.title('Pattern learned by the base model')
[Figure 1: Pattern learned by the base model]

Post-hoc adaptation can be used to improve the model somewhat.

For starters, the pattern is quite noisy. The main distinguishing feature between the conditions should be the auditory evoked field around 0.05 seconds. Let's apply post-hoc adaptation to inform the model of this, by multiplying the pattern with a Gaussian kernel to restrict it to a specific time interval.

# The function that modifies the pattern takes as input the original pattern,
# the training data (X and y), and two parameters that define the center and
# width (in samples) of the Gaussian kernel. Since the optimizer evaluates the
# same parameters many times, we cache the computed kernels.
cache = dict()
def pattern_modifier(pattern, X, y, center, width):
    """Multiply the pattern with a Gaussian kernel."""
    mod_pattern = pattern.reshape(n_channels, n_samples)
    key = (center, width)
    if key in cache:
        kernel = cache[key]
    else:
        kernel = norm(center, width).pdf(np.arange(n_samples))
        kernel /= kernel.max()
        cache[key] = kernel
    mod_pattern = mod_pattern * kernel[np.newaxis, :]
    return mod_pattern.reshape(pattern.shape)
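To get a feel for what this modifier does, we can plot the kernel produced by the starting parameters used below (a center at sample 13, the index nearest to 0.05 s in this epoch, and a width of 5 samples). This figure is purely illustrative and not part of the optimization itself.

kernel = norm(13, 5).pdf(np.arange(n_samples))
kernel /= kernel.max()
plt.figure()
plt.plot(epochs.times, kernel)
plt.xlabel('Time (s)')
plt.ylabel('Kernel weight')
plt.title('Gaussian kernel with the initial parameters')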

We will search for the optimal center and width parameters using an optimization algorithm. To guide the search, we must define a scoring function. Let's use the ROC-AUC score.

def scorer(model, X, y):
    """Our scoring function."""
    y_hat = model.predict(X)
    return roc_auc_score(y, y_hat)

Now we can assemble the post-hoc model. The covariance matrix is computed using a shrinkage estimator; since the number of features far exceeds the number of training observations, the kernel formulation of the estimator is much faster. We modify the pattern using the pattern_modifier function defined earlier. Modifying the pattern like this affects the scaling of the output, so to obtain a result with consistent scaling, we also modify the normalizer such that the modified pattern passes through the model with unit gain.
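Unit gain here means that applying the final, normalized filter to the pattern itself yields exactly 1. A tiny sketch with made-up numbers illustrates the idea (a conceptual sketch, not posthoc's actual implementation):

w = np.array([0.5, 2.0, -1.0])  # hypothetical filter weights
p = np.array([1.0, 0.2, 0.1])   # hypothetical (modified) pattern
normalizer = 1.0 / (w @ p)      # rescaling factor for the filter
print((normalizer * w) @ p)     # prints 1.0: the pattern passes with unit gain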

# Define initial values for the parameters. The optimization algorithm will
# use these as its starting point.
initial_center = np.searchsorted(epochs.times, 0.05)
initial_width = 5

# Define the allowed range for each parameter. The optimizer will stay within
# these bounds.
center_bounds = (5, 25)
width_bounds = (1, 50)

# Define the post-hoc model using an optimizer to fine-tune the parameters.
optimized_model = WorkbenchOptimizer(
    base_model,
    cov=cov_estimators.ShrinkageKernel(1.0),
    cov_param_bounds=[(0.9, 1.0)],
    pattern_modifier=pattern_modifier,
    pattern_param_x0=[initial_center, initial_width],
    pattern_param_bounds=[center_bounds, width_bounds],
    normalizer_modifier=normalizers.unit_gain,
    scoring=scorer,
    verbose=True,
    random_search=20,
).fit(X_train, y_train)

# Decode the test data
y_hat = optimized_model.predict(X_test).ravel()

# Assign the 'right' class (+1) to values above 0 and the 'left' class (-1)
# to values below 0.
y_bin = np.zeros(len(y_hat), dtype=int)
y_bin[y_hat >= 0] = 1
y_bin[y_hat < 0] = -1

# How many epochs did we decode correctly?
optimized_model_accuracy = accuracy_score(y_test, y_bin)
print('Base model accuracy:', base_model_accuracy)
print('Optimized model accuracy:', optimized_model_accuracy)

Out:

Computing patterns for each leave-one-out iteration...
Fit intercept: no
Normalize: no
cov_params=(1.0,), pattern_modifier_params=(13.0, 5.0), normalizer_modifier_params=() score=0.746914
cov_params=(0.9113509276554594,), pattern_modifier_params=(10.411227010505158, 42.34336204944506), normalizer_modifier_params=() score=0.756173
cov_params=(0.9271996319006757,), pattern_modifier_params=(21.994330754488224, 49.009912028277945), normalizer_modifier_params=() score=0.794753
cov_params=(0.9032461817238148,), pattern_modifier_params=(6.9309025581067685, 17.624591304250792), normalizer_modifier_params=() score=0.495370
cov_params=(0.9607169989947569,), pattern_modifier_params=(13.09444386153178, 4.304106657324297), normalizer_modifier_params=() score=0.622685
cov_params=(0.918081454154554,), pattern_modifier_params=(11.22667199253668, 25.725414010951305), normalizer_modifier_params=() score=0.683642
cov_params=(0.9577092953237923,), pattern_modifier_params=(18.078106011871387, 12.631005521608426), normalizer_modifier_params=() score=0.763889
cov_params=(0.9036947409885082,), pattern_modifier_params=(24.76878292431204, 22.686964599753505), normalizer_modifier_params=() score=0.743056
cov_params=(0.9639215312971519,), pattern_modifier_params=(21.44192622979467, 38.74473694335619), normalizer_modifier_params=() score=0.790895
cov_params=(0.9266839014046406,), pattern_modifier_params=(6.6003461858735974, 32.851670970732584), normalizer_modifier_params=() score=0.688272
cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
cov_params=(0.9438736515591699,), pattern_modifier_params=(18.449062891227054, 40.98989486273545), normalizer_modifier_params=() score=0.793210
cov_params=(0.9265038930540532,), pattern_modifier_params=(9.431697337623397, 48.72519354209997), normalizer_modifier_params=() score=0.763117
cov_params=(0.9586610528019213,), pattern_modifier_params=(15.231230791203723, 9.075609258029225), normalizer_modifier_params=() score=0.673611
cov_params=(0.9828590828512768,), pattern_modifier_params=(5.5094149822305205, 17.12178311876502), normalizer_modifier_params=() score=0.655864
cov_params=(0.9336756802073747,), pattern_modifier_params=(17.6222612110126, 35.84392126323846), normalizer_modifier_params=() score=0.790123
cov_params=(0.9711194947975753,), pattern_modifier_params=(8.463350661678902, 49.597062335495245), normalizer_modifier_params=() score=0.773920
cov_params=(0.9402124815242155,), pattern_modifier_params=(7.587103856473012, 38.88081880420404), normalizer_modifier_params=() score=0.747685
cov_params=(0.9033086309161991,), pattern_modifier_params=(5.101200703646223, 13.635510516157682), normalizer_modifier_params=() score=0.444444
cov_params=(0.9038418955633067,), pattern_modifier_params=(13.722633118341754, 43.69803416894229), normalizer_modifier_params=() score=0.777778
cov_params=(0.9079623102391366,), pattern_modifier_params=(13.37499044026506, 23.619085416871826), normalizer_modifier_params=() score=0.696759
cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
cov_params=(0.9220154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
cov_params=(0.9210154569765379,), pattern_modifier_params=(19.543954092730726, 36.068781911391554), normalizer_modifier_params=() score=0.797840
cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.06978191139155), normalizer_modifier_params=() score=0.797840
Base model accuracy: 0.7534246575342466
Optimized model accuracy: 0.7123287671232876

In this particular split, the optimized model does not outperform the base model on the test set, even though its inner leave-one-out score improved during the search; with a test set this small, such discrepancies between the inner cross-validation score and the test score are not unusual. Let's visualize the optimized pattern.

plt.figure()
plt.plot(epochs.times,
         optimized_model.pattern_.reshape(n_channels, n_samples).T,
         color='black', alpha=0.2)
plt.xlabel('Time (s)')
plt.ylabel('Signal (normalized units)')
plt.title('Optimized pattern')
[Figure 2: Optimized pattern]

References

[1] Marijn van Vliet and Riitta Salmelin. Post-hoc modification of linear models: combining machine learning with domain information to make solid inferences from noisy data. In preparation.
