Reward modulated STDP (Legenstein et al. 2008)

 Download zip file 
Help downloading and running models
Accession:116837
"... This article provides tools for an analytic treatment of reward-modulated STDP, which allows us to predict under which conditions reward-modulated STDP will achieve a desired learning effect. These analytical results imply that neurons can learn through reward-modulated STDP to classify not only spatial but also temporal firing patterns of presynaptic neurons. They also can learn to respond to specific presynaptic firing patterns with particular spike patterns. Finally, the resulting learning theory predicts that even difficult credit-assignment problems, where it is very hard to tell which synaptic weights should be modified in order to increase the global reward for the system, can be solved in a self-organizing manner through reward-modulated STDP. This yields an explanation for a fundamental experimental result on biofeedback in monkeys by Fetz and Baker. In this experiment monkeys were rewarded for increasing the firing rate of a particular neuron in the cortex and were able to solve this extremely difficult credit assignment problem. ... In addition our model demonstrates that reward-modulated STDP can be applied to all synapses in a large recurrent neural network without endangering the stability of the network dynamics."
Reference:
1. Legenstein R, Pecevski D, Maass W (2008) A learning theory for reward-modulated spike-timing-dependent plasticity with application to biofeedback. PLoS Comput Biol 4:e1000180 [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism: Neocortex;
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: Python; PCSIM;
Model Concept(s): Pattern Recognition; Spatio-temporal Activity Patterns; Reinforcement Learning; STDP; Biofeedback; Reward-modulated STDP;
Implementer(s):
import sys
import os
import re
sys.path.append('../packages')
from numpy import *
import random, getopt
from datetime import datetime
from math import *
from pylab import *
from tables import *
from math import exp 
from BeforeAfterExperiment import *
from frame import FrameAxes

from matplotlib import rc
# Figure text: sans-serif font, and all text (including the $...$ labels used
# below) rendered through LaTeX via text.usetex, which needs a working LaTeX
# installation.
rc('font',**{'family':'sans-serif','sans-serif':['Computer Modern Sans serif']})
rc('text', usetex=True)



# Placeholder network object. In this file it is only passed to Recordings(...)
# in multi_run_and_save_beforeAfter to create a container for spike data.
# NOTE(review): SingleThreadNetwork presumably comes from PCSIM via the star
# import of BeforeAfterExperiment — confirm.
dummy_net = SingleThreadNetwork()

 

def display_spike_trains(f, idx, XBeforeMin, XBeforeMax, ax):
    """Draw the four spike-train rows of panel B for run number *idx*.

    Rows, top to bottom (y value in brackets): neuron before learning (7),
    target train S* (5), realizable part of S* (3), neuron after learning (1).
    Each spike is drawn as a vertical error bar of half-height 0.8 around its
    row; ``visible=False`` is passed so that, presumably, only the bars and
    not the connecting data line are shown.

    Parameters
    ----------
    f : whatever ``constructRecordingsFromH5File`` accepts (h5 file/handle).
    idx : index of the run whose spike trains are shown.
    XBeforeMin, XBeforeMax : time window; spike times are shifted by
        -XBeforeMin so the window starts at x = 0.
    ax : axes whose ``transAxes`` places the row labels left of the plot.
    """
    recs = constructRecordingsFromH5File(f)

    # (row height, name of the per-run spike-train list on the recordings)
    raster_rows = (
        (7, 'before_learning_spikes'),
        (5, 'target_nrn_spikes'),
        (3, 'realiz_target_nrn_spikes'),
        (1, 'after_learning_spikes'),
    )
    for row_y, attr_name in raster_rows:
        spike_times = getattr(recs, attr_name)[idx]
        n_spikes = len(spike_times)
        errorbar(spike_times - XBeforeMin, row_y * ones(n_spikes),
                 0.8 * ones(n_spikes), capsize=0, visible=False, color='k')

    xlim(0, XBeforeMax - XBeforeMin + 0.04)
    xlabel("time [sec]")
    ylim(0, 8)
    yticks([])

    # Row captions, in axes coordinates, to the left of the raster.
    row_captions = (
        (0.871, 'before learning'),
        (0.715, 'target $S^*$'),
        (0.625, '(= rewarded'),
        (0.535, 'spike times)'),
        (0.385, 'realizable part'),
        (0.298, 'of target $S^*$'),
        (0.125, 'after learning'),
    )
    for caption_y, caption in row_captions:
        text(-0.27, caption_y, caption, horizontalalignment='center',
             verticalalignment='center', fontsize=13, transform=ax.transAxes)


def norm_vec(v):
    """Return the Euclidean (L2) length of the vector *v*."""
    squared_length = dot(v, v)
    return sqrt(squared_length)
 
def rectangle_kernel(x, width):
    """Rectangular (box-car) kernel: return 1 if -width < x < width, else 0.

    The previous test ``x < width or x > -width`` is true for every x when
    width > 0 (each x satisfies at least one side), so the kernel was
    identically 1. Both bounds have to hold.

    Note: this changes the values produced by calculate_corr_coeff (panel C)
    compared with results computed with the old, always-1 kernel.
    """
    if -width < x < width:
        return 1
    return 0
    
def calculate_corr_coeff(spikes, target_spikes, start, end):
    """Return the normalised correlation <a, b> / (|a| |b|) of two spike trains.

    Both trains are smoothed with ``convolve_spikes`` using
    ``rectangle_kernel(t, sigma)`` (sigma = 2e-3) over the kernel range
    [-10*sigma, 10*sigma] between *start* and *end*; 1e-3 is passed as the
    third argument (presumably the sampling step — confirm against
    convolve_spikes).

    NOTE(review): if either smoothed signal is all zeros (e.g. no spikes in
    the window) the denominator is 0 and the result is presumably nan.
    """
    sigma = 2e-3
    support = sigma * 10

    def kernel(t):
        return rectangle_kernel(t, sigma)

    smoothed = convolve_spikes(spikes, kernel, 1e-3, start, end, -support, support)
    smoothed_target = convolve_spikes(target_spikes, kernel, 1e-3, start, end, -support, support)

    numerator = dot(smoothed, smoothed_target)
    denominator = norm_vec(smoothed) * norm_vec(smoothed_target)
    return numerator / denominator

def generate_spike_corr(h5file, XBeforeMin, XBeforeMax):
    """Return the list of spike correlations plotted in panel C.

    The first entries are, for every stored run i, the correlation between
    ``before_learning_spikes[i]`` and ``realiz_target_nrn_spikes[i]``. The
    last entry is the correlation between ``after_learning_spikes[0]`` and
    ``realiz_target_nrn_spikes[0]``. All trains are restricted to
    [XBeforeMin, XBeforeMax] with clip_window(..., shift=True) before the
    comparison.
    """
    recs = constructRecordingsFromH5File(h5file)
    window_len = XBeforeMax - XBeforeMin

    def in_window(train):
        # Restrict to the analysis window; shift=True presumably moves the
        # window start to t = 0 (matching the 0..window_len range below).
        return clip_window(train, XBeforeMin, XBeforeMax, shift=True)

    corr = []
    for run_no, before in enumerate(recs.before_learning_spikes):
        before_clipped = in_window(before)
        target_clipped = in_window(recs.realiz_target_nrn_spikes[run_no])
        corr.append(calculate_corr_coeff(before_clipped, target_clipped, 0, window_len))

    after_clipped = in_window(recs.after_learning_spikes[0])
    target_clipped = in_window(recs.realiz_target_nrn_spikes[0])
    corr.append(calculate_corr_coeff(after_clipped, target_clipped, 0, window_len))
    return corr
    
def multi_run_and_save_beforeAfter(results_file, h5file, numRuns, Tsim):
    """Run *numRuns* BeforeAfterExperiments and collect their spike trains.

    Run i is configured with ``sampleIdx = i * int(200 / numRuns)`` and
    ``h5filename = results_file`` (presumably selecting the i-th stored weight
    snapshot of the learning simulation — confirm; 200 presumably is the number
    of stored snapshots). Each experiment gets ``Tsim`` as its experiment
    parameter. Its ``biofeed`` recordings are read back and the before-learning,
    after-learning, target and realizable-target spike trains are appended to a
    new Recordings object, which is saved into *h5file*.

    Returns
    -------
    (new_rec, first_run_rec)
        new_rec: the combined Recordings (also written to *h5file*);
        first_run_rec: the ``biofeed`` recordings of the first run.

    Raises
    ------
    ValueError
        If ``numRuns < 1``. Previously the loop did not execute and the return
        statement failed with UnboundLocalError on ``first_run_rec``.
    """
    if numRuns < 1:
        raise ValueError("numRuns must be >= 1, got %r" % (numRuns,))

    new_rec = Recordings(dummy_net)
    new_rec.before_learning_spikes = []
    new_rec.after_learning_spikes = []
    new_rec.target_nrn_spikes = []
    new_rec.realiz_target_nrn_spikes = []
    for i in range(numRuns):
        sampleIdx = i * int(200 / numRuns)
        exper = BeforeAfterExperiment('beforeAfter', experParams={"Tsim": Tsim},
                                      modelParams={"biofeed": {"sampleIdx": sampleIdx,
                                                               "h5filename": results_file}})
        exper.run("longrun")
        try:
            r = constructRecordingsFromH5File(exper.data_filename).biofeed
        finally:
            # The per-run data file is only an intermediate product: delete it
            # even when reading it back fails, so failed runs leave no files behind.
            os.remove(exper.data_filename)
        if i == 0:
            first_run_rec = r
        new_rec.before_learning_spikes.append(array(r.before_learning_nrn_spikes))
        new_rec.after_learning_spikes.append(array(r.after_learning_nrn_spikes))
        new_rec.realiz_target_nrn_spikes.append(array(r.realiz_target_nrn_spikes))
        new_rec.target_nrn_spikes.append(array(r.target_nrn_spikes))
    # Remember which learning-simulation file the spike trains were derived from.
    new_rec.src_filename = results_file
    new_rec.saveInOneH5File(h5file)
    return new_rec, first_run_rec
    
def plot_spike_corr(corr, p):
    """Plot the spike-correlation curve of panel C (output of generate_spike_corr).

    Points sit at x = 0 .. len(corr)-1. The x tick labels are minutes of
    simulated time, from 0 in steps of int(Tsim/60/6) up to (excluding)
    int(Tsim/60), where Tsim is ``p.experiment.Tsim``.

    NOTE(review): the number of labels need not equal len(corr) (e.g. 6 labels
    for 7 points). Older matplotlib leaves the extra tick unlabelled; newer
    versions raise on the mismatch — confirm.
    """
    exp_params = p.experiment
    plot(corr, 'k-')
    plot(corr, 'kd', markersize=5)
    total_min = int(exp_params.Tsim / 60)
    step_min = int(exp_params.Tsim / 60 / 6)
    minute_labels = ["%d" % m for m in arange(0, total_min, step_min)]
    xticks(arange(0, len(corr)), minute_labels)
    xlim(0, 6.1)
    xlabel('time [min]')
    ylabel('spike correlation')
    corr_ticks = arange(0.50, 0.95, 0.1)
    yticks(corr_ticks, ['%.2f' % (tick,) for tick in corr_ticks])
    ylim(0.50, 0.91)
    
def plot_weightvec_angle(p, r):
    """Plot the angle between learned and target weight vectors over time (panel D).

    The target vector w* has its first ``numStrongTargetSynapses`` entries at
    Wmax and the next ``numWeakTargetSynapses`` entries at 0. ``r.weights`` has
    one row per synapse and one column per stored sample (the column-wise
    normalisation and the dot product with w* rely on that layout). For every
    sample, arccos(<w*/|w*|, w/|w|>) is plotted against time in steps of
    ``DTsim * samplingTime``.

    NOTE(review): an all-zero weight column gives 0/0 = nan for that sample.
    """
    ep = p.experiment
    bp = p.biofeed
    target_w = hstack((ones(bp.numStrongTargetSynapses) * bp.Wmax,
                       zeros(bp.numWeakTargetSynapses)))
    norm_target_w = target_w / sqrt(inner(target_w, target_w))

    # Normalise every column (the weight vector at one sample) to unit length.
    column_norms = sqrt((r.weights * r.weights).sum(axis=0))
    normed_weights = r.weights / column_norms

    # Rounding can push the cosine slightly outside [-1, 1], where arccos
    # returns nan; clip so the angle stays defined.
    angle = arccos(clip(dot(norm_target_w, normed_weights), -1.0, 1.0))

    sample_dt = ep.DTsim * bp.samplingTime
    # Stop half a step early (as plot_multi_run_wstar does): with a float step,
    # arange(0, n*dt, dt) can return n+1 values and break plot().
    times = arange(0, (len(angle) - .5) * sample_dt, sample_dt)
    plot(times, angle, 'k-')
    xlim(0, ep.Tsim + 1)
    xticks(arange(0, ep.Tsim + 1, ep.Tsim / 4.0),
           ["%d" % i for i in arange(0, float(ep.Tsim + 10) / 60, int(ep.Tsim / 60.0) / 4.0)])
    xlabel('time [min]')
    ylabel('angular error [rad]')
    yticks(arange(0.0, 1.01, 0.2), ["%.1f" % x for x in arange(0.0, 1.01, 0.2)])
    ylim(0.0, 0.9)
    
def plot_weight_evolution(p, r, ax):
    """Draw panel F: the weight matrix ``r.weights`` as an image plus a colour legend.

    ``box()``/``xticks([])``/``yticks([])`` strip *ax* (box() toggles the
    frame, presumably switching it off). Two new axes are then placed inside
    *ax*'s area:
      * an image axes using the left 80 % of the width, showing ``r.weights``
        (rows labelled as synapses, columns as time), and
      * a legend axes 7 % wide, starting 88 % of the way across, showing a
        colour gradient from 1 to 0 labelled w/w_max.
    """
    ep = p.experiment

    box()
    xticks([])
    yticks([])

    ax_length = 0.8   # fraction of ax's width used by the weight image
    ax_gap = 0.08     # gap between image and legend, as a fraction of ax's width
    leg_width = 0.07  # legend width, as a fraction of ax's width

    # Bbox corners [[x0, y0], [x1, y1]] -> [left, bottom, width, height],
    # the rectangle format expected by axes().
    ax_pos = ax.get_position().get_points().flatten()
    ax_pos[2] -= ax_pos[0]
    ax_pos[3] -= ax_pos[1]

    # Legend axes, to the right of the image.
    leg_ax_pos = list(ax_pos)
    leg_ax_pos[0] = leg_ax_pos[0] + leg_ax_pos[2] * (ax_length + ax_gap)
    leg_ax_pos[2] = leg_width * leg_ax_pos[2]
    leg_ax = axes(leg_ax_pos)

    # One-column colour gradient: row 0 holds 1.0, the last row 0.01.
    arr = arange(1, 0, -0.01)
    arr.resize(100, 1)
    imshow(arr, aspect=0.098)
    xticks([])
    # Tick labels must follow the values stored in the rows (row 0 -> 1,
    # row 50 -> 0.5, row 100 -> 0). The previous labels ['0', '0.5', '1'] put
    # '0' next to the row holding 1.0, i.e. the legend contradicted its colours.
    yticks(arange(0, 101, 50), ['1', '0.5', '0'])
    text(1.17, 0.5, '$w/w_{max}$', horizontalalignment='center', verticalalignment='center',
         rotation=90, transform=ax.transAxes)
    leg_ax.yaxis.tick_right()

    # Weight image, to the left.
    new_ax_pos = list(ax_pos)
    new_ax_pos[2] = new_ax_pos[2] * ax_length
    im_ax = axes(new_ax_pos)

    # NOTE(review): the legend gradient above is drawn before jet() is called;
    # that it ends up in the jet colormap relies on pyplot applying jet() to
    # the most recent image when the current axes has none — confirm.
    jet()
    imshow(r.weights, aspect=1.7, interpolation='nearest')

    # NOTE(review): row 0 (synapse 0) is labelled '100' here, while
    # plot_weight_change_fig numbers synapses from 0 — confirm intended.
    yticks(arange(0, 101, 50), ['%d' % (x,) for x in arange(100, -1, -50)])
    # Tick positions assume ~200 stored samples (columns) — confirm.
    xticks(arange(0, 201, 50), ["%d" % i for i in arange(0, float(ep.Tsim + 10) / 60, int(ep.Tsim / 60.0) / 4.0)])
    xlabel('time [min]')
    ylabel('synapse \#')
    
    

def plot_multi_run_wstar(directory):
    """Plot the mean weight of strong and weak target synapses over time (panel A).

    Every file in *directory* whose name starts with 'biofeed' and contains
    '.h5' afterwards is loaded (in sorted order). For each file, the mean over
    the first ``numStrongTargetSynapses`` rows of ``weights`` (solid line) and
    over the remaining rows (dashed line) is plotted in its own colour. Axis
    limits and ticks use the parameters of the last file loaded.

    Raises
    ------
    ValueError
        If no matching file exists. Previously this surfaced later as an
        UnboundLocalError on ``ep`` when the axis limits were set.
    """
    name_re = re.compile(r'biofeed.*\.h5')
    files = sorted(x for x in os.listdir(directory) if name_re.match(x))
    print(files)
    if not files:
        raise ValueError("no biofeed*.h5 result files found in %r" % (directory,))

    plot_colors = ['b', 'r', 'g', 'm', 'k']
    for col_n, fname in enumerate(files):
        h5file = openFile(os.path.join(directory, fname), mode="r")
        try:
            all_p = constructParametersFromH5File(h5file)
            all_r = constructRecordingsFromH5File(h5file)
        finally:
            # Close the file even if parsing it fails.
            h5file.close()

        p = all_p.biofeed
        ep = all_p.experiment
        r = all_r.biofeed
        n_strong = p.numStrongTargetSynapses

        strong_syn_avg = average(r.weights[:n_strong], 0)
        weak_syn_avg = average(r.weights[n_strong:], 0)

        sample_dt = ep.DTsim * p.samplingTime
        # Cycle through the colours so more than five runs do not raise IndexError.
        color = plot_colors[col_n % len(plot_colors)]
        # End half a sample early so float rounding in arange cannot produce
        # one more x value than there are samples.
        plot(arange(0, (len(strong_syn_avg) - .5) * sample_dt, sample_dt), strong_syn_avg, color + '-')
        plot(arange(0, (len(weak_syn_avg) - .5) * sample_dt, sample_dt), weak_syn_avg, color + '--')

    xlim(0, ep.Tsim + 1)
    tick_pos = arange(0, ep.Tsim + 1, ep.Tsim / 4.0)
    print("range is  %s" % (tick_pos,))
    xticks(tick_pos, ["%d" % i for i in arange(0, float(ep.Tsim + 10) / 60, int(ep.Tsim / 60.0) / 4.0)])
    xlabel('time [min]')
    ylim(0, p.Wmax)
    yticks(arange(0, p.Wmax * 1.001, p.Wmax / 5.0), ["%.1f" % i for i in arange(0, 1.01, 0.2)])

    ylabel('avg. weights $(w/w_{max})$')
    
    
def plot_weight_change_fig(r, p):
    """Plot initial vs. final weight of every synapse (panel E).

    ``r.weights`` has one row per synapse and one column per stored sample.
    For each synapse the first sample (x marker) and the mean of the last 10
    samples (o marker) are drawn, joined by a vertical line. The dotted line
    is the target weight: Wmax for the first ``p.numStrongTargetSynapses``
    synapses and 0 for the rest. The dashed line is Wmax/2.

    Changes from the previous version:
      * the final weight averages ``w[-10:]``; ``w[-10:-1]`` dropped the last
        sample and averaged only 9;
      * the synapse counts come from the data and ``p`` instead of a
        hard-coded 50/100 (same result for 100 synapses, 50 strong);
      * the unused ``failed_strong``/``failed_weak`` computations are removed.
        They compared a list with a float (TypeError on Python 3) and used
        pylab's ``find``, which newer matplotlib no longer provides.
    """
    n_syn = len(r.weights)
    n_strong = p.numStrongTargetSynapses
    initial_weights = [w[0] for w in r.weights]
    last_weights = [mean(w[-10:]) for w in r.weights]
    syn_idx = arange(n_syn)

    target_w = hstack((p.Wmax * ones(n_strong), zeros(n_syn - n_strong)))
    plot(syn_idx, target_w, 'k:')
    plot(syn_idx, p.Wmax/2 * ones(n_syn), 'k--')
    plot(syn_idx, initial_weights, 'k x', markersize=3.4)
    plot(syn_idx, last_weights, 'k o', markersize=3.4)
    vlines(syn_idx, initial_weights, last_weights)
    xlabel('synapse \#')
    ylim(0, p.Wmax)
    yticks(arange(0, p.Wmax * 1.001, p.Wmax / 5.0), ["%.1f" % i for i in arange(0, 1.01, 0.2)])
    ylabel('syn. weight $(w/w_{max})$')
    xticks(arange(0, n_syn + 1, 50), ['%d' % (x,) for x in arange(0, n_syn + 1, 50)])
    
   
if __name__ == "__main__":
    # Script entry point: builds the six-panel figure (panels A-F) and saves
    # it to wstar_static_current.eps.
    #
    # Command-line arguments as read below:
    #   argv[1]  simulation results directory
    #            (default: last_created_dir('biofeed.*'))
    #   argv[2]  output name for the spikes_corr h5 file ('complete' mode only)
    #
    # 'complete' re-simulates the network with multi_run_and_save_beforeAfter
    # and writes a new spikes_corr file. The other branch reloads an existing
    # spikes_corr file instead. The second assignment overrides the first,
    # so 'complete' is always used.
    mode = "just_corr"
    mode = 'complete'
    
    # Analysis window for the spike trains (the spike plots label time in
    # seconds). XBeforeMax is also passed as Tsim of every re-simulation.
    XBeforeMin, XBeforeMax = (5,35)   
    if mode == 'complete':
        if len(sys.argv) > 1:
            sim_dir = sys.argv[1]
        else:
            sim_dir = last_created_dir('biofeed.*')
        sim_file = os.path.join(sim_dir, last_file('biofeed.*er18.*', sim_dir))
        
        print " loading simulation filename : ", sim_file
        
        output_name = 'noname'
        if len(sys.argv) > 2:
            output_name = sys.argv[2]
        spikes_h5file = open_experiment_h5file("spikes_corr", output_name)
        
        # Six re-simulations; their spike trains are stored in spikes_h5file.
        new_rec, first_run_rec = multi_run_and_save_beforeAfter(sim_file, spikes_h5file, 6, XBeforeMax)        
    else:
        if len(sys.argv) > 1:
            sim_dir = sys.argv[1]
        else:
            sim_dir = last_created_dir('biofeed.*')
        sim_file = os.path.join(sim_dir, last_file('biofeed.*', sim_dir))            
        if len(sys.argv) > 1:
            # NOTE(review): argv[1] was already consumed as sim_dir above; an
            # explicit spikes file probably belongs in another argument — confirm.
            spikes_h5file = sys.argv[1]
        else:
            spikes_h5file = last_file('spikes_corr.*\.h5$')        
        print " loading h5 filename : ", spikes_h5file
        # The simulation file is taken from the spikes file's src_filename,
        # replacing the sim_file chosen above.
        sim_file = constructRecordingsFromH5File(spikes_h5file).src_filename

        print "loading sim h5 filename : " , sim_file
    
    
        
    sim_r = constructRecordingsFromH5File(sim_file).biofeed
    sim_p = constructParametersFromH5File(sim_file)
    
    pp = sim_p.biofeed
    
    # Diagnostic printout of two derived constants and their ratio.
    # NOTE(review): A_minus uses pp.stdpApos (not an stdpAneg parameter) —
    # confirm this is intended and not a copy of the line above.
    A_plus_kappa_theory = pp.DAStdpRate * pp.stdpApos * pp.KappaApos / (0.01 * pp.Wmax)
    A_minus_kappa_theory = pp.DAStdpRate * pp.stdpApos * pp.KappaAneg / (0.01 * pp.Wmax)
    
    print "A_plus_kappa_theory = ", A_plus_kappa_theory
    print "A_minus_kappa_theory = ", A_minus_kappa_theory
    print " ratio of A_plus and A_minus kappas = ", A_plus_kappa_theory / A_minus_kappa_theory
    
        
    f = figure(1,figsize=(8,9), facecolor = 'w')
    
    f.subplots_adjust(top= 0.93, left = 0.11, bottom = 0.06, right = 0.93, hspace = 0.55, wspace = 0.55)
    clf()
    
    print sim_p
    
    # Panel A: average strong/weak weights of all runs in sim_dir.
    # 'frameaxes' is presumably the projection registered by the FrameAxes import.
    ax = subplot(3, 2, 1, projection = 'frameaxes')
    text(-0.25, 1.13, 'A', fontsize = 'x-large', transform = ax.transAxes )
    plot_multi_run_wstar(sim_dir)
    
    
    
    # Panel B: spike trains of run 0 between t = 19.1 and 21.6.
    ax = subplot(3, 2, 2, projection = 'frameaxes')
    text(-0.25, 1.13,  'B', fontsize = 'x-large', transform = ax.transAxes )
    start = 19.1 
    end = 21.6
    spike_idx = 0
    display_spike_trains(spikes_h5file, spike_idx, start, end, ax)

    
    # Panel C: spike correlation with the realizable target over learning.
    ax = subplot(3, 2, 3, projection = 'frameaxes')
    text(-0.25, 1.13,  'C', fontsize = 'x-large', transform = ax.transAxes )
    corr = generate_spike_corr(spikes_h5file, XBeforeMin, XBeforeMax)
    plot_spike_corr(corr, sim_p)
    
    
    
    # Panel D: angle between learned and target weight vectors.
    ax = subplot(3, 2, 4, projection = 'frameaxes')
    text(-0.25, 1.13,  'D', fontsize = 'x-large', transform = ax.transAxes )
    plot_weightvec_angle(sim_p, sim_r)
    
    
    
    # Panel E: initial vs. final weight of every synapse.
    ax = subplot(3, 2, 5)
    text(-0.25, 1.12,  'E', fontsize = 'x-large', transform = ax.transAxes )
    plot_weight_change_fig(sim_r, sim_p.biofeed)
    
    
    # Panel F: weight matrix over time with a colour legend.
    ax = subplot(3, 2, 6)
    text(-0.25, 1.12,  'F', fontsize = 'x-large', transform = ax.transAxes )    
    plot_weight_evolution(sim_p, sim_r, ax)    


    savefig("wstar_static_current.eps")

Loading data, please wait...