%% MODELING BIRTH FREQUENCIES IN THE US 1969-1988
%
%  Description:
%
%    Demonstration of analysis of birthday frequencies in USA 1969-1988
%    using Gaussian process with several components by converting the
%    Gaussian process regression problem into a state space model [1].
%
%    This example follows the presentation by Aki Vehtari (example 
%    included in GPstuff) and the study is explained in more detail in
%    the book Bayesian Data Analysis [2].
%
%    The codes are provided as proof-of-concept, and are not optimized
%    for speed. The computational burden still scales linearly, however.
%
%  References:
%
%    [1] Arno Solin and Simo Sarkka (2014). Explicit link between periodic 
%        covariance functions and state space models. Accepted for 
%        publication in Proceedings of the Seventeenth International 
%        Conference on Artificial Intelligence and Statistics (AISTATS 2014).
%
%    [2] Andrew Gelman, John B. Carlin, Hal S. Stern, David B. Dunson, 
%        Aki Vehtari, and Donald B. Rubin. Bayesian Data Analysis. 
%        Chapman & Hall/CRC Press, third edition, 2013.
%
% Copyright:
%
%    2013-2014 Arno Solin and Simo Sarkka
%
% This software is distributed under the GNU General Public
% License (version 3 or later); please refer to the file
% License.txt, included with the software, for details.
%

%% Add path to core functions

  % Put the sibling 'core' directory on the MATLAB search path. The path
  % is relative, so it is resolved against the current working directory.
  addpath('../core/')


