clc
clear all
close all
% *************************************************************************
% This script reproduces Figure 2B-G of the manuscript.
%   Florian Raudies, 01/30/2014, Boston University.
%   This script will run for about 2 minutes.
%
% Outline:
%   1. Train a DoubleContextLearnerDBNaLP with learn(800,{'A1','B1'}) and
%      draw figures 'B' (per-layer activations), 'C' (cumulative weighted
%      activations), 'D' (16th cumulative sum) and 'E' (perceptron line and
%      data points).
%   2. Re-seed, train again with learn(800,{}) ("all stimuli") and draw
%      figures 'F' and 'G', the counterparts of 'D' and 'E'.
% *************************************************************************

LABEL_SIZE = 16;  % Font size of axis labels and tick labels.
N_SHOWN    = 16;  % Number of hidden units / cumulative sums displayed.
N_COL      = 8;   % Number of subplot columns in figures 'B' and 'C'.

% Set seed for random number generator to be able to replicate data.
rng(2);
LetterLabel = {'A','B','C','D'};
NumberLabel = {'1','2','3','4'};

% Initialize the four quadrant learner.
dcl = DoubleContextLearnerDBNaLP(LetterLabel,NumberLabel,40,3);
% NOTE(review): the second argument presumably lists the stimuli that are
% withheld from training (the second run below passes {} and is labelled
% "train with all stimuli") -- confirm against the class definition.
dcl.learn(800,{'A1','B1'});
% err is not used below; the call is kept because testError may have side
% effects (e.g. on the random number stream) -- TODO confirm.
err = dcl.testError;
[A, Wc] = dcl.getDBNActivationSortedByWeights();
[nLayer, nHidden, nState] = size(A);

% Order in which the 16 states are arranged in the 4 x 4 display images.
PlotIndex = [1 5 2 6; 9 13 10 14; 3 7 4 8; 11 15 12 16];

% Weight the last layer's activations (nHidden x nState) by Wc, then
% accumulate / sum over the hidden units (dimension 1). These are computed
% from the raw activations, i.e. before A is rescaled below.
Weighted = repmat(Wc,[nState 1])'.*squeeze(A(end,:,:));
Asum = cumsum(Weighted,1);
Data = sum(Weighted,1);

% Rescale activations and cumulative sums to the range [0, 1] for display.
A    = (A-min(A(:)))/(max(A(:))-min(A(:)));
Asum = (Asum-min(Asum(:)))/(max(Asum(:))-min(Asum(:)));

% *************************************************************************
% Figure of activations within the network.
% *************************************************************************
figure('Position',[50 50 1200 600],'PaperPosition',[2 2 12 5],'Name','B');
nRowPerLayer = ceil(N_SHOWN/N_COL);
nRow = nRowPerLayer*nLayer;
for iLayer = 1:nLayer
    for iHidden = 1:N_SHOWN
        iRow = nRowPerLayer*(iLayer-1) + ceil(iHidden/N_COL);
        iCol = mod(iHidden-1,N_COL) + 1;
        % subplot numbers its panels row by row, so the linear index must
        % be row-major; sub2ind would return a column-major index and
        % scatter the panels over the grid.
        subplot(nRow,N_COL,(iRow-1)*N_COL + iCol);
        imshow(reshape(A(iLayer,iHidden,PlotIndex),[4 4]),[0 1]);
    end
end

% *************************************************************************
% Figure of activations of the network output.
% *************************************************************************
figure('Position',[50 50 1200 200],'PaperPosition',[2 2 12 2],'Name','C');
for iHidden = 1:N_SHOWN
    iRow = ceil(iHidden/N_COL);
    iCol = mod(iHidden-1,N_COL) + 1;
    % Row-major panel index, see the comment in figure 'B'.
    subplot(nRowPerLayer,N_COL,(iRow-1)*N_COL + iCol);
    imshow(reshape(Asum(iHidden,PlotIndex),[4 4]),[0 1]);
