Motor system model with reinforcement learning drives virtual arm (Dura-Bernal et al 2017)

 Download zip file   Auto-launch 
Help downloading and running models
Accession:194897
"We implemented a model of the motor system with the following components: dorsal premotor cortex (PMd), primary motor cortex (M1), spinal cord and musculoskeletal arm (Figure 1). PMd modulated M1 to select the target to reach, M1 excited the descending spinal cord neurons that drove the arm muscles, and received arm proprioceptive feedback (information about the arm position) via the ascending spinal cord neurons. The large-scale model of M1 consisted of 6,208 spiking Izhikevich model neurons [37] of four types: regular-firing and bursting pyramidal neurons, and fast-spiking and low-threshold-spiking interneurons. These were distributed across cortical layers 2/3, 5A, 5B and 6, with cell properties, proportions, locations, connectivity, weights and delays drawn primarily from mammalian experimental data [38], [39], and described in detail in previous work [29]. The network included 486,491 connections, with synapses modeling properties of four different receptors ..."
Reference:
1 . Dura-Bernal S, Neymotin SA, Kerr CC, Sivagnanam S, Majumdar A, Francis JT, Lytton WW (2017) Evolutionary algorithm optimization of biological learning parameters in a biomimetic neuroprosthesis. IBM Journal of Research and Development (Computational Neuroscience special issue) 61(2/3):6:1-6:14
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s): Abstract Izhikevich neuron;
Channel(s):
Gap Junctions:
Receptor(s): GabaA; GabaB; NMDA; AMPA;
Gene(s):
Transmitter(s): Glutamate; Gaba;
Simulation Environment: NEURON; Python;
Model Concept(s): Learning; Reinforcement Learning; Reward-modulated STDP; STDP; Motor control; Sensory processing;
Implementer(s): Dura-Bernal, Salvador [salvadordura at gmail.com]; Kerr, Cliff [cliffk at neurosim.downstate.edu];
Search NeuronDB for information about:  GabaA; GabaB; AMPA; NMDA; Gaba; Glutamate;
"""
analysis.py

Functions to plot and analyse results

Version: 2014July9
"""
from pylab import pcolor, nonzero, mean, histogram, arange, bar, vstack,scatter, figure, hold, isscalar, gca, unique, subplot, axes, shape, imshow, colorbar, plot, xlabel, ylabel, title, xlim, ylim, clim, show, zeros, legend, savefig, cm, specgram, get_cmap, psd
from scipy.io import loadmat
from scipy import loadtxt, size, array, linspace, ceil
from datetime import datetime
from time import time
import csv
import pickle
import shared as s

###############################################################################
### Simulation-related graph plotting functions
###############################################################################

## Create colormap
def bicolormap(gap=0.1,mingreen=0.2,redbluemix=0.5,epsilon=0.01):
    """Return a 256-level diverging colormap: blue at 0, gray at 0.5, red at 1.

    gap        -- all three channels equal 1-gap at the center (gap=0 gives white)
    mingreen   -- green level at both ends of the map
    redbluemix -- red level reached just below the center band (blue half) and
                  blue level just above it (red half)
    epsilon    -- half-width of the band around 0.5; at 0.5-epsilon and
                  0.5+epsilon the channels have separate left/right values
    """
    from matplotlib.colors import LinearSegmentedColormap

    center = 1 - gap
    below, above = 0.5 - epsilon, 0.5 + epsilon

    def channel(start, belowPair, abovePair, end):
        # Anchor rows (x, value_left, value_right) in the layout
        # LinearSegmentedColormap expects; only the values differ per channel.
        return ((0.00000, start, start),
                (below,) + belowPair,
                (0.50000, center, center),
                (above,) + abovePair,
                (1.00000, end, end))

    segments = {
        'red':   channel(0.0, (redbluemix, center), (center, 1.0), 1.0),
        'green': channel(mingreen, (center, center), (center, center), mingreen),
        'blue':  channel(1.0, (1.0, center), (center, redbluemix), 0.0),
    }
    return LinearSegmentedColormap('bicolormap', segments, 256)

