Hierarchical Gaussian Filter (HGF) model of conditioned hallucinations task (Powers et al 2017)

Accession: 229278
This is an instantiation of the Hierarchical Gaussian Filter (HGF) model for use with the Conditioned Hallucinations Task.
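For orientation, a minimal MATLAB fitting sketch is given below. It pairs the binary HGF perceptual model with the conditioned-hallucination response model and the quasi-Newton optimizer shipped in hgfToolBox_condhalluc1.4. The variables y (the participant's binary detection responses) and u (the task inputs) are placeholders; their exact format is defined by the task's analysis scripts, so treat this as a sketch rather than a verbatim recipe.

% Illustrative fitting sketch; y and u are placeholders for the participant's
% 0/1 detection responses and the task inputs, prepared per the analysis scripts
addpath(genpath('hgfToolBox_condhalluc1.4'));
est = tapas_fitModel(y, u, ...
    'tapas_hgf_binary_config', ...           % perceptual model: binary HGF
    'tapas_condhalluc_obs_config', ...       % response model for conditioned hallucinations
    'tapas_quasinewton_optim_config');       % quasi-Newton optimization
tapas_hgf_binary_condhalluc_plotTraj(est);   % plot the fitted belief trajectories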
Reference:
1. Powers AR, Mathys C, Corlett PR (2017) Pavlovian conditioning-induced hallucinations result from overweighting of perceptual priors. Science 357:596-600
Model Information
Model Type:
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: MATLAB;
Model Concept(s): Hallucinations;
Implementer(s): Powers, Al [albert.powers at yale.edu]; Mathys, Chris H;
/
HGF
analysis
hgfToolBox_condhalluc1.4
README
COPYING *
example_binary_input.txt
example_categorical_input.mat
example_usdchf.txt
Manual.pdf
tapas_autocorr.m
tapas_bayes_optimal.m
tapas_bayes_optimal_binary.m
tapas_bayes_optimal_binary_config.m
tapas_bayes_optimal_binary_transp.m
tapas_bayes_optimal_categorical.m
tapas_bayes_optimal_categorical_config.m
tapas_bayes_optimal_categorical_transp.m
tapas_bayes_optimal_config.m
tapas_bayes_optimal_transp.m
tapas_bayes_optimal_whatworld.m
tapas_bayes_optimal_whatworld_config.m
tapas_bayes_optimal_whatworld_transp.m
tapas_bayes_optimal_whichworld.m
tapas_bayes_optimal_whichworld_config.m
tapas_bayes_optimal_whichworld_transp.m
tapas_bayesian_parameter_average.m
tapas_beta_obs.m
tapas_beta_obs_config.m
tapas_beta_obs_namep.m
tapas_beta_obs_sim.m
tapas_beta_obs_transp.m
tapas_boltzmann.m
tapas_cdfgaussian_obs.m
tapas_cdfgaussian_obs_config.m
tapas_cdfgaussian_obs_transp.m
tapas_condhalluc_obs.m
tapas_condhalluc_obs_config.m
tapas_condhalluc_obs_namep.m
tapas_condhalluc_obs_sim.m
tapas_condhalluc_obs_transp.m
tapas_condhalluc_obs2.m
tapas_condhalluc_obs2_config.m
tapas_condhalluc_obs2_namep.m
tapas_condhalluc_obs2_sim.m
tapas_condhalluc_obs2_transp.m
tapas_Cov2Corr.m
tapas_datagen_categorical.m
tapas_fit_plotCorr.m
tapas_fit_plotResidualDiagnostics.m
tapas_fitModel.m
tapas_gaussian_obs.m
tapas_gaussian_obs_config.m
tapas_gaussian_obs_namep.m
tapas_gaussian_obs_sim.m
tapas_gaussian_obs_transp.m
tapas_hgf.m
tapas_hgf_ar1.m
tapas_hgf_ar1_binary.m
tapas_hgf_ar1_binary_config.m
tapas_hgf_ar1_binary_namep.m
tapas_hgf_ar1_binary_plotTraj.m
tapas_hgf_ar1_binary_transp.m
tapas_hgf_ar1_config.m
tapas_hgf_ar1_mab.m
tapas_hgf_ar1_mab_config.m
tapas_hgf_ar1_mab_plotTraj.m
tapas_hgf_ar1_mab_transp.m
tapas_hgf_ar1_namep.m
tapas_hgf_ar1_plotTraj.m
tapas_hgf_ar1_transp.m
tapas_hgf_binary.m
tapas_hgf_binary_condhalluc_plotTraj.m
tapas_hgf_binary_config.m
tapas_hgf_binary_config_startpoints.m
tapas_hgf_binary_mab.m
tapas_hgf_binary_mab_config.m
tapas_hgf_binary_mab_plotTraj.m
tapas_hgf_binary_mab_transp.m
tapas_hgf_binary_namep.m
tapas_hgf_binary_plotTraj.m
tapas_hgf_binary_pu.m
tapas_hgf_binary_pu_config.m
tapas_hgf_binary_pu_namep.m
tapas_hgf_binary_pu_tbt.m
tapas_hgf_binary_pu_tbt_config.m
tapas_hgf_binary_pu_tbt_namep.m
tapas_hgf_binary_pu_tbt_transp.m
tapas_hgf_binary_pu_transp.m
tapas_hgf_binary_transp.m
tapas_hgf_categorical.m
tapas_hgf_categorical_config.m
tapas_hgf_categorical_namep.m
tapas_hgf_categorical_norm.m
tapas_hgf_categorical_norm_config.m
tapas_hgf_categorical_norm_transp.m
tapas_hgf_categorical_plotTraj.m
tapas_hgf_categorical_transp.m
tapas_hgf_config.m
tapas_hgf_demo.m
tapas_hgf_demo_commands.m
tapas_hgf_jget.m
tapas_hgf_jget_config.m
tapas_hgf_jget_plotTraj.m
tapas_hgf_jget_transp.m
tapas_hgf_namep.m
tapas_hgf_plotTraj.m
tapas_hgf_transp.m
tapas_hgf_whatworld.m
tapas_hgf_whatworld_config.m
tapas_hgf_whatworld_namep.m
tapas_hgf_whatworld_plotTraj.m
tapas_hgf_whatworld_transp.m
tapas_hgf_whichworld.m
tapas_hgf_whichworld_config.m
tapas_hgf_whichworld_namep.m
tapas_hgf_whichworld_plotTraj.m
tapas_hgf_whichworld_transp.m
tapas_hhmm.m
tapas_hhmm_binary_displayResults.m
tapas_hhmm_config.m
tapas_hhmm_transp.m
tapas_hmm.m
tapas_hmm_binary_displayResults.m
tapas_hmm_config.m
tapas_hmm_transp.m
tapas_kf.m
tapas_kf_config.m
tapas_kf_namep.m
tapas_kf_plotTraj.m
tapas_kf_transp.m
tapas_logit.m
tapas_logrt_linear_binary.m
tapas_logrt_linear_binary_config.m
tapas_logrt_linear_binary_minimal.m
tapas_logrt_linear_binary_minimal_config.m
tapas_logrt_linear_binary_minimal_transp.m
tapas_logrt_linear_binary_namep.m
tapas_logrt_linear_binary_sim.m
tapas_logrt_linear_binary_transp.m
tapas_logrt_linear_whatworld.m
tapas_logrt_linear_whatworld_config.m
tapas_logrt_linear_whatworld_transp.m
tapas_ph_binary.m
tapas_ph_binary_config.m
tapas_ph_binary_namep.m
tapas_ph_binary_plotTraj.m
tapas_ph_binary_transp.m
tapas_quasinewton_optim.m
tapas_quasinewton_optim_config.m
tapas_riddersdiff.m
tapas_riddersdiff2.m
tapas_riddersdiffcross.m
tapas_riddersgradient.m
tapas_riddershessian.m
tapas_rs_belief.m
tapas_rs_belief_config.m
tapas_rs_precision.m
tapas_rs_precision_config.m
tapas_rs_precision_whatworld.m
tapas_rs_precision_whatworld_config.m
tapas_rs_surprise.m
tapas_rs_surprise_config.m
tapas_rs_transp.m
tapas_rs_whatworld_transp.m
tapas_rw_binary.m
tapas_rw_binary_config.m
tapas_rw_binary_dual.m
tapas_rw_binary_dual_config.m
tapas_rw_binary_dual_plotTraj.m
tapas_rw_binary_dual_transp.m
tapas_rw_binary_namep.m
tapas_rw_binary_plotTraj.m
tapas_rw_binary_transp.m
tapas_sgm.m
tapas_simModel.m
tapas_softmax.m
tapas_softmax_2beta.m
tapas_softmax_2beta_config.m
tapas_softmax_2beta_transp.m
tapas_softmax_binary.m
tapas_softmax_binary_config.m
tapas_softmax_binary_namep.m
tapas_softmax_binary_sim.m
tapas_softmax_binary_transp.m
tapas_softmax_config.m
tapas_softmax_namep.m
tapas_softmax_sim.m
tapas_softmax_transp.m
tapas_squared_pe.m
tapas_squared_pe_config.m
tapas_squared_pe_transp.m
tapas_sutton_k1_binary.m
tapas_sutton_k1_binary_config.m
tapas_sutton_k1_binary_plotTraj.m
tapas_sutton_k1_binary_transp.m
tapas_unitsq_sgm.m
tapas_unitsq_sgm_config.m
tapas_unitsq_sgm_mu3.m
tapas_unitsq_sgm_mu3_config.m
tapas_unitsq_sgm_mu3_transp.m
tapas_unitsq_sgm_namep.m
tapas_unitsq_sgm_sim.m
tapas_unitsq_sgm_transp.m
                            
function [traj, infStates] = tapas_hgf_whatworld(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF
%
% This function can be called in two ways:
% 
% (1) tapas_hgf_whatworld(r, p)
%   
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_hgf_whatworld(r, ptrans, 'trans')
% 
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
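%
% Example (an illustrative sketch; it assumes est was returned by
% tapas_fitModel with the whatworld configuration, and that est.p_prc.p
% holds the native-space perceptual parameter vector; field names may
% differ between toolbox versions):
%
%     est = tapas_fitModel(y, u, 'tapas_hgf_whatworld_config', ...
%                          'tapas_logrt_linear_whatworld_config', ...
%                          'tapas_quasinewton_optim_config');
%     [traj, infStates] = tapas_hgf_whatworld(est, est.p_prc.p);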
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Check whether we have a configuration structure
if ~isfield(r,'c_prc')
    error('tapas:hgf:ConfigRequired', 'Configuration required: before calling tapas_hgf_whatworld, tapas_hgf_whatworld_config has to be called.');
end

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_hgf_whatworld_transp(r, p);
end

% Number of states whose contingencies have to be learned
ns = r.c_prc.n_states;

% Number of elements of the transition matrix
ntr = ns^2;

% Unpack parameters
mu2_0 = reshape(p(1:ntr)',ns,ns);
sa2_0 = reshape(p(ntr+1:2*ntr)',ns,ns);
mu3_0 = p(2*ntr+1);
sa3_0 = p(2*ntr+2);
ka    = p(2*ntr+3);
om    = p(2*ntr+4);
th    = p(2*ntr+5);
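% (Hence p is expected to contain 2*ntr + 5 = 2*ns^2 + 5 elements: the
% ns-by-ns prior means and variances at the second level, followed by
% mu3_0, sa3_0, ka, om, and th.)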

% Add dummy "zeroth" trial
u = [1; r.u(:,1)];

% Number of trials (including prior)
n = length(u);

% Initialize updated quantities

% Representations
mu1 = NaN(n,ns,ns);
pi1 = NaN(n,ns,ns);
mu2 = NaN(n,ns,ns);
pi2 = NaN(n,ns,ns);
mu3 = NaN(n,1);
pi3 = NaN(n,1);

% Other quantities
mu1hat = NaN(n,ns,ns);
pi1hat = NaN(n,ns,ns);
mu2hat = NaN(n,ns,ns);
pi2hat = NaN(n,ns,ns);
mu3hat = NaN(n,1);
pi3hat = NaN(n,1);
v2     = NaN(n,1);
w2     = NaN(n,ns,ns);
da1    = NaN(n,ns,ns);
da2    = NaN(n,ns,ns);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu1(1,:,:) = tapas_sgm(mu2_0, 1);
pi1(1,:,:) = 1./(mu1(1,:,:).*(1-mu1(1,:,:)));
mu2(1,:,:) = mu2_0;
pi2(1,:,:) = 1./sa2_0;
mu3(1)   = mu3_0;
pi3(1)   = 1/sa3_0;

% Pass through representation update loop
for k = 2:1:n
    if not(ismember(k-1, r.ign))
        
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        
        % Note: there is only an effect on the column of 
        % the transition matrix that corresponds to the
        % previous outcome u(k-1)
        
        % 1st level
        % ~~~~~~~~~
        % Unnormalized predictions
        mu1hat(k,:,:) = squeeze(tapas_sgm(mu2(k-1,:,:), 1));
        
        % Precisions of predictions
        pi1hat(k,:,:) = 1./(mu1hat(k,:,:).*(1 -mu1hat(k,:,:)));

        % Posterior for each possible transition
        mu1(k,:,u(k-1)) = 0;
        mu1(k,u(k),u(k-1)) = 1;

        % Precision of posterior for each possible
        % transition is infinite owing to absence
        % of noise
        pi1(k,:,u(k-1)) = Inf;
        
        % Prediction errors
        da1(k,:,u(k-1)) = mu1(k,:,u(k-1)) -mu1hat(k,:,u(k-1));

        % 2nd level
        % ~~~~~~~~~
        % Predictions
        mu2hat(k,:,:) = mu2(k-1,:,:);
        
        % Precisions of predictions
        pi2hat(k,:,:) = 1./(1./pi2(k-1,:,:) +exp(ka *mu3(k-1) +om));

        % Updates
        % Without observation, pi2 is equal to pi2hat
        pi2(k,:,:) = pi2hat(k,:,:);
        % However, where we have an observation, the usual update applies
        pi2(k,:,u(k-1)) = pi2hat(k,:,u(k-1)) +1./pi1hat(k,:,u(k-1));

        % By default, carry means of predictions forward
        mu2(k,:,:) = mu2hat(k,:,:);
        % However, where we have a prediction error, perform an update
        mu2(k,:,u(k-1)) = mu2hat(k,:,u(k-1)) +1./pi2(k,:,u(k-1)) .*da1(k,:,u(k-1));

        % Volatility prediction errors
        da2(k,:,u(k-1)) = (1./pi2(k,:,u(k-1)) +(mu2(k,:,u(k-1)) -mu2hat(k,:,u(k-1))).^2) .*pi2hat(k,:,u(k-1)) -1;


        % 3rd level
        % ~~~~~~~~~
        % Prediction
        mu3hat(k) = mu3(k-1);
        
        % Precision of prediction
        pi3hat(k) = 1/(1/pi3(k-1) +th);

        % Weighting factors
        v2(k)          = exp(ka *mu3(k-1) +om);
        w2(k,:,u(k-1)) = v2(k) *pi2hat(k,:,u(k-1));

        % Updates
        pi3(k) = pi3hat(k) +sum(1/2 *ka^2 *w2(k,:,u(k-1)) .*(w2(k,:,u(k-1)) +(2 *w2(k,:,u(k-1)) -1) .*da2(k,:,u(k-1))));

        if pi3(k) <= 0
            error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
        end

        mu3(k) = mu3hat(k) +sum(1/2 *1/pi3(k) *ka *w2(k,:,u(k-1)) .*da2(k,:,u(k-1)));
    
    else
        mu1(k,:,:) = mu1(k-1,:,:);
        pi1(k,:,:) = pi1(k-1,:,:);
        mu2(k,:,:) = mu2(k-1,:,:);
        pi2(k,:,:) = pi2(k-1,:,:);
        mu3(k)     = mu3(k-1);
        pi3(k)     = pi3(k-1);

        mu1hat(k,:,:) = mu1hat(k-1,:,:);
        pi1hat(k,:,:) = pi1hat(k-1,:,:);
        mu2hat(k,:,:) = mu2hat(k-1,:,:);
        pi2hat(k,:,:) = pi2hat(k-1,:,:);
        mu3hat(k)     = mu3hat(k-1);
        pi3hat(k)     = pi3hat(k-1);
        v2(k)         = v2(k-1);
        w2(k,:,:)     = w2(k-1,:,:);
        da1(k,:,:)    = da1(k-1,:,:);
        da2(k,:,:)    = da2(k-1,:,:);
    end
end

% Remove representation priors
mu1(1,:,:)  = [];
pi1(1,:,:)  = [];
mu2(1,:,:)  = [];
pi2(1,:,:)  = [];
mu3(1)      = [];
pi3(1)      = [];

% Remove other dummy initial values
mu1hat(1,:,:) = [];
pi1hat(1,:,:) = [];
mu2hat(1,:,:) = [];
pi2hat(1,:,:) = [];
mu3hat(1)     = [];
pi3hat(1)     = [];
v2(1)         = [];
w2(1,:,:)     = [];
da1(1,:,:)    = [];
da2(1,:,:)    = [];

% Create result data structure
traj = struct;

traj.mu = NaN(n-1,3,ns,ns);
traj.mu(:,1,:,:) = mu1;
traj.mu(:,2,:,:) = mu2;
traj.mu(:,3,1,1) = mu3;

traj.sa = NaN(n-1,3,ns,ns);
traj.sa(:,1,:,:) = 1./pi1;
traj.sa(:,2,:,:) = 1./pi2;
traj.sa(:,3,1,1) = 1./pi3;

traj.muhat = NaN(n-1,3,ns,ns);
traj.muhat(:,1,:,:) = mu1hat;
traj.muhat(:,2,:,:) = mu2hat;
traj.muhat(:,3,1,1) = mu3hat;

traj.sahat = NaN(n-1,3,ns,ns);
traj.sahat(:,1,:,:) = 1./pi1hat;
traj.sahat(:,2,:,:) = 1./pi2hat;
traj.sahat(:,3,1,1) = 1./pi3hat;

traj.v       = v2;
traj.w       = w2;

traj.da = NaN(n-1,2,ns,ns);
traj.da(:,1,:,:) = da1;
traj.da(:,2,:,:) = da2;

% Updates with respect to prediction
traj.ud = traj.mu -traj.muhat;

% Psi (precision weights on prediction errors)
psi = NaN(n-1,3,ns,ns);
for k = 1:n-1
    psi(k,2,:,:) = 1./pi2(k,:,:);
    psi(k,3,:,:) = pi2hat(k,:,:)./pi3(k);
end
traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi = NaN(n-1,3,ns,ns);
for k = 1:n-1
    epsi(k,2,:,:) = squeeze(psi(k,2,:,:)) .*squeeze(da1(k,:,:));
    epsi(k,3,:,:) = squeeze(psi(k,3,:,:)) .*squeeze(da2(k,:,:));
end
traj.epsi = epsi;

% Implied learning rates at the first level
lr1 = NaN(n-1,ns,ns);
for k = 1:n-1
    upd1       = tapas_sgm(mu2(k,:,:), 1) -mu1hat(k,:,:);
    lr1(k,:,:) = upd1./da1(k,:,:);
end

% Full learning rate (full weights on prediction errors)
wt          = NaN(n-1,3,ns,ns);
wt(:,1,:,:) = lr1;
wt(:,2,:,:) = psi(:,2,:,:);
v2psi = NaN(n-1,ns,ns);
for k = 1:n-1
    v2psi(k,:,:) = v2(k)*psi(k,3,:,:);
end
wt(:,3,:,:) = 1/2 *ka *v2psi;
traj.wt     = wt;

% Create matrices for use by the observation model
infStates = NaN(n-1,3,ns,ns,4);
infStates(:,:,:,:,1) = traj.muhat;
infStates(:,:,:,:,2) = traj.sahat;
infStates(:,:,:,:,3) = traj.mu;
infStates(:,:,:,:,4) = traj.sa;

return;