Cortical model with reinforcement learning drives realistic virtual arm (Dura-Bernal et al 2015)

 Download zip file   Auto-launch 
Help downloading and running models
Accession:183014
We developed a 3-layer sensorimotor cortical network consisting of 704 spiking model-neurons, including excitatory, fast-spiking and low-threshold spiking interneurons. Neurons were interconnected with AMPA/NMDA and GABAA synapses. We trained our model using spike-timing-dependent reinforcement learning to control a virtual musculoskeletal human arm, with realistic anatomical and biomechanical properties, to reach a target. Virtual arm position was used to simultaneously control a robot arm via a network interface.
References:
1 . Dura-Bernal S, Zhou X, Neymotin SA, Przekwas A, Francis JT, Lytton WW (2015) Cortical Spiking Network Interfaced with Virtual Musculoskeletal Arm and Robotic Arm. Front Neurorobot 9:13 [PubMed]
2 . Dura-Bernal S, Li K, Neymotin SA, Francis JT, Principe JC, Lytton WW (2016) Restoring behavior via inverse neurocontroller in a lesioned cortical spiking model driving a virtual arm. Front. Neurosci. Neuroprosthetics 10:28
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s): Neocortex M1 pyramidal pyramidal tract L5B cell; Neocortex M1 pyramidal intratelencephalic L2-5 cell; Neocortex M1 interneuron basket PV cell; Neocortex fast spiking (FS) interneuron; Neostriatum fast spiking interneuron; Neocortex spiking regular (RS) neuron; Neocortex spiking low threshold (LTS) neuron;
Channel(s):
Gap Junctions:
Receptor(s): GabaA; AMPA; NMDA;
Gene(s):
Transmitter(s): Gaba; Glutamate;
Simulation Environment: NEURON; Python (web link to model);
Model Concept(s): Synaptic Plasticity; Learning; Reinforcement Learning; STDP; Reward-modulated STDP; Sensory processing; Motor control;
Implementer(s): Neymotin, Sam [samn at neurosim.downstate.edu]; Dura, Salvador [ salvadordura at gmail.com];
Search NeuronDB for information about:  Neocortex M1 pyramidal intratelencephalic L2-5 cell; Neocortex M1 pyramidal pyramidal tract L5B cell; Neocortex M1 interneuron basket PV cell; GabaA; AMPA; NMDA; Gaba; Glutamate;
/
arm2dms_modeldb
mod
msarm
stimdata
README.html
analyse_funcs.py
analysis.py
armGraphs.py
arminterface_pipe.py
basestdp.hoc
bicolormap.py
boxes.hoc *
bpf.h *
col.hoc
colors.hoc *
declist.hoc *
decmat.hoc *
decnqs.hoc *
decvec.hoc *
default.hoc *
drline.hoc *
filtutils.hoc *
grvec.hoc
hinton.hoc *
hocinterface.py
infot.hoc *
init.hoc
intfsw.hoc *
labels.hoc
load.hoc
load.py
local.hoc *
main.hoc
main_demo.hoc
main_neurostim.hoc
misc.h *
misc.py *
msarm.hoc
network.hoc
neuroplot.py *
neurostim.hoc
nload.hoc
nqs.hoc *
nqsnet.hoc *
nrnoc.hoc
params.hoc
perturb.hoc
python.hoc
pywrap.hoc *
run.hoc
runbatch_neurostim.py
runsim_neurostim
samutils.hoc *
saveoutput.hoc
saveoutput2.hoc
setup.hoc *
sim.hoc
sim.py
sim_demo.py
simctrl.hoc *
stats.hoc *
stim.hoc
syncode.hoc *
units.hoc *
vector.py
xgetargs.hoc *
                            
// SAVEOUTPUT
// This program saves the output of an intfcol-derived simulation.
// Version: georgec 10/4/12

print "Loading saveoutput.hoc..."

