Deep belief network learns context dependent behavior (Raudies, Zilli, Hasselmo 2014)

 Download zip file 
Help downloading and running models
Accession:194883
We tested rule-generalization capability with a Deep Belief Network (DBN), a Multi-Layer Perceptron (MLP) network, and the combination of a DBN with a linear perceptron (LP). Overall, the combination of the DBN and LP had the highest success rate for generalization.
Reference:
1 . Raudies F, Zilli EA, Hasselmo ME (2014) Deep belief networks learn context dependent behavior. PLoS One 9:e93250 [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Connectionist Network;
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: MATLAB;
Model Concept(s):
Implementer(s): Raudies, Florian [florian.raudies at gmail.com];
/
Matlab
screenshots
README.html
DeepBeliefNetwork.m
DoubleContextLearner.m
DoubleContextLearnerDBN.m
DoubleContextLearnerDBNaLP.m
DoubleContextLearnerMLP.m
DoubleContextTask.m
Figure2.m
Figure3A.m
Figure3B.m
Figure3C.m
Figure3D.m
Figure3E.m
Figure3F.m
Figure3G.m
Figure3H.m
Figure4B.m
Figure4C.m
Figure4D.m
gpl-3.0.txt *
LinearPerceptron.m
logistic.m
MultiLayerPerceptronNetwork.m
num2cellstr.m
RestrictedBoltzmannMachine.m
rotateXLabels.m *
                            
classdef DoubleContextTask < handle
    % DoubleContextTask
    % The double context task requires the association of 16 stimulus 
    % (A,B,C,D) - context (1,2,3,4) pairs with one of the two responses X
    % or Y.
    %
    % The task is as follows.
    %   ----------------                 -----------
    %   | A1 B1 | A2 B2 |               | X X | Y Y |
    %   | C1 D1 | C2 D2 |  associate    | Y Y | X X |
    %   ----------------   --------->    -----------
    %   | A3 B3 | A4 B4 |               | Y Y | X X |
    %   | C3 D3 | C4 D4 |               | X X | Y Y |
    %   ----------------                 -----------
    %
    % Encoding used by this class (nState = nLetter * nNumber):
    %   DataBlock  - nState x (nLetter+nNumber). Columns 1..nLetter one-hot
    %                encode the letter; columns nLetter+1..nLetter+nNumber
    %                one-hot encode the number (context).
    %   LabelBlock - nState x 2, one-hot. Column 1 corresponds to X and
    %                column 2 to Y in the diagram above.
    %   Rows are ordered with the number varying fastest:
    %   A1, A2, ..., A<nNumber>, B1, B2, ...
    %
    %   Florian Raudies, 01/30/2014, Boston University.
    properties (SetAccess = protected)
        LetterLabel     % Cell array of letter (stimulus) labels.
        NumberLabel     % Cell array of number (context) labels.
        StateName       % nState x 1 cell of names such as 'A1'.
        DataBlock       % nState x (nLetter+nNumber) input patterns.
        LabelBlock      % nState x 2 one-hot response labels.
        blockTrain      % Train with ordered blocks (0 = shuffled trials).
    end    
    methods
        % For the double-context task call with
        % LetterLabel = {'A','B','C','D'} and 
        % NumberLabel = {'1','2','3','4'}
        function obj = DoubleContextTask(LetterLabel,NumberLabel)
            obj.LetterLabel = LetterLabel;
            obj.NumberLabel = NumberLabel;
            nLetter = length(obj.LetterLabel);
            nNumber = length(obj.NumberLabel);
            nState  = nLetter * nNumber;
            obj.StateName   = cell(nState, 1);
            obj.DataBlock   = zeros(nState, nLetter+nNumber);
            obj.LabelBlock  = zeros(nState, 2);
            for iLetter = 1:nLetter
                % Letters in the second half of LetterLabel use the
                % inverted context rule.
                invert = iLetter > nLetter/2;
                for iNumber = 1:nNumber
                    % Row index with the number varying fastest.
                    iData = sub2ind([nNumber nLetter], iNumber, iLetter);
                    obj.StateName{iData} = [obj.LetterLabel{iLetter}, ...
                                            obj.NumberLabel{iNumber}];
                    obj.DataBlock(iData, iLetter)         = 1;
                    obj.DataBlock(iData, nLetter+iNumber) = 1;
                    % Label column 2 for contexts 2 and 3 (first-half
                    % letters) or for all other contexts (second-half
                    % letters); column 1 otherwise.
                    isSecond = xor(invert, iNumber==2 || iNumber==3);
                    obj.LabelBlock(iData, 1 + isSecond) = 1;
                end
            end
            obj.blockTrain  = 0;
        end
        function [Data, Label] = generateData(obj, nBlock, ExcludeState)
            % generateData  Build a training set from the task states.
            %   [Data, Label] = generateData(obj, nBlock, ExcludeState)
            %   repeats every state not named in ExcludeState nBlock
            %   times. Data rows use the DataBlock encoding and Label rows
            %   the matching LabelBlock rows. With blockTrain == 0 the
            %   trials are randomly permuted; otherwise the order comes
            %   from arrangeBlocks.
            %   ExcludeState - cell array of state names (e.g. {'A1'}) to
            %                  leave out. Optional; defaults to {} (none).
            if nargin < 3
                ExcludeState = {};
            end
            Include = obj.includedRows(ExcludeState);
            Data    = obj.DataBlock(Include,:);
            Label   = obj.LabelBlock(Include,:);
            if ~obj.blockTrain
                Data    = repmat(Data,nBlock,1);
                Label   = repmat(Label,nBlock,1);
                % Use the row count, not length(): for a single N x 2 row
                % length() returns 2 and the permutation would index past
                % the only row.
                Index   = randperm(size(Data,1));
            else
                % NOTE(review): arrangeBlocks is not defined in this file;
                % presumably a function on the MATLAB path — confirm.
                Index   = arrangeBlocks(size(Data,1),nBlock,1);
            end
            Data    = Data(Index,:);
            Label   = Label(Index,:);
        end
        function Data = getDataBlock(obj)
            % Return all input patterns, one row per state.
            Data    = obj.DataBlock;
        end
        function Label = getLabelBlock(obj)
            % Return all one-hot labels, one row per state.
            Label   = obj.LabelBlock;
        end
        function Data = getDataBlockExclude(obj, ExcludeState)
            % Return the input patterns of all states not named in
            % ExcludeState (optional; defaults to {}), in DataBlock order.
            if nargin < 2
                ExcludeState = {};
            end
            Data    = obj.DataBlock(obj.includedRows(ExcludeState),:);
        end
        function Label = getLabelBlockExclude(obj, ExcludeState)
            % Return the labels of all states not named in ExcludeState
            % (optional; defaults to {}), in LabelBlock order.
            if nargin < 2
                ExcludeState = {};
            end
            Label   = obj.LabelBlock(obj.includedRows(ExcludeState),:);
        end
    end
    methods (Access = protected)
        function Include = includedRows(obj, ExcludeState)
            % Sorted row indices of DataBlock/LabelBlock whose state name
            % is not listed in ExcludeState. Names that do not match any
            % state map to 0 in ismember and therefore exclude nothing.
            [~, Exclude] = ismember(ExcludeState, obj.StateName);
            Include = setdiff(1:size(obj.DataBlock,1), Exclude);
        end
    end
end

Loading data, please wait...