function [traj, infStates] = tapas_hhmm(r, p, varargin)
% Estimates a hierarchical hidden Markov model (HHMM)
%
% This function can be called in two ways:
%
% (1) hhmm(r, p)
%
%     where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) hhmm(r, ptrans, 'trans')
%
%     where r is the structure generated by fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% NOTE(review): the model tree is only obtained from the second output of hhmm_transp, so
% calling form (1) has no way to build it. It raises an explicit error instead of failing on
% an undefined variable. Confirm how native-space calls are meant to obtain the tree.
%
% The tree N is validated and then flattened into one transition matrix over its production
% (childless) nodes. The normalized forward probabilities of those nodes ("alpha-prime") are
% then updated trial by trial from the first column of r.u. For trials k > 1 listed in r.ign,
% the previous row is carried forward unchanged.
%
% Outputs:
%   traj.alpr     n-by-P matrix (n trials, P production nodes). Row 1 is proportional to
%                 prior .* Bflat(u(1),:). Row k > 1 is proportional to
%                 Bflat(u(k),:) .* (alpr(k-1,:)*Aflat). Each row is normalized to sum to 1.
%   traj.alprhat  n-by-P matrix. Row 1 is the prior over production nodes; row k > 1 equals
%                 alpr(k-1,:).
%   infStates     identical to traj.alpr
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or .

% Tolerance for "probabilities sum to one" checks. The sums are accumulated in floating point,
% so an exact comparison with 1 rejects valid trees: ten children with V = 0.1, added one by
% one, sum to 0.9999999999999999.
tol = 1e-10;

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    [~, pstruct] = hhmm_transp(r, p);
else
    % Without the transformation step there is no pstruct, hence no model tree.
    error('tapas:hgf:hhmm:NoModelTree', ...
        'The HHMM model tree is only available when called with transformed parameters and the ''trans'' flag.');
end

% Fixed configuration elements
%
% Number of possible outcomes
m = r.c_prc.n_outcomes;

% Get model tree (N is for node)
N = pstruct.N;

% Check constraints (throws on any violation)
check_tree(N, tol);

% Flatten the tree into one large transition matrix
% ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
% Find production nodes (nodes without children), as a row vector
pn = reshape(find(cellfun(@(nd) isempty(nd.children), N)), 1, []);
flatDim = length(pn);

% Outcome contingency matrix: column i holds B of production node pn(i)
Bflat = NaN(m,flatDim);
for i = 1:flatDim
    B = N{pn(i)}.B;
    if numel(B) ~= m
        error('tapas:hgf:hhmm:NoOutConMat', 'Could not construct outcome contingency matrix.');
    end
    Bflat(:,i) = B(:);
end

% Check Bflat
if any(~isfinite(Bflat(:)))
    error('tapas:hgf:hhmm:NoOutConMat', 'Could not construct outcome contingency matrix.');
end

% Flattened transition matrix: Aflat(i,j) is the one-step probability pn(i) -> pn(j)
Aflat = NaN(flatDim);
for i = 1:flatDim
    for j = 1:flatDim
        Aflat(i,j) = flat_trans_prob(N, pn(i), pn(j));
    end
end

% Check Aflat
if any(~isfinite(Aflat(:)))
    error('tapas:hgf:hhmm:NoFlatTransMat', 'Could not flatten transition matrix.');
end

% Prior probabilities of production nodes: product of the vertical entry probabilities V
% along the path from the root (excluded) down to the node
pnp = NaN(1,flatDim);
for i = 1:flatDim
    anci = anc(N,pn(i));
    pnp(i) = entry_prob(N, anci(2:end));
end

% Check pnp
if abs(sum(pnp) - 1) > tol
    error('tapas:hgf:hhmm:NoPriorProdNodes', 'Cannot calculate prior probabilities of production nodes.');
end

% Input and number of trials
u = r.u(:,1);
n = length(u);

% Initialize alpha-prime
alpr = NaN(n,flatDim);

% alpr(1,:)
altmp = pnp.*Bflat(u(1),:);
llh = sum(altmp);
alpr(1,:) = altmp./llh;

% Pass through alpha-prime update loop
for k = 2:n
    if not(ismember(k, r.ign))
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        altmp = Bflat(u(k),:).*(alpr(k-1,:)*Aflat);
        llh = sum(altmp);
        alpr(k,:) = altmp./llh;
    else
        % Ignored trial: carry the previous state forward
        alpr(k,:) = alpr(k-1,:);
    end
end

% Predicted states: prior for trial 1, previous posterior afterwards
alprhat = [pnp; alpr];
alprhat(end,:) = [];

% Create result data structure
traj = struct;
traj.alpr = alpr;
traj.alprhat = alprhat;

% Create matrix needed by observation model
infStates = traj.alpr;

end % function hhmm