end

% *************************************************************************
% Figure of the rescaled 16th sum activation.
% *************************************************************************
figure('Name','D');
% The display range [] lets imshow stretch the image to its own min/max.
imshow(reshape(Asum(N_SHOWN,PlotIndex),[4 4]),[],...
    'InitialMagnification',3*10^3);

% Convert the label block into a zero-based class index per row: take the
% column index of the row maximum and subtract one.
Label = dcl.getLabelBlock();
[~, Label] = max(Label,[],2);
Label = Label - 1;
% Determine class labels for display purposes.
CY = Label > 0;
CX = ~CY;
% Retrieve parameters and visualize the result.
w     = dcl.lp.getWeight;
theta = dcl.lp.getThreshold;
X     = -10:2;
Y     = -w*X+theta;

% *************************************************************************
% Figure of the linear perceptron hyperplane and data.
% *************************************************************************
figure('Name','E');
% Plotting a vector against the scalar 0 draws discrete points; the data
% are placed on the horizontal axis and the vertical axis is auxiliary.
h1 = plot(Data(CX),0,'k^','MarkerSize',12);
hold on;
h2 = plot(Data(CY),0,'ko','MarkerSize',12);
h3 = plot(X,Y,'-k','LineWidth',1.5);
hold off;
% Use the first handle of each class as legend entry; indexing h2(2) would
% fail if class Y contained only a single point.
legend([h1(1) h2(1) h3],'X','Y','Hyperplane','Location','SouthEast');
axis equal;
axis([-10 2 -.5 .5]);
xlabel('activation','FontSize',LABEL_SIZE);
ylabel('auxiliary','FontSize',LABEL_SIZE);
set(gca,'FontSize',LABEL_SIZE);

% *************************************************************************
% Train with all stimuli.
% *************************************************************************
% Re-seed so that this run starts from the same random state as the first.
rng(2);
% Initialize the four quadrant learner.
dcl = DoubleContextLearnerDBNaLP(LetterLabel,NumberLabel,40,3);
dcl.learn(800,{});
% Kept for parity with the first run; unused below (see note above).
err = dcl.testError;
[A, Wc] = dcl.getDBNActivationSortedByWeights();
[nLayer, nHidden, nState] = size(A);
Weighted = repmat(Wc,[nState 1])'.*squeeze(A(end,:,:));
Asum = cumsum(Weighted,1);
Data = sum(Weighted,1);

% *************************************************************************
% Figure of the rescaled 16th sum activation.
% *************************************************************************
figure('Name','F');
% No explicit rescaling of Asum here; imshow with [] scales the display.
imshow(reshape(Asum(N_SHOWN,PlotIndex),[4 4]),[],...
    'InitialMagnification',3*10^3);

% Zero-based class index per row, as for figure 'E'.
Label = dcl.getLabelBlock();
[~, Label] = max(Label,[],2);
Label = Label - 1;
% Determine class labels for display purposes.
CY = Label > 0;
CX = ~CY;
% Retrieve parameters and visualize the result.
w     = dcl.lp.getWeight;
theta = dcl.lp.getThreshold;
X     = -10:2;
Y     = -w*X+theta;

% *************************************************************************
% Figure of the linear perceptron hyperplane and data.
% *************************************************************************
figure('Name','G');
h1 = plot(Data(CX),0,'k^','MarkerSize',12);
hold on;
h2 = plot(Data(CY),0,'ko','MarkerSize',12);
h3 = plot(X,Y,'-k','LineWidth',1.5);
hold off;
% First handle of each class, see the comment in figure 'E'.
legend([h1(1) h2(1) h3],'X','Y','Hyperplane','Location','SouthEast');
axis equal;
axis([-10 2 -.5 .5]);
xlabel('activation','FontSize',LABEL_SIZE);
ylabel('auxiliary','FontSize',LABEL_SIZE);
set(gca,'FontSize',LABEL_SIZE);