//strdef filestem 
//filestem = "arch3d/3darch"

// runtype is the tag embedded in every output filename built below.
// It is "train" when the externally defined variable `type` is 0 and
// "test" for any other value.
strdef runtype
if (type != 0) {
	runtype = "test"
} else {
	runtype = "train"
}

// What to save -- 1=save, 0=don't save.
// Each flag gates one of the sections further down in this file; they are
// listed here in the order in which those sections run.

savelfp = 0             // LFP text matrix (and LFP NQS table when savenqss is also set)
savenqa = 1             // nqa NQS table
savenqaupd = 0          // nqaupd NQS table
savespikes = 1          // spike NQS table
savelocations = 0       // cell-location text matrix (and cell NQS table when savenqss is also set)
saveconnectivity = 1    // connectivity text matrix
savenqss = 1            // NQS copies: LFP, cell and connectivity tables
savedynamicweights = 1  // nqsy NQS table (synaptic changes)
savegrvec = 0           // grvec data written through pvall()

// Output filenames. Every name has the form <filestem>_<runtype>-<kind>.<ext>,
// where filestem is defined outside this file and runtype is "train" or "test".
// Names ending in "b" are the NQS counterparts of the plain-text files.
strdef outfn1, outfn1b
strdef outfn2, outfn2b
strdef outfn3, outfn3b
strdef outfn4, outfn4b
strdef outfn5
strdef outfn6, outfn6b
strdef outfn7, outfn7b
strdef outfn8

// Local field potentials
sprint(outfn1,  "%s_%s-lfp.txt", filestem, runtype)
sprint(outfn1b, "%s_%s-lfp.nqs", filestem, runtype)

// Spikes
sprint(outfn2,  "%s_%s-spk.txt", filestem, runtype)
sprint(outfn2b, "%s_%s-spk.nqs", filestem, runtype)

// Cell locations
sprint(outfn3,  "%s_%s-loc.txt", filestem, runtype)
sprint(outfn3b, "%s_%s-loc.nqs", filestem, runtype)

// Connectivity
sprint(outfn4,  "%s_%s-con.txt", filestem, runtype)
sprint(outfn4b, "%s_%s-con.nqs", filestem, runtype)

// grvec data
sprint(outfn5,  "%s_%s-prlist.gvc", filestem, runtype)

// nqa table
sprint(outfn6,  "%s_%s-nqa.txt", filestem, runtype)
sprint(outfn6b, "%s_%s-nqa.nqs", filestem, runtype)

// nqaupd table
sprint(outfn7,  "%s_%s-nqaupd.txt", filestem, runtype)
sprint(outfn7b, "%s_%s-nqaupd.nqs", filestem, runtype)

// Synaptic changes (NQS only)
sprint(outfn8,  "%s_%s-syn.nqs", filestem, runtype)

// Object references used by the save sections below. They are declared here,
// up front, because the declarations cannot be placed inside the if-blocks
// that use them.
objref fobj, tempvec, tempstr, storelfp                                                         // LFP section
objref fobj2, storespikes, tmpt, tmpid, tmptype, tmpcol                                         // spike section
objref fobj3, celllist, celllocations                                                           // cell-location section
objref fobj4, conpreid, conpostid, condelay, condistance, conweight1, conweight2, connectivity  // connectivity section

