Hierarchical Gaussian Filter (HGF) model of conditioned hallucinations task (Powers et al. 2017)

 Download zip file 
Help downloading and running models
Accession:229278
This is an instantiation of the Hierarchical Gaussian Filter (HGF) model for use with the Conditioned Hallucinations Task.
Reference:
1. Powers AR, Mathys C, Corlett PR (2017) Pavlovian conditioning-induced hallucinations result from overweighting of perceptual priors. Science 357:596-600 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type:
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: MATLAB;
Model Concept(s): Hallucinations;
Implementer(s): Powers, Al [albert.powers at yale.edu]; Mathys, Chris H ;
/
HGF
analysis
hgfToolBox_condhalluc1.4
README
COPYING *
example_binary_input.txt
example_categorical_input.mat
example_usdchf.txt
Manual.pdf
tapas_autocorr.m
tapas_bayes_optimal.m
tapas_bayes_optimal_binary.m
tapas_bayes_optimal_binary_config.m
tapas_bayes_optimal_binary_transp.m
tapas_bayes_optimal_categorical.m
tapas_bayes_optimal_categorical_config.m
tapas_bayes_optimal_categorical_transp.m
tapas_bayes_optimal_config.m
tapas_bayes_optimal_transp.m
tapas_bayes_optimal_whatworld.m
tapas_bayes_optimal_whatworld_config.m
tapas_bayes_optimal_whatworld_transp.m
tapas_bayes_optimal_whichworld.m
tapas_bayes_optimal_whichworld_config.m
tapas_bayes_optimal_whichworld_transp.m
tapas_bayesian_parameter_average.m
tapas_beta_obs.m
tapas_beta_obs_config.m
tapas_beta_obs_namep.m
tapas_beta_obs_sim.m
tapas_beta_obs_transp.m
tapas_boltzmann.m
tapas_cdfgaussian_obs.m
tapas_cdfgaussian_obs_config.m
tapas_cdfgaussian_obs_transp.m
tapas_condhalluc_obs.m
tapas_condhalluc_obs_config.m
tapas_condhalluc_obs_namep.m
tapas_condhalluc_obs_sim.m
tapas_condhalluc_obs_transp.m
tapas_condhalluc_obs2.m
tapas_condhalluc_obs2_config.m
tapas_condhalluc_obs2_namep.m
tapas_condhalluc_obs2_sim.m
tapas_condhalluc_obs2_transp.m
tapas_Cov2Corr.m
tapas_datagen_categorical.m
tapas_fit_plotCorr.m
tapas_fit_plotResidualDiagnostics.m
tapas_fitModel.m
tapas_gaussian_obs.m
tapas_gaussian_obs_config.m
tapas_gaussian_obs_namep.m
tapas_gaussian_obs_sim.m
tapas_gaussian_obs_transp.m
tapas_hgf.m
tapas_hgf_ar1.m
tapas_hgf_ar1_binary.m
tapas_hgf_ar1_binary_config.m
tapas_hgf_ar1_binary_namep.m
tapas_hgf_ar1_binary_plotTraj.m
tapas_hgf_ar1_binary_transp.m
tapas_hgf_ar1_config.m
tapas_hgf_ar1_mab.m
tapas_hgf_ar1_mab_config.m
tapas_hgf_ar1_mab_plotTraj.m
tapas_hgf_ar1_mab_transp.m
tapas_hgf_ar1_namep.m
tapas_hgf_ar1_plotTraj.m
tapas_hgf_ar1_transp.m
tapas_hgf_binary.m
tapas_hgf_binary_condhalluc_plotTraj.m
tapas_hgf_binary_config.m
tapas_hgf_binary_config_startpoints.m
tapas_hgf_binary_mab.m
tapas_hgf_binary_mab_config.m
tapas_hgf_binary_mab_plotTraj.m
tapas_hgf_binary_mab_transp.m
tapas_hgf_binary_namep.m
tapas_hgf_binary_plotTraj.m
tapas_hgf_binary_pu.m
tapas_hgf_binary_pu_config.m
tapas_hgf_binary_pu_namep.m
tapas_hgf_binary_pu_tbt.m
tapas_hgf_binary_pu_tbt_config.m
tapas_hgf_binary_pu_tbt_namep.m
tapas_hgf_binary_pu_tbt_transp.m
tapas_hgf_binary_pu_transp.m
tapas_hgf_binary_transp.m
tapas_hgf_categorical.m
tapas_hgf_categorical_config.m
tapas_hgf_categorical_namep.m
tapas_hgf_categorical_norm.m
tapas_hgf_categorical_norm_config.m
tapas_hgf_categorical_norm_transp.m
tapas_hgf_categorical_plotTraj.m
tapas_hgf_categorical_transp.m
tapas_hgf_config.m
tapas_hgf_demo.m
tapas_hgf_demo_commands.m
tapas_hgf_jget.m
tapas_hgf_jget_config.m
tapas_hgf_jget_plotTraj.m
tapas_hgf_jget_transp.m
tapas_hgf_namep.m
tapas_hgf_plotTraj.m
tapas_hgf_transp.m
tapas_hgf_whatworld.m
tapas_hgf_whatworld_config.m
tapas_hgf_whatworld_namep.m
tapas_hgf_whatworld_plotTraj.m
tapas_hgf_whatworld_transp.m
tapas_hgf_whichworld.m
tapas_hgf_whichworld_config.m
tapas_hgf_whichworld_namep.m
tapas_hgf_whichworld_plotTraj.m
tapas_hgf_whichworld_transp.m
tapas_hhmm.m
tapas_hhmm_binary_displayResults.m
tapas_hhmm_config.m
tapas_hhmm_transp.m
tapas_hmm.m
tapas_hmm_binary_displayResults.m
tapas_hmm_config.m
tapas_hmm_transp.m
tapas_kf.m
tapas_kf_config.m
tapas_kf_namep.m
tapas_kf_plotTraj.m
tapas_kf_transp.m
tapas_logit.m
tapas_logrt_linear_binary.m
tapas_logrt_linear_binary_config.m
tapas_logrt_linear_binary_minimal.m
tapas_logrt_linear_binary_minimal_config.m
tapas_logrt_linear_binary_minimal_transp.m
tapas_logrt_linear_binary_namep.m
tapas_logrt_linear_binary_sim.m
tapas_logrt_linear_binary_transp.m
tapas_logrt_linear_whatworld.m
tapas_logrt_linear_whatworld_config.m
tapas_logrt_linear_whatworld_transp.m
tapas_ph_binary.m
tapas_ph_binary_config.m
tapas_ph_binary_namep.m
tapas_ph_binary_plotTraj.m
tapas_ph_binary_transp.m
tapas_quasinewton_optim.m
tapas_quasinewton_optim_config.m
tapas_riddersdiff.m
tapas_riddersdiff2.m
tapas_riddersdiffcross.m
tapas_riddersgradient.m
tapas_riddershessian.m
tapas_rs_belief.m
tapas_rs_belief_config.m
tapas_rs_precision.m
tapas_rs_precision_config.m
tapas_rs_precision_whatworld.m
tapas_rs_precision_whatworld_config.m
tapas_rs_surprise.m
tapas_rs_surprise_config.m
tapas_rs_transp.m
tapas_rs_whatworld_transp.m
tapas_rw_binary.m
tapas_rw_binary_config.m
tapas_rw_binary_dual.m
tapas_rw_binary_dual_config.m
tapas_rw_binary_dual_plotTraj.m
tapas_rw_binary_dual_transp.m
tapas_rw_binary_namep.m
tapas_rw_binary_plotTraj.m
tapas_rw_binary_transp.m
tapas_sgm.m
tapas_simModel.m
tapas_softmax.m
tapas_softmax_2beta.m
tapas_softmax_2beta_config.m
tapas_softmax_2beta_transp.m
tapas_softmax_binary.m
tapas_softmax_binary_config.m
tapas_softmax_binary_namep.m
tapas_softmax_binary_sim.m
tapas_softmax_binary_transp.m
tapas_softmax_config.m
tapas_softmax_namep.m
tapas_softmax_sim.m
tapas_softmax_transp.m
tapas_squared_pe.m
tapas_squared_pe_config.m
tapas_squared_pe_transp.m
tapas_sutton_k1_binary.m
tapas_sutton_k1_binary_config.m
tapas_sutton_k1_binary_plotTraj.m
tapas_sutton_k1_binary_transp.m
tapas_unitsq_sgm.m
tapas_unitsq_sgm_config.m
tapas_unitsq_sgm_mu3.m
tapas_unitsq_sgm_mu3_config.m
tapas_unitsq_sgm_mu3_transp.m
tapas_unitsq_sgm_namep.m
tapas_unitsq_sgm_sim.m
tapas_unitsq_sgm_transp.m
                            
