function [traj, infStates] = tapas_hgf_categorical_norm(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF for categorical inputs,
% with first-level predictions normalized across outcomes (via tapas_boltzmann).
%
% This function can be called in two ways:
%
% (1) tapas_hgf_categorical_norm(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_hgf_categorical_norm(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Fields of r read here:
%     r.c_prc.n_outcomes   number of outcomes (no)
%     r.u(:,1)             observed outcome on each trial, an integer index in 1..no
%     r.ign                indices of trials to skip (all states are carried over unchanged)
%
% Parameter layout in native space:
%     p(1:no)         mu2_0   prior means at the 2nd level
%     p(no+1:2*no)    sa2_0   prior variances at the 2nd level
%     p(2*no+1)       mu3_0   prior mean at the 3rd level
%     p(2*no+2)       sa3_0   prior variance at the 3rd level
%     p(2*no+3)       ka      scales mu3 in the 2nd-level step variance exp(ka*mu3 + om)
%     p(2*no+4)       om      additive constant in that log step variance
%     p(2*no+5)       th      step variance at the 3rd level
%
% Outputs:
%     traj        structure of trajectories, one row per trial. mu, sa, muhat and sahat are
%                 (trials x 3 x no); the 3rd level is stored in outcome slot 1 only (others NaN).
%     infStates   (trials x 3 x no x 4) array stacking muhat, sahat, mu and sa for the
%                 observation model.
%
% Errors: tapas:hgf:ConfigRequired, tapas:hgf:InvalidInput, tapas:hgf:NegPostPrec.
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Check whether we have a configuration structure
if ~isfield(r,'c_prc')
    error('tapas:hgf:ConfigRequired', 'Configuration required: before calling tapas_hgf_categorical, tapas_hgf_categorical_config has to be called.');
end

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_hgf_categorical_transp(r, p);
end

% Number of states whose contingencies have to be learned
no = r.c_prc.n_outcomes;

% Unpack parameters
mu2_0 = p(1:no);
sa2_0 = p(no+1:2*no);
mu3_0 = p(2*no+1);
sa3_0 = p(2*no+2);
ka    = p(2*no+3);
om    = p(2*no+4);
th    = p(2*no+5);

% Add dummy "zeroth" trial
u = [1; r.u(:,1)];

% Number of trials (including prior)
n = length(u);

% Initialize updated quantities

% Representations
mu1 = NaN(n,no);
pi1 = NaN(n,no);
mu2 = NaN(n,no);
pi2 = NaN(n,no);
mu3 = NaN(n,1);
pi3 = NaN(n,1);

% Other quantities
mu1hat = NaN(n,no);
pi1hat = NaN(n,no);
mu2hat = NaN(n,no);
pi2hat = NaN(n,no);
mu3hat = NaN(n,1);
pi3hat = NaN(n,1);
v2     = NaN(n,1);
w2     = NaN(n,no);
da1    = NaN(n,no);
da2    = NaN(n,no);

% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mu2(1,:) = mu2_0;
pi2(1,:) = 1./sa2_0;
% The 1st-level prior is normalized the same way as the predictions in the loop below.
% It is only read if the first trial is ignored (copied forward from row 1).
mu1(1,:) = tapas_boltzmann(mu2(1,:), 1);
pi1(1,:) = 1./(mu1(1,:).*(1-mu1(1,:)));
mu3(1) = mu3_0;
pi3(1) = 1/sa3_0;

% Pass through representation update loop
for k = 2:1:n
    if not(ismember(k-1, r.ign))

        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%

        % The input selects the observed outcome's column below, so it must be a
        % valid outcome index; otherwise indexing fails with an unclear error.
        if ~(u(k) >= 1 && u(k) <= no && u(k) == round(u(k)))
            error('tapas:hgf:InvalidInput', ...
                'Input on trial %d must be an integer outcome index between 1 and %d.', k-1, no);
        end

        % 1st level
        % ~~~~~~~~~
        % Normalized predictions
        mu1hat(k,:) = tapas_boltzmann(mu2(k-1,:), 1);

        % Precisions of predictions
        pi1hat(k,:) = 1./(mu1hat(k,:).*(1 -mu1hat(k,:)));

        % Posterior for each possible transition: one-hot on the observed outcome
        mu1(k,:) = 0;
        mu1(k,u(k)) = 1;

        % Precision of posterior for each possible
        % transition is infinite owing to absence
        % of noise
        pi1(k,:) = Inf;

        % Prediction errors
        da1(k,:) = mu1(k,:) -mu1hat(k,:);

        % 2nd level
        % ~~~~~~~~~
        % Predictions
        mu2hat(k,:) = mu2(k-1,:);

        % Precisions of predictions
        pi2hat(k,:) = 1./(1./pi2(k-1,:) +exp(ka *mu3(k-1) +om));

        % Updates
        pi2(k,:) = pi2hat(k,:) +1./pi1hat(k,:);
        mu2(k,:) = mu2hat(k,:) +1./pi2(k,:) .*da1(k,:);

        % Volatility prediction errors
        da2(k,:) = (1./pi2(k,:) +(mu2(k,:) -mu2hat(k,:)).^2) .*pi2hat(k,:) -1;

        % 3rd level
        % ~~~~~~~~~
        % Prediction
        mu3hat(k) = mu3(k-1);

        % Precision of prediction
        pi3hat(k) = 1/(1/pi3(k-1) +th);

        % Weighting factors
        v2(k)   = exp(ka *mu3(k-1) +om);
        w2(k,:) = v2(k) *pi2hat(k,:);

        % Updates
        pi3(k) = pi3hat(k) +sum(1/2 *ka^2 *w2(k,:) .*(w2(k,:) +(2 *w2(k,:) -1) .*da2(k,:)));

        if pi3(k) <= 0
            error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
        end

        mu3(k) = mu3hat(k) +sum(1/2 *1/pi3(k) *ka *w2(k,:) .*da2(k,:));
    else
        % Ignored trial: carry every quantity over from the previous trial
        mu1(k,:) = mu1(k-1,:);
        pi1(k,:) = pi1(k-1,:);
        mu2(k,:) = mu2(k-1,:);
        pi2(k,:) = pi2(k-1,:);
        mu3(k)   = mu3(k-1);
        pi3(k)   = pi3(k-1);

        mu1hat(k,:) = mu1hat(k-1,:);
        pi1hat(k,:) = pi1hat(k-1,:);
        mu2hat(k,:) = mu2hat(k-1,:);
        pi2hat(k,:) = pi2hat(k-1,:);
        mu3hat(k)   = mu3hat(k-1);
        pi3hat(k)   = pi3hat(k-1);

        v2(k)    = v2(k-1);
        w2(k,:)  = w2(k-1,:);
        da1(k,:) = da1(k-1,:);
        da2(k,:) = da2(k-1,:);
    end
end

% Remove representation priors
mu1(1,:) = [];
pi1(1,:) = [];
mu2(1,:) = [];
pi2(1,:) = [];
mu3(1)   = [];
pi3(1)   = [];

% Remove other dummy initial values
mu1hat(1,:) = [];
pi1hat(1,:) = [];
mu2hat(1,:) = [];
pi2hat(1,:) = [];
mu3hat(1)   = [];
pi3hat(1)   = [];
v2(1)       = [];
w2(1,:)     = [];
da1(1,:)    = [];
da2(1,:)    = [];

% Create result data structure
traj = struct;

traj.mu = NaN(n-1,3,no);
traj.mu(:,1,:) = mu1;
traj.mu(:,2,:) = mu2;
traj.mu(:,3,1) = mu3;

traj.sa = NaN(n-1,3,no);
traj.sa(:,1,:) = 1./pi1;
traj.sa(:,2,:) = 1./pi2;
traj.sa(:,3,1) = 1./pi3;

traj.muhat = NaN(n-1,3,no);
traj.muhat(:,1,:) = mu1hat;
traj.muhat(:,2,:) = mu2hat;
traj.muhat(:,3,1) = mu3hat;

traj.sahat = NaN(n-1,3,no);
traj.sahat(:,1,:) = 1./pi1hat;
traj.sahat(:,2,:) = 1./pi2hat;
traj.sahat(:,3,1) = 1./pi3hat;

traj.v = v2;
traj.w = w2;

traj.da = NaN(n-1,2,no);
traj.da(:,1,:) = da1;
traj.da(:,2,:) = da2;

% Updates with respect to prediction
traj.ud = traj.mu -traj.muhat;

% Psi (precision weights on prediction errors)
psi = NaN(n-1,3,no);
for k = 1:n-1
    psi(k,2,:) = 1./pi2(k,:);
    psi(k,3,:) = pi2hat(k,:)./pi3(k);
end
traj.psi = psi;

% Epsilons (precision-weighted prediction errors)
epsi = NaN(n-1,3,no);
for k = 1:n-1
    epsi(k,2,:) = squeeze(psi(k,2,:))' .*squeeze(da1(k,:));
    epsi(k,3,:) = squeeze(psi(k,3,:))' .*squeeze(da2(k,:));
end
traj.epsi = epsi;

% Implied learning rates at the first level: the change from this trial's prediction
% to the next trial's prediction, divided by the prediction error. The next prediction
% is the normalized tapas_boltzmann(mu2(k,:), 1), matching how mu1hat is formed above.
lr1 = NaN(n-1,no);
for k = 1:n-1
    upd1 = tapas_boltzmann(mu2(k,:), 1) -mu1hat(k,:);
    lr1(k,:) = upd1./da1(k,:);
end

% Full learning rate (full weights on prediction errors)
wt = NaN(n-1,3,no);
wt(:,1,:) = lr1;
wt(:,2,:) = psi(:,2,:);
v2psi = NaN(n-1,no);
for k = 1:n-1
    % psi(k,3,:) is 1x1xno; reshape it explicitly to a row before scaling
    v2psi(k,:) = v2(k) *reshape(psi(k,3,:), 1, no);
end
wt(:,3,:) = 1/2 *ka *v2psi;
traj.wt = wt;

% Create matrices for use by the observation model
infStates = NaN(n-1,3,no,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;

return;