Hierarchical Gaussian Filter (HGF) model of conditioned hallucinations task (Powers et al 2017)

 Download zip file 
Help downloading and running models
Accession:229278
This is an instantiation of the Hierarchical Gaussian Filter (HGF) model for use with the Conditioned Hallucinations Task.
Reference:
1. Powers AR, Mathys C, Corlett PR (2017) Pavlovian conditioning-induced hallucinations result from overweighting of perceptual priors. Science 357:596-600 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type:
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: MATLAB;
Model Concept(s): Hallucinations;
Implementer(s): Powers, Al [albert.powers at yale.edu]; Mathys, Chris H ;
/
HGF
analysis
hgfToolBox_condhalluc1.4
README
COPYING *
example_binary_input.txt
example_categorical_input.mat
example_usdchf.txt
Manual.pdf
tapas_autocorr.m
tapas_bayes_optimal.m
tapas_bayes_optimal_binary.m
tapas_bayes_optimal_binary_config.m
tapas_bayes_optimal_binary_transp.m
tapas_bayes_optimal_categorical.m
tapas_bayes_optimal_categorical_config.m
tapas_bayes_optimal_categorical_transp.m
tapas_bayes_optimal_config.m
tapas_bayes_optimal_transp.m
tapas_bayes_optimal_whatworld.m
tapas_bayes_optimal_whatworld_config.m
tapas_bayes_optimal_whatworld_transp.m
tapas_bayes_optimal_whichworld.m
tapas_bayes_optimal_whichworld_config.m
tapas_bayes_optimal_whichworld_transp.m
tapas_bayesian_parameter_average.m
tapas_beta_obs.m
tapas_beta_obs_config.m
tapas_beta_obs_namep.m
tapas_beta_obs_sim.m
tapas_beta_obs_transp.m
tapas_boltzmann.m
tapas_cdfgaussian_obs.m
tapas_cdfgaussian_obs_config.m
tapas_cdfgaussian_obs_transp.m
tapas_condhalluc_obs.m
tapas_condhalluc_obs_config.m
tapas_condhalluc_obs_namep.m
tapas_condhalluc_obs_sim.m
tapas_condhalluc_obs_transp.m
tapas_condhalluc_obs2.m
tapas_condhalluc_obs2_config.m
tapas_condhalluc_obs2_namep.m
tapas_condhalluc_obs2_sim.m
tapas_condhalluc_obs2_transp.m
tapas_Cov2Corr.m
tapas_datagen_categorical.m
tapas_fit_plotCorr.m
tapas_fit_plotResidualDiagnostics.m
tapas_fitModel.m
tapas_gaussian_obs.m
tapas_gaussian_obs_config.m
tapas_gaussian_obs_namep.m
tapas_gaussian_obs_sim.m
tapas_gaussian_obs_transp.m
tapas_hgf.m
tapas_hgf_ar1.m
tapas_hgf_ar1_binary.m
tapas_hgf_ar1_binary_config.m
tapas_hgf_ar1_binary_namep.m
tapas_hgf_ar1_binary_plotTraj.m
tapas_hgf_ar1_binary_transp.m
tapas_hgf_ar1_config.m
tapas_hgf_ar1_mab.m
tapas_hgf_ar1_mab_config.m
tapas_hgf_ar1_mab_plotTraj.m
tapas_hgf_ar1_mab_transp.m
tapas_hgf_ar1_namep.m
tapas_hgf_ar1_plotTraj.m
tapas_hgf_ar1_transp.m
tapas_hgf_binary.m
tapas_hgf_binary_condhalluc_plotTraj.m
tapas_hgf_binary_config.m
tapas_hgf_binary_config_startpoints.m
tapas_hgf_binary_mab.m
tapas_hgf_binary_mab_config.m
tapas_hgf_binary_mab_plotTraj.m
tapas_hgf_binary_mab_transp.m
tapas_hgf_binary_namep.m
tapas_hgf_binary_plotTraj.m
tapas_hgf_binary_pu.m
tapas_hgf_binary_pu_config.m
tapas_hgf_binary_pu_namep.m
tapas_hgf_binary_pu_tbt.m
tapas_hgf_binary_pu_tbt_config.m
tapas_hgf_binary_pu_tbt_namep.m
tapas_hgf_binary_pu_tbt_transp.m
tapas_hgf_binary_pu_transp.m
tapas_hgf_binary_transp.m
tapas_hgf_categorical.m
tapas_hgf_categorical_config.m
tapas_hgf_categorical_namep.m
tapas_hgf_categorical_norm.m
tapas_hgf_categorical_norm_config.m
tapas_hgf_categorical_norm_transp.m
tapas_hgf_categorical_plotTraj.m
tapas_hgf_categorical_transp.m
tapas_hgf_config.m
tapas_hgf_demo.m
tapas_hgf_demo_commands.m
tapas_hgf_jget.m
tapas_hgf_jget_config.m
tapas_hgf_jget_plotTraj.m
tapas_hgf_jget_transp.m
tapas_hgf_namep.m
tapas_hgf_plotTraj.m
tapas_hgf_transp.m
tapas_hgf_whatworld.m
tapas_hgf_whatworld_config.m
tapas_hgf_whatworld_namep.m
tapas_hgf_whatworld_plotTraj.m
tapas_hgf_whatworld_transp.m
tapas_hgf_whichworld.m
tapas_hgf_whichworld_config.m
tapas_hgf_whichworld_namep.m
tapas_hgf_whichworld_plotTraj.m
tapas_hgf_whichworld_transp.m
tapas_hhmm.m
tapas_hhmm_binary_displayResults.m
tapas_hhmm_config.m
tapas_hhmm_transp.m
tapas_hmm.m
tapas_hmm_binary_displayResults.m
tapas_hmm_config.m
tapas_hmm_transp.m
tapas_kf.m
tapas_kf_config.m
tapas_kf_namep.m
tapas_kf_plotTraj.m
tapas_kf_transp.m
tapas_logit.m
tapas_logrt_linear_binary.m
tapas_logrt_linear_binary_config.m
tapas_logrt_linear_binary_minimal.m
tapas_logrt_linear_binary_minimal_config.m
tapas_logrt_linear_binary_minimal_transp.m
tapas_logrt_linear_binary_namep.m
tapas_logrt_linear_binary_sim.m
tapas_logrt_linear_binary_transp.m
tapas_logrt_linear_whatworld.m
tapas_logrt_linear_whatworld_config.m
tapas_logrt_linear_whatworld_transp.m
tapas_ph_binary.m
tapas_ph_binary_config.m
tapas_ph_binary_namep.m
tapas_ph_binary_plotTraj.m
tapas_ph_binary_transp.m
tapas_quasinewton_optim.m
tapas_quasinewton_optim_config.m
tapas_riddersdiff.m
tapas_riddersdiff2.m
tapas_riddersdiffcross.m
tapas_riddersgradient.m
tapas_riddershessian.m
tapas_rs_belief.m
tapas_rs_belief_config.m
tapas_rs_precision.m
tapas_rs_precision_config.m
tapas_rs_precision_whatworld.m
tapas_rs_precision_whatworld_config.m
tapas_rs_surprise.m
tapas_rs_surprise_config.m
tapas_rs_transp.m
tapas_rs_whatworld_transp.m
tapas_rw_binary.m
tapas_rw_binary_config.m
tapas_rw_binary_dual.m
tapas_rw_binary_dual_config.m
tapas_rw_binary_dual_plotTraj.m
tapas_rw_binary_dual_transp.m
tapas_rw_binary_namep.m
tapas_rw_binary_plotTraj.m
tapas_rw_binary_transp.m
tapas_sgm.m
tapas_simModel.m
tapas_softmax.m
tapas_softmax_2beta.m
tapas_softmax_2beta_config.m
tapas_softmax_2beta_transp.m
tapas_softmax_binary.m
tapas_softmax_binary_config.m
tapas_softmax_binary_namep.m
tapas_softmax_binary_sim.m
tapas_softmax_binary_transp.m
tapas_softmax_config.m
tapas_softmax_namep.m
tapas_softmax_sim.m
tapas_softmax_transp.m
tapas_squared_pe.m
tapas_squared_pe_config.m
tapas_squared_pe_transp.m
tapas_sutton_k1_binary.m
tapas_sutton_k1_binary_config.m
tapas_sutton_k1_binary_plotTraj.m
tapas_sutton_k1_binary_transp.m
tapas_unitsq_sgm.m
tapas_unitsq_sgm_config.m
tapas_unitsq_sgm_mu3.m
tapas_unitsq_sgm_mu3_config.m
tapas_unitsq_sgm_mu3_transp.m
tapas_unitsq_sgm_namep.m
tapas_unitsq_sgm_sim.m
tapas_unitsq_sgm_transp.m
                            