function [traj, infStates] = tapas_hgf_ar1_mab(r, p, varargin)
% Calculates the trajectories of the agent's representations under the AR(1)-HGF in a multi-armed
% bandit task
%
% This function can be called in two ways:
% 
% (1) tapas_hgf_ar1_mab(r, p)
%   
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_hgf_ar1_mab(r, ptrans, 'trans')
% 
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Fields of r read here:
%   r.u                         column 1: input on each trial
%                               column 2: index (1..b) of the bandit the input belongs to
%                               last column (required only if irregular_intervals is true, and
%                               then it must be a third or later column): trial time steps
%   r.ign                       trial indices on which beliefs are carried over without update
%   r.c_prc.n_levels            number of levels l (if absent, inferred as length(p)/6)
%   r.c_prc.n_bandits           number of bandits b (required)
%   r.c_prc.irregular_intervals logical flag selecting the time axis
%
% Layout of p in native space (l levels):
%   p(1:l)          mu_0  initial means
%   p(l+1:2*l)      sa_0  initial variances
%   p(2*l+1:3*l)    phi   rate at which each level's prediction is pulled toward m
%   p(3*l+1:4*l)    m     value each level's prediction is pulled toward
%   p(4*l+1:5*l-1)  ka    (l-1) scaling of level j+1's mean in level j's log step variance
%   p(5*l:6*l-2)    om    (l-1) additive constant in level j's log step variance
%   p(6*l-1)        enters as th = exp(p(6*l-1)), the step variance of the top level
%   p(6*l)          al    input variance (1/al is added to the observed bandit's level-1 precision)
%
% Outputs (N = size(r.u,1) trials):
%   traj.mu, traj.sa        N x l x b posterior means / variances
%   traj.muhat, traj.sahat  N x l x b predicted means / variances
%   traj.v, traj.da         N x l;  traj.w  N x (l-1);  traj.dau  N x 1
%   traj.ud                 N x l x b  (mu - muhat)
%   traj.psi, traj.epsi, traj.wt   N x l, evaluated at the observed bandit of each trial
%   infStates               N x l x b x 4 array stacking muhat, sahat, mu, sa
%
% Errors: tapas:hgf:UndetNumLevels, tapas:hgf:NumOfBanditsConfig, tapas:hgf:InputSingleColumn,
%         tapas:hgf:NegPostPrec, tapas:hgf:VarApproxInvalid
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.


% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_hgf_ar1_mab_transp(r, p);
end

% Number of levels
try
    l = r.c_prc.n_levels;
catch
    % mu_0, sa_0, phi, m have l entries; ka, om have l-1; th, al one each => 6*l in total
    l = length(p)/6;
    
    if l ~= floor(l)
        error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
    end
end

% Number of bandits
try
    b = r.c_prc.n_bandits;
catch
    error('tapas:hgf:NumOfBanditsConfig', 'Number of bandits has to be configured in r.c_prc.n_bandits.');
end

% Unpack parameters
mu_0 = p(1:l);
sa_0 = p(l+1:2*l);
phi  = p(2*l+1:3*l);
m    = p(3*l+1:4*l);
ka   = p(4*l+1:5*l-1);
om   = p(5*l:6*l-2);
th   = exp(p(6*l-1));
al   = p(6*l);

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
y = [0; r.u(:,2)];

% Number of trials (including prior)
n = size(u,1);

% Construct time axis
if r.c_prc.irregular_intervals
    % Columns 1 and 2 of r.u are input and bandit index, so trial times need a further column.
    % (The padded u above is always a single column and cannot be used for this check.)
    if size(r.u,2) > 2
        t = [0; r.u(:,end)];
    else
        error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than two columns if irregular_intervals is set to true.');
    end
else
    t = ones(n,1);
end

% Initialize updated quantities

% Representations
mu = NaN(n,l,b);
pi = NaN(n,l,b);

% Other quantities
muhat = NaN(n,l,b);
pihat = NaN(n,l,b);
v     = NaN(n,l);
w     = NaN(n,l-1);
da    = NaN(n,l);
dau   = NaN(n,1);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu(1,:,:) = repmat(mu_0,[1 1 b]);
pi(1,:,:) = repmat(1./sa_0,[1 1 b]);

% Representation update loop
% Pass through trials 
% Note: slices like pi(k,j,:) are 1-by-1-by-b arrays, so reciprocals of them use ./ (elementwise).
for k = 2:1:n
    if not(ismember(k-1, r.ign))
        
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        
        % 1st level
        % ~~~~~~~~~
        % Prediction (all bandits)
        muhat(k,1,:) = mu(k-1,1,:) +t(k) *phi(1) *(m(1) -mu(k-1,1,:));
        
        % Precision of prediction (all bandits)
        pihat(k,1,:) = 1./(1./pi(k-1,1,:) +t(k) *exp(ka(1) *mu(k-1,2,:) +om(1)));
        
        % Input prediction error (observed bandit only)
        dau(k) = u(k) -muhat(k,1,y(k));
        
        % Updates: only the observed bandit y(k) receives new information at level 1
        pi(k,1,:) = pihat(k,1,:);
        pi(k,1,y(k)) = pi(k,1,y(k)) +1/al;
        
        mu(k,1,:) = muhat(k,1,:);
        mu(k,1,y(k)) = mu(k,1,y(k)) +1/pihat(k,1,y(k)) *1/(1/pihat(k,1,y(k)) +al) *dau(k);

        % Volatility prediction error
        da(k,1) = (1/pi(k,1,y(k)) +(mu(k,1,y(k)) -muhat(k,1,y(k)))^2) *pihat(k,1,y(k)) -1;
        
        if l > 2
            % Pass through higher levels
            % ~~~~~~~~~~~~~~~~~~~~~~~~~~
            for j = 2:l-1
                % Prediction
                muhat(k,j,:) = mu(k-1,j,:) +t(k) *phi(j) *(m(j) -mu(k-1,j,:));
                
                % Precision of prediction
                pihat(k,j,:) = 1./(1./pi(k-1,j,:) +t(k) *exp(ka(j) *mu(k-1,j+1,:) +om(j)));

                % Weighting factor
                v(k,j-1) = t(k) *exp(ka(j-1) *mu(k-1,j,y(k)) +om(j-1));
                w(k,j-1) = v(k,j-1) *pihat(k,j-1,y(k));

                % Updates (the same increment is added to every bandit's predicted precision)
                pi(k,j,:) = pihat(k,j,:) +1/2 *ka(j-1)^2 *w(k,j-1) *(w(k,j-1) +(2 *w(k,j-1) -1) *da(k,j-1));

                % Every bandit's precision must stay positive, not just the first one's
                if any(pi(k,j,:) <= 0)
                    error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
                end

                mu(k,j,:) = muhat(k,j,:) +1/2 *1./pi(k,j,:) *ka(j-1) *w(k,j-1) *da(k,j-1);
    
                % Volatility prediction error
                da(k,j) = (1/pi(k,j,y(k)) +(mu(k,j,y(k)) -muhat(k,j,y(k)))^2) *pihat(k,j,y(k)) -1;
            end
        end

        % Last level
        % ~~~~~~~~~~
        % Prediction
        muhat(k,l,:) = mu(k-1,l,:) +t(k) *phi(l) *(m(l) -mu(k-1,l,:));
        
        % Precision of prediction
        pihat(k,l,:) = 1./(1./pi(k-1,l,:) +t(k) *th);

        % Weighting factor
        v(k,l)   = t(k) *th;
        v(k,l-1) = t(k) *exp(ka(l-1) *mu(k-1,l,y(k)) +om(l-1));
        w(k,l-1) = v(k,l-1) *pihat(k,l-1,y(k));
        
        % Updates
        pi(k,l,:) = pihat(k,l,:) +1/2 *ka(l-1)^2 *w(k,l-1) *(w(k,l-1) +(2 *w(k,l-1) -1) *da(k,l-1));

        % Every bandit's precision must stay positive, not just the first one's
        if any(pi(k,l,:) <= 0)
            error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
        end

        mu(k,l,:) = muhat(k,l,:) +1/2 *1./pi(k,l,:) *ka(l-1) *w(k,l-1) *da(k,l-1);
    
        % Volatility prediction error
        da(k,l) = (1/pi(k,l,y(k)) +(mu(k,l,y(k)) -muhat(k,l,y(k)))^2) *pihat(k,l,y(k)) -1;
    else
        % Ignored trial: carry everything over from the previous trial (dau stays NaN)
        mu(k,:,:) = mu(k-1,:,:);
        pi(k,:,:) = pi(k-1,:,:);

        muhat(k,:,:) = muhat(k-1,:,:);
        pihat(k,:,:) = pihat(k-1,:,:);
        
        v(k,:)  = v(k-1,:);
        w(k,:)  = w(k-1,:);
        da(k,:) = da(k-1,:);
        
    end
end

% Remove representation priors
mu(1,:,:)  = [];
pi(1,:,:)  = [];

% Check validity of trajectories
if any(isnan(mu(:))) || any(isnan(pi(:)))
    error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
else
    % Check for implausible jumps in trajectories. Differences and their root mean square are taken
    % explicitly along the trial dimension (dim 1); with few trials the default dimension of
    % diff/mean would otherwise switch to levels.
    dmu = diff(mu,1,1);
    dpi = diff(pi,1,1);
    rmdmu = repmat(sqrt(mean(dmu.^2,1)),size(dmu,1),1);
    rmdpi = repmat(sqrt(mean(dpi.^2,1)),size(dpi,1),1);

    jumpTol = 256;
    if any(abs(dmu(:)) > jumpTol*rmdmu(:)) || any(abs(dpi(:)) > jumpTol*rmdpi(:))
        error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
    end
end

% Remove other dummy initial values
muhat(1,:,:) = [];
pihat(1,:,:) = [];
v(1,:)       = [];
w(1,:)       = [];
da(1,:)      = [];
dau(1)       = [];

% Create result data structure
traj = struct;

traj.mu     = mu;
traj.sa     = 1./pi;

traj.muhat  = muhat;
traj.sahat  = 1./pihat;

traj.v      = v;
traj.w      = w;
traj.da     = da;
traj.dau    = dau;

% Updates with respect to prediction
traj.ud = mu -muhat;

% Observed bandit per trial, with the dummy zeroth entry dropped so it lines up with the
% trajectories above, and the linear indices selecting that bandit from an (n-1)-by-b matrix.
% reshape (rather than squeeze) keeps the (n-1)-by-b shape even when n-1 == 1 or b == 1.
yobs   = y(2:end);
obsIdx = sub2ind([n-1, b], (1:n-1)', yobs);

% Psi (precision weights on prediction errors)
psi      = NaN(n-1,l);
pi1      = reshape(pi(:,1,:), n-1, b);
psi(:,1) = 1./(al*pi1(obsIdx));
for i=2:l
    pihati   = reshape(pihat(:,i-1,:), n-1, b);
    pii      = reshape(pi(:,i,:), n-1, b);
    psi(:,i) = pihati(obsIdx)./pii(obsIdx);
end
traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi        = NaN(n-1,l);
epsi(:,1)   = psi(:,1) .*dau;
epsi(:,2:l) = psi(:,2:l) .*da(:,1:l-1);
traj.epsi   = epsi;

% Full learning rate (full weights on prediction errors)
wt        = NaN(n-1,l);
wt(:,1)   = psi(:,1);
wt(:,2:l) = 1/2 *(v(:,1:l-1) *diag(ka(1:l-1))) .*psi(:,2:l);
traj.wt   = wt;

% Create matrices for use by the observation model
infStates = NaN(n-1,l,b,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;

return;