Reward modulated STDP (Legenstein et al. 2008)

Accession: 116837
"... This article provides tools for an analytic treatment of reward-modulated STDP, which allows us to predict under which conditions reward-modulated STDP will achieve a desired learning effect. These analytical results imply that neurons can learn through reward-modulated STDP to classify not only spatial but also temporal firing patterns of presynaptic neurons. They also can learn to respond to specific presynaptic firing patterns with particular spike patterns. Finally, the resulting learning theory predicts that even difficult credit-assignment problems, where it is very hard to tell which synaptic weights should be modified in order to increase the global reward for the system, can be solved in a self-organizing manner through reward-modulated STDP. This yields an explanation for a fundamental experimental result on biofeedback in monkeys by Fetz and Baker. In this experiment monkeys were rewarded for increasing the firing rate of a particular neuron in the cortex and were able to solve this extremely difficult credit assignment problem. ... In addition our model demonstrates that reward-modulated STDP can be applied to all synapses in a large recurrent neural network without endangering the stability of the network dynamics."
Reference:
1. Legenstein R, Pecevski D, Maass W (2008) A learning theory for reward-modulated spike-timing-dependent plasticity with application to biofeedback. PLoS Comput Biol 4:e1000180 [PubMed]
Model Information
Model Type: Realistic Network;
Brain Region(s)/Organism: Neocortex;
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: Python; PCSIM;
Model Concept(s): Pattern Recognition; Spatio-temporal Activity Patterns; Reinforcement Learning; STDP; Biofeedback; Reward-modulated STDP;
Implementer(s):
from numpy import *

def KappaFunction(t, KappaApos, KappaAneg, KappaTaupos, KappaTauneg, KappaTe,
                  KappaTaupos2=3e-3, KappaTauneg2=3e-3, KernelType='square'):
    """Reward kernel kappa(t), shifted by KappaTe.

    'DblExp' gives a difference-of-exponentials lobe on each side of t = -KappaTe.
    Any other KernelType gives the 'square' kernel: a window of height KappaApos
    and width KappaTaupos starting at t = -KappaTe, preceded by a window of height
    KappaAneg and width KappaTauneg; zero elsewhere.
    """
    if KernelType == 'DblExp':
        if t + KappaTe > 0:
            return KappaApos * (exp(-(t + KappaTe) / KappaTaupos) - exp(-(t + KappaTe) / KappaTaupos2))
        else:
            return KappaAneg * (exp(-(-t - KappaTe) / KappaTauneg) - exp(-(-t - KappaTe) / KappaTauneg2))
    else:
        if t > KappaTaupos - KappaTe:
            return 0
        elif t > -KappaTe:
            return KappaApos
        elif t > -KappaTauneg - KappaTe:
            return KappaAneg
        return 0
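
# Example (illustrative values): with the default 'square' kernel and no shift,
# the kernel is KappaApos for 0 < t < KappaTaupos and KappaAneg for
# -KappaTauneg < t < 0, e.g.
#   KappaFunction(0.005, 1.0, -1.0, 0.02, 0.02, 0.0)    # -> 1.0
#   KappaFunction(-0.005, 1.0, -1.0, 0.02, 0.02, 0.0)   # -> -1.0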


def optimal_Te_value(maxValue, TeDT, KappaApos, KappaAneg, KappaTaupos, KappaTauneg,
                     KappaTaupos2, KappaTauneg2, KernelType, synTau):
    """Scan shift values Te in [0, maxValue) and return the first Te at which the
    correlation of the synaptic EPSP kernel epsilon with the reward kernel kappa,
    shifted by Te, changes sign. Returns None if no sign change is found."""
    DT = 1e-5
    LeftMargin = -0.3
    RightMargin = 0.3 + DT
    # EPSP kernel epsilon(t): zero for t < 0, exponential decay with time constant
    # synTau for t >= 0; reversed so the dot product below implements a correlation.
    EpsilonArray = array([0 for t in arange(LeftMargin, 0, DT)] +
                         [1 / synTau * exp(-t / synTau) for t in arange(0, RightMargin, DT)])
    EpsilonArray = EpsilonArray[::-1]
    # Reward kernel kappa(t) sampled on a grid wide enough for all tested shifts.
    KappaArray = array([KappaFunction(t, KappaApos, KappaAneg, KappaTaupos, KappaTauneg, 0,
                                      KappaTaupos2, KappaTauneg2, KernelType)
                        for t in arange(LeftMargin, RightMargin + maxValue, DT)])
    prevSign = None
    for KappaTe in arange(0, maxValue, TeDT):
        Sign = sum(EpsilonArray * KappaArray[int(KappaTe / DT):int(KappaTe / DT) + len(EpsilonArray)]) * DT > 0
        if prevSign is not None and Sign != prevSign:
            return KappaTe
        prevSign = Sign
    return None

def WStdpFunction(t, stdpApos, stdpAneg, stdpTaupos, stdpTauneg):
    """Standard STDP window W(t): potentiation for pre-before-post pairings (t >= 0),
    depression for post-before-pre pairings (t < 0), exponentially decaying on each side."""
    if t >= 0:
        return stdpApos * exp(-t / stdpTaupos)
    else:
        return stdpAneg * exp(t / stdpTauneg)
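
# Example (illustrative values): with stdpApos = 1.0, stdpAneg = -0.5 and both
# time constants at 20 ms, a post-before-pre pairing at t = -10 ms gives
#   WStdpFunction(-0.01, 1.0, -0.5, 0.02, 0.02)   # -> -0.5 * exp(-0.5) ~ -0.30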


def checkConstraints(synTau, NumSyn, ratioStrong, Wmax, inputRate, Rbase,
                     stdpApos, stdpAneg, stdpTaupos, stdpTauneg,
                     DAStdpRate, DATraceDelay, DATraceTau, DATraceShape, rewardDelay,
                     KappaApos, KappaAneg, KappaTaupos, KappaTauneg, KappaTe,
                     numAdditionalTargetSynapses, KappaTaupos2=3e-3, KappaTauneg2=3e-3,
                     KernelType='square'):
    """Numerically evaluate the kernel integrals of the reward-modulated STDP
    learning theory, print the resulting weight-drift rates and convergence times
    for strong and weak synapses, and return them together with the left- and
    right-hand sides of the weak and strong learning conditions."""
    # Nudge synTau if it coincides with an STDP time constant, to avoid zero
    # denominators in the closed-form integrals further below.
    if stdpTaupos == synTau:
        synTau = synTau * 1.0001
    if stdpTauneg == synTau:
        synTau = synTau * 1.0001

    DT = 1e-5
    
    LeftMargin = -0.3
    RightMargin = 0.3 + DT
    
    KappaArray = array([ KappaFunction(t, KappaApos, KappaAneg, KappaTaupos, KappaTauneg, KappaTe, KappaTaupos2, KappaTauneg2, KernelType) for t in arange(LeftMargin, RightMargin, DT) ])
    
    EpsilonArray = array( [ 0 for t in arange(LeftMargin, 0, DT)] + [ 1/synTau * exp(-t/synTau) for t in arange(0, RightMargin, DT) ])
    
    WstdpArray = array([ WStdpFunction(t,stdpApos,stdpAneg, stdpTaupos, stdpTauneg) for t in arange(LeftMargin, RightMargin, DT) ])
    
    # Convolution (epsilon * kappa)(t), evaluated on the same time grid.
    EpsilonKappaArray = convolve(KappaArray, EpsilonArray, 'same') * DT
    
    # Numerical values of the kernel integrals that enter the drift-rate expressions below.
    IntWstdpEpsKappaNumeric = sum(WstdpArray * EpsilonKappaArray) * DT
    
    print "IntWstdpEpsKappaNumeric = ", IntWstdpEpsKappaNumeric
    
    IntWstdpEpsilonEpsKappaNumeric = sum(WstdpArray * EpsilonArray * EpsilonKappaArray) * DT
    
    print "IntWstdpEpsilonEpsKappaNumeric = ", IntWstdpEpsilonEpsKappaNumeric
    
    IntEpsilonEpsilonKappaNumeric = sum(EpsilonArray * EpsilonKappaArray) * DT
    
    print "IntEpsilonEpsilonKappaNumeric = ", IntEpsilonEpsilonKappaNumeric
    
    IntWstdpEpsilonNumeric = sum(WstdpArray * EpsilonArray) * DT
    
    print "IntWstdpEpsilonNumeric = ", IntWstdpEpsilonNumeric
    
    IntKappaNumeric = sum(KappaArray) * DT
    
    print "IntKappaNumeric = ", IntKappaNumeric
    
    IntWstdpNumeric = sum(WstdpArray) * DT
    
    print "IntWstdpNumeric = ", IntWstdpNumeric
    
    print "Rbase = ", Rbase
    print "Wmax = ", Wmax
    print "NumSyn = ", NumSyn
    print "inputRate = ", inputRate
    
    
    Rpost = Rbase + Wmax * NumSyn * inputRate * 1.0 / 2.0
    
    print "Rpost = ", Rpost
    
    Rstar = Wmax * ratioStrong * (NumSyn + numAdditionalTargetSynapses) * inputRate
    
    print "Rstar = ", Rstar
    
    Weight = Wmax / 2
    
    VarianceFactor = Rstar * Rpost * KappaTaupos * stdpApos * 1e5
    
    
    # Eligibility/dopamine trace evaluated at the time the delayed reward arrives.
    if DATraceShape == 'exp':
        FcTr = exp(- (rewardDelay - DATraceDelay) / DATraceTau )
    elif DATraceShape == 'alpha':
        FcTr = (rewardDelay - DATraceDelay) / DATraceTau * exp( - (rewardDelay - DATraceDelay) / DATraceTau )
    else:
        raise ValueError("unknown DATraceShape: %s" % DATraceShape)
    
    # Integral of the trace over time (equals DATraceTau for both shapes).
    IntFc = DATraceTau
    
    
    Term1 = IntKappaNumeric * ( Rpost * IntWstdpNumeric + Weight * IntWstdpEpsilonNumeric ) * ( Rpost * Rstar * IntFc +  FcTr * (Rstar + Rstar * Weight +  Rpost * Wmax))
    
    Term2 = FcTr * Wmax * Rpost * IntWstdpEpsKappaNumeric
    
    Term3 = FcTr * Wmax * Weight * IntWstdpEpsilonEpsKappaNumeric
    
    Term4 = FcTr * Wmax * Weight * Rpost * IntWstdpNumeric * IntEpsilonEpsilonKappaNumeric
    
    Term5 = FcTr * Wmax * Weight * Weight * IntWstdpEpsilonNumeric * IntEpsilonEpsilonKappaNumeric
    
    print "Term1 = ", Term1
    print "Term2 = ", Term2
    print "Term3 = ", Term3
    print "Term4 = ", Term4
    print "Term5 = ", Term5
    
    
    DW_over_DT_strong = Term1 + Term2 + Term3 + Term4 + Term5
    
    DW_over_DT_strong = DAStdpRate * inputRate * DW_over_DT_strong
    
    
    print "********************************************************"    
    print "DW_over_DT_strong = ", DW_over_DT_strong
    
    Tconv_strong = Wmax / 2.0 / DW_over_DT_strong
    
    print "Tconv strong = ", Tconv_strong
    
    print "Tconv_strong in hours = ", Tconv_strong / 3600
    
    
    DW_over_DT_weak = DAStdpRate * IntKappaNumeric * IntFc * Rstar * inputRate * Rpost * ( Rpost * IntWstdpNumeric + Weight * IntWstdpEpsilonNumeric ) +\
                      IntKappaNumeric * FcTr * inputRate * ( Rpost * IntWstdpNumeric + Weight * IntWstdpEpsilonNumeric ) * ( Rstar + Rstar * Weight )
                      
    
    print "DW_over_DT_weak = ", DW_over_DT_weak    
    
    Tconv_weak = Wmax / 2 / -DW_over_DT_weak
    
    print "Tconv_weak = ", Tconv_weak
    
    print "Tconv_weak in hours = ", Tconv_weak / 3600
                                                
    IntWEpsKappa1 = (-stdpAneg) * (-KappaAneg) * stdpTauneg * exp(-KappaTe * stdpTauneg ) 
    IntWEpsKappa2 = (-stdpAneg) * KappaApos * (  stdpTauneg * synTau / (stdpTauneg - synTau) * (exp(KappaTe * (stdpTauneg - synTau) / (stdpTauneg * synTau) ) - 1 ) +\
                     stdpTauneg * (exp(-KappaTe / stdpTauneg ) - 1) )
    IntWEpsKappa3 = stdpApos * KappaApos * ( stdpTaupos - stdpTaupos * synTau / (stdpTaupos + synTau)  )
    
    IntWEpsKappa = IntWEpsKappa1 + IntWEpsKappa2 + IntWEpsKappa3
                                                
    print "IntWepsKappa = ", IntWEpsKappa                                            
                                                
    IntWstdp = abs(stdpApos * stdpTaupos + stdpAneg * stdpTauneg)
    
    print "IntWstdp = ", IntWstdp
    
    IntWstdpEpsil = stdpApos * stdpTaupos / (synTau + stdpTaupos)
    
    print "IntWstdpEpsil", IntWstdpEpsil
    
    IntKappa = KappaApos * KappaTaupos + KappaAneg * KappaTauneg
    
    print "IntKappa = ", IntKappa
    
    # Ratio IntFc / FcTr of the trace integral to its value at the reward delay.
    if DATraceShape == 'exp':
        IntFcOverFcTr = DATraceTau / exp(- (rewardDelay - DATraceDelay) / DATraceTau )
    elif DATraceShape == 'alpha':
        IntFcOverFcTr = DATraceTau / ( (rewardDelay - DATraceDelay) / DATraceTau * exp( - (rewardDelay - DATraceDelay) / DATraceTau ) )
    else:
        raise ValueError("unknown DATraceShape: %s" % DATraceShape)
        
    
    RatePostMax = Rbase + NumSyn * Wmax * inputRate
    
    minLearningNrnRate = Rbase + Wmax * NumSyn * (1.0 / 4.0) * inputRate
    
    targetNrnRate = Wmax * NumSyn * ratioStrong * inputRate
    
    
    print "********************************************************"
    
    
    LeftSide_weak = minLearningNrnRate * abs(IntWstdpNumeric)
    RightSide_weak = IntWstdpEpsil * Wmax
    
    
    LeftSide_strong = IntWstdpEpsKappaNumeric
    RightSide_strong = abs(IntWstdpNumeric) * IntKappaNumeric * targetNrnRate / Wmax * ( Rpost * IntFcOverFcTr  + 1)
    
    
    return [ DW_over_DT_strong , DW_over_DT_weak, Tconv_strong, Tconv_weak, LeftSide_weak, RightSide_weak, LeftSide_strong, RightSide_strong ]
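

# Example call with illustrative parameter values (placeholders, not the settings
# used for the published simulations). The returned list contains the strong/weak
# weight-drift rates, the corresponding convergence times, and the left/right-hand
# sides of the weak and strong learning conditions.
if __name__ == '__main__':
    results = checkConstraints(synTau=5e-3, NumSyn=100, ratioStrong=0.5, Wmax=1e-2,
                               inputRate=10.0, Rbase=2.0,
                               stdpApos=1e-3, stdpAneg=-1.2e-3, stdpTaupos=20e-3, stdpTauneg=20e-3,
                               DAStdpRate=0.1, DATraceDelay=0.0, DATraceTau=0.4,
                               DATraceShape='exp', rewardDelay=0.3,
                               KappaApos=1.0, KappaAneg=-1.0, KappaTaupos=20e-3, KappaTauneg=20e-3,
                               KappaTe=0.0, numAdditionalTargetSynapses=0)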