classdef DoubleContextTask < handle
    % DoubleContextTask
    %   The double context task requires the association of 16 stimulus
    %   (A,B,C,D) - context (1,2,3,4) pairs with one of the two responses
    %   X or Y.
    %
    %   The task is as follows.
    %   -----------------                -------------
    %   | A1 B1 | A2 B2 |                | X X | Y Y |
    %   | C1 D1 | C2 D2 |   associate    | Y Y | X X |
    %   -----------------   --------->   -------------
    %   | A3 B3 | A4 B4 |                | Y Y | X X |
    %   | C3 D3 | C4 D4 |                | X X | Y Y |
    %   -----------------                -------------
    %
    %   Encoding used by this class (one row per stimulus-context pair,
    %   ordered A1, A2, ..., A4, B1, ..., D4 for the default labels):
    %     DataBlock  - [nLetter*nNumber x (nLetter+nNumber)], a one-hot
    %                  code of the letter in the first nLetter columns
    %                  followed by a one-hot code of the number.
    %     LabelBlock - [nLetter*nNumber x 2], a one-hot code of the
    %                  response; column 1 is response X and column 2 is
    %                  response Y of the diagram above.
    %
    %   Florian Raudies, 01/30/2014, Boston University.

    properties (SetAccess = protected)
        LetterLabel     % Cell array with the stimulus names.
        NumberLabel     % Cell array with the context names.
        StateName       % Cell array with the names [letter, number].
        DataBlock       % One-hot input rows, one per state.
        LabelBlock      % One-hot response rows, one per state.
        blockTrain      % Train with ordered blocks.
    end

    methods
        function obj = DoubleContextTask(LetterLabel, NumberLabel)
            % Constructor.
            %   For the double-context task call with
            %   LetterLabel = {'A','B','C','D'} and
            %   NumberLabel = {'1','2','3','4'}.
            %
            %   Note: the response rule is hard-wired to the contexts at
            %   positions 2 and 3 of NumberLabel and to the first half of
            %   LetterLabel (see the diagram in the class header).
            obj.LetterLabel = LetterLabel;
            obj.NumberLabel = NumberLabel;
            nLetter = length(obj.LetterLabel);
            nNumber = length(obj.NumberLabel);
            nState  = nLetter * nNumber;

            obj.StateName = cell(nState, 1);
            obj.DataBlock = zeros(nState, nLetter + nNumber);
            LabelIndex    = zeros(nState, 1);

            for iLetter = 1:nLetter
                letter = obj.LetterLabel{iLetter};
                for iNumber = 1:nNumber
                    % The number varies fastest: A1, A2, ..., then B1, ...
                    iData = sub2ind([nNumber nLetter], iNumber, iLetter);
                    obj.StateName{iData} = [letter, ...
                                            obj.NumberLabel{iNumber}];

                    % First half of the letters: second response in
                    % contexts 2 and 3. Second half: the opposite.
                    inContext23 = iNumber == 2 || iNumber == 3;
                    if iLetter <= nLetter/2
                        LabelIndex(iData) = inContext23;
                    else
                        LabelIndex(iData) = ~inContext23;
                    end

                    obj.DataBlock(iData, iLetter)           = 1;
                    obj.DataBlock(iData, nLetter + iNumber) = 1;
                end
            end

            % Turn the 0/1 flags into column indices 1/2 and set the
            % matching entry of each row of the label matrix.
            LabelIndex = 1 + double(LabelIndex);
            LabelIndex = sub2ind([nState 2], (1:nState)', LabelIndex);
            obj.LabelBlock = zeros(nState, 2);
            obj.LabelBlock(LabelIndex) = 1;
            obj.blockTrain = 0;
        end

        function [Data, Label] = generateData(obj, nBlock, ExcludeState)
            % generateData
            %   nBlock       - Number of repetitions of the state block.
            %   ExcludeState - Optional. State name (char) or cell array
            %                  of state names (e.g. {'A1','B1'}) to leave
            %                  out. Unknown names are ignored. Default is
            %                  to exclude nothing.
            %
            % RETURN
            %   Data         - Rows of DataBlock for the included states.
            %   Label        - Rows of LabelBlock matching Data.
            %
            %   If blockTrain is false, the included rows are repeated
            %   nBlock times and shuffled with randperm.
            if nargin < 3
                ExcludeState = {};
            end
            Include = obj.includedRows(ExcludeState);
            Data    = obj.DataBlock(Include, :);
            Label   = obj.LabelBlock(Include, :);
            if ~obj.blockTrain
                Data  = repmat(Data,  nBlock, 1);
                Label = repmat(Label, nBlock, 1);
                % Permute over the number of ROWS. length(Label) is the
                % largest dimension and equals 2 (the number of label
                % columns) when only one row is left, which produced
                % out-of-range row indices.
                Index = randperm(size(Label, 1));
            else
                % NOTE(review): arrangeBlocks is not defined in this
                % file; presumably it returns the row order for ordered
                % block training -- confirm against its definition.
                Index = arrangeBlocks(size(Data, 1), nBlock, 1);
            end
            Data  = Data(Index, :);
            Label = Label(Index, :);
        end

        function Data = getDataBlock(obj)
            % Returns the data rows of all states.
            Data = obj.DataBlock;
        end

        function Label = getLabelBlock(obj)
            % Returns the label rows of all states.
            Label = obj.LabelBlock;
        end

        function Data = getDataBlockExclude(obj, ExcludeState)
            % Returns the data rows of all states not named in
            % ExcludeState, in their original order.
            Data = obj.DataBlock(obj.includedRows(ExcludeState), :);
        end

        function Label = getLabelBlockExclude(obj, ExcludeState)
            % Returns the label rows of all states not named in
            % ExcludeState, in their original order.
            Label = obj.LabelBlock(obj.includedRows(ExcludeState), :);
        end
    end

    methods (Access = protected)
        function Include = includedRows(obj, ExcludeState)
            % Row indices of DataBlock/LabelBlock for all states that are
            % not named in ExcludeState.
            nState = size(obj.DataBlock, 1);
            if isempty(ExcludeState)
                % Nothing to exclude ({}, [] or '').
                Include = 1:nState;
                return;
            end
            % ismember returns index 0 for unknown names; setdiff drops
            % those because 0 is not a row index.
            [~, Exclude] = ismember(ExcludeState, obj.StateName);
            Include = setdiff(1:nState, Exclude);
        end
    end
end