Learning spatial transformations through STDP (Davison, Frégnac 2006)

Accession: 64261
A common problem in tasks involving the integration of spatial information from multiple senses, or in sensorimotor coordination, is that different modalities represent space in different frames of reference. Coordinate transformations between different reference frames are therefore required. One way to achieve this relies on the encoding of spatial information using population codes. The set of network responses to stimuli in different locations (tuning curves) constitutes a basis set of functions that can be combined linearly through weighted synaptic connections in order to approximate non-linear transformations of the input variables. The question then arises of how the appropriate synaptic connectivity is obtained. This model shows that a network of spiking neurons can learn the coordinate transformation from one frame of reference to another, with connectivity that develops continuously in an unsupervised manner, based only on the correlations available in the environment, and with a biologically realistic plasticity mechanism (spike timing-dependent plasticity).
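The basis-function idea can be illustrated with a small, self-contained Python sketch (not part of the NEURON model). It assumes 30 circular, von Mises-shaped tuning curves with preferred positions j/30 (mirroring theta = i/ncells in the hoc listing) and an assumed width of 0.2. The weights are a simple illustrative choice (the target value at each preferred position, divided by the summed basis activity), not the weights the network learns through STDP.

```python
import math

N = 30        # number of basis neurons (mirrors ncells)
SIGMA = 0.2   # tuning width; an assumed value for illustration

def tuning(x, theta, sigma=SIGMA):
    """Circular tuning curve on [0, 1) with peak value 1 at x == theta."""
    return math.exp((math.cos(2 * math.pi * (x - theta)) - 1) / sigma ** 2)

def target(x):
    """Non-linear function to approximate: f(x) = (sin(2*pi*x) + 1) / 2."""
    return (math.sin(2 * math.pi * x) + 1) / 2

# Preferred positions, evenly spaced around the circle.
thetas = [j / N for j in range(N)]

# With evenly spaced, identical tuning curves the summed activity is almost
# independent of x, so sampling f at each preferred position and dividing by
# that sum gives a usable set of linear weights.
norm = sum(tuning(0.0, th) for th in thetas)
weights = [target(th) / norm for th in thetas]

def readout(x):
    """Weighted linear sum of the population response to stimulus x."""
    return sum(w * tuning(x, th) for w, th in zip(weights, thetas))

if __name__ == "__main__":
    for x in (0.0, 0.125, 0.25, 0.5, 0.75):
        print(f"x = {x:5.3f}   f(x) = {target(x):.3f}   readout = {readout(x):.3f}")
```

The readout tracks f(x) closely but slightly underestimates its peaks, because the weighted sum is effectively f smoothed by the tuning curve; narrower tuning curves (with enough neurons to cover the circle) reduce this bias.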
Reference:
1. Davison AP, Frégnac Y (2006) Learning cross-modal spatial transformations through spike timing-dependent plasticity. J Neurosci 26:5604-15
Model Information
Model Type: Realistic Network;
Brain Region(s)/Organism: Generic;
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s): GabaA; AMPA;
Gene(s):
Transmitter(s):
Simulation Environment: NEURON;
Model Concept(s): Synaptic Plasticity; Long-term Synaptic Plasticity; Unsupervised Learning; STDP;
Implementer(s): Davison, Andrew [Andrew.Davison at iaf.cnrs-gif.fr];
// Learning basis functions to implement functions of one population-encoded
// variable using STDP.

// The model has an Input Layer (cellLayer[0]) and a Training Layer
// (cellLayer[1]), each consisting of spike sources, and projecting to an Output
// Layer (cellLayer[2]) consisting of integrate-and-fire neurons.

// The synaptic weights from Training-->Output are fixed.
// The synaptic weights from Input-->Output are plastic and obey an STDP rule.

// During training, the Input Layer receives input x, and the Training Layer
// input f(x). After training, the Training Layer is silent, and an input x to
// the Input Layer produces an output f(x) in the Output Layer.

// Uses the NetStimVR2 mechanism, rather than VecStimMs

// Andrew P. Davison, UNIC, CNRS, July 2004-May 2006

startsw()
objref cvode
cvode = new CVode()
xopen("netLayer.hoc")
xopen("layerConn.hoc")
xopen("ObjectArray.hoc")
xopen("intfire4nc.hoc")
xopen("plotweights.hoc")

// =-=-= Create objects and strings  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

objref random, fileobj[3], histfileobj
objref cellLayer[3], conn[3], spikecontrol
objref cellParams, spikerec[2]
objref deltat_vec[2][3], deltat_hist
strdef fileroot, infile, filename, save_fileroot
strdef command, funcstr, label, datadir
double m[2][3]

// =-=-= Global Parameters =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

seed             = 0           // Seed for the random number generator
ncells           = 30          // Number of input spike trains per layer
pconnect         = 1.0         // Connection probability
wmax             = 0.02        // Maximum synaptic weight
f_winhib         = 0.0         // Inhibitory weight = f_winhib*wmax (fixed)
f_wtr            = 1.0         // Max training weight = f_wtr*wmax
syndelay         = 0.0         // Synaptic delay
tauLTP_StdwaSA   = 20          // (ms) Time constant for LTP
tauLTD_StdwaSA   = 20          // (ms) Time constant for LTD
B                = 1.06        // B = (aLTD*tauLTD)/(aLTP*tauLTP)
aLTP             = 0.01        // Amplitude parameter for LTP
Rmax             = 60          // (Hz) Peak firing rate of input distribution
Rmin             = 0           // (Hz) Minimum input firing rate
Rsigma           = 0.2         // Width parameter for input distribution
alpha            = 1.0         // Gain of Training Layer rates compared to Input Layer
correlation_time = 20          // (ms) Correlation time of the stimulus (sets spikecontrol.tau_corr)
bgRate           = 1000        // (Hz) Firing rate for background activity
bgWeight         = 0.02        // Weight for background activity
funcstr          = "sin"       // Label for function to be approximated
nfuncparam       = 1           // Number of parameters of function
double k[nfuncparam]
k[0]             = 0.0         // Function parameter(s)
wtr_square       = 1           // 1: square, 0: bell-shaped profile for Training-->Output weights
wtr_sigma        = 0.15        // Width parameter for Training-->Output weights
noise            = 1           // Noise parameter
histbins         = 100         // Number of bins for weight histograms
record_spikes    = 0           // Whether or not to record spikes
wfromfile        = 0           // If non-zero, read connections/weights from file
infile           = ""          // File to read connections/weights from
tstop            = 1e7         // (ms)
trw              = 1e5         // (ms) Time between reading input spikes/printing weights
numhist          = 10          // Number of histograms between each weight printout
label            = "bfstdp_demo_" // Extra label for labelling output files
datadir          = ""          // Sub-directory of Data for writing output files
tau_m            = 20          // (ms) Membrane time constant

// =-=-= Create utility objects  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

random = new Random(seed)
histfileobj = new File()
for i = 0,2 { 
  fileobj[i] = new File()
}
spikerec[0] = new ObjectArray(1,ncells,"Vector","")
spikerec[1] = new ObjectArray(1,ncells,"Vector","")

// =-=-= Create the network  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

// Input spike trains are implemented using NetStimVR2s.
print "Creating network layers (time ", stopsw(), "s)"

cellParams = new Vector(4)
cellParams.x[0] = tau_m
cellParams.x[1] = 5
cellParams.x[2] = 10
cellParams.x[3] = 15

// Create network layers
for layer = 0,1 {
  cellLayer[layer] = new NetLayer(1,ncells,"NetStimVR2",0.5)
  cellLayer[layer].set("noise",noise)
  for i = 0,ncells-1 {
    cellLayer[layer].cell[i].theta = i/ncells
  }
}
Rmax_NetStimVR2 = Rmax
Rmin_NetStimVR2 = Rmin
sigma_NetStimVR2 = Rsigma

cellLayer[0].set("transform",0)
cellLayer[0].set("prmtr",0)
if (strcmp(funcstr,"") == 0) cellLayer[1].set("transform",0)
if (strcmp(funcstr,"mul") == 0) cellLayer[1].set("transform",1)
if (strcmp(funcstr,"sin") == 0) cellLayer[1].set("transform",2)
if (strcmp(funcstr,"sq") == 0) cellLayer[1].set("transform",3)
if (strcmp(funcstr,"asin") == 0) cellLayer[1].set("transform",4)
if (strcmp(funcstr,"sinn") == 0) cellLayer[1].set("transform",5)
cellLayer[1].set("prmtr",k[0])
cellLayer[1].set("alpha",alpha)

spikecontrol = new ControlNSVR2(0.5)
spikecontrol.tau_corr = correlation_time
spikecontrol.seed(seed)
setpointer spikecontrol.thetastim, thetastim_NetStimVR2
setpointer spikecontrol.tchange, tchange_NetStimVR2

cellLayer[2] = new NetLayer(1,ncells,"IntFire4nc",cellParams)

// Create synaptic connections
print "Creating synaptic connections (time ", stopsw(), "s)"

random.uniform(0,1)
if (wfromfile) { // read connections from file
  for i = 0,1 {
    sprint(filename,"%s.conn%d.conn",infile,i+1)
    fileobj[0].ropen(filename)
    conn[i] = new LayerConn(cellLayer[i],"",cellLayer[2],"syn",4,fileobj[0])
    fileobj[0].close()
  }
  if (f_winhib != 0) {
    sprint(filename,"%s.conn2.conn",infile)
    fileobj[0].ropen(filename)
    conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",4,fileobj[0])
    fileobj[0].close()
  }
} else {         // or generate them according to the rules specified
  conn[0] = new LayerConn(cellLayer[0],"",cellLayer[2],"syn",1,pconnect,random) // 1 for all:all
  r = random.uniform(0,wmax)  // set the Random distribution to uniform on [0,wmax] before randomize_weights
  conn[0].randomize_weights(random)
  conn[1] = new LayerConn(cellLayer[1],"",cellLayer[2],"syn",1,pconnect,random)
  if (syndelay < 0) {
    conn[0].set_delays(-1*syndelay)
    conn[1].set_delays(0)
  } else if (syndelay > 0) {
    conn[0].set_delays(0)
    conn[1].set_delays(syndelay)
  }
  if (f_winhib != 0) {
    conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",1)
    conn[2].remove_self_connections()
    conn[2].set_weights(wmax*f_winhib)
  }
}

// Turn on STDP for Input-->Output connections
print "Setting up STDP for Input-->Output connections (time ", stopsw(), "s)"
conn[0].stdp("StdwaSA")
conn[0].set_max_weight(wmax)
conn[0].wa_set("aLTP",aLTP)
conn[0].wa_set("aLTD",B*aLTP*tauLTP_StdwaSA/tauLTD_StdwaSA)

// Set background input
print "Setting background activity (time ", stopsw(), "s)"
sprint(command,"%f, %f, 0, 1, 1e12",bgWeight,bgRate)
cellLayer[2].call("set_background",command)

// Turn on recording of spikes
if (record_spikes) {
  cellLayer[2].call("record","1")
  for i = 0,ncells-1 {
    conn[0].nc[i][i].record(spikerec[0].x[i])
    conn[1].nc[i][i].record(spikerec[1].x[i])
  }
}


// =-=-= Procedures =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

// Utility procedures ----------------------------------------------------------

proc set_fileroot() { local i
  system("date '+%Y%m%d_%H%M' > starttime")
  fileobj[0].ropen("starttime")
  fileobj[0].scanstr(save_fileroot)
  fileobj[0].close()
  sprint(fileroot,"Data/%s/%s%s",datadir,label,funcstr)
  for i = 0, nfuncparam-1 {
    sprint(fileroot,"%s-%3.1f",fileroot,k[i])
  }
  sprint(fileroot,"%s_%s",fileroot,save_fileroot)
  print "fileroot = ", fileroot
}

// Procedures to read input spike trains from file -----------------------------

// Procedures to set weights ---------------------------------------------------

proc set_training_weights() { local i, j, d
  // Set the Training-->Output weights
  
  for i = 0, ncells-1 {
    for j = 0, ncells-1 {
      if(object_id(conn[1].nc[i][j])) {
	d = i-j
	if (d > ncells/2)  { d = ncells - d }
	if (d < -ncells/2) { d = ncells + d }
	if (wtr_square) {
	  if (d <= wtr_sigma*ncells && d >= -wtr_sigma*ncells) {
	    conn[1].nc[i][j].weight = f_wtr*wmax
	  }
	} else {
	  conn[1].nc[i][j].weight = f_wtr*wmax*exp( (cos(2*PI*d/ncells) - 1) / (wtr_sigma*wtr_sigma) )
	}
      }
    }
  }
}

// Procedures for writing results to file --------------------------------------

proc save_parameters() { local i
  sprint(filename,"%s.param",fileroot)
  fileobj[0].wopen(filename)
  fileobj[0].printf("// Parameters for bfstdp_nsvr2.hoc\n")
  fileobj[0].printf("%-17s = %d\n","seed",seed)
  fileobj[0].printf("%-17s = %d\n","ncells",ncells)
  fileobj[0].printf("%-17s = %f\n","pconnect",pconnect)
  fileobj[0].printf("%-17s = %f\n","wmax",wmax)
  fileobj[0].printf("%-17s = %f\n","f_winhib",f_winhib)
  fileobj[0].printf("%-17s = %f\n","f_wtr",f_wtr)
  fileobj[0].printf("%-17s = %f\n","syndelay",syndelay)
  fileobj[0].printf("%-17s = %f\n","tauLTP_StdwaSA",tauLTP_StdwaSA)
  fileobj[0].printf("%-17s = %f\n","tauLTD_StdwaSA",tauLTD_StdwaSA)
  fileobj[0].printf("%-17s = %f\n","B",B)
  fileobj[0].printf("%-17s = %f\n","aLTP",aLTP)  
  fileobj[0].printf("%-17s = %f\n","Rmax",Rmax)
  fileobj[0].printf("%-17s = %f\n","Rmin",Rmin)
  fileobj[0].printf("%-17s = %f\n","Rsigma",Rsigma)
  fileobj[0].printf("%-17s = %f\n","alpha",alpha)
  fileobj[0].printf("%-17s = %f\n","correlation_time",correlation_time)
  fileobj[0].printf("%-17s = %f\n","bgWeight",bgWeight)
  fileobj[0].printf("%-17s = %f\n","bgRate",bgRate)
  fileobj[0].printf("%-17s = \"%s\"\n","funcstr",funcstr)
  fileobj[0].printf("%-17s = %f\n","nfuncparam",nfuncparam)
  for i = 0, nfuncparam-1 {
    fileobj[0].printf("%-14s[%d] = %f\n","k",i,k[i])
  }
  fileobj[0].printf("%-17s = %f\n","wtr_square",wtr_square)
  fileobj[0].printf("%-17s = %f\n","wtr_sigma",wtr_sigma)
  fileobj[0].printf("%-17s = %f\n","noise",noise)
  fileobj[0].printf("%-17s = %f\n","tau_m",tau_m)
  if (wfromfile) {
    fileobj[0].printf("%-17s = \"%s\"\n","infile",infile)
  }
  fileobj[0].close()
}

proc print_rasters() { local i,j,k
  // Write spike times to files.
  // Plot using
  //   gnuplot> plot "<fileroot>.cell1.ras" u 1:2 w d
  
  if (record_spikes) {
    for i = 0,1 {
      sprint(filename,"%s.cell%d.ras",fileroot,i+1)
      $o1.wopen(filename)
      for j = 0,ncells-1 {
	for k = 0,spikerec[i].x[j].size()-1 {
	  $o1.printf("%15.5g\t%d\n",spikerec[i].x[j].x[k],j)
	}
	$o1.printf("\n")
      }
      $o1.close()
    }
    sprint(filename,"%s.cell3.ras",fileroot)
    $o1.wopen(filename)
    cellLayer[2].print_spikes($o1)
    $o1.close()
  }
}

proc print_weights() { local i
  sprint(filename,"%s.conn%d.w",fileroot,$1+1)
  fileobj[0].wopen(filename)
  conn[$1].print_weights(fileobj[0])
  fileobj[0].close()
}

proc save_connections() { local i
  for i = 0,2-(f_winhib==0) {
    sprint(filename,"%s.conn%d.conn",fileroot,i+1)
    fileobj[0].wopen(filename)
    conn[i].save_connections(fileobj[0])
    fileobj[0].close()
  }
}

proc print_weight_distribution() { local i
  // Only the plastic Input-->Output weights (conn[0]) are histogrammed;
  // the Training-->Output (conn[1]) and inhibitory (conn[2]) weights are fixed.
  conn[0].print_weight_hist(histfileobj,histbins,1)
}

proc print_delta_t() { local i,ii, histbins, range, total_size
  binwidth = $1 // ms; the bin labels printed below are bin centres only when binwidth = 1
  range = $2
  histbins = 2*range+1
  deltat_hist = new Vector(histbins)
  for layer = 0,1 {
    total_size = deltat_vec[layer][0].size() + deltat_vec[layer][1].size() + deltat_vec[layer][2].size()
    for ii = 0,2 {
      deltat_hist.hist(deltat_vec[layer][ii],-range-0.5,histbins,binwidth)
      if ($3 == 1) deltat_hist.div(total_size)
      sprint(filename,"%s.conn%d.deltat%d",fileroot,layer+1,ii)
      fileobj[0].wopen(filename)
      for i = 0, histbins-1 { //print in a column
	fileobj[0].printf("%g\t%g\n",-range+binwidth*i,deltat_hist.x[i])
      }
      //deltat_vec.printf(fileobj[0])
      fileobj[0].close()
    }
  }
}

// Procedures that process recorded data ---------------------------------------

proc calc_delta_t() { local i,j,k,l,ii, nspikes_post, nspikes_pre, deltat, d
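  // Calculate the distribution of spike-time differences (post-pre)
  // in three classes: connections for which d < 0.1, 0.1 <= d < 0.2, d >= 0.2.
  // For the Input layer, the target position is hard-coded below for the "sin" transformation.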
  if (record_spikes) {
    for ii = 0,2 {
      for layer = 0,1 {
	deltat_vec[layer][ii] = new Vector(1e6)
	m[layer][ii] = 0
      }
    }
    for i = 0,ncells-1 {
      nspikes_post = cellLayer[2].cell[i].spiketimes.size()
      if (nspikes_post > 0) {
	for j = 0, nspikes_post-1 {
	  tpost = cellLayer[2].cell[i].spiketimes.x[j]
	  for k = 0,ncells-1 {
	    for layer = 0,1 {
	      if (layer==0) {
		d  = i/ncells - (sin(2*PI*k/ncells)+1)/2
	      } else {
		d = i/ncells - k/ncells
	      }
	      if (d < -0.5) d += 1
	      if (d >= 0.5) d -= 1
	      d = abs(d)
	      if (d < 0.1) {
		ii = 0
	      } else {
		if (d < 0.2) {
		  ii = 1
		} else {
		  ii = 2
		}
	      }
	      nspikes_pre = spikerec[layer].x[k].size()
	      if (nspikes_pre > 0) {
		for l = 0, nspikes_pre-1 {
		  deltat = tpost - spikerec[layer].x[k].x[l]
		  if (deltat < $2 && deltat > -1*$2) {
		    deltat_vec[layer][ii].x[m[layer][ii]] = deltat
		    m[layer][ii] += 1
		    if (m[layer][ii] >= deltat_vec[layer][ii].size()-1) {
		      deltat_vec[layer][ii].resize(2*deltat_vec[layer][ii].size())
		      printf("deltat_vec[%d][%d] resized\n",layer,ii)
		    }
		  }
		}
	      }
	    }
	  }
	}
      }
    }
    printf("Spike pairs: %d,%d  %d,%d  %d,%d\n",m[0][0],m[1][0],m[0][1],m[1][1],m[0][2],m[1][2])
    for ii = 0,2 {
      deltat_vec[0][ii].resize(m[0][ii])
      deltat_vec[1][ii].resize(m[1][ii])
    }
    print_delta_t($1,$2,$3)
    
  }
}



// Procedures that run simulations ---------------------------------------------

proc run_training() { local i, j, fileopen, thist
  // Train the network. The weight histogram is written to file
  // every thist = trw/numhist ms, and the weights every trw ms.
  // If record_spikes is set, the spike times of the network
  // cells are written to file at the end.
  

  on_StdwaSA = 1
  thist = int(trw/numhist)

  sprint(filename,"%s.conn1.whist",fileroot)
  histfileobj.wopen(filename)
  
  save_parameters()
  save_fileroot = fileroot
  sprint(fileroot,"%s_%d",save_fileroot,0)
  print_weights(0)
  print_weights(1)
  save_connections()
  
  i = 0
  j = 0

  running_ = 1
  stoprun = 0
  setup_weight_plot()
  finitialize(-65)
  plot_weights(conn[0])
  starttime = startsw()
  while (t < tstop && stoprun == 0) {
    sprint(fileroot,"%s_%d",save_fileroot,j*thist)
    print_weight_distribution()
    if (i == numhist) {
      print_weights(0)
      i = 0
      printf("--- Simulated %d seconds in %d seconds\r",int(t/1000),startsw()-starttime)
      flushf()
    }
    i += 1
    j += 1
    while (t < j*thist) {
      fadvance()
    }
    //continuerun(j*thist)
    plot_weights(conn[0])
  }
  printf("--- Simulated %d seconds in %d seconds\n",int(t/1000),stopsw())
  
  sprint(fileroot,"%s_%d",save_fileroot,j*thist)
  print_weights(0)
  print_weights(1) // for debugging. Should not have changed since t = 0
  print_weight_distribution()
  save_connections()
  
  fileroot = save_fileroot
  
  // This corrects the pre-synaptic spiketimes for syndelay.
  // This is necessary because nc.record records spike times at the source
  // whereas we want to know them at the target.
  
  if (syndelay < 0) {
    for i = 0,ncells-1 {
      spikerec[0].x[i].add(-1*syndelay)
    } 
  } else if (syndelay > 0) {
    for i = 0,ncells-1 {
      spikerec[1].x[i].add(syndelay)
    }
  }

  
  print_rasters(fileobj[0])
  
  histfileobj.close()
  print "Training complete. Time ", stopsw()
  calc_delta_t(1.0,1000,0)
}

// =-=-= Initialize the network =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

set_fileroot()
cvode.active(1)
cvode.use_local_dt(1)         // The variable time step method must be used.
cvode.condition_order(2)      // Improves threshold-detection.
set_training_weights()
//steps_per_ms = 10
//dt = 0.1

print "Finished set-up (time ", stopsw(), "s)"

print "Running training ..."

run_training()
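The STDP parameters set on conn[0] can be explored with the Python sketch below. It assumes a pair-based additive rule with exponential windows, weight changes scaled by wmax, and hard bounds at 0 and wmax. This is a generic form consistent with the parameter names, not a transcription of the StdwaSA mechanism, whose source is not shown here. The sketch computes aLTD exactly as the hoc code does and checks numerically that B is the ratio of the areas under the depression and potentiation lobes.

```python
import math

# Values copied from the Global Parameters section of the hoc listing.
aLTP = 0.01
tauLTP = 20.0   # ms
tauLTD = 20.0   # ms
B = 1.06
wmax = 0.02

# Same expression as conn[0].wa_set("aLTD", B*aLTP*tauLTP_StdwaSA/tauLTD_StdwaSA)
aLTD = B * aLTP * tauLTP / tauLTD

def stdp_window(dt):
    """Relative change for one spike pair, dt = t_post - t_pre in ms (assumed exponential form)."""
    if dt > 0:
        return aLTP * math.exp(-dt / tauLTP)   # pre before post: potentiation
    if dt < 0:
        return -aLTD * math.exp(dt / tauLTD)   # post before pre: depression
    return 0.0

def apply_pair(w, dt):
    """Additive update scaled by wmax and clipped to [0, wmax] (assumed hard bounds)."""
    return min(max(w + wmax * stdp_window(dt), 0.0), wmax)

def lobe_area(sign, step=0.01, t_max=500.0):
    """Midpoint-rule integral of |window| over one side (sign=+1: LTP, -1: LTD)."""
    n = int(t_max / step)
    return sum(abs(stdp_window(sign * (k + 0.5) * step)) for k in range(n)) * step

if __name__ == "__main__":
    print("aLTD =", aLTD)
    print("LTD/LTP area ratio =", lobe_area(-1) / lobe_area(+1))
```

Because B > 1, the depression lobe is slightly larger than the potentiation lobe, so for uncorrelated pre- and postsynaptic spiking the expected drift (proportional to aLTP*tauLTP - aLTD*tauLTD) is negative; net potentiation requires inputs whose spikes reliably precede output spikes.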

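The Training-->Output weight profile built by set_training_weights can be reproduced outside NEURON with the Python sketch below, which copies the procedure's wrap-around index distance and its two profiles. The function name training_weight is hypothetical. One assumption is added: in the square case the hoc procedure only assigns weights inside the window, so the sketch returns 0 elsewhere, assuming those connections start at zero weight.

```python
import math

def training_weight(i, j, ncells=30, wmax=0.02, f_wtr=1.0,
                    wtr_sigma=0.15, wtr_square=True):
    """Weight between Training cell i and Output cell j, following set_training_weights."""
    d = i - j
    # Wrap the index difference onto the ring, as in the hoc code.
    if d > ncells / 2:
        d = ncells - d
    if d < -ncells / 2:
        d = ncells + d
    if wtr_square:
        # Constant weight within +/- wtr_sigma*ncells cells.
        if -wtr_sigma * ncells <= d <= wtr_sigma * ncells:
            return f_wtr * wmax
        return 0.0  # assumption: weights outside the window are left at 0
    # Bell-shaped (von Mises-like) profile on the ring.
    return f_wtr * wmax * math.exp((math.cos(2 * math.pi * d / ncells) - 1) / wtr_sigma ** 2)

if __name__ == "__main__":
    row_square = [training_weight(0, j) for j in range(30)]
    row_bell = [training_weight(0, j, wtr_square=False) for j in range(30)]
    print("square:", [round(w, 4) for w in row_square])
    print("bell:  ", [round(w, 4) for w in row_bell])
```

With the default parameters (ncells = 30, wtr_sigma = 0.15), the square profile connects each cell at full weight f_wtr*wmax to the nine cells of the other layer closest to it on the ring.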

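The bookkeeping in calc_delta_t and print_delta_t can be checked in isolation with the Python sketch below. It copies the distance calculation (including the hard-coded (sin(2*PI*k/ncells)+1)/2 target for Input-layer cells), the three distance classes, the collection of post-minus-pre differences inside a window, and bins of width 1 ms centred on whole milliseconds, matching deltat_hist.hist(..., -range-0.5, histbins, binwidth) when binwidth = 1. The function names are hypothetical.

```python
import math

def distance_class(i, k, layer, ncells=30):
    """Class 0, 1 or 2 for output cell i and presynaptic cell k, as in calc_delta_t."""
    if layer == 0:
        # Input layer: distance from the output position expected for input k.
        d = i / ncells - (math.sin(2 * math.pi * k / ncells) + 1) / 2
    else:
        # Training layer: plain distance on the ring.
        d = i / ncells - k / ncells
    if d < -0.5:
        d += 1
    if d >= 0.5:
        d -= 1
    d = abs(d)
    if d < 0.1:
        return 0
    if d < 0.2:
        return 1
    return 2

def pair_differences(post_times, pre_times, window):
    """All t_post - t_pre values with |t_post - t_pre| < window (ms)."""
    return [tp - tq for tp in post_times for tq in pre_times
            if -window < tp - tq < window]

def histogram(diffs, rng):
    """Counts in 2*rng+1 bins of width 1 ms centred on -rng, ..., rng."""
    counts = [0] * (2 * rng + 1)
    for dt in diffs:
        idx = math.floor(dt + rng + 0.5)   # lowest bin edge is -rng - 0.5
        if 0 <= idx < len(counts):
            counts[idx] += 1
    return counts

if __name__ == "__main__":
    diffs = pair_differences([10.0, 50.0], [8.0, 12.0, 2000.0], window=1000)
    print(diffs)
    print(histogram(diffs, rng=45))
```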