function [traj, infStates] = tapas_kf(r, p, varargin)
% The scalar Kalman filter
%
% This function can be called in two ways:
%
% (1) tapas_kf(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in
%     native space;
%
% (2) tapas_kf(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Native-space parameters, in the order they are read:
%     p(1)  initial gain g_0
%     p(2)  initial hidden-state mean mu_0
%     p(3)  log of the process variance (exponentiated below)
%     p(4)  observation precision pi_u
%
% Outputs:
%     traj       struct with fields g, muhat, mu, da, each holding one row per trial
%     infStates  [traj.muhat, traj.mu], the matrix handed to the observation model
%
% Only the first column of r.u is used as input. For trials whose index appears in r.ign, the
% prediction error is set to 0 and the gain and mean are carried over unchanged.
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2016 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <https://www.gnu.org/licenses/>.

% Map the parameters from transformed back to native space when the 'trans' flag is given
if nargin > 2 && strcmp(varargin{1}, 'trans')
    p = tapas_kf_transp(r, p);
end

% Native-space parameters
g0   = p(1);        % initial gain
mu0  = p(2);        % initial hidden-state mean
om   = exp(p(3));   % process variance
piU  = p(4);        % observation precision
piOm = piU*om;      % loop-invariant term of the gain recursion

% Inputs, prefixed by a dummy "zeroth" trial whose slot holds the priors
u       = [0; r.u(:,1)];
nSlots  = length(u);
nTrials = nSlots - 1;

% Flag the ignored trials once instead of testing membership on every iteration
skip = ismember((1:nTrials)', r.ign);

% Trajectories; slot 1 holds the priors, slot t+1 holds trial t
da = NaN(nSlots, 1);   % prediction error
g  = NaN(nSlots, 1);   % Kalman gain
mu = NaN(nSlots, 1);   % hidden-state mean
g(1)  = g0;
mu(1) = mu0;

for t = 1:nTrials
    k = t + 1;
    if skip(t)
        % Ignored trial: no learning, carry the previous state forward
        da(k) = 0;
        g(k)  = g(k-1);
        mu(k) = mu(k-1);
    else
        % Prediction error on the current input
        da(k) = u(k) - mu(k-1);
        % Gain recursion: (g + pi_u*omega)/(g + pi_u*omega + 1)
        gPlus = g(k-1) + piOm;
        g(k)  = gPlus/(gPlus + 1);
        % Mean moves toward the input by the new gain times the prediction error
        mu(k) = mu(k-1) + g(k)*da(k);
    end
end

% The prediction for each trial is the mean after the preceding trial (or the prior)
muhat = mu(1:end-1);

% Drop the prior slot from the trajectories
da = da(2:end);
g  = g(2:end);
mu = mu(2:end);

% Result structure (field order: g, muhat, mu, da)
traj = struct('g', g, 'muhat', muhat, 'mu', mu, 'da', da);

% Matrix (here: two columns) consumed by the observation model
infStates = [muhat, mu];