function [traj, infStates] = tapas_hgf_jget(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF model of the jumping
% Gaussian estimation task (JGET)
%
% This function can be called in two ways:
% 
% (1) tapas_hgf_jget(r, p)
%   
%     where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) tapas_hgf_jget(r, ptrans, 'trans')
% 
%     where r is the structure generated by fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Fields of r read by this function:
%     r.u                          inputs; column 1 holds the input values and, if
%                                  r.c_prc.irregular_intervals is true, the last column holds the
%                                  time intervals between inputs
%     r.ign                        indices of trials to be ignored (beliefs are carried over
%                                  unchanged on these trials)
%     r.c_prc.irregular_intervals  flag: take time intervals from r.u instead of using 1
%     r.c_prc.n_levels             number of levels l (optional; if it cannot be read, l is
%                                  taken to be length(p)/8)
%
% Layout of the parameter vector p (length 8*l):
%     [mux_0 (l), sax_0 (l), mua_0 (l), saa_0 (l), kau (1), kax (l-1), kaa (l-1), omu (1),
%      omx (l-1), log(thx) (1), oma (l-1), log(tha) (1)]
%
% Outputs:
%     traj       structure containing the belief trajectories (mux, mua, sax, saa), predictions
%                (muuhat, muxhat, muahat, sauhat, saxhat, saahat), weighting factors (wx, wa),
%                prediction errors (daux, daua, dax, daa), and learning rates (lrx, lra); all
%                with one row per trial
%     infStates  (number of trials)-by-1-by-10 array for use by the observation model:
%                muuhat, sauhat, muxhat_1, saxhat_1, muahat_1, saahat_1, mux_1, sax_1, mua_1,
%                saa_1 (in this order along the third dimension)
%
% Errors:
%     tapas:hgf:UndetNumLevels     number of levels cannot be determined
%     tapas:hgf:InputSingleColumn  irregular intervals requested but r.u has only one column
%     tapas:hgf:NegPostPrec        a posterior precision became non-positive
%     tapas:hgf:VarApproxInvalid   trajectories contain NaNs or implausible jumps
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.


% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_hgf_jget_transp(r, p);
end

% Number of levels
try
    l = r.c_prc.n_levels;
catch
    l = length(p)/8;
    
    if l ~= floor(l)
        error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
    end
end

% Unpack parameters
mux_0 = p(1:l);
sax_0 = p(l+1:2*l);
mua_0 = p(2*l+1:3*l);
saa_0 = p(3*l+1:4*l);
kau   = p(4*l+1);
kax   = p(4*l+2:5*l);
kaa   = p(5*l+1:6*l-1);
omu   = p(6*l);
omx   = p(6*l+1:7*l-1);
oma   = p(7*l+1:8*l-1);
thx   = exp(p(7*l));
tha   = exp(p(8*l));

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];

% Number of trials (including prior)
n = length(u);

% Construct time axis
if r.c_prc.irregular_intervals
    % Check the input matrix r.u itself: u above always has exactly one column
    if size(r.u,2) > 1
        t = [0; r.u(:,end)];
    else
        error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than one column if irregular_intervals is set to true.');
    end
else
    t = ones(n,1);
end

% Initialize updated quantities

% Representations
mux = NaN(n,l);
pix = NaN(n,l);
mua = NaN(n,l);
pia = NaN(n,l);

% Other quantities
muuhat = NaN(n,1);
piuhat = NaN(n,1);
muxhat = NaN(n,l);
pixhat = NaN(n,l);
muahat = NaN(n,l);
piahat = NaN(n,l);
daux   = NaN(n,1);
daua   = NaN(n,1);
wx     = NaN(n,l-1);
dax    = NaN(n,l-1);
wa     = NaN(n,l-1);
daa    = NaN(n,l-1);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mux(1,:) = mux_0;
pix(1,:) = 1./sax_0;
mua(1,:) = mua_0;
pia(1,:) = 1./saa_0;