if (savelfp) {
	// Export the LFP data held in the nqLFP NQS table(s):
	//   1. a text matrix (outfn1) with time in column 0 followed by one
	//      downsampled trace per column/layer combination;
	//   2. when savenqss is set, the NQS table of column 0 itself (outfn1b)
	//      with an added "ts" time-stamp column.
	print "Saving LFP..."
	oldhz=nqLFP.cob.v.size/tstop*1000 // Original sampling rate; *1000 because tstop is in ms
	newhz=1000 // The new frequency to sample at, in Hz
	ratio=oldhz/newhz // Ratio between the old and new sampling rates
	// NOTE(review): the averaging windows below use k*ratio as vector indices,
	// so they presumably expect ratio to be an integer >= 1 -- confirm.
	npts=tstop/1000*newhz // Number of points in the resampled time series
	nlayers=nqLFP.m // Number of layers (usually 5 -- 2/3, 4, 5, 6, all)
	storelfp = new Matrix(npts, nlayers*numcols+1) // Combine layers/columns into one dimension, and make first column time
	for k=0,npts-1 storelfp.x[k][0]=k/newhz*1000 // Save time data (ms)
	count=1 // Current column of storelfp (column zero is time)
	for i=0,numcols-1 { // Loop over all columns
		for j=0,nlayers-1 { // Loop over all layers
			tempstr=nqLFP[i].s[j] // Get this particular NQS column header
			tempvec=nqLFP[i].getcol(tempstr.s) // Vector holding this particular NQS column
			for k=0,npts-1 { // Loop over points in the resampled time series
				// Downsample by averaging each window of `ratio` original samples
				storelfp.x[k][count]=tempvec.mean(k*ratio,(k+1)*ratio-1)
			}
			// Progress report; the total is the number of data columns in storelfp
			{fprint("  Column/layer %g of %g...\n",count,nlayers*numcols)}
			count+=1 // Advance to the next column of storelfp
		}
	}
	// Write the resampled matrix to a text file
	print "  Saving to file..."
	fobj = new File(outfn1)
	fobj.wopen()
	storelfp.fprint(0,fobj,"%10.1f") // It's usually in the thousands so one d.p. should do
	fobj.close()
	print "  ...done..."

	// Save the NQS file if desired
	if (savenqss) {
		// Add a time-stamps column to the NQS table.
		// A fresh Vector is allocated here: tempvec still refers to whatever the
		// last getcol() call returned, which may be the table's own column
		// vector, and overwriting that in place would corrupt the LFP data.
		tempvec = new Vector(nqLFP.v[0].size)
		tempvec.indgen(vdt) // 0, vdt, 2*vdt, ...
		nqLFP[0].resize("ts",tempvec)
		nqLFP[0].sv(outfn1b)  // assume 1 column for the time being (0)
	}
}

// Write the nqa table to its NQS file when requested.
if (savenqa) { nqa.sv(outfn6b) }

// Write the nqaupd table to its NQS file when requested.
if (savenqaupd) { nqaupd.sv(outfn7b) }


if (savespikes) {
	// For saving spike results.
	// Active behaviour: clear skipsnq, call initAllMyNQs(), then write snq[0]
	// to outfn2b. The plain-text export further down is commented out, so
	// outfn2 is never written by this section.
	print "Saving spikes..."
	skipsnq=0 // flag to create NQS with spike times, one per column (presumably read by initAllMyNQs() -- confirm)
	initAllMyNQs() 
	snq[0].sv(outfn2b)  // assume 1 column for the time being (0)
	//skipsnq=0 // flag to create NQS with spike times, one per column
	//initAllMyNQs() // setup of NQS objects with spike/other information
	/*
	Disabled legacy text export: gathers the spikes of every column into one
	matrix and prints it to outfn2.

	totalnumberofspikes=0 // Calculate the total number of spikes generated across all columns
	for i=0,numcols-1 totalnumberofspikes+=snq[i].cob.v.size 
	storespikes = new Matrix(totalnumberofspikes, 3) // Three columns: spike time, cell ID, cell type
	count=-1 // Initialize row count
	for i=0,numcols-1 { // Loop over columns
		tmpt=snq[i].getcol("t")
		tmpid=snq[i].getcol("id")
		tmptype=snq[i].getcol("type")
		for j=0,snq[i].cob.v.size-1 { // Loop over spikes
			if (1) { //(mod(tmpid.x[j],scale)==0) { // Only collect spikes from one out of every "scale" cells
				count+=1
				if (mod(count,10000)==0) {fprint("  %3.0f%% complete...\n",count*100/snq[i].cob.v.size)} // Print progress
				storespikes.x[count][0]=tmpt.x[j] // Store spike times
				storespikes.x[count][1]=tmpid.x[j] // Store cell number
				storespikes.x[count][2]=tmptype.x[j] // Store cell type
			}
		}
	}
	storespikes.resize(count,3) // Get rid of extra zeros
	// For outputting spikes to a file
	print "  Saving to file..."
	fobj2 = new File(outfn2)
	fobj2.wopen()
	storespikes.fprint(0,fobj2,"%6.0f") // All quantities are integers, so this should be fine
	fobj2.close()
	print "  ...done..."
	*/
   // Save the NQS file if desired
   //if (savenqss) {

   //}
}

