.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples\plot_optimization.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_optimization.py:


Automatic post-hoc optimization of linear models
=================================================

This example demonstrates how to define custom modifications to a linear
model that introduce new hyperparameters. We then use post-hoc's optimizer to
find the optimal values for these hyperparameters.

We start with ordinary linear regression as the base model. Then, we modify
the covariance matrix by applying shrinkage, modify the pattern with a
Gaussian kernel, and modify the normalizer to be "unit noise gain", meaning
the weights all sum to 1.

Author: Marijn van Vliet

.. GENERATED FROM PYTHON SOURCE LINES 16-27

.. code-block:: default


    # Required imports
    from matplotlib import pyplot as plt
    from posthoc import Workbench, WorkbenchOptimizer, cov_estimators, normalizers
    from scipy.stats import norm, pearsonr
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_predict
    from sklearn.preprocessing import normalize
    from functools import partial
    import mne
    import numpy as np

.. GENERATED FROM PYTHON SOURCE LINES 28-43

We use some data from the original publication [1]_. A participant silently
read word pairs in which the two words had a varying forward association
strength (FAS). For example, ``locomotive -> train`` has a high association
strength, whereas ``dog -> submarine`` does not. When a word pair has a high
association strength, the brain processes the second word faster, because it
has been semantically primed by the first word.

We will estimate this priming effect from epochs of EEG data and use it to
predict the forward association strength of each word pair. Let's first load
the data and plot a contrast between word pairs with a high versus a low
association strength, so we can observe how the priming effect manifests in
the EEG data.

.. GENERATED FROM PYTHON SOURCE LINES 43-50

.. code-block:: default


    epochs = mne.read_epochs('subject04-epo.fif')

    # Average the epochs of strongly and weakly associated word pairs
    related = epochs['FAS > 0.2'].average()
    related.comment = 'related'
    unrelated = epochs['FAS < 0.2'].average()
    unrelated.comment = 'unrelated'
    mne.viz.plot_evoked_topo([related, unrelated])


.. image:: /auto_examples/images/sphx_glr_plot_optimization_001.png
    :alt: plot optimization
    :class: sphx-glr-single-img

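The metadata queries above split the epochs on their FAS value before averaging. The sketch below imitates that selection with plain NumPy boolean masks on synthetic data. The arrays and FAS values are made up, and this is not how MNE evaluates query strings internally. It also makes one detail explicit: both thresholds are strict, so word pairs with an FAS of exactly 0.2 end up in neither average.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)

    # Synthetic stand-in for epochs.get_data(): (n_epochs, n_channels, n_times)
    data = rng.standard_normal((200, 32, 60))
    # Synthetic stand-in for epochs.metadata['FAS'], rounded so that some values
    # are exactly 0.2
    fas = np.round(rng.uniform(0, 0.5, size=200), 1)

    related_mask = fas > 0.2
    unrelated_mask = fas < 0.2

    # Averaging over the epoch axis gives one (n_channels, n_times) array per condition
    related = data[related_mask].mean(axis=0)
    unrelated = data[unrelated_mask].mean(axis=0)

    # Epochs with an FAS of exactly 0.2 match neither query
    n_excluded = int(np.sum(~(related_mask | unrelated_mask)))
    print(related.shape, unrelated.shape, n_excluded)
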
.. GENERATED FROM PYTHON SOURCE LINES 51-59

Around 400 ms after the presentation of the second word, there is a negative
peak named the N400 potential. We can clearly observe the semantic priming
effect: the N400 is more prominent when the words have a low forward
association strength.

A naive approach to deduce the forward association strength of a word pair is
to take the average signal around 400 ms at some sensors that show the N400
well:

.. GENERATED FROM PYTHON SOURCE LINES 59-67

.. code-block:: default


    # Restrict the data to three parietal sensors and the N400 time window
    ROI = epochs.copy()
    ROI.pick_channels(['P3', 'Pz', 'P4'])
    ROI.crop(0.3, 0.47)

    # Average over sensors and time to obtain a single value per epoch
    FAS_pred = ROI.get_data().mean(axis=(1, 2))
    perf_naive, _ = pearsonr(epochs.metadata['FAS'], FAS_pred)
    print(f'Performance: {perf_naive:.2f}')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Performance: 0.30

.. GENERATED FROM PYTHON SOURCE LINES 68-69

Let's try ordinary linear regression next, using 10-fold cross-validation.

.. GENERATED FROM PYTHON SOURCE LINES 69-76

.. code-block:: default


    # Flatten each epoch (32 channels x 60 time samples) into one feature vector
    # and scale each vector to unit length
    X = normalize(epochs.get_data().reshape(len(epochs), 32 * 60))
    y = epochs.metadata['FAS'].values
    ols = LinearRegression()
    FAS_pred = cross_val_predict(ols, X, y, cv=10)
    perf_ols, _ = pearsonr(epochs.metadata['FAS'], FAS_pred)
    print(f'Performance: {perf_ols:.2f} (to beat: {perf_naive:.2f})')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Performance: 0.21 (to beat: 0.30)

.. GENERATED FROM PYTHON SOURCE LINES 77-91

Feeding all the data into a linear regression model performs worse than
averaging the signal at a few well-chosen sensors. This is because the model
overfits: it has 32 × 60 = 1920 features to fit, but only 200 epochs to learn
from. We could restrict the data going into the model to the same sensors and
time window that we used for averaging, but we can do much better.

Let's use the post-hoc framework to modify the linear regression model and
incorporate some information about the nature of the data and the N400
potential.

First, let's try to reduce overfitting by applying some shrinkage to the
covariance matrix.

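For intuition, the simplest form of shrinkage blends the empirical covariance matrix with a scaled identity matrix. The sketch below implements that textbook estimator in NumPy on random data. It is not the Kronecker estimator used in this example, which applies separate amounts of shrinkage to parts of the covariance matrix. ``shrink_cov`` is a hypothetical helper and is not part of post-hoc. The sketch shows why shrinkage counters overfitting when there are more features than observations. In that case the empirical covariance matrix is singular, while any amount of shrinkage above zero makes it invertible.

.. code-block:: python

    import numpy as np


    def shrink_cov(X, alpha):
        """Shrink the empirical covariance of X towards a scaled identity matrix.

        Hypothetical helper for illustration only. Computes
        (1 - alpha) * S + alpha * nu * I, where S is the empirical covariance and
        nu is the average variance (trace of S divided by the number of features).
        """
        S = np.cov(X, rowvar=False)
        nu = np.trace(S) / S.shape[0]
        return (1 - alpha) * S + alpha * nu * np.eye(S.shape[0])


    rng = np.random.default_rng(0)
    # More features (100) than observations (50), similar to this example, which
    # has 32 * 60 = 1920 features but only 200 epochs
    X_demo = rng.standard_normal((50, 100))

    S = np.cov(X_demo, rowvar=False)
    S_shrunk = shrink_cov(X_demo, alpha=0.1)

    # The empirical covariance is rank-deficient; the shrunk version is full rank
    print(np.linalg.matrix_rank(S), np.linalg.matrix_rank(S_shrunk))

With ``alpha=0`` the function returns the empirical covariance unchanged. Because ``nu`` is the average variance, the total variance (the trace) stays the same for any ``alpha``.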
The data consists of 32 EEG electrodes, each recording 60 samples of data.
This causes a clear pattern to appear in the covariance matrix:

.. GENERATED FROM PYTHON SOURCE LINES 91-94

.. code-block:: default


    plt.figure()
    # Note: plt.matshow creates a figure of its own, so the plt.figure() call
    # above results in an additional, empty figure
    plt.matshow(np.cov(X.T), cmap='magma')


.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/images/sphx_glr_plot_optimization_002.png
          :alt: plot optimization
          :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_plot_optimization_003.png
          :alt: plot optimization
          :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 95-99

The covariance matrix is built up from 32 × 32 blocks, one for each pair of
electrodes, and each block is 60 × 60, one entry for each pair of time
samples. The ``KroneckerKernel`` covariance estimator can make use of this
structure and apply different amounts of shrinkage to the diagonal of each
block and to the covariance matrix overall.

.. GENERATED FROM PYTHON SOURCE LINES 99-101

.. code-block:: default


    # 32 x 32 outer blocks (electrodes), each block 60 x 60 (time samples)
    cov = cov_estimators.KroneckerKernel(outer_size=32, inner_size=60)

.. GENERATED FROM PYTHON SOURCE LINES 102-106

To determine the optimal amount of Kronecker shrinkage to apply, we can wrap
our linear regression model in the ``WorkbenchOptimizer`` class. By default,
this combines a heavily optimized leave-one-out cross-validation with a
gradient descent algorithm to find the best values.

.. GENERATED FROM PYTHON SOURCE LINES 106-118

.. code-block:: default


    # We're optimizing for correlation between model prediction and true FAS
    def scorer(model, X, y):
        return pearsonr(model.predict(X), y)[0]


    # Construct the post-hoc workbench, tell it to modify the model by applying
    # Kronecker shrinkage.
    model = WorkbenchOptimizer(ols, cov=cov, scoring=scorer).fit(X, y)
    shrinkage_params = model.cov_params_
    print('Optimal shrinkage parameters:', shrinkage_params)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Computing patterns for each leave-one-out iteration...
Fit intercept: yes Normalize: no Choosing optimized code-path for LinearRegression() model. Using kernel path. cov_params=(0.5, 0.5), pattern_modifier_params=(), normalizer_modifier_params=() score=0.128617 cov_params=(0.5, 0.5), pattern_modifier_params=(), normalizer_modifier_params=() score=0.128617 cov_params=(0.501, 0.5), pattern_modifier_params=(), normalizer_modifier_params=() score=0.128665 cov_params=(0.5, 0.501), pattern_modifier_params=(), normalizer_modifier_params=() score=0.128489 cov_params=(0.5475143042386605, 0.37168759154748676), pattern_modifier_params=(), normalizer_modifier_params=() score=0.151458 cov_params=(0.5485143042386605, 0.37168759154748676), pattern_modifier_params=(), normalizer_modifier_params=() score=0.151526 cov_params=(0.5475143042386605, 0.37268759154748676), pattern_modifier_params=(), normalizer_modifier_params=() score=0.151266 cov_params=(0.6152884717698088, 0.1794167207357365), pattern_modifier_params=(), normalizer_modifier_params=() score=0.213952 cov_params=(0.6162884717698088, 0.1794167207357365), pattern_modifier_params=(), normalizer_modifier_params=() score=0.214079 cov_params=(0.6152884717698088, 0.1804167207357365), pattern_modifier_params=(), normalizer_modifier_params=() score=0.213513 cov_params=(0.6785316392745661, 0.0), pattern_modifier_params=(), normalizer_modifier_params=() score=0.303706 cov_params=(0.6795316392745661, 0.0), pattern_modifier_params=(), normalizer_modifier_params=() score=0.303890 cov_params=(0.6785316392745661, 0.001), pattern_modifier_params=(), normalizer_modifier_params=() score=0.312336 cov_params=(0.61866807939433, 0.1698289954321378), pattern_modifier_params=(), normalizer_modifier_params=() score=0.218710 cov_params=(0.61966807939433, 0.1698289954321378), pattern_modifier_params=(), normalizer_modifier_params=() score=0.218841 cov_params=(0.61866807939433, 0.1708289954321378), pattern_modifier_params=(), normalizer_modifier_params=() score=0.218249 cov_params=(0.6601594401656801, 
0.05212072460748246), pattern_modifier_params=(), normalizer_modifier_params=() score=0.300725 cov_params=(0.6611594401656801, 0.05212072460748246), pattern_modifier_params=(), normalizer_modifier_params=() score=0.300878 cov_params=(0.6601594401656801, 0.05312072460748246), pattern_modifier_params=(), normalizer_modifier_params=() score=0.299882 cov_params=(0.6721658286859352, 0.018059387372573676), pattern_modifier_params=(), normalizer_modifier_params=() score=0.329133 cov_params=(0.6731658286859352, 0.018059387372573676), pattern_modifier_params=(), normalizer_modifier_params=() score=0.329258 cov_params=(0.6721658286859352, 0.019059387372573677), pattern_modifier_params=(), normalizer_modifier_params=() score=0.328614 cov_params=(0.7971864936933489, 0.0), pattern_modifier_params=(), normalizer_modifier_params=() score=0.327041 cov_params=(0.7981864936933489, 0.0), pattern_modifier_params=(), normalizer_modifier_params=() score=0.327251 cov_params=(0.7971864936933489, 0.001), pattern_modifier_params=(), normalizer_modifier_params=() score=0.331873 cov_params=(0.7362515282264148, 0.008802129998560695), pattern_modifier_params=(), normalizer_modifier_params=() score=0.337934 cov_params=(0.7372515282264148, 0.008802129998560695), pattern_modifier_params=(), normalizer_modifier_params=() score=0.338059 cov_params=(0.7362515282264148, 0.009802129998560696), pattern_modifier_params=(), normalizer_modifier_params=() score=0.338368 cov_params=(0.8249149452241371, 0.0045867872725880885), pattern_modifier_params=(), normalizer_modifier_params=() score=0.345517 cov_params=(0.8259149452241371, 0.0045867872725880885), pattern_modifier_params=(), normalizer_modifier_params=() score=0.345663 cov_params=(0.8249149452241371, 0.0055867872725880886), pattern_modifier_params=(), normalizer_modifier_params=() score=0.346892 cov_params=(1.0, 0.0), pattern_modifier_params=(), normalizer_modifier_params=() score=0.297086 cov_params=(0.999, 0.0), pattern_modifier_params=(), 
normalizer_modifier_params=() score=0.312730 cov_params=(1.0, 0.001), pattern_modifier_params=(), normalizer_modifier_params=() score=0.297086 cov_params=(0.8945405407430417, 0.002762772105894922), pattern_modifier_params=(), normalizer_modifier_params=() score=0.353331 cov_params=(0.8955405407430417, 0.002762772105894922), pattern_modifier_params=(), normalizer_modifier_params=() score=0.353489 cov_params=(0.8945405407430417, 0.003762772105894922), pattern_modifier_params=(), normalizer_modifier_params=() score=0.354723 cov_params=(0.962123177308625, 0.0009922773161266168), pattern_modifier_params=(), normalizer_modifier_params=() score=0.359635 cov_params=(0.963123177308625, 0.0009922773161266168), pattern_modifier_params=(), normalizer_modifier_params=() score=0.359641 cov_params=(0.962123177308625, 0.001992277316126617), pattern_modifier_params=(), normalizer_modifier_params=() score=0.360245 cov_params=(0.9423239917892359, 0.002780662328775404), pattern_modifier_params=(), normalizer_modifier_params=() score=0.359813 cov_params=(0.9433239917892359, 0.002780662328775404), pattern_modifier_params=(), normalizer_modifier_params=() score=0.359902 cov_params=(0.9423239917892359, 0.003780662328775404), pattern_modifier_params=(), normalizer_modifier_params=() score=0.360618 cov_params=(0.9352708588718786, 0.007470138094729195), pattern_modifier_params=(), normalizer_modifier_params=() score=0.362489 cov_params=(0.9362708588718786, 0.007470138094729195), pattern_modifier_params=(), normalizer_modifier_params=() score=0.362556 cov_params=(0.9352708588718786, 0.008470138094729195), pattern_modifier_params=(), normalizer_modifier_params=() score=0.362981 cov_params=(0.9348100949822981, 0.013358888211589011), pattern_modifier_params=(), normalizer_modifier_params=() score=0.364529 cov_params=(0.9358100949822981, 0.013358888211589011), pattern_modifier_params=(), normalizer_modifier_params=() score=0.364588 cov_params=(0.9348100949822981, 0.01435888821158901), 
pattern_modifier_params=(), normalizer_modifier_params=() score=0.364712 cov_params=(0.9516666877061215, 0.016280071434003486), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365456 cov_params=(0.9526666877061215, 0.016280071434003486), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365438 cov_params=(0.9516666877061215, 0.017280071434003487), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365567 cov_params=(0.9489996860582463, 0.0193851650562842), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365721 cov_params=(0.9499996860582463, 0.0193851650562842), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365726 cov_params=(0.9489996860582463, 0.0203851650562842), pattern_modifier_params=(), normalizer_modifier_params=() score=0.365753
Optimal shrinkage parameters: [0.9489996860582463, 0.0193851650562842]

.. GENERATED FROM PYTHON SOURCE LINES 119-120

Let's inspect the pattern that the model has learned:

.. GENERATED FROM PYTHON SOURCE LINES 120-126

.. code-block:: default


    plt.figure()
    # One line per electrode, showing the pattern over time
    plt.plot(epochs.times, model.pattern_.reshape(32, 60).T, color='black', alpha=0.2)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal (normalized units)')
    plt.title('Pattern learned by the model using Kronecker shrinkage')


.. image:: /auto_examples/images/sphx_glr_plot_optimization_004.png
    :alt: Pattern learned by the model using Kronecker shrinkage
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Text(0.5, 1.0, 'Pattern learned by the model using Kronecker shrinkage')

.. GENERATED FROM PYTHON SOURCE LINES 127-130

We can clearly see that the model picks up on the N400. Let's fine-tune the
pattern a bit by multiplying it with a Gaussian kernel centered around 400 ms.
The kernel is defined over sample indices, so its mean and standard deviation
are expressed in samples rather than seconds.

.. GENERATED FROM PYTHON SOURCE LINES 130-139

.. code-block:: default


    def pattern_modifier(pattern, X, y, mean, std):
        """Multiply the pattern with a Gaussian kernel.

        ``mean`` and ``std`` are expressed in samples, not seconds.
        """
        n_channels, n_samples = 32, 60
        kernel = norm(mean, std).pdf(np.arange(n_samples))
        kernel /= kernel.max()  # Scale the kernel to a peak value of 1
        mod_pattern = pattern.reshape(n_channels, n_samples)
        mod_pattern = mod_pattern * kernel[np.newaxis, :]
        return mod_pattern.reshape(pattern.shape)

.. GENERATED FROM PYTHON SOURCE LINES 140-142

Now the optimizer has four hyperparameters to tune: two shrinkage values and
two values dictating the shape of the Gaussian kernel. We also modify the
normalizer using ``normalizers.unit_gain``.

.. GENERATED FROM PYTHON SOURCE LINES 142-153

.. code-block:: default


    model_opt = WorkbenchOptimizer(
        ols,
        cov=cov,
        pattern_modifier=pattern_modifier,
        pattern_param_x0=[30, 5],  # Initial guess for decent kernel shape
        pattern_param_bounds=[(0, 60), (2, None)],  # Boundaries for what values to try
        normalizer_modifier=normalizers.unit_gain,
        scoring=scorer,
    ).fit(X, y)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Computing patterns for each leave-one-out iteration...
    Fit intercept: yes
    Normalize: no
    Choosing optimized code-path for LinearRegression() model.
    Using kernel path.
cov_params=(0.5, 0.5), pattern_modifier_params=(30.0, 5.0), normalizer_modifier_params=() score=0.238878 cov_params=(0.5, 0.5), pattern_modifier_params=(30.0, 5.0), normalizer_modifier_params=() score=0.238878 cov_params=(0.501, 0.5), pattern_modifier_params=(30.0, 5.0), normalizer_modifier_params=() score=0.238909 cov_params=(0.5, 0.501), pattern_modifier_params=(30.0, 5.0), normalizer_modifier_params=() score=0.238779 cov_params=(0.5, 0.5), pattern_modifier_params=(30.001, 5.0), normalizer_modifier_params=() score=0.238870 cov_params=(0.5, 0.5), pattern_modifier_params=(30.0, 5.001), normalizer_modifier_params=() score=0.238855 cov_params=(0.5310971479546562, 0.4016941829457773), pattern_modifier_params=(29.99198616260383, 4.977396261185516), normalizer_modifier_params=() score=0.251690 cov_params=(0.5320971479546562, 0.4016941829457773), pattern_modifier_params=(29.99198616260383, 4.977396261185516), normalizer_modifier_params=() score=0.251735 cov_params=(0.5310971479546562, 0.4026941829457773), pattern_modifier_params=(29.99198616260383, 4.977396261185516), normalizer_modifier_params=() score=0.251561 cov_params=(0.5310971479546562, 0.4016941829457773), pattern_modifier_params=(29.99298616260383, 4.977396261185516), normalizer_modifier_params=() score=0.251682 cov_params=(0.5310971479546562, 0.4016941829457773), pattern_modifier_params=(29.99198616260383, 4.978396261185516), normalizer_modifier_params=() score=0.251667 cov_params=(0.5765024048484941, 0.27285167506202446), pattern_modifier_params=(29.98404622431943, 4.955346354637884), normalizer_modifier_params=() score=0.275617 cov_params=(0.5775024048484941, 0.27285167506202446), pattern_modifier_params=(29.98404622431943, 4.955346354637884), normalizer_modifier_params=() score=0.275692 cov_params=(0.5765024048484941, 0.27385167506202446), pattern_modifier_params=(29.98404622431943, 4.955346354637884), normalizer_modifier_params=() score=0.275416 cov_params=(0.5765024048484941, 0.27285167506202446), 
pattern_modifier_params=(29.985046224319433, 4.955346354637884), normalizer_modifier_params=() score=0.275609 cov_params=(0.5765024048484941, 0.27285167506202446), pattern_modifier_params=(29.98404622431943, 4.956346354637884), normalizer_modifier_params=() score=0.275596 cov_params=(0.6726577855810936, 0.0), pattern_modifier_params=(29.9672317001661, 4.908650943043517), normalizer_modifier_params=() score=0.343913 cov_params=(0.6736577855810936, 0.0), pattern_modifier_params=(29.9672317001661, 4.908650943043517), normalizer_modifier_params=() score=0.344053 cov_params=(0.6726577855810936, 0.001), pattern_modifier_params=(29.9672317001661, 4.908650943043517), normalizer_modifier_params=() score=0.348688 cov_params=(0.6726577855810936, 0.0), pattern_modifier_params=(29.9682317001661, 4.908650943043517), normalizer_modifier_params=() score=0.343906 cov_params=(0.6726577855810936, 0.0), pattern_modifier_params=(29.9672317001661, 4.909650943043517), normalizer_modifier_params=() score=0.343914 cov_params=(0.5809782829125139, 0.26015086915978014), pattern_modifier_params=(29.983263535308776, 4.953172758455817), normalizer_modifier_params=() score=0.278623 cov_params=(0.5819782829125139, 0.26015086915978014), pattern_modifier_params=(29.983263535308776, 4.953172758455817), normalizer_modifier_params=() score=0.278701 cov_params=(0.5809782829125139, 0.26115086915978014), pattern_modifier_params=(29.983263535308776, 4.953172758455817), normalizer_modifier_params=() score=0.278412 cov_params=(0.5809782829125139, 0.26015086915978014), pattern_modifier_params=(29.984263535308777, 4.953172758455817), normalizer_modifier_params=() score=0.278616 cov_params=(0.5809782829125139, 0.26015086915978014), pattern_modifier_params=(29.983263535308776, 4.954172758455817), normalizer_modifier_params=() score=0.278603 cov_params=(0.6443532335067832, 0.08031734039754651), pattern_modifier_params=(29.97218126794771, 4.922396327847094), normalizer_modifier_params=() score=0.343618 
cov_params=(0.6453532335067832, 0.08031734039754651), pattern_modifier_params=(29.97218126794771, 4.922396327847094), normalizer_modifier_params=() score=0.343742 cov_params=(0.6443532335067832, 0.08131734039754651), pattern_modifier_params=(29.97218126794771, 4.922396327847094), normalizer_modifier_params=() score=0.343154 cov_params=(0.6443532335067832, 0.08031734039754651), pattern_modifier_params=(29.973181267947712, 4.922396327847094), normalizer_modifier_params=() score=0.343614 cov_params=(0.6443532335067832, 0.08031734039754651), pattern_modifier_params=(29.97218126794771, 4.923396327847095), normalizer_modifier_params=() score=0.343605 cov_params=(0.6627163536499598, 0.028209927871517604), pattern_modifier_params=(29.968970141073424, 4.913478746254578), normalizer_modifier_params=() score=0.368872 cov_params=(0.6637163536499598, 0.028209927871517604), pattern_modifier_params=(29.968970141073424, 4.913478746254578), normalizer_modifier_params=() score=0.368976 cov_params=(0.6627163536499598, 0.029209927871517605), pattern_modifier_params=(29.968970141073424, 4.913478746254578), normalizer_modifier_params=() score=0.368606 cov_params=(0.6627163536499598, 0.028209927871517604), pattern_modifier_params=(29.969970141073425, 4.913478746254578), normalizer_modifier_params=() score=0.368869 cov_params=(0.6627163536499598, 0.028209927871517604), pattern_modifier_params=(29.968970141073424, 4.914478746254578), normalizer_modifier_params=() score=0.368864 cov_params=(0.7670465209893657, 0.0), pattern_modifier_params=(29.96589839861781, 4.9058939152727365), normalizer_modifier_params=() score=0.357881 cov_params=(0.7680465209893657, 0.0), pattern_modifier_params=(29.96589839861781, 4.9058939152727365), normalizer_modifier_params=() score=0.358038 cov_params=(0.7670465209893657, 0.001), pattern_modifier_params=(29.96589839861781, 4.9058939152727365), normalizer_modifier_params=() score=0.361368 cov_params=(0.7670465209893657, 0.0), 
pattern_modifier_params=(29.96689839861781, 4.9058939152727365), normalizer_modifier_params=() score=0.357875 cov_params=(0.7670465209893657, 0.0), pattern_modifier_params=(29.96589839861781, 4.906893915272737), normalizer_modifier_params=() score=0.357883 cov_params=(0.7048322843401036, 0.016822163456784937), pattern_modifier_params=(29.967730142177057, 4.910416906864553), normalizer_modifier_params=() score=0.373740 cov_params=(0.7058322843401036, 0.016822163456784937), pattern_modifier_params=(29.967730142177057, 4.910416906864553), normalizer_modifier_params=() score=0.373836 cov_params=(0.7048322843401036, 0.017822163456784938), pattern_modifier_params=(29.967730142177057, 4.910416906864553), normalizer_modifier_params=() score=0.373924 cov_params=(0.7048322843401036, 0.016822163456784937), pattern_modifier_params=(29.96873014217706, 4.910416906864553), normalizer_modifier_params=() score=0.373737 cov_params=(0.7048322843401036, 0.016822163456784937), pattern_modifier_params=(29.967730142177057, 4.911416906864553), normalizer_modifier_params=() score=0.373736 cov_params=(0.7540905730693692, 0.013499243030878894), pattern_modifier_params=(29.96626904920268, 4.906920126090405), normalizer_modifier_params=() score=0.377343 cov_params=(0.7550905730693692, 0.013499243030878894), pattern_modifier_params=(29.96626904920268, 4.906920126090405), normalizer_modifier_params=() score=0.377439 cov_params=(0.7540905730693692, 0.014499243030878895), pattern_modifier_params=(29.96626904920268, 4.906920126090405), normalizer_modifier_params=() score=0.377763 cov_params=(0.7540905730693692, 0.013499243030878894), pattern_modifier_params=(29.96726904920268, 4.906920126090405), normalizer_modifier_params=() score=0.377340 cov_params=(0.7540905730693692, 0.013499243030878894), pattern_modifier_params=(29.96626904920268, 4.9079201260904055), normalizer_modifier_params=() score=0.377340 cov_params=(1.0, 0.0), pattern_modifier_params=(29.950171128477027, 4.8686416206840235), 
normalizer_modifier_params=() score=0.332510 cov_params=(0.999, 0.0), pattern_modifier_params=(29.950171128477027, 4.8686416206840235), normalizer_modifier_params=() score=0.348507 cov_params=(1.0, 0.001), pattern_modifier_params=(29.950171128477027, 4.8686416206840235), normalizer_modifier_params=() score=0.332510 cov_params=(1.0, 0.0), pattern_modifier_params=(29.95117112847703, 4.8686416206840235), normalizer_modifier_params=() score=0.332511 cov_params=(1.0, 0.0), pattern_modifier_params=(29.950171128477027, 4.869641620684024), normalizer_modifier_params=() score=0.332506 cov_params=(0.8526990211647191, 0.008086114212062043), pattern_modifier_params=(29.959813863648666, 4.8915706372129515), normalizer_modifier_params=() score=0.383847 cov_params=(0.8536990211647191, 0.008086114212062043), pattern_modifier_params=(29.959813863648666, 4.8915706372129515), normalizer_modifier_params=() score=0.383957 cov_params=(0.8526990211647191, 0.009086114212062044), pattern_modifier_params=(29.959813863648666, 4.8915706372129515), normalizer_modifier_params=() score=0.384617 cov_params=(0.8526990211647191, 0.008086114212062043), pattern_modifier_params=(29.960813863648667, 4.8915706372129515), normalizer_modifier_params=() score=0.383844 cov_params=(0.8526990211647191, 0.008086114212062043), pattern_modifier_params=(29.959813863648666, 4.892570637212952), normalizer_modifier_params=() score=0.383847 cov_params=(0.9499176671960045, 0.0027492788321010937), pattern_modifier_params=(29.953449658435385, 4.876437486303859), normalizer_modifier_params=() score=0.391218 cov_params=(0.9509176671960045, 0.0027492788321010937), pattern_modifier_params=(29.953449658435385, 4.876437486303859), normalizer_modifier_params=() score=0.391328 cov_params=(0.9499176671960045, 0.0037492788321010937), pattern_modifier_params=(29.953449658435385, 4.876437486303859), normalizer_modifier_params=() score=0.391733 cov_params=(0.9499176671960045, 0.0027492788321010937), 
pattern_modifier_params=(29.954449658435387, 4.876437486303859), normalizer_modifier_params=() score=0.391214 cov_params=(0.9499176671960045, 0.0027492788321010937), pattern_modifier_params=(29.953449658435385, 4.877437486303859), normalizer_modifier_params=() score=0.391218 cov_params=(0.9803858688101427, 0.0010767213240918862), pattern_modifier_params=(29.951455124511412, 4.871694775807695), normalizer_modifier_params=() score=0.391335 cov_params=(0.9813858688101427, 0.0010767213240918862), pattern_modifier_params=(29.951455124511412, 4.871694775807695), normalizer_modifier_params=() score=0.391090 cov_params=(0.9803858688101427, 0.0020767213240918863), pattern_modifier_params=(29.951455124511412, 4.871694775807695), normalizer_modifier_params=() score=0.391431 cov_params=(0.9803858688101427, 0.0010767213240918862), pattern_modifier_params=(29.952455124511413, 4.871694775807695), normalizer_modifier_params=() score=0.391333 cov_params=(0.9803858688101427, 0.0010767213240918862), pattern_modifier_params=(29.951455124511412, 4.8726947758076955), normalizer_modifier_params=() score=0.391333 cov_params=(0.9574185809282916, 0.002337514798937927), pattern_modifier_params=(29.95295862758659, 4.8752698866293045), normalizer_modifier_params=() score=0.391789 cov_params=(0.9584185809282916, 0.002337514798937927), pattern_modifier_params=(29.95295862758659, 4.8752698866293045), normalizer_modifier_params=() score=0.391879 cov_params=(0.9574185809282916, 0.003337514798937927), pattern_modifier_params=(29.95295862758659, 4.8752698866293045), normalizer_modifier_params=() score=0.392216 cov_params=(0.9574185809282916, 0.002337514798937927), pattern_modifier_params=(29.95395862758659, 4.8752698866293045), normalizer_modifier_params=() score=0.391785 cov_params=(0.9574185809282916, 0.002337514798937927), pattern_modifier_params=(29.95295862758659, 4.876269886629305), normalizer_modifier_params=() score=0.391788 cov_params=(0.9699932528467375, 0.0016472258808617798), 
pattern_modifier_params=(29.952135454297867, 4.873312500720917), normalizer_modifier_params=() score=0.392297 cov_params=(0.9709932528467375, 0.0016472258808617798), pattern_modifier_params=(29.952135454297867, 4.873312500720917), normalizer_modifier_params=() score=0.392300 cov_params=(0.9699932528467375, 0.00264722588086178), pattern_modifier_params=(29.952135454297867, 4.873312500720917), normalizer_modifier_params=() score=0.392549 cov_params=(0.9699932528467375, 0.0016472258808617798), pattern_modifier_params=(29.95313545429787, 4.873312500720917), normalizer_modifier_params=() score=0.392293 cov_params=(0.9699932528467375, 0.0016472258808617798), pattern_modifier_params=(29.952135454297867, 4.874312500720917), normalizer_modifier_params=() score=0.392295 cov_params=(0.9225231238046329, 0.015211881150587817), pattern_modifier_params=(29.960910155832725, 4.896086572506774), normalizer_modifier_params=() score=0.393809 cov_params=(0.9235231238046329, 0.015211881150587817), pattern_modifier_params=(29.960910155832725, 4.896086572506774), normalizer_modifier_params=() score=0.393880 cov_params=(0.9225231238046329, 0.016211881150587817), pattern_modifier_params=(29.960910155832725, 4.896086572506774), normalizer_modifier_params=() score=0.394013 cov_params=(0.9225231238046329, 0.015211881150587817), pattern_modifier_params=(29.961910155832726, 4.896086572506774), normalizer_modifier_params=() score=0.393807 cov_params=(0.9225231238046329, 0.015211881150587817), pattern_modifier_params=(29.960910155832725, 4.897086572506774), normalizer_modifier_params=() score=0.393808 cov_params=(0.9492204678121738, 0.039092974004306724), pattern_modifier_params=(29.971528400079535, 4.92782219525526), normalizer_modifier_params=() score=0.394942 cov_params=(0.9502204678121738, 0.039092974004306724), pattern_modifier_params=(29.971528400079535, 4.92782219525526), normalizer_modifier_params=() score=0.394960 cov_params=(0.9492204678121738, 0.040092974004306725), 
    pattern_modifier_params=(29.971528400079535, 4.92782219525526), normalizer_modifier_params=() score=0.394828
    cov_params=(0.9492204678121738, 0.039092974004306724), pattern_modifier_params=(29.972528400079536, 4.92782219525526), normalizer_modifier_params=() score=0.394941
    cov_params=(0.9492204678121738, 0.039092974004306724), pattern_modifier_params=(29.971528400079535, 4.92882219525526), normalizer_modifier_params=() score=0.394939
    cov_params=(0.9700451920044341, 0.02829741724845628), pattern_modifier_params=(29.965043913355956, 4.910663919854389), normalizer_modifier_params=() score=0.394539
    cov_params=(0.9710451920044341, 0.02829741724845628), pattern_modifier_params=(29.965043913355956, 4.910663919854389), normalizer_modifier_params=() score=0.394370
    cov_params=(0.9700451920044341, 0.029297417248456283), pattern_modifier_params=(29.965043913355956, 4.910663919854389), normalizer_modifier_params=() score=0.394510
    cov_params=(0.9700451920044341, 0.02829741724845628), pattern_modifier_params=(29.966043913355957, 4.910663919854389), normalizer_modifier_params=() score=0.394538
    cov_params=(0.9700451920044341, 0.02829741724845628), pattern_modifier_params=(29.965043913355956, 4.911663919854389), normalizer_modifier_params=() score=0.394537
    cov_params=(0.9581116939510036, 0.034483753948152875), pattern_modifier_params=(29.968759814175524, 4.920496378820285), normalizer_modifier_params=() score=0.395324
    cov_params=(0.9591116939510036, 0.034483753948152875), pattern_modifier_params=(29.968759814175524, 4.920496378820285), normalizer_modifier_params=() score=0.395284
    cov_params=(0.9581116939510036, 0.035483753948152875), pattern_modifier_params=(29.968759814175524, 4.920496378820285), normalizer_modifier_params=() score=0.395252
    cov_params=(0.9581116939510036, 0.034483753948152875), pattern_modifier_params=(29.969759814175525, 4.920496378820285), normalizer_modifier_params=() score=0.395323
    cov_params=(0.9581116939510036, 0.034483753948152875), pattern_modifier_params=(29.968759814175524, 4.921496378820286), normalizer_modifier_params=() score=0.395321
    cov_params=(0.9480406724865438, 0.02056353932381044), pattern_modifier_params=(29.961869988891713, 4.900891609949009), normalizer_modifier_params=() score=0.395698
    cov_params=(0.9490406724865438, 0.02056353932381044), pattern_modifier_params=(29.961869988891713, 4.900891609949009), normalizer_modifier_params=() score=0.395708
    cov_params=(0.9480406724865438, 0.02156353932381044), pattern_modifier_params=(29.961869988891713, 4.900891609949009), normalizer_modifier_params=() score=0.395756
    cov_params=(0.9480406724865438, 0.02056353932381044), pattern_modifier_params=(29.962869988891715, 4.900891609949009), normalizer_modifier_params=() score=0.395696
    cov_params=(0.9480406724865438, 0.02056353932381044), pattern_modifier_params=(29.961869988891713, 4.9018916099490095), normalizer_modifier_params=() score=0.395696
    cov_params=(0.9503945927129629, 0.026202708718212262), pattern_modifier_params=(29.964469693063545, 4.908585304151802), normalizer_modifier_params=() score=0.395840
    cov_params=(0.9513945927129629, 0.026202708718212262), pattern_modifier_params=(29.964469693063545, 4.908585304151802), normalizer_modifier_params=() score=0.395838
    cov_params=(0.9503945927129629, 0.027202708718212263), pattern_modifier_params=(29.964469693063545, 4.908585304151802), normalizer_modifier_params=() score=0.395828
    cov_params=(0.9503945927129629, 0.026202708718212262), pattern_modifier_params=(29.965469693063547, 4.908585304151802), normalizer_modifier_params=() score=0.395839
    cov_params=(0.9503945927129629, 0.026202708718212262), pattern_modifier_params=(29.964469693063545, 4.909585304151802), normalizer_modifier_params=() score=0.395838

.. GENERATED FROM PYTHON SOURCE LINES 154-155

Let's take a look at the optimal parameters:

.. GENERATED FROM PYTHON SOURCE LINES 155-160

.. code-block:: default

    # Read out the hyperparameter values found by the optimizer
    shrinkage_params = model_opt.cov_params_
    pattern_params = model_opt.pattern_modifier_params_
    print('Optimal shrinkage parameters:', shrinkage_params)
    print('Optimal pattern parameters:', pattern_params)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Optimal shrinkage parameters: [0.9503945927129629, 0.026202708718212262]
    Optimal pattern parameters: [29.964469693063545, 4.908585304151802]

.. GENERATED FROM PYTHON SOURCE LINES 161-168

To evaluate the performance of the new model, you can pass the
:class:`WorkbenchOptimizer` object to :func:`cross_val_predict`. This runs
the optimization procedure anew in every iteration of the cross-validation
loop. To save time in this example, we instead freeze the parameters before
passing the model to the cross-validation loop. Take the result with a grain
of salt: the hyperparameters have been tuned using all the data, not just the
training set, so the performance estimate is likely to be optimistic!

.. GENERATED FROM PYTHON SOURCE LINES 168-178

.. code-block:: default

    # Rebuild the model with the optimized hyperparameters frozen in place
    model = Workbench(
        ols,
        cov=cov_estimators.ShrinkageKernel(alpha=shrinkage_params[0]),
        pattern_modifier=partial(pattern_modifier, mean=pattern_params[0], std=pattern_params[1]),
        normalizer_modifier=normalizers.unit_gain,
    )

    # 10-fold cross-validated predictions, scored by Pearson correlation with the true FAS
    FAS_pred = cross_val_predict(model, X, y, cv=10)
    perf_opt, _ = pearsonr(epochs.metadata['FAS'], FAS_pred)
    print(f'Performance: {perf_opt:.2f} (to beat: {perf_naive:.2f})')

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Performance: 0.37 (to beat: 0.30)

.. GENERATED FROM PYTHON SOURCE LINES 179-180

Here is the final pattern:

.. GENERATED FROM PYTHON SOURCE LINES 180-187

.. code-block:: default

    # Fit on all data, then plot the pattern as 32 channels x 60 time samples
    model.fit(X, y)
    plt.figure()
    plt.plot(epochs.times, model.pattern_.reshape(32, 60).T, color='black', alpha=0.2)
    plt.xlabel('Time (s)')
    plt.ylabel('Signal (normalized units)')
    plt.title('Pattern learned by the post-hoc model')

.. image:: /auto_examples/images/sphx_glr_plot_optimization_005.png
    :alt: Pattern learned by the post-hoc model
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Text(0.5, 1.0, 'Pattern learned by the post-hoc model')

.. GENERATED FROM PYTHON SOURCE LINES 188-196

References
----------
.. [1] Marijn van Vliet and Riitta Salmelin (2020). Post-hoc modification
   of linear models: combining machine learning with domain information
   to make solid inferences from noisy data. NeuroImage, 204, 116221.
   https://doi.org/10.1016/j.neuroimage.2019.116221

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 5 minutes 0.946 seconds)

.. _sphx_glr_download_auto_examples_plot_optimization.py:

.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_optimization.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_optimization.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
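The caveat about tuning the hyperparameters on all the data can be illustrated with a generic nested cross-validation sketch. It does not use post-hoc at all: as an assumption, scikit-learn's ``GridSearchCV`` tuning the penalty of a ``Ridge`` model stands in for ``WorkbenchOptimizer``, and the data are synthetic (``make_regression``). Passing the search object to ``cross_val_predict`` reruns the search inside every training fold, whereas tuning once on all the data and then freezing the value, as this example does, lets the test folds influence the hyperparameter choice.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, KFold, cross_val_predict

# Synthetic stand-in data: many noisy features, few samples (like EEG epochs)
X, y = make_regression(n_samples=100, n_features=500, n_informative=20,
                       noise=50.0, random_state=0)

# Stand-in for WorkbenchOptimizer (assumption): a search over the Ridge penalty
search = GridSearchCV(Ridge(), {'alpha': np.logspace(-2, 4, 13)}, cv=5)
outer_cv = KFold(n_splits=10, shuffle=True, random_state=0)

# Nested CV: the search is refit inside every outer training fold,
# so the held-out fold never influences the chosen hyperparameter
y_pred_nested = cross_val_predict(search, X, y, cv=outer_cv)
r_nested, _ = pearsonr(y, y_pred_nested)

# Shortcut: tune once on all data, then freeze the value for the CV loop
best_alpha = search.fit(X, y).best_params_['alpha']
y_pred_frozen = cross_val_predict(Ridge(alpha=best_alpha), X, y, cv=outer_cv)
r_frozen, _ = pearsonr(y, y_pred_frozen)

print(f'Nested CV: r={r_nested:.2f}  Frozen hyperparameters: r={r_frozen:.2f}')
```

The nested estimate costs one full search per outer fold, which is the same trade-off the text describes for passing the ``WorkbenchOptimizer`` object itself to ``cross_val_predict``.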
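The optimal pattern parameters printed in this example (a mean of about 29.96 and a standard deviation of about 4.91) describe a Gaussian kernel. Assuming, as the ``reshape(32, 60)`` in the plotting code suggests, that the pattern holds 32 channels by 60 time samples and that both parameters are measured in samples, a hypothetical modifier that applies such a kernel along the time axis could look like the sketch below. The function name ``gaussian_time_kernel`` is made up for illustration, and it is not necessarily what the ``pattern_modifier`` defined earlier in this example does.

```python
import numpy as np
from scipy.stats import norm


def gaussian_time_kernel(pattern, mean, std, n_channels=32, n_samples=60):
    """Hypothetical pattern modifier: weight each time sample by a Gaussian.

    ``pattern`` is a flat vector of length n_channels * n_samples, laid out
    channel by channel (as in ``pattern.reshape(n_channels, n_samples)``).
    """
    # Gaussian evaluated at every time-sample index (assumed unit: samples)
    kernel = norm(mean, std).pdf(np.arange(n_samples))
    kernel /= kernel.max()  # the sample nearest the mean keeps weight 1
    # Broadcast the same temporal weighting over all channels
    weighted = pattern.reshape(n_channels, n_samples) * kernel[np.newaxis, :]
    return weighted.ravel()


rng = np.random.default_rng(0)
pattern = rng.standard_normal(32 * 60)  # random stand-in for a learned pattern
modified = gaussian_time_kernel(pattern, mean=29.964469693063545,
                                std=4.908585304151802)
print(modified.shape)  # (1920,)
```

With a standard deviation of roughly 5 samples, only about the middle third of the 60-sample window keeps substantial weight; samples near the edges are driven almost to zero.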