% Representation update loop
% Pass through trials 
for k = 2:1:n
    if not(ismember(k-1, r.ign))
        
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        
        % Input level
        % ~~~~~~~~~~~
        % Prediction (same as prediction of x_1, see below)
        muuhat(k) = mux(k-1,1);
        
        % Precision of prediction
        piuhat(k) = 1/exp(kau *mua(k-1,1) +omu);
        
        % Mean prediction error
        daux(k) = u(k) -muuhat(k);
        
        % 1st level
        % ~~~~~~~~~
        % Predictions
        muxhat(k,1) = mux(k-1,1);
        muahat(k,1) = mua(k-1,1);
        
        % Precisions of predictions
        pixhat(k,1) = 1/(1/pix(k-1,1) +t(k) *exp(kax(1) *mux(k-1,2) +omx(1)));
        piahat(k,1) = 1/(1/pia(k-1,1) +t(k) *exp(kaa(1) *mua(k-1,2) +oma(1)));
        
        % x-updates
        pix(k,1) = pixhat(k,1) +piuhat(k);
        mux(k,1) = muxhat(k,1) +piuhat(k)/pix(k,1) *daux(k);

        % Prediction error of input precision
        daua(k) = (1/pix(k,1) +(mux(k,1) -u(k))^2) *piuhat(k) -1;

        % alpha-updates
        pia(k,1) = piahat(k,1) +1/2 *kau^2 *(1 +daua(k));
        mua(k,1) = muahat(k,1) +1/2 *1/pia(k,1) *kau *daua(k);

        % Volatility prediction errors
        dax(k,1) = (1/pix(k,1) +(mux(k,1) -muxhat(k,1))^2) *pixhat(k,1) -1;
        daa(k,1) = (1/pia(k,1) +(mua(k,1) -muahat(k,1))^2) *piahat(k,1) -1;
        
        if l > 2
            % Pass through higher levels
            % ~~~~~~~~~~~~~~~~~~~~~~~~~~
            for j = 2:l-1
                % Predictions
                muxhat(k,j) = mux(k-1,j);
                muahat(k,j) = mua(k-1,j);
                
                % Precisions of predictions
                pixhat(k,j) = 1/(1/pix(k-1,j) +t(k) *exp(kax(j) *mux(k-1,j+1) +omx(j)));
                piahat(k,j) = 1/(1/pia(k-1,j) +t(k) *exp(kaa(j) *mua(k-1,j+1) +oma(j)));

                % Weighting factors
                wx(k,j-1) = t(k) *exp(kax(j-1) *mux(k-1,j) +omx(j-1)) *pixhat(k,j-1);
                wa(k,j-1) = t(k) *exp(kaa(j-1) *mua(k-1,j) +oma(j-1)) *piahat(k,j-1);

                % Updates
                pix(k,j) = pixhat(k,j) +1/2 *kax(j-1)^2 *wx(k,j-1) *(wx(k,j-1) +(2 *wx(k,j-1) -1) *dax(k,j-1));
                pia(k,j) = piahat(k,j) +1/2 *kaa(j-1)^2 *wa(k,j-1) *(wa(k,j-1) +(2 *wa(k,j-1) -1) *daa(k,j-1));

                if pix(k,j) <= 0 || pia(k,j) <= 0
                    error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
                end

                mux(k,j) = muxhat(k,j) +1/2 *1/pix(k,j) *kax(j-1) *wx(k,j-1) *dax(k,j-1);
                mua(k,j) = muahat(k,j) +1/2 *1/pia(k,j) *kaa(j-1) *wa(k,j-1) *daa(k,j-1);
    
                % Volatility prediction errors
                dax(k,j) = (1/pix(k,j) +(mux(k,j) -muxhat(k,j))^2) *pixhat(k,j) -1;
                daa(k,j) = (1/pia(k,j) +(mua(k,j) -muahat(k,j))^2) *piahat(k,j) -1;
            end
        end

        % Last level
        % ~~~~~~~~~~
        % Predictions
        muxhat(k,l) = mux(k-1,l);
        muahat(k,l) = mua(k-1,l);
        
        % Precision of prediction
        pixhat(k,l) = 1/(1/pix(k-1,l) +t(k) *thx);
        piahat(k,l) = 1/(1/pia(k-1,l) +t(k) *tha);

        % Weighting factor
        wx(k,l-1) = t(k) *exp(kax(l-1) *mux(k-1,l) +omx(l-1)) *pixhat(k,l-1);
        wa(k,l-1) = t(k) *exp(kaa(l-1) *mua(k-1,l) +oma(l-1)) *piahat(k,l-1);
        
        % Updates
        pix(k,l) = pixhat(k,l) +1/2 *kax(l-1)^2 *wx(k,l-1) *(wx(k,l-1) +(2 *wx(k,l-1) -1) *dax(k,l-1));
        pia(k,l) = piahat(k,l) +1/2 *kaa(l-1)^2 *wa(k,l-1) *(wa(k,l-1) +(2 *wa(k,l-1) -1) *daa(k,l-1));

        if pix(k,l) <= 0 || pia(k,l) <= 0
            error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
        end

        mux(k,l) = muxhat(k,l) +1/2 *1/pix(k,l) *kax(l-1) *wx(k,l-1) *dax(k,l-1);
        mua(k,l) = muahat(k,l) +1/2 *1/pia(k,l) *kaa(l-1) *wa(k,l-1) *daa(k,l-1);
    
        % Volatility prediction error
        % Note: this assignment grows dax and daa to l columns
        dax(k,l) = (1/pix(k,l) +(mux(k,l) -muxhat(k,l))^2) *pixhat(k,l) -1;
        daa(k,l) = (1/pia(k,l) +(mua(k,l) -muahat(k,l))^2) *piahat(k,l) -1;
    else
        % Ignored trial: carry all quantities over unchanged from the previous trial

        mux(k,:) = mux(k-1,:); 
        mua(k,:) = mua(k-1,:); 
        pix(k,:) = pix(k-1,:);
        pia(k,:) = pia(k-1,:);

        muuhat(k) = muuhat(k-1);
        piuhat(k) = piuhat(k-1);

        muxhat(k,:) = muxhat(k-1,:);
        muahat(k,:) = muahat(k-1,:);
        pixhat(k,:) = pixhat(k-1,:);
        piahat(k,:) = piahat(k-1,:);

        daux(k) = daux(k-1);
        daua(k) = daua(k-1);

        wx(k,:)  = wx(k-1,:);
        wa(k,:)  = wa(k-1,:);
        dax(k,:) = dax(k-1,:);
        daa(k,:) = daa(k-1,:);
        
    end