% ----------------------------------------------------------------------------------------
% Validate the HHMM tree N; throws a tapas:hgf:hhmm:* error on the first violation found.
function check_tree(N, tol)
% The root must not have an entry probability
if ~isempty(N{1}.V)
    error('tapas:hgf:hhmm:IllegEntryProbRoot', 'Illegal entry probability for root node.');
end

for id = 1:length(N)
    nd = N{id};
    % Each node has exactly one of A (internal node) or B (production node)
    if isempty(nd.A) == isempty(nd.B)
        error('tapas:hgf:hhmm:IllegCombOfAB', 'Illegal combination of A and B for node no. %d.', id);
    end
    if length(nd.children(:)) ~= size(nd.A,2)
        error('tapas:hgf:hhmm:NumOfChildIncons', 'Number of children inconsistent with A for node no. %d.', id);
    end
    if ~isempty(nd.A) && any(sum(nd.A,2) > 1 + tol)
        error('tapas:hgf:hhmm:IllegA', 'Illegal transition matrix A for node no. %d: row sums have to be less than or equal to 1.', id);
    end
    % Iterate children as a row so column-vector children are handled too
    children = reshape(nd.children, 1, []);
    if ~isempty(nd.A)
        for cidx = 1:length(children)
            if ~isempty(N{children(cidx)}.children) && nd.A(cidx,cidx) ~= 0
                error('tapas:hgf:hhmm:IllegASelf', 'Illegal transition matrix A for node no. %d: only production nodes may have self-transitions.', id);
            end
        end
    end
    if ~isempty(nd.B) && abs(sum(nd.B(:)) - 1) > tol
        error('tapas:hgf:hhmm:IllegB', 'Illegal outcome contingency vector B for node no. %d.', id);
    end
    if ~isempty(children)
        Vsum = 0;
        for cid = children
            Vsum = Vsum + N{cid}.V;
        end
        % A child with empty V makes Vsum empty; treat that as illegal rather than skip it
        if ~isscalar(Vsum) || abs(Vsum - 1) > tol
            error('tapas:hgf:hhmm:IllegV', 'Illegal vertical transition probabilities V from node no. %d.', id);
        end
    end
end
end

% ----------------------------------------------------------------------------------------
% One-step transition probability from production node `from` to production node `to`
% in the flattened model.
function tp = flat_trans_prob(N, from, to)
if N{from}.parent == N{to}.parent
    % Siblings: read out their parent's A matrix
    pid = N{from}.parent;
    idx = find(N{pid}.children==from);
    jdx = find(N{pid}.children==to);
    tp = N{pid}.A(idx,jdx);
    return
end

% Otherwise go through the lowest common ancestor
caid = ca(N,from,to);
tp = 1;

% Move up to one node below the lowest common ancestor, multiplying in the probability
% of ending each sub-HMM that is left on the way (1 - row sum of that node's A)
nid = from;
pid = N{nid}.parent;
nidx = find(N{pid}.children==nid);
while pid ~= caid
    tp = tp*(1-sum(N{pid}.A(nidx,:)));
    nid = pid;
    pid = N{nid}.parent;
    nidx = find(N{pid}.children==nid);
end

% Horizontal transition at the common ancestor onto the ancestral line of `to`:
% drop everything up to and including caid from that line
ancj = anc(N,to);
ancj = ancj(find(ancj==caid, 1)+1:end);
caidxj = find(N{caid}.children==ancj(1));
tp = tp*N{caid}.A(nidx,caidxj);

% Go down to the target node, multiplying in the vertical entry probabilities
tp = tp*entry_prob(N, ancj(2:end));
end

% ----------------------------------------------------------------------------------------
% Product of the vertical entry probabilities V of the nodes in ids (1 if ids is empty).
function prob = entry_prob(N, ids)
prob = 1;
for id = reshape(ids, 1, [])
    prob = prob*N{id}.V;
end
end

% ----------------------------------------------------------------------------------------
% Find lowest common ancestor of nodes ida and idb (NaN if their root paths share no node).
% Walks the common prefix of both root-to-node paths without indexing past either path's end,
% so it also works when one node is an ancestor of the other.
function ca = ca(N,ida,idb)
anca = anc(N,ida);
ancb = anc(N,idb);
ca = NaN;
for k = 1:min(length(anca), length(ancb))
    if anca(k) ~= ancb(k)
        break
    end
    ca = anca(k);
end
end
% ----------------------------------------------------------------------------------------
% Find ancestors of a node
function anc = anc(N,id)
% Returns the row vector [root, ..., parent, id]: the path from the top of the tree down to
% node id. The path is built by following the .parent field of the structs in cell array N
% until a node with an empty .parent is reached.
anc = id;
node = id;
while true
    up = N{node}.parent;
    if isempty(up)
        break
    end
    % Prepend, so the path ends up ordered root-first
    anc = [up, anc];
    node = up;
end
end