function [traj, infStates] = tapas_rw_binary_dual(r, p, varargin)
% Calculates the trajectories of v under the Rescorla-Wagner learning model for dual updates.
%
% This function can be called in two ways:
%
% (1) tapas_rw_binary_dual(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in
%     native space;
%
% (2) tapas_rw_binary_dual(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Inputs read from r:
%     r.u(:,1)  inputs, one per trial (0 or 1)
%     r.y(:,1)  responses, one per trial; used as a column index into v, so must be 1 or 2
%     r.ign     trials to skip (assumed to index rows of r.u, i.e. trial numbers that do NOT
%               count the dummy "zeroth" trial added below - confirm against tapas_fitModel)
%
% Parameter vector p (native space):
%     p(1:2)    v_0, the prior values of the two options
%     p(3)      al, the learning rate applied to the chosen option
%     p(4)      ka, factor scaling the learning rate (ka*al) for the unchosen option
%
% Outputs:
%     traj.v     values after each trial's update          (nTrials x 2)
%     traj.vhat  values before each trial's update         (nTrials x 2)
%     traj.da    prediction errors                         (nTrials x 2)
%     infStates  vhat arranged as infStates(:,1,j) = vhat(:,j) for the observation model
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2012-2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1}, 'trans')
    p = tapas_rw_binary_dual_transp(r, p);
end

% Unpack parameters
v_0 = reshape(p(1:2), 1, 2);
al  = p(3);
ka  = p(4);

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
y = [0; r.y(:,1)];
n = length(u);

% Initialize updated quantities: value and prediction error
v  = NaN(n,2);
da = NaN(n,2);

% Prior
v(1,:) = v_0;

% Row k of u/y/v holds trial k-1 because of the dummy row, so the trial number compared
% against r.ign is k-1 (comparing k itself would skip the wrong trial and process the
% irregular one). Computed once here instead of calling ismember on every iteration.
ignored = [false; ismember((1:n-1)', r.ign)];

% Pass through value update loop
for k = 2:1:n
    if ~ignored(k)
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%

        chosen   = y(k);      % column of the option that was responded with
        unchosen = 3 - y(k);  % the other column (1 <-> 2)

        % Prediction error: the chosen option's target is u(k), the unchosen option's
        % target is 1-u(k). For any u(k) other than 0 or 1 da(k,:) stays NaN.
        if u(k) == 1
            da(k,chosen)   = 1 - v(k-1,chosen);
            da(k,unchosen) = 0 - v(k-1,unchosen);
        elseif u(k) == 0
            da(k,chosen)   = 0 - v(k-1,chosen);
            da(k,unchosen) = 1 - v(k-1,unchosen);
        end

        % Value: the unchosen option learns with rate ka*al
        v(k,chosen)   = v(k-1,chosen)   + al*da(k,chosen);
        v(k,unchosen) = v(k-1,unchosen) + ka*al*da(k,unchosen);
    else
        % Ignored trial: no prediction error, values carried forward unchanged
        da(k,:) = [0, 0];
        v(k,:)  = v(k-1,:);
    end
end

% Predicted value: the value held before each trial's update
vhat = v;
vhat(end,:) = [];

% Remove representation priors
v(1,:)  = [];
da(1,:) = [];

% Create result data structure
traj = struct;

traj.v    = v;
traj.vhat = vhat;
traj.da   = da;

% Create matrix needed by observation model
infStates = NaN(n-1,1,2,1,1);
infStates(:,1,1,1,1) = vhat(:,1);
infStates(:,1,2,1,1) = vhat(:,2);