## Raster plot
def plotraster(filename=None):
    """Draw a spike raster of every recorded spike, colored by cell type.

    Reads s.allspiketimes, s.allspikecells, s.EorI, s.ncells, s.connspercell,
    s.backgroundweight, s.firingrate and s.duration from the shared module.
    The per-cell s.EorI value (0 or 1) selects orange (excitatory) or blue
    (inhibitory). The plotting time is printed. If *filename* is given, the
    figure is also saved there.
    """
    t0 = time()
    palette = array([(1,0.4,0) , (0,0.2,0.8)])  # excitatory orange, inhibitory blue
    spikeCellIds = array(s.allspikecells, dtype=int)
    spikeColors = palette[array(s.EorI)[spikeCellIds]]

    figure()
    scatter(s.allspiketimes, s.allspikecells, 10, spikeColors, linewidths=0.5, marker='|')
    xlabel('Time (ms)')
    ylabel('Cell ID')
    titleArgs = (s.ncells, s.connspercell, s.backgroundweight[0], s.firingrate)
    title('cells=%i syns/cell=%0.1f noise=%0.1f rate=%0.1f Hz' % titleArgs, fontsize=12)
    xlim(0, s.duration)
    ylim(0, s.ncells)

    elapsed = time() - t0
    print('  Done; time = %0.1f s' % elapsed)
    if filename:
        savefig(filename)