%% Load data

  % The data handling mirrors Aki Vehtari's example in GPstuff.
  % Data source: National Vital Statistics System natality data, as
  % provided by Google BigQuery and exported to csv by Robert Kern
  % (http://www.mechanicalkern.com/static/birthdates-1968-1988.csv).

  % Read the comma-separated file into a dataset array; the columns
  % used in this script are births, year, month, day and day_of_week
  d = dataset('File','birthdates-1968-1988.csv','Delimiter',',');

  % Births column
  y = d.births;

  % Number of rows in the data set
  n = size(d,1);

  % One serial date number per row, built from the calendar columns
  t = datenum(d.year,d.month,d.day);

  % Normalize the births (normdata is a GPstuff function). ymean and
  % ystd are kept so results can be mapped back with denormdata later.
  [yn,ymean,ystd] = normdata(y);
  
  
%% Show

  % Quick look at the normalized series in a cleared figure 1
  figure(1)
  clf
  plot(t,yn)

  % Label the time axis with date format number 10
  datetick('x',10)


%% Set up model
  
  % Drop any model left over from an earlier run
  clear model

  % Inputs (dates) and targets (normalized births) as row vectors
  model.x = reshape(t,1,[]);
  model.y = reshape(yn,1,[]);

  % Gaussian likelihood model: initial sigma2 and its opt flag
  % (presumably the noise variance and whether it is optimized;
  % confirm against ss_optimize)
  model.sigma2 = 0.01;
  model.opt    = true;

  % The four additive components. Each entry names the function that
  % builds its state space form (make_ss), gives initial hyperparameter
  % values, and lists in 'opt' the hyperparameters handed to the
  % optimizer. The double braces make struct() store the cell array as a
  % single field value instead of creating a struct array.
  model.ss = cell(1,4);

  % Slow bias (matern52 works here as well)
  model.ss{1} = struct( ...
      'make_ss',      @cf_se_to_ss, ...
      'lengthScale',  560, ...
      'magnSigma2',   0.15, ...
      'opt',          {{'magnSigma2','lengthScale'}});
    
  % Yearly oscillation (period of one year, in days)
  model.ss{2} = struct( ...
      'make_ss',      @cf_quasiperiodic_to_ss, ...
      'lengthScale',  1, ...
      'magnSigma2',   0.2, ...
      'period',       365.25, ...
      'N',            8, ...
      'nu',           3/2, ...
      'mlengthScale', 1000, ...
      'opt',          {{'magnSigma2','lengthScale','mlengthScale'}});
  
  % Weekly oscillation (period of one week, in days)
  model.ss{3} = struct( ...
      'make_ss',      @cf_quasiperiodic_to_ss, ...
      'lengthScale',  0.5, ...
      'magnSigma2',   0.75, ...
      'period',       7, ...
      'N',            8, ...
      'nu',           3/2, ...
      'mlengthScale', 5000, ...
      'opt',          {{'magnSigma2','lengthScale','mlengthScale'}});

  % Faster variations
  model.ss{4} = struct( ...
      'make_ss',      @cf_matern32_to_ss, ...
      'lengthScale',  61, ...
      'magnSigma2',   0.02, ...
      'opt',          {{'magnSigma2','lengthScale'}});
  

%% Optimize hyperparameters
  
  % Optimizer settings, collected in a single optimset call:
  % gradients supplied by the objective, loose tolerances, the
  % medium-scale algorithm, per-iteration output and no derivative check
  options = optimset('GradObj',         'on', ...
                     'TolX',            1e-3, ...
                     'TolFun',          1e-3, ...
                     'LargeScale',      'off', ...
                     'Display',         'iter', ...
                     'DerivativeCheck', 'off');
  
  % Find hyperparameters
  % 'fminunc' from the optimization toolbox is used here. Any other
  % routine (such as some bfgs implementation) could be passed as the
  % 'optimizer' instead.
  tic
  [model,lik] = ss_optimize(model, ...
                            'optimizer',@fminunc, ...
                            'options',options);
  toc
  
  % Report the value returned by the optimization
  fprintf('Model marginal likelihood: %.2f \n',lik)
  
  
%% ... or load the pre-calculated results
%
%  load model.mat
%  
%  
%% Predict values

  % Use the training dates themselves as the prediction inputs
  xt = t;

  % Request component-wise output. The rest of the script indexes meanf
  % as (component, time) and Varf as (component, component, time).
  [meanf,Varf] = ss_predict(model, ...
                            'xt',xt, ...
                            'components',true);
  

%% Make figure
  
  % The trends over the whole time period: rows 1 and 4 of the
  % component-wise prediction (presumably the slow bias and the faster
  % variations, following the order of model.ss)
  trend_mean = meanf([1 4],:);
  trend_var  = squeeze(Varf([1 4],[1 4],:));

  % Logical masks selecting every day of the three comparison years
  ind_1972  = (d.year==1972);
  ind_1980  = (d.year==1980);
  ind_1988  = (d.year==1988);

  % Day of week effect (1972,1980,1988): seven consecutive days starting
  % at the first day of each year whose day_of_week equals 1
  wk_1972 = find((d.day_of_week==1) & ind_1972,1,'first')+(0:6);
  wk_1980 = find((d.day_of_week==1) & ind_1980,1,'first')+(0:6);
  wk_1988 = find((d.day_of_week==1) & ind_1988,1,'first')+(0:6);
  week_mean = [meanf(3,wk_1972); meanf(3,wk_1980); meanf(3,wk_1988)];

  % Varf(3,3,idx) is 1-by-1-by-7. squeeze() would turn it into a column,
  % and stacking three columns gives a 21-by-1 vector. Reshape each slice
  % to a row so week_var is 3-by-7, matching week_mean.
  week_var  = [reshape(Varf(3,3,wk_1972),1,[]); ...
               reshape(Varf(3,3,wk_1980),1,[]); ...
               reshape(Varf(3,3,wk_1988),1,[])];
  
  % Time of year effect (1972,1980,1988): all days of each year
  year_mean = [meanf(2,ind_1972); meanf(2,ind_1980); meanf(2,ind_1988)];

  % Rows for the same reason as above, so year_var has one row per year
  % like year_mean
  year_var  = [reshape(Varf(2,2,ind_1972),1,[]); ...
               reshape(Varf(2,2,ind_1980),1,[]); ...
               reshape(Varf(2,2,ind_1988),1,[])];
  
  % Common y-range shared by all three panels
  lims = [78 115];
  
  % Open figure 2 and clear whatever it showed before
  figure(2); clf
  
  % Panel 1 of 3: trends
  subplot(3,1,1); hold on
  
     % Each curve is mapped back with denormdata and expressed as a
     % percentage of ymean. Only every fifth point is drawn to save tex
     % memory.
     plot(t(1:5:end), ...
          denormdata(trend_mean(1,1:5:end),ymean,ystd)/ymean*100, ...
          '--k','LineWidth',.7)
     plot(t(1:5:end), ...
          denormdata(trend_mean(2,1:5:end),ymean,ystd)/ymean*100, ...
          '-k','LineWidth',.7)
     
     % Thin reference line at 100 percent
     plot([t(1) t(end)],[100 100],'-k','LineWidth',.2)
      
     % Legend entries for the two curves drawn first
     legend('Slow trend','Fast non-periodic component', ...
            'Location','SE')
     
     % Axis options: a tick every third year at January 1st, fixed
     % limits and a closed box
     years = 1970:3:1988;
     set(gca,'XGrid','on', ...
             'YTick',80:10:110, ...
             'XTick',datenum(years,1,1), ...
             'XTickLabel',years, ...
             'XLim',[datenum(1969,1,1) datenum(1988,12,31)], ...
             'YLim',lims, ...
             'Box','on')
     ylabel('Trends')
     
  % Panel 2 of 3: yearly effects
  subplot(3,1,2); hold on
  
     % One curve per comparison year over days 1..366, mapped back with
     % denormdata and expressed as a percentage of ymean
     plot(1:366,denormdata(year_mean(1,:),ymean,ystd)/ymean*100, ...
          '-.k','LineWidth',.7)
     plot(1:366,denormdata(year_mean(2,:),ymean,ystd)/ymean*100, ...
          '--k','LineWidth',.7)
     plot(1:366,denormdata(year_mean(3,:),ymean,ystd)/ymean*100, ...
          '-k','LineWidth',.7)
     
     % Thin reference line at 100 percent
     plot([1 366],[100 100],'-k','LineWidth',.2)

     % Legend entries for the three curves drawn first
     legend('1972','1980','1988','Location','SW')
     
     % Axis options: ticks at the first day of each month (day-of-year
     % positions taken from 1972) plus an unlabeled tick at day 366
     months = datenum(1972,1:12,1)-datenum(1972,1,1)+1;
     set(gca,'XTick',[months 366], ...
             'XTickLabel',{'Jan','Feb','Mar','Apr','May','Jun', ...
                           'Jul','Aug','Sep','Oct','Nov','Dec',''}, ...
             'XGrid','on', ...
             'XLim',[1 366], ...
             'YLim',lims, ...
             'Box','on')
     ylabel('Seasonal effect')
     
  % Panel 3 of 3: weekly effects
  subplot(3,1,3); hold on

     % One curve per comparison year over the seven weekdays, mapped
     % back with denormdata and expressed as a percentage of ymean
     plot(1:7,denormdata(week_mean(1,:),ymean,ystd)/ymean*100, ...
          '-.k','LineWidth',.7)
     plot(1:7,denormdata(week_mean(2,:),ymean,ystd)/ymean*100, ...
          '--k','LineWidth',.7)
     plot(1:7,denormdata(week_mean(3,:),ymean,ystd)/ymean*100, ...
          '-k','LineWidth',.7)
           
     % Thin reference line at 100 percent, drawn past both axis ends
     plot([0 8],[100 100],'-k','LineWidth',.2)
      
     % Legend entries for the three curves drawn first
     legend('1972','1980','1988','Location','SW')
     
     % Axis options: one labeled tick per weekday, fixed limits and a
     % closed box
     set(gca,'XTick',1:7, ...
             'XTickLabel',{'Mon','Tue','Wed','Thu','Fri','Sat','Sun'}, ...
             'YTick',80:10:110, ...
             'XGrid','on', ...
             'XLim',[.5 7.5], ...
             'YLim',lims, ...
             'Box','on')
     ylabel('Day of week effect')
     
  % White background for the whole figure window
  set(gcf,'Color','w')
  
  