end

% Remove representation priors
mux(1,:)  = [];
mua(1,:)  = [];
pix(1,:)  = [];
pia(1,:)  = [];

% Check validity of trajectories
if any(isnan(mux(:))) || any(isnan(pix(:))) || any(isnan(mua(:))) || any(isnan(pia(:)))
    error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
else
    % Check for implausible jumps in trajectories
    dmux = diff(mux);
    dmua = diff(mua);
    dpix = diff(pix);
    dpia = diff(pia);
    % Root mean squared trial-to-trial difference per level, replicated for every trial.
    % The dimension arguments keep this over trials (rows) even when there are fewer
    % differences than levels.
    rmdmux = repmat(sqrt(mean(dmux.^2,1)),size(dmux,1),1);
    rmdmua = repmat(sqrt(mean(dmua.^2,1)),size(dmua,1),1);
    rmdpix = repmat(sqrt(mean(dpix.^2,1)),size(dpix,1),1);
    rmdpia = repmat(sqrt(mean(dpia.^2,1)),size(dpia,1),1);

    jumpTol = 256;
    if any(abs(dmux(:)) > jumpTol*rmdmux(:)) || any(abs(dmua(:)) > jumpTol*rmdmua(:)) || any(abs(dpix(:)) > jumpTol*rmdpix(:)) || any(abs(dpia(:)) > jumpTol*rmdpia(:))
        error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
    end
