.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_auto_examples_plot_optimizer.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_optimizer.py:


Optimizing hyperparameters using L-BFGS-B
=========================================

This is a version of the post-hoc example in which all parameters are tuned
using a general-purpose convex minimizer (L-BFGS-B) with an inner
leave-one-out cross-validation loop.

Author: Marijn van Vliet

.. code-block:: python

    # Required imports
    import numpy as np
    from scipy.stats import zscore, norm
    import mne
    from posthoc import (Workbench, WorkbenchOptimizer, cov_estimators,
                         normalizers)
    from sklearn.model_selection import StratifiedKFold
    from sklearn.metrics import accuracy_score, roc_auc_score
    from sklearn.linear_model import LogisticRegression
    from matplotlib import pyplot as plt

    import warnings
    warnings.simplefilter('ignore')

We will use the MNE sample dataset: an MEG recording of a participant
listening to auditory beeps and looking at visual stimuli. For this example,
we attempt to discriminate between auditory beeps presented to the left
versus the right of the head. The following code reads in the sample dataset.

.. code-block:: python

    path = mne.datasets.sample.data_path()
    raw = mne.io.read_raw_fif(path + '/MEG/sample/sample_audvis_raw.fif',
                              preload=True)
    events = mne.find_events(raw)
    event_id = dict(left=1, right=2)

    raw.pick_types(meg='grad')
    raw.filter(None, 20)
    raw, events = raw.resample(50, events=events)

    # Create epochs
    epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.5,
                        baseline=(-0.2, 0), preload=True)
    n_epochs, n_channels, n_samples = epochs.get_data().shape

The data is now loaded as an :class:`mne.Epochs` object. In order to use the
``sklearn`` and ``posthoc`` packages effectively, we need to shape this data
into an (observations x features) matrix ``X`` and a corresponding
(observations x targets) matrix ``y``.

.. code-block:: python

    X = epochs.get_data().reshape(len(epochs), -1)

    # There is currently a bug in the intercept calculation for the
    # kernel-formulated leave-one-out code path (for the paper, the slower
    # 'traditional' path was used). For this example, we'll just zscore the
    # data and abandon the intercept calculation altogether.
    X = zscore(X, axis=0)

    # Create training labels based on the event codes recorded during the
    # experiment, mapping code 1 (left) onto -1 and code 2 (right) onto +1.
    y = epochs.events[:, [2]]
    y = y - 1.5
    y *= 2
    y = y.astype(int)

Now we are ready to define a logistic regression model and apply it to the
data. We split the data 50/50 into a training set and a test set. We present
the training data to the model to learn from and test its performance on the
test set.

.. code-block:: python

    # Split the data 50/50, but make sure the numbers of left/right epochs
    # are balanced.
    folds = StratifiedKFold(n_splits=2)
    train_index, test_index = next(folds.split(X, y))
    X_train, y_train = X[train_index], y[train_index]
    X_test, y_test = X[test_index], y[test_index]

    # The parameter `C` is the inverse of the regularization strength, so a
    # setting of 25 means "only light regularization". We also specify the
    # seed for the random number generator, so that this example replicates
    # exactly every time.
    base_model = LogisticRegression(C=25, solver='lbfgs', random_state=0,
                                    fit_intercept=False)
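    # (Added sketch, not part of the original example.) Before touching the
    # test set, a cross-validated estimate on the training data alone can
    # serve as a quick sanity check for this choice of `C`.
    from sklearn.model_selection import cross_val_score
    cv_scores = cross_val_score(base_model, X_train, y_train.ravel(), cv=5)
    print('Cross-validated training accuracy: %.3f' % cv_scores.mean())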
    # Train on the training data and predict the test data.
    base_model.fit(X_train, y_train)
    y_hat = base_model.predict(X_test)

    # How many epochs did we decode correctly?
    base_model_accuracy = accuracy_score(y_test, y_hat)
    print('Base model accuracy:', base_model_accuracy)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Base model accuracy: 0.7534246575342466

To inspect the pattern that the model has learned, we wrap the model in a
:class:`posthoc.Workbench` object. After fitting, this object exposes the
``.pattern_`` attribute.

.. code-block:: python

    base_model = Workbench(base_model).fit(X_train, y_train)

    # Plot the pattern
    plt.figure()
    plt.plot(epochs.times,
             base_model.pattern_.reshape(n_channels, n_samples).T,
             color='black', alpha=0.2)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal (normalized units)')
    plt.title('Pattern learned by the base model')

.. image:: /auto_examples/images/sphx_glr_plot_optimizer_001.png
    :class: sphx-glr-single-img

Post-hoc adaptation [1]_ can be used to improve the model somewhat. For
starters, the pattern is quite noisy. The main distinctive feature between
the conditions should be the auditory evoked potential around 0.05 seconds.
Let's apply post-hoc adaptation to inform the model of this, by multiplying
the pattern with a Gaussian kernel that restricts it to a specific time
interval.

.. code-block:: python

    # The function that modifies the pattern takes as input the original
    # pattern, the training data, and two parameters that define the center
    # and width of the Gaussian kernel. Computed kernels are cached, since
    # the optimizer may revisit identical parameter values.
    cache = dict()

    def pattern_modifier(pattern, X, y, center, width):
        """Multiply the pattern with a Gaussian kernel."""
        mod_pattern = pattern.reshape(n_channels, n_samples)
        key = (center, width)
        if key in cache:
            kernel = cache[key]
        else:
            kernel = norm(center, width).pdf(np.arange(n_samples))
            kernel /= kernel.max()
            cache[key] = kernel
        mod_pattern = mod_pattern * kernel[np.newaxis, :]
        return mod_pattern.reshape(pattern.shape)

We will search for the optimal ``center`` and ``width`` parameters using an
optimization algorithm. In order to select the best parameters, we must
define a scoring function. Let's use the ROC-AUC score.

.. code-block:: python

    def scorer(model, X, y):
        """Our scoring function."""
        y_hat = model.predict(X)
        return roc_auc_score(y, y_hat)

Now we can assemble the post-hoc model. The covariance matrix is computed
using a shrinkage estimator. Since the number of features far exceeds the
number of training observations, the kernel version of the estimator is much
faster. We modify the pattern using the ``pattern_modifier`` function that we
defined earlier. Modifying the pattern like this affects the scaling of the
output, so to obtain a result with consistent scaling, we also modify the
normalizer such that the modified pattern passes through the model with unit
gain.

.. code-block:: python

    # Define initial values for the parameters. The optimization algorithm
    # will perform gradient descent using these as its starting point.
    initial_center = np.searchsorted(epochs.times, 0.05)
    initial_width = 5

    # Define the allowed range for the parameters. The optimizer will not
    # exceed these bounds.
    center_bounds = (5, 25)
    width_bounds = (1, 50)
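    # (Added illustration, not part of the original example.) To make the
    # starting point concrete, build the kernel that `pattern_modifier` will
    # apply for these initial parameters. Both parameters are expressed in
    # samples; at the 50 Hz sampling rate, a width (standard deviation) of
    # 5 samples corresponds to 0.1 seconds.
    demo_kernel = norm(initial_center, initial_width).pdf(np.arange(n_samples))
    demo_kernel /= demo_kernel.max()
    print('Kernel peaks at %.3f s' % epochs.times[np.argmax(demo_kernel)])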
    # Define the post-hoc model, using an optimizer to fine-tune the
    # parameters.
    optimized_model = WorkbenchOptimizer(
        base_model,
        cov=cov_estimators.ShrinkageKernel(1.0),
        cov_param_bounds=[(0.9, 1.0)],
        pattern_modifier=pattern_modifier,
        pattern_param_x0=[initial_center, initial_width],
        pattern_param_bounds=[center_bounds, width_bounds],
        normalizer_modifier=normalizers.unit_gain,
        scoring=scorer,
        verbose=True,
        random_search=20,
    ).fit(X_train, y_train)

    # Decode the test data
    y_hat = optimized_model.predict(X_test).ravel()

    # Assign the 'right' class (+1) to outputs at or above 0 and the 'left'
    # class (-1) to outputs below 0, matching the encoding of `y`.
    y_bin = np.zeros(len(y_hat), dtype=int)
    y_bin[y_hat >= 0] = 1
    y_bin[y_hat < 0] = -1

    # How many epochs did we decode correctly?
    optimized_model_accuracy = accuracy_score(y_test, y_bin)
    print('Base model accuracy:', base_model_accuracy)
    print('Optimized model accuracy:', optimized_model_accuracy)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Computing patterns for each leave-one-out iteration...
    Fit intercept: no
    Normalize: no
    cov_params=(1.0,), pattern_modifier_params=(13.0, 5.0), normalizer_modifier_params=() score=0.746914
    cov_params=(0.9113509276554594,), pattern_modifier_params=(10.411227010505158, 42.34336204944506), normalizer_modifier_params=() score=0.756173
    cov_params=(0.9271996319006757,), pattern_modifier_params=(21.994330754488224, 49.009912028277945), normalizer_modifier_params=() score=0.794753
    cov_params=(0.9032461817238148,), pattern_modifier_params=(6.9309025581067685, 17.624591304250792), normalizer_modifier_params=() score=0.495370
    cov_params=(0.9607169989947569,), pattern_modifier_params=(13.09444386153178, 4.304106657324297), normalizer_modifier_params=() score=0.622685
    cov_params=(0.918081454154554,), pattern_modifier_params=(11.22667199253668, 25.725414010951305), normalizer_modifier_params=() score=0.683642
    cov_params=(0.9577092953237923,), pattern_modifier_params=(18.078106011871387, 12.631005521608426), normalizer_modifier_params=() score=0.763889
    cov_params=(0.9036947409885082,), pattern_modifier_params=(24.76878292431204, 22.686964599753505), normalizer_modifier_params=() score=0.743056
    cov_params=(0.9639215312971519,), pattern_modifier_params=(21.44192622979467, 38.74473694335619), normalizer_modifier_params=() score=0.790895
    cov_params=(0.9266839014046406,), pattern_modifier_params=(6.6003461858735974, 32.851670970732584), normalizer_modifier_params=() score=0.688272
    cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
    cov_params=(0.9438736515591699,), pattern_modifier_params=(18.449062891227054, 40.98989486273545), normalizer_modifier_params=() score=0.793210
    cov_params=(0.9265038930540532,), pattern_modifier_params=(9.431697337623397, 48.72519354209997), normalizer_modifier_params=() score=0.763117
    cov_params=(0.9586610528019213,), pattern_modifier_params=(15.231230791203723, 9.075609258029225), normalizer_modifier_params=() score=0.673611
    cov_params=(0.9828590828512768,), pattern_modifier_params=(5.5094149822305205, 17.12178311876502), normalizer_modifier_params=() score=0.655864
    cov_params=(0.9336756802073747,), pattern_modifier_params=(17.6222612110126, 35.84392126323846), normalizer_modifier_params=() score=0.790123
    cov_params=(0.9711194947975753,), pattern_modifier_params=(8.463350661678902, 49.597062335495245), normalizer_modifier_params=() score=0.773920
    cov_params=(0.9402124815242155,), pattern_modifier_params=(7.587103856473012, 38.88081880420404), normalizer_modifier_params=() score=0.747685
    cov_params=(0.9033086309161991,), pattern_modifier_params=(5.101200703646223, 13.635510516157682), normalizer_modifier_params=() score=0.444444
    cov_params=(0.9038418955633067,), pattern_modifier_params=(13.722633118341754, 43.69803416894229), normalizer_modifier_params=() score=0.777778
    cov_params=(0.9079623102391366,), pattern_modifier_params=(13.37499044026506, 23.619085416871826), normalizer_modifier_params=() score=0.696759
    cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
    cov_params=(0.9220154569765379,), pattern_modifier_params=(19.542954092730724, 36.068781911391554), normalizer_modifier_params=() score=0.797840
    cov_params=(0.9210154569765379,), pattern_modifier_params=(19.543954092730726, 36.068781911391554), normalizer_modifier_params=() score=0.797840
    cov_params=(0.9210154569765379,), pattern_modifier_params=(19.542954092730724, 36.06978191139155), normalizer_modifier_params=() score=0.797840
    Base model accuracy: 0.7534246575342466
    Optimized model accuracy: 0.7123287671232876
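Since the optimizer maximized the ROC-AUC score rather than accuracy, it is
also informative to compare the two models on that same metric. Below is a
minimal sketch, not part of the original example; it assumes that the
``predict`` method of :class:`posthoc.Workbench` returns continuous decision
values, as ``WorkbenchOptimizer.predict`` does above.

.. code-block:: python

    # Compare the models on the metric that was actually optimized. `y_hat`
    # still holds the optimized model's continuous outputs from above.
    base_auc = roc_auc_score(y_test, base_model.predict(X_test).ravel())
    optimized_auc = roc_auc_score(y_test, y_hat)
    print('Base model ROC-AUC:      %.3f' % base_auc)
    print('Optimized model ROC-AUC: %.3f' % optimized_auc)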
Even though the inner leave-one-out cross-validation score improved during
optimization, the optimized model's accuracy on this particular test set
turns out slightly lower than the base model's. Let's visualize the
optimized pattern.

.. code-block:: python

    # sphinx_gallery_thumbnail_number = 2
    plt.figure()
    plt.plot(epochs.times,
             optimized_model.pattern_.reshape(n_channels, n_samples).T,
             color='black', alpha=0.2)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal (normalized units)')
    plt.title('Optimized pattern')

.. image:: /auto_examples/images/sphx_glr_plot_optimizer_002.png
    :class: sphx-glr-single-img

References
----------
.. [1] Marijn van Vliet and Riitta Salmelin. Post-hoc modification of linear
   models: combining machine learning with domain information to make solid
   inferences from noisy data. In preparation.

**Total running time of the script:** ( 0 minutes 12.432 seconds)


.. _sphx_glr_download_auto_examples_plot_optimizer.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_optimizer.py <plot_optimizer.py>`



  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_optimizer.ipynb <plot_optimizer.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_