if (savelocations) {
	// Export one row per cell (ID, population type, x, y, z) to a text file
	// and, when savenqss is set, the cell NQS table of column 0 as well.
	print "Saving cell locations..."

	celllist = col.ce          // short alias for the column's cell list
	n = celllist.count()       // number of cells

	// One row per cell, five fields per row
	celllocations = new Matrix(n, 5)
	for (i = 0; i < n; i += 1) {
		celllocations.x[i][0] = i                     // cell ID
		celllocations.x[i][1] = celllist.o[i].type    // cell population
		celllocations.x[i][2] = celllist.o[i].xloc    // X position
		celllocations.x[i][3] = celllist.o[i].yloc    // Y position
		celllocations.x[i][4] = celllist.o[i].zloc    // Z position
	}

	// Text output, every field printed with one decimal place
	fobj3 = new File(outfn3)
	fobj3.wopen()
	celllocations.fprint(0, fobj3, "%10.1f")
	fobj3.close()

	// NQS output, if desired
	if (savenqss) {
		col[0].cellsnq.sv(outfn3b)  // assume 1 column for the time being (0)
	}
}

if (saveconnectivity) {
	// Export one row per synapse (pre ID, post ID, delay, distance, first
	// weight) from the column's connsnq table to a text file.
	print "Saving cell connectivity..."

	// Pull the relevant columns out of the connectivity NQS table
	conpreid    = col.connsnq.getcol("id1")
	conpostid   = col.connsnq.getcol("id2")
	condelay    = col.connsnq.getcol("del")
	condistance = col.connsnq.getcol("dist")
	conweight1  = col.connsnq.getcol("wt1")
	conweight2  = col.connsnq.getcol("wt2")  // fetched but not written to the text file

	// One row per synapse, five fields per row
	n = conpreid.size()
	connectivity = new Matrix(n, 5)
	for (i = 0; i < n; i += 1) {
		connectivity.x[i][0] = conpreid.x[i]     // presynaptic cell ID
		connectivity.x[i][1] = conpostid.x[i]    // postsynaptic cell ID
		connectivity.x[i][2] = condelay.x[i]     // delay
		connectivity.x[i][3] = condistance.x[i]  // distance
		connectivity.x[i][4] = conweight1.x[i]   // first weight
	}

	// Text output, every field printed with one decimal place
	print "  Saving to file..."
	fobj4 = new File(outfn4)
	fobj4.wopen()
	connectivity.fprint(0, fobj4, "%7.1f")
	fobj4.close()
	print "  ...done..."
}

// Connectivity NQS table of column 0. This is controlled by savenqss alone,
// independently of saveconnectivity.
if (savenqss) {
	print "Saving nqs with cell connectivity..."
	col[0].connsnq.sv(outfn4b)  // assume 1 column for the time being (0)
}

// NQS table with the synaptic changes.
if (savedynamicweights) {
	print "Saving nqs with synaptic changes ..."
	nqsy.sv(outfn8)
	//writeweightstodisk()
}

// grvec data, written through pvall().
if (savegrvec) { pvall("grvec data",outfn5) }

Loading data, please wait...