## Perievent time histogram
def plotPETH():
    """Plot a peri-event time histogram (spike count per 20 ms bin) for each population.

    Reads s.allspiketimes, s.allspikecells, s.cellpops, s.popnames, s.duration
    and s.scale from the shared module. Draws one line per population in a new
    figure, with x ticks labelled by bin start time in ms.
    """
    binsize = 20 # bin size in ms
    binedges = arange(0, s.duration+binsize, binsize)

    # Population of the cell behind each spike; computed once, not once per population.
    spikepops = array([s.cellpops[int(i)] for i in s.allspikecells])
    peth = []
    for ipop in unique(s.cellpops):
        hist, _ = histogram(s.allspiketimes[spikepops == ipop], binedges)
        peth.append(hist)

    figure()
    plot(array(peth).T)
    title('PETH (%d ms bins)'%binsize)
    xlabel('Time (ms)')
    ylabel('Spikes/bin')
    ylim(0,s.scale*binsize*2)

    # gca(): in recent matplotlib, axes() adds a new empty Axes instead of
    # returning the one just plotted into.
    h = gca()
    # About 10 ticks. Use an integer step of at least 1 so short runs (<10 bins)
    # and Python 3 division do not break range/slicing. Labels are taken at the
    # same positions, so tick and label counts always match.
    tickstep = max(1, len(binedges)//10)
    ticks = arange(0, len(binedges)-1, tickstep)
    h.set_xticks(ticks)
    h.set_xticklabels(binedges[ticks].astype(int))
    legend(s.popnames)

## Plot power spectra density
def plotpsd(Fs=200, plotPops=False):
    """Plot the power spectral density (PSD) of the simulated LFP, then call show().

    Column c of s.lfps, for c < len(s.lfppops), is taken as one population's
    LFP trace. The overall PSD is computed on the sum of those traces and drawn
    in a new figure.

    Fs       -- sampling frequency passed to psd(); 200 is the value previously
                hard-coded (NOTE(review): confirm it matches the LFP sampling step)
    plotPops -- if true, first draw each population's PSD in its own figure
                (the legend labels assume four layers: L2/3, L5A, L5B, L6)
    """
    colorspsd=array([[0.42,0.67,0.84],[0.42,0.83,0.59],[0.90,0.76,0.00],[0.90,0.32,0.00],[0.34,0.67,0.67],[0.42,0.82,0.83],[0.90,0.59,0.00],[0.33,0.67,0.47],[1.00,0.85,0.00],[0.71,0.82,0.41],[0.57,0.67,0.33],[1.00,0.38,0.60],[0.5,0.2,0.0],[0.0,0.2,0.5]])

    npops = len(s.lfppops)
    if npops == 0:
        print('  No LFP populations to plot')
        return
    lfpv = [s.lfps[:,c] for c in range(npops)]  # one trace per population
    lfptot = sum(lfpv)  # element-wise sum of the traces

    # plot pops separately
    if plotPops:
        figure() # Open a new figure
        for p in range(npops):
            psd(lfpv[p],Fs=Fs, linewidth= 2,color=colorspsd[p])
        xlabel('Frequency (Hz)')
        ylabel('Power')
        gca().set_yticklabels([])  # gca(): axes() would add an empty Axes on top
        legend(['L2/3','L5A', 'L5B', 'L6'])

    # plot overall psd
    figure() # Open a new figure
    psd(lfptot,Fs=Fs, linewidth= 2)
    xlabel('Frequency (Hz)')
    ylabel('Power')
    gca().set_yticklabels([])

    show()


## Plot connectivity (for diagnostic purposes). Based on conndiagram.py.
def plotconn():
    """Show the population-by-population connectivity matrix as a color image.

    Entry [c1, c2] is the sum over receptors w < s.nreceptors of
    connprobs[c1,c2] * connweights[c1,c2,w] * scaleconnweight[popEorI[c1], popEorI[c2]],
    with receptors w >= 2 counted as negative. Both axes are labelled with
    s.popnames. The color scale is symmetric around 0 (bicolormap, gap=0).
    """
    # Create plot with tight margins and no spacing between subplots
    figh = figure(figsize=(8,6))
    figh.subplots_adjust(left=0.02, right=0.98, top=0.96, bottom=0.02, wspace=0, hspace=0)
    h = axes()  # the new figure's first (only) Axes

    totalconns = zeros(shape(s.connprobs))
    for c1 in range(size(s.connprobs,0)):
        for c2 in range(size(s.connprobs,1)):
            scale = s.scaleconnweight[s.popEorI[c1],s.popEorI[c2]]  # invariant over receptors
            for w in range(s.nreceptors):
                sign = -1 if w>=2 else 1
                totalconns[c1,c2] += s.connprobs[c1,c2]*s.connweights[c1,c2,w]*sign*scale
    imshow(totalconns,interpolation='nearest',cmap=bicolormap(gap=0))

    # Plot grid lines between populations. hold(True) is not needed: plots
    # overlay by default, and pylab.hold was removed in matplotlib 3.0.
    for pop in range(s.npops):
        plot(array([0,s.npops])-0.5,array([pop,pop])-0.5,'-',c=(0.7,0.7,0.7))
        plot(array([pop,pop])-0.5,array([0,s.npops])-0.5,'-',c=(0.7,0.7,0.7))

    # Make pretty
    h.set_xticks(range(s.npops))
    h.set_yticks(range(s.npops))
    h.set_xticklabels(s.popnames)
    h.set_yticklabels(s.popnames)
    h.xaxis.set_ticks_position('top')
    xlim(-0.5,s.npops-0.5)
    ylim(s.npops-0.5,-0.5)
    vmax = abs(totalconns).max()
    clim(-vmax, vmax)
    colorbar()


## Plot weight changes
def _plotWeightChangesOverTime(relative=True):
    """Heat map of each STDP connection's weight per update step (one row per connection).

    relative -- if true, plot the change since the connection's first record;
                otherwise plot the recorded weight itself.
    Connections with fewer records than the longest keep their last value for
    the remaining steps.
    """
    figure()
    maxSteps = max([len(w) for w in s.allweightchanges])
    wc = zeros((len(s.allweightchanges), maxSteps))
    for iconn, conn in enumerate(s.allweightchanges):
        for it in range(maxSteps):
            if it < len(conn):
                wc[iconn, it] = conn[it][-1]-conn[0][-1] if relative else conn[it][-1]
            else:
                wc[iconn, it] = wc[iconn, it-1]  # hold the last recorded value
    pcolor(wc, cmap='hot_r', vmin=wc.min(), vmax=wc.max())
    xlim((0,maxSteps))
    ylim((0,len(wc)))
    xlabel('Time (weight updates)')
    ylabel('Synaptic connection id')
    colorbar()


def plotweightchanges(filename=None, changeOverTime=False):
    """Plot final-minus-initial weight of every STDP connection as a pre x post cell matrix.

    Does nothing unless s.usestdp is set. Each s.allstdpconndata entry provides
    (pre gid, post gid, receptor index). The matching s.allweightchanges entry
    is the sequence of recorded updates, each holding the weight in its last
    field. Changes for receptor index >= 2 are drawn negated.

    filename       -- if given, the matrix figure is saved there
    changeOverTime -- if true, also plot each connection's weight change per
                      update step in a second figure
    """
    if not s.usestdp:
        return
    if len(s.allstdpconndata) == 0:
        print('  No STDP connections; nothing to plot')
        return

    # create plot with tight margins and no spacing between subplots
    figh = figure(figsize=(1.2*8,1.2*6))
    figh.subplots_adjust(left=0.02, right=0.98, top=0.96, bottom=0.02, wspace=0, hspace=0)
    h = axes()

    # create data matrix: absolute weight change (last record minus first)
    wcs = [x[-1][-1]-x[0][-1] for x in s.allweightchanges]
    pre,post,recep = zip(*[(x[0],x[1],x[2]) for x in s.allstdpconndata])
    ncells = int(max(max(pre),max(post))+1)
    wcmat = zeros([ncells, ncells])
    for iwc,ipre,ipost,irecep in zip(wcs,pre,post,recep):
        wcmat[int(ipre),int(ipost)] = iwc *(-1 if irecep>=2 else 1)

    # plot
    imshow(wcmat,interpolation='nearest',cmap=bicolormap(gap=0,mingreen=0.2,redbluemix=0.1,epsilon=0.01))
    xlabel('post-synaptic cell id')
    ylabel('pre-synaptic cell id')
    h.set_xticks(s.popGidStart)
    h.set_yticks(s.popGidStart)
    h.set_xticklabels(s.popnames)
    h.set_yticklabels(s.popnames)
    h.xaxis.set_ticks_position('top')
    xlim(-0.5,ncells-0.5)
    ylim(ncells-0.5,-0.5)
    vmax = abs(wcmat).max()
    clim(-vmax, vmax)
    colorbar()

    if filename:
        savefig(filename)

    if changeOverTime:
        _plotWeightChangesOverTime()
            


## plot motor subpopulations connectivity changes
def _motorPrePostWeights(cellRange):
    """Return (initial, final) weights of the STDP connections whose postsynaptic gid is in cellRange.

    Element [1] of each s.allstdpconndata entry is the postsynaptic gid. The
    matching s.allweightchanges entry is a sequence of records with the weight
    in the last field.
    """
    pre, post = [], []
    for icon, history in enumerate(s.allweightchanges):
        if s.allstdpconndata[icon][1] in cellRange:
            pre.append(history[0][-1])
            post.append(history[-1][-1])
    return pre, post


def _printPrePostSummary(tag, preSum, postSum):
    """Print summed initial/final weights and their absolute and relative differences."""
    pre, post = array(preSum), array(postSum)
    # Single-argument print(...) gives the same output on Python 2 and 3.
    print('\ninitial %s weights:  %s' % (tag, preSum))
    print('final %s weights:  %s' % (tag, postSum))
    print('absolute %s difference:  %s' % (tag, post - pre))
    print('relative %s difference:  %s' % (tag, (post - pre) / pre))


def _plotPrePostBars(preSum, postSum, labels, showLegend):
    """One figure: pre (blue) vs post (red) bars on top, relative change below."""
    figh = figure(figsize=(1.2*8,1.2*6))
    ind = arange(len(preSum))  # the x locations for the groups
    # align='edge' (bars start at x) is what these tick offsets assume; it was
    # matplotlib's default before 2.0 and must be explicit afterwards.
    ax1 = figh.add_subplot(2,1,1)
    width = 0.35       # the width of the bars
    ax1.bar(ind, preSum, width, color='b', align='edge')
    ax1.bar(ind+width, postSum, width, color='r', align='edge')
    ax1.set_xticks(ind+width)
    ax1.set_xticklabels(labels)
    if showLegend:
        ax1.legend(['pre','post'])
    ax1.grid()

    ax2 = figh.add_subplot(2,1,2)
    width = 0.70       # the width of the bars
    ax2.bar(ind, (array(postSum) - array(preSum)) / array(preSum), width, color='b', align='edge')
    ax2.set_xticks(ind+width/2)
    ax2.set_xticklabels(labels)
    ax2.grid()


def plotmotorpopchanges(showInh=True):
    """Summarize and plot STDP weight changes onto each motor command subpopulation.

    Does nothing unless s.usestdp is set. For each cell range in
    s.motorCmdCellRange, sums the initial and final weights of STDP connections
    targeting those cells, prints the sums with absolute and relative
    differences, and draws bar plots.

    showInh -- also process the matching cell range offset from
               s.popGidStart[s.EDSC] to s.popGidStart[s.IDSC]
    Bar labels assume four subpopulations ('shext','shflex','elext','elflex').
    """
    if not s.usestdp:
        return
    labels = ('shext','shflex','elext','elflex')

    EwpreSum, EwpostSum = [], []
    IwpreSum, IwpostSum = [], []
    for cellRange in s.motorCmdCellRange:
        Ewpre, Ewpost = _motorPrePostWeights(cellRange)
        EwpreSum.append(sum(Ewpre))
        EwpostSum.append(sum(Ewpost))
        if showInh:
            # Same offset within the IDSC population as within EDSC
            inhRange = cellRange - s.popGidStart[s.EDSC] + s.popGidStart[s.IDSC]
            Iwpre, Iwpost = _motorPrePostWeights(inhRange)
            IwpreSum.append(sum(Iwpre))
            IwpostSum.append(sum(Iwpost))

    _printPrePostSummary('E', EwpreSum, EwpostSum)
    if showInh:
        _printPrePostSummary('I', IwpreSum, IwpostSum)

    _plotPrePostBars(EwpreSum, EwpostSum, labels, showLegend=False)
    if showInh:
        _plotPrePostBars(IwpreSum, IwpostSum, labels, showLegend=True)
        

## plot 3d architecture:
def plot3darch(xlocs=None, ylocs=None, zlocs=None):
    """Scatter-plot points in 3D, colored by their z value.

    xlocs, ylocs, zlocs -- coordinate sequences of equal length. Any omitted
    argument falls back to the three placeholder points previously hard-coded.
    Axes are labelled lateral distance (x, y) and cortical depth (z), in mm.
    """
    # Importing mplot3d registers the '3d' projection on older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    if xlocs is None:
        xlocs = [1,2,3]
    if ylocs is None:
        ylocs = [3,2,1]
    if zlocs is None:
        zlocs = [0.1,0.5,1.2]

    figh = figure(figsize=(1.2*8,1.2*6))
    ax = figh.add_subplot(1,1,1, projection='3d')
    ax.scatter(xlocs, ylocs, zlocs, s=10, c=zlocs, edgecolors='none', cmap='jet_r', linewidths=0.0, alpha=1, marker='o')
    ax.view_init(elev=60, azim=40)
    # Label the 3D Axes directly. pylab xlabel/ylabel targeted the extra 2D
    # Axes created by axes(), and the depth label overwrote the y label.
    ax.set_xlabel('lateral distance (mm)')
    ax.set_ylabel('lateral distance (mm)')
    ax.set_zlabel('cortical depth (mm)')


###############################################################################
### Evolutionary-algorithm analysis/plotting functions
###############################################################################

#%% plot filled error bars
def errorfill(x, y, yerr, lw=1, elinewidth=1, color=None, alpha_fill=0.2, ax=None):
    """Plot *y* against *x* as a line with a shaded error band.

    yerr -- a scalar or a sequence with len(y) elements (band y-yerr .. y+yerr),
            or a pair (ymin, ymax) giving the band edges directly
    lw, elinewidth -- line width of the curve / of the band's edge
    color -- defaults to the next color of the axes' property cycle
    alpha_fill -- transparency of the band
    ax -- Axes to draw on; defaults to gca()

    Raises ValueError if yerr matches neither accepted form.
    """
    import numpy as np

    ax = ax if ax is not None else gca()
    y = np.asarray(y)  # allow plain lists for y - yerr
    if np.isscalar(yerr) or len(yerr) == len(y):
        yerr = np.asarray(yerr)
        ymin, ymax = y - yerr, y + yerr
    elif len(yerr) == 2:
        ymin, ymax = yerr
    else:
        raise ValueError('yerr must be a scalar, have len(y) elements, or be a (ymin, ymax) pair')

    # When no color is given, let matplotlib take the next cycle color for the
    # line and reuse it for the band (ax._get_lines.color_cycle no longer exists).
    lineKw = {'lw': lw}
    if color is not None:
        lineKw['color'] = color
    line = ax.plot(x, y, **lineKw)[0]
    ax.fill_between(x, ymax, ymin, color=line.get_color(), lw=elinewidth, alpha=alpha_fill)

#%% function to obtain unique list of lists
def uniqueList(seq):
    """Remove repeated entries from *seq*, keeping the first occurrence of each.

    Two entries count as equal when tuple(entry) is equal, so each entry must be
    an iterable of hashable values.
    Returns (kept_entries, their_indices_in_seq), both in original order.
    """
    alreadySeen = set()
    kept, keptIdx = [], []
    for pos, entry in enumerate(seq):
        key = tuple(entry)
        if key not in alreadySeen:
            alreadySeen.add(key)
            kept.append(entry)
            keptIdx.append(pos)
    return kept, keptIdx
                
#%% function to read data               
def _readIndividuals(folderFinal):
    """Read ../data/<folderFinal>/individuals.csv into (gens, cands, fits, params) lists.

    Each row holds generation, candidate, fitness, then the candidate's
    parameter values (any '[' / ']' around them is stripped before parsing).
    """
    gens, cands, fits, params = [], [], [], []
    with open('../data/%s/individuals.csv' % (folderFinal)) as f:
        for row in csv.reader(f):
            gens.append(int(row[0]))
            cands.append(int(row[1]))
            fits.append(float(row[2]))
            params.append([float(v.replace("[","").replace("]","")) for v in row[3:]])
    return gens, cands, fits, params


def _readStatistics(folderFinal):
    """Read ../data/<folderFinal>/statistics.csv into (gens, worst, best, avg, std).

    Uses columns 0 (generation), 2 (worst), 3 (best), 4 (average) and 6 (std)
    fitness. A generation may appear more than once (e.g. after rerunning on
    HPC). Only its first row is kept, and generations come back sorted.
    """
    gens, worst, best, avg, std = [], [], [], [], []
    with open('../data/%s/statistics.csv' % (folderFinal)) as f:
        for row in csv.reader(f):
            gens.append(float(row[0]))
            worst.append(float(row[2]))
            best.append(float(row[3]))
            avg.append(float(row[4]))
            std.append(float(row[6]))
    gens, firstIdx = unique(gens, 1)  # return_index: first occurrence of each generation
    worst, best, avg, std = zip(*[[worst[i], best[i], avg[i], std[i]] for i in firstIdx])
    return gens, worst, best, avg, std


def _readEvaluations(folderFinal, lastGen, lastCand):
    """Load pickled error/params files for generations 0..lastGen and candidates 0..lastCand (inclusive).

    Reads ../data/<folderFinal>/gen_<g>_cand_<c>_error and ..._params. Pairs with
    a missing or unreadable file are skipped. A pair is recorded only when both
    files load, so the returned (gens, cands, fits, params) lists stay aligned.
    NOTE: pickle.load can run arbitrary code; only use on trusted result files.
    """
    gens, cands, fits, params = [], [], [], []
    for igen in range(lastGen + 1):
        for ican in range(lastCand + 1):
            prefix = '../data/%s/gen_%d_cand_%d_' % (folderFinal, igen, ican)
            try:
                with open(prefix + 'error', 'rb') as f:
                    fit = pickle.load(f)
                with open(prefix + 'params', 'rb') as f:
                    par = pickle.load(f)
            except (IOError, OSError, EOFError, ValueError, AttributeError,
                    ImportError, IndexError, pickle.UnpicklingError):
                continue  # this pair was not saved, or its files are incomplete
            gens.append(igen)
            cands.append(ican)
            fits.append(fit)
            params.append(par)
    return gens, cands, fits, params


def _sortByFitness(fits, gens, cands, params):
    """Return (fits, gens, cands, params) reordered by ascending fitness (ties keep input order)."""
    order = sorted(range(len(fits)), key=lambda k: fits[k])
    return ([fits[i] for i in order], [gens[i] for i in order],
            [cands[i] for i in order], [params[i] for i in order])


def _loadIsland(folderFinal, dataFrom):
    """Load and rank the results in ../data/<folderFinal>/; returns the 13-tuple documented in loadData."""
    ind_gens, ind_cands, ind_fits, ind_cs = _readIndividuals(folderFinal)
    stats = _readStatistics(folderFinal)

    if dataFrom == 'fitness':
        # Scan every (gen, cand) pair up to the largest ids in individuals.csv.
        lastGen = max(ind_gens) if ind_gens else -1
        lastCand = max(ind_cands) if ind_cands else -1
        eval_gens, eval_cands, eval_fits, eval_params = _readEvaluations(folderFinal, lastGen, lastCand)
        ranked = _sortByFitness(eval_fits, eval_gens, eval_cands, eval_params)
    else:  # 'individuals': drop repeated parameter sets, keep the first occurrence
        params_unique, unique_indices = uniqueList(ind_cs)
        ranked = _sortByFitness([ind_fits[i] for i in unique_indices],
                                [ind_gens[i] for i in unique_indices],
                                [ind_cands[i] for i in unique_indices],
                                params_unique)
    return (ind_gens, ind_cands, ind_fits, ind_cs) + tuple(stats) + tuple(ranked)


def loadData(folder, islands, dataFrom):
    """Load evolutionary-optimization results stored under ../data/.

    folder   -- results folder name under ../data/
    islands  -- >1: load folder_island_0 .. folder_island_<islands-1> and return,
                for each quantity, a list with one entry per island;
                1: load folder_island_0; 0 or less: load folder itself.
                In the last two cases the quantities are returned directly.
    dataFrom -- 'fitness' to rank the per-evaluation pickled error/params files,
                'individuals' to rank the unique rows of individuals.csv

    Returns (ind_gens, ind_cands, ind_fits, ind_cs, stat_gens, stat_worstfits,
    stat_bestfits, stat_avgfits, stat_stdfits, fits_sort, gens_sort,
    cands_sort, params_sort). The *_sort entries are ordered by ascending fitness.
    Raises ValueError for an unknown dataFrom.
    """
    if dataFrom not in ('fitness', 'individuals'):
        raise ValueError("dataFrom must be 'fitness' or 'individuals', got %r" % (dataFrom,))

    if islands > 1:
        perIsland = [_loadIsland(folder+"_island_"+str(island), dataFrom) for island in range(islands)]
        # Transpose: one list per quantity, holding that quantity for every island.
        return tuple(list(quantity) for quantity in zip(*perIsland))
    if islands == 1:
        return _loadIsland(folder+"_island_0", dataFrom)
    return _loadIsland(folder, dataFrom)


Loading data, please wait...