end

% Remove other dummy initial values
muuhat(1)   = [];
piuhat(1)   = [];
muxhat(1,:) = [];
muahat(1,:) = [];
pixhat(1,:) = [];
piahat(1,:) = [];
wx(1,:)     = [];
wa(1,:)     = [];
daux(1)     = [];
daua(1)     = [];
dax(1,:)    = [];
daa(1,:)    = [];

% Extract learning rates
lrx = NaN(n-1,l);
lra = NaN(n-1,l);

% Scale each column of wx (wa) by the coupling of the corresponding level; this has to
% be done elementwise because kax and kaa are vectors when there are more than two levels
kaxhalf = repmat(kax(:)'./2, n-1, 1);
kaahalf = repmat(kaa(:)'./2, n-1, 1);

lrx(:,1) = piuhat./pix(:,1);
lrx(:,2:end) = kaxhalf .*wx ./pix(:,2:end);
lra(:,1) = 1/2 *kau./pia(:,1);
lra(:,2:end) = kaahalf .*wa ./pia(:,2:end);

% Create result data structure
traj = struct;

traj.mux     = mux;
traj.mua     = mua;
traj.sax     = 1./pix;
traj.saa     = 1./pia;

traj.muuhat  = muuhat;
traj.muxhat  = muxhat;
traj.muahat  = muahat;
traj.sauhat  = 1./piuhat;
traj.saxhat  = 1./pixhat;
traj.saahat  = 1./piahat;

traj.wx      = wx;
traj.wa      = wa;

traj.daux    = daux;
traj.daua    = daua;
traj.dax     = dax;
traj.daa     = daa;

traj.lrx     = lrx;
traj.lra     = lra;

% Create matrices for use by the observation model
infStates = NaN(n-1,1,10);
infStates(:,1,1)  = traj.muuhat;
infStates(:,1,2)  = traj.sauhat;
infStates(:,1,3)  = traj.muxhat(:,1);
infStates(:,1,4)  = traj.saxhat(:,1);
infStates(:,1,5)  = traj.muahat(:,1);
infStates(:,1,6)  = traj.saahat(:,1);
infStates(:,1,7)  = traj.mux(:,1);
infStates(:,1,8)  = traj.sax(:,1);
infStates(:,1,9)  = traj.mua(:,1);
infStates(:,1,10) = traj.saa(:,1);

return;