Biochemically detailed model of LTP and LTD in a cortical spine (Maki-Marttunen et al 2020)

 Download zip file 
Help downloading and running models
Accession:260971
"Signalling pathways leading to post-synaptic plasticity have been examined in many types of experimental studies, but a unified picture on how multiple biochemical pathways collectively shape neocortical plasticity is missing. We built a biochemically detailed model of post-synaptic plasticity describing CaMKII, PKA, and PKC pathways and their contribution to synaptic potentiation or depression. We developed a statistical AMPA-receptor-tetramer model, which permits the estimation of the AMPA-receptor-mediated maximal synaptic conductance based on numbers of GluR1s and GluR2s predicted by the biochemical signalling model. We show that our model reproduces neuromodulator-gated spike-timing-dependent plasticity as observed in the visual cortex and can be fit to data from many cortical areas, uncovering the biochemical contributions of the pathways pinpointed by the underlying experimental studies. Our model explains the dependence of different forms of plasticity on the availability of different proteins and can be used for the study of mental disorder-associated impairments of cortical plasticity."
Reference:
1 . Mäki-Marttunen T, Iannella N, Edwards AG, Einevoll GT, Blackwell KT (2020) A unified computational model for cortical post-synaptic plasticity. Elife [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Synapse;
Brain Region(s)/Organism: Neocortex;
Cell Type(s): Neocortex spiking regular (RS) neuron;
Channel(s): I Calcium;
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s): Glutamate; Norephinephrine; Acetylcholine;
Simulation Environment: NEURON; NeuroRD;
Model Concept(s): Long-term Synaptic Plasticity;
Implementer(s): Maki-Marttunen, Tuomo [tuomomm at uio.no];
Search NeuronDB for information about:  I Calcium; Acetylcholine; Norephinephrine; Glutamate;
/
synaptic
NEURON
fitfiles
README.html
#drawfig3.py#
calcconds.py
calcconds_dimerdimer.py
drawfig11.py
drawfig2.py
drawfig3.py
drawfig3_1.py
drawfig4.py
drawfig5.py
drawfig9abc.py
emoo.py
fit_cAMP_withdiss_1d.py
fits_goodparams.mat
fits_goodparams_manyb.mat
fitter_fewer.py
fitter_fewer_check.py
fitter_fewer_check_given.py *
fitter_fewer1withCK_check_given.py *
fitter_manyb_check_given.py
mesh_general.out *
model_nrn_altered_noU.py
model_nrn_altered_noU_extfilename_lowmem_recall.py
model_nrn_altered_noU_noninterp.py
model_nrn_altered_noU_simpleflux_extfilename_lowmem.py
model_nrn_oldCaM_altered_noU.py
model_nrn_oldCaM_altered_noU_extfilename_lowmem_recall.py
model_nrn_oldPKA_altered_noU.py
model_nrn_paired_contnm_var.py
model_nrn_paired_contnm_var_noninterp.py
model_nrn_testPKA_withdiss.py
model_nrn_testPKA_withdiss_williamson_varyrates.py
mytools.py
protocols_many.py
protocols_many_78withoutCK.py
protocols_many_78withoutCK_1withCK.py
reactionGraph.mat
runfig11.sh
runfig2.sh
runfig3_1.sh
runfig3-4.sh
runfig5.sh
runfig9.sh
simsteadystate_flexible.py
simsteadystate_li2020.py
simsteadystate_oldCaM_li2020.py
                            
import math
import copy
import pylab

def spike_times(time,vrec,V_min_peak=-20,V_max_valley=0):
  valley_reached = 1
  sptime = []  
  for j in range(1,len(time)-1):
    if valley_reached and vrec[j] >= V_min_peak and vrec[j] > vrec[j-1] and vrec[j] >= vrec[j+1]:
      valley_reached = 0
      sptime.append(time[j])
    elif valley_reached==False and vrec[j] <= V_max_valley:
      valley_reached = 1
  return sptime  

#membpotderivs(time,vrec): Given the membrane potentials (vrec) at time points time[0],time[1],...,time[N-1],
#return the derivatives at time points time[1],time[2],...,time[N-3]
def membpotderivs(time,vrec):
  N = len(time)
  tdiff = [x-y for x,y in zip(time[1:N-1],time[0:N-2])]
  vdiff = [x-y for x,y in zip(vrec[1:N-1],vrec[0:N-2])]
  mderiv = [x/y for x,y in zip(vdiff,tdiff)]
  return [0.5*(x+y) for x,y in zip(mderiv[1:N-2],mderiv[0:N-3])]

#limitcyclescaledv(v1,dv1,v2,dv2): Give the coefficient for memb. pot. derivative that one has to use in order to make
#the difference on the derivative axis as significant as the difference on the memb. pot. axis
def limitcyclescaledv(v1,dv1,v2,dv2):
  maxv = max(max(v1),max(v2))
  minv = min(min(v1),min(v2))
  maxdv = max(max(dv1),max(dv2))
  mindv = min(min(dv1),min(dv2))
  return 1.0*(maxv-minv)/(maxdv-mindv)

def limitcyclediff(v1,dv1,v2,dv2,dvcoeff=0.1):
  N1 = len(v1)
  N2 = len(v2)
  dv1 = [dvcoeff*x for x in dv1]
  dv2 = [dvcoeff*x for x in dv2]
  Dmin = N1*[0]
  for i in range(0,N1):
    Dmin[i] = math.sqrt(min([(x-v1[i])**2+(y-dv1[i])**2 for x,y in zip(v2,dv2)]))
  vdiff = [x-y for x,y in zip(v1[1:N1],v1[0:N1-1])] 
  dvdiff = [x-y for x,y in zip(dv1[1:N1],dv1[0:N1-1])] 
  h = [math.sqrt(x**2+y**2) for x,y in zip(vdiff,dvdiff)]
  #use the trapezoid rule for integration:
  Dminmean = [(x+y)/2.0 for x,y in zip(Dmin[1:N1],Dmin[0:N1-1])]
  #print "hsum="+str(sum(h))
  return sum([x*y for x,y in zip(Dminmean,h)])
  

def interpolate(tref,vref,tint): #Assumes that the trefs come sorted!
  vint = len(tint)*[0.0]
  addedOne = False
  #print tref
  #print tint
  #if tref[len(tref)-1] == tint[len(tint)-1]:
  #  tref.append(tref[len(tref)-1]+0.0001)
  #  vref.append(vref[len(tref)-1])
  #  addedOne = True
  if tref[0] > tint[0] or tref[len(tref)-1] < tint[len(tint)-1]:
    print "Extrapolation needed!"
    return len(tint)*[-1]
  indvrecnow = 0  
  for j in range(0,len(tint)):
    while tref[indvrecnow+1] <= tint[j]:
      indvrecnow = indvrecnow + 1
      if indvrecnow == len(tref)-1: # It must be the last index if this happens
        vint[j:len(tint)] = [vref[indvrecnow]]*(len(tint)-j)
        return vint
    vint[j] = vref[indvrecnow] + 1.0*(tint[j]-tref[indvrecnow])/(tref[indvrecnow+1]-tref[indvrecnow])*(vref[indvrecnow+1]-vref[indvrecnow])
  return vint

def interpolate_extrapolate_constant(tref,vref,tint): #Assumes that the trefs come sorted!
  vint = len(tint)*[0.0]
  addedOne = False
  #print tref
  #print tint
  #if tref[len(tref)-1] == tint[len(tint)-1]:
  #  tref.append(tref[len(tref)-1]+0.0001)
  #  vref.append(vref[len(tref)-1])
  #  addedOne = True

  indvrecnow = 0  
  for j in range(0,len(tint)):
    if tint[j] < tref[0]:
      vint[j] = vref[0]
      continue
    if indvrecnow >= len(tref) - 1:
      vint[j] = vref[-1]
    while tref[indvrecnow+1] <= tint[j]:
      indvrecnow = indvrecnow + 1
      if indvrecnow == len(tref)-1: # It must be the last index if this happens
        vint[j:len(tint)] = [vref[indvrecnow]]*(len(tint)-j)
        return vint
    vint[j] = vref[indvrecnow] + 1.0*(tint[j]-tref[indvrecnow])/(tref[indvrecnow+1]-tref[indvrecnow])*(vref[indvrecnow+1]-vref[indvrecnow])
  return vint

#kronecker product of list A and list B
def kron(A,B):
  C = []
  if type(B[0]) is int or type(B[0]) is float:
    for i in range(0,len(A)):
      for j in range(0,len(B)):
        print "asdf"
        print B[j]
        C.append(A[i]*B[j])
  elif type(B[0][0]) is int or type(B[0][0]) is float:
    for i in range(0,len(A)):
      for j in range(0,len(B)):
        C.append([x*A[i] for x in B[j]])
  return C
  
def cumprod(A):
  B = len(A)*[0]; B[0]=A[0]
  for j in range(1,len(A)):
    B[j] = B[j-1]*A[j]
  return B

def oscillatorypoissontimeseries(N, minfreq, maxfreq, oscfreq, phase, T, isprinted=0):
  a = 0.5*(minfreq + maxfreq)
  b = 0.5*(maxfreq - minfreq)
  meanT = 1.0/a
  omega = 2*pylab.pi*oscfreq
  sptime = pylab.zeros([int(0.001*2*T*N/meanT)+100,2],dtype='d')
  #sptime = []
  Nplaced = 0
  for iN in range(0,N):
    t = 0
    while t < T:
      e = pylab.log(1-pylab.rand()) 
      Ttrial = meanT #Ttrial in sec, t and T in msec
      # Newton's method:
      for iter in range(0,50):
        Ttrialold = Ttrial
        Ttrial = Ttrial - (e + a*Ttrial - b/omega*(pylab.cos(omega*(0.001*t+Ttrial)+phase)-pylab.cos(omega*0.001*t+phase))) / (a + b*pylab.sin(omega*(0.001*t+Ttrial)+phase))
      myerr = abs(e + a*Ttrial - b/omega*(pylab.cos(omega*(0.001*t+Ttrial)+phase)-pylab.cos(omega*0.001*t+phase)))
      if Ttrial < 0 or myerr > 0.0001:
        Ttrial = meanT
        # If Newton's method didn't give good result, try fixed point method: Find T such that -aT+b/omega(pylab.cos(omega*t+phase)-pylab.cos(phase)) = pylab.log(1-x):
        for iter in range(0,30000):
          Ttrial = (b/omega*(pylab.cos(omega*(0.001*t+Ttrial)+phase)-pylab.cos(omega*0.001*t+phase)) - e)/a
        Ttrialold = Ttrial
        Ttrial = (b/omega*(pylab.cos(omega*(0.001*t+Ttrial)+phase)-pylab.cos(omega*0.001*t+phase)) - e)/a
      if isprinted > 1:
        print('iN='+str(iN)+', Err(Ttrial) = '+str(abs(Ttrial-Ttrialold))+', Ttrial = '+str(Ttrial)+', t = '+str(t)+', newPhase='+str(omega*(t/1000+Ttrial)+phase)+', myerr='+str(myerr))
      t = t+Ttrial*1000
      sptime[Nplaced,:] = [iN,t]
      #sptime.append([iN,t])
      Nplaced = Nplaced + 1
    if isprinted > 0:
      print('spikes for iN='+str(iN)+' complete')
  sptime = sptime[:Nplaced,:]
  if isprinted > 0:
    print(str(Nplaced)+' spikes in total, '+str(int(0.001*2*T*N/meanT)+100)+' reserved')
  return sptime

def printlistlen(A):
  #TODO: recursive method might work out but needs some thought...
  #toCheck = A
  #lens = [len(x) for x in A]
  #levels = [0 for x in A]
  #while type(toCheck) is list and len(toCheck) > 0:
  #  while type(toCheck[0]) is list and len(toCheck[0]) > 0:
  #    toCheck.append(toCheck[0][0])
  #    lens.append[len
  #    toCheck[0].pop(0)
  nan = -1
  if type(A) is list:
   B = copy.deepcopy(A)
   listFound0 = False
   for i0 in range(0,len(B)):
    if type(B[i0]) is list:
     listFound0 = True
     listFound1 = False
     for i1 in range(0,len(B[i0])):
      if type(B[i0][i1]) is list:
       listFound1 = True
       listFound2 = False
       for i2 in range(0,len(B[i0][i1])):
        if type(B[i0][i1][i2]) is list:
         listFound2 = True
         listFound3 = False
         for i3 in range(0,len(B[i0][i1][i2])):
          if type(B[i0][i1][i2][i3]) is list:
           listFound3 = True
           listFound4 = False
           for i4 in range(0,len(B[i0][i1][i2][i3])):
            if type(B[i0][i1][i2][i3][i4]) is list:
             listFound4 = True
             listFound5 = False
             for i5 in range(0,len(B[i0][i1][i2][i3][i4])):
              if type(B[i0][i1][i2][i3][i4][i5]) is list:
               listFound5 = True
               listFound6 = False               
               for i6 in range(0,len(B[i0][i1][i2][i3][i4][i5])):
                if type(B[i0][i1][i2][i3][i4][i5][i6]) is list:
                  listFound6 = True
                  B[i0][i1][i2][i3][i4][i5][i6] = len(B[i0][i1][i2][i3][i4][i5][i6])
                else:
                  B[i0][i1][i2][i3][i4][i5][i6] = nan
               if not listFound6:
                 B[i0][i1][i2][i3][i4][i5] = len(B[i0][i1][i2][i3][i4][i5])
              else:
                B[i0][i1][i2][i3][i4][i5] = nan
             if not listFound5:
               B[i0][i1][i2][i3][i4] = len(B[i0][i1][i2][i3][i4])
            else:
              B[i0][i1][i2][i3][i4] = nan
           if not listFound4:
             B[i0][i1][i2][i3] = len(B[i0][i1][i2][i3])
          else:
            B[i0][i1][i2][i3] = nan
         if not listFound3:
           B[i0][i1][i2] = len(B[i0][i1][i2])
        else:
          B[i0][i1][i2] = nan
       if not listFound2:
         B[i0][i1] = len(B[i0][i1])
      else:
        B[i0][i1] = nan
     if not listFound1:
       B[i0] = len(B[i0])
    else:
      B[i0] = nan
   if not listFound0:
     B = len(B)
  else:
   B = nan
  print B

def drawarrow(ax,x,y,acoeff=1,prc=0.9,lw=1,lc='#000000'):
  d = [x[1]-x[0], y[1]-y[0]];
  k = pylab.sqrt(d[0]**2 + d[1]**2)
  d = d/k

  dperp = [acoeff*z for z in [-d[1], d[0]]];
  lens = k-k*(1-prc);
  perplen = 0.5*k*(1-prc);

  px = [x[0],x[1],x[0]+lens*d[0]+perplen*dperp[0],x[0]+lens*d[0]-perplen*dperp[0]];
  py = [y[0],y[1],y[0]+lens*d[1]+perplen*dperp[1],y[0]+lens*d[1]-perplen*dperp[1]];

  px = [px[0],px[1],px[2],px[1],px[3]]
  py = [py[0],py[1],py[2],py[1],py[3]]
  #px = reshape([px(:,[1 2 3 2 4]) nan(size(x,1),1)]',size(x,1)*6,1);
  #py = reshape([py(:,[1 2 3 2 4]) nan(size(x,1),1)]',size(x,1)*6,1);
  ax.plot(px,py,'k-',linewidth=lw,color=lc)

def timeseriesmean(times,x):
  return 1.0*sum([(t2-t1)*(x1+x2)/2 for t1,t2,x1,x2 in zip(times[0:-1],times[1:],x[0:-1],x[1:])])/(times[-1]-times[0])

def timeseriessecondmoment(times,x):
  return 1.0*sum([(t2-t1)*(x1**2+x2**2)/2 for t1,t2,x1,x2 in zip(times[0:-1],times[1:],x[0:-1],x[1:])])

def timeseriesstd(times,x,xmean=pylab.nan):
  if pylab.isnan(xmean):
    xmean = timeseriesmean(times,x)
  return pylab.sqrt(1.0*sum([(t2-t1)*((x1-xmean)**2+(x2-xmean)**2)/2 for t1,t2,x1,x2 in zip(times[0:-1],times[1:],x[0:-1],x[1:])])/(times[-1]-times[0]))

def drawdiscontinuity(ax,y,yoffset,x=0,xoffset=0.1,lw=2.0,lw2=1.0):
  thisline = ax.plot([x-xoffset,x+xoffset],[y-yoffset,y],'k-',linewidth=lw2)
  thisline[0].set_clip_on(False)
  thisline = ax.plot([x-xoffset,x+xoffset],[y,y+yoffset],'k-',linewidth=lw2)
  thisline[0].set_clip_on(False)
  thisline = ax.plot([x-xoffset,x+xoffset],[y-0.5*yoffset,y+0.5*yoffset],'k-',color='#FFFFFF',zorder=100,linewidth=lw)
  thisline[0].set_clip_on(False)

def firingratecurve(spikes,T=[],dt=1.0,gauss_std=5.0):
  if type(spikes) is list:
    spikes = pylab.array(spikes)
  if type(T) is list and len(T) == 0:
    T = max(spikes[0,:])
  FRs = pylab.zeros([int(T/dt),1])
  FRts = [dt*(i+0.5) for i in range(0,len(FRs))]
  for iFRt in range(0,len(FRs)):
    FRs[iFRt] = sum(pylab.exp(-1/2*(FRts[iFRt]-spikes[0,:])**2/gauss_std**2))
  return [FRs,FRts]

def find(condition):
    "Return the indices where ravel(condition) is true"
    res, = np.nonzero(np.ravel(condition))
    return res

def mylegend(fig,pos,styles,texts,nx=1,dx=2,yplus=0.5,yplustext=0.35,colors=[],dashes=[],linewidths=[],myfontsize=8):
  ny = int(pylab.ceil(1.0*len(styles)/nx))
  axnew = fig.add_axes(pos)
  handles = []
  for i in range(0,len(styles)):
    handles.append(axnew.plot([dx*(i/ny)+0.15,dx*(i/ny)+0.35],[yplus+ny-1-(i%ny),yplus+ny-1-(i%ny)],styles[i]))
    axnew.text(dx*(i/ny)+0.5,yplustext+ny-1-(i%ny),texts[i],fontsize=myfontsize)
  axnew.set_xlim([0,dx*nx])
  axnew.set_ylim([0,len(styles)])
  for i in range(0,len(dashes)):
    if len(dashes[i]) > 0:
      handles[i][0].set_dashes(dashes[i])
  for i in range(0,len(colors)):
    if len(colors[i]) > 0:
      handles[i][0].set_color(colors[i])
  for i in range(0,len(linewidths)):
    if type(linewidths[i]) is not list and linewidths[i] > 0:
      handles[i][0].set_linewidth(linewidths[i])
  axnew.get_xaxis().set_visible(False)
  axnew.get_yaxis().set_visible(False)
  axnew.set_xlim([0,dx*nx])
  axnew.set_ylim([0,ny])
  return axnew

import math

def colorsredtolila(N,brightness=1.0,whiteness=0.0):
  if N==1:
    return ['#0000FF']
  C = colorsredtolilaint(N,brightness,whiteness)
  hexlist = []
  for j in range(0,N):
    myhex = '#'
    for k in range(0,3):
      if C[j][k] < 16:
        myhex = myhex + '0' + hex(C[j][k])[2]
      else:
        myhex = myhex + hex(C[j][k])[2:4]
    hexlist.append(myhex)
  return hexlist

def colorsredtolilaint(N,brightness=1.0,whiteness=0.0):

  linchange = [1.0/(N-1)*x for x in range(0,N)]
  linchange_x = [x/0.8 for x in [0.0, 0.36, 0.45, 0.65, 0.7, 0.8]]
  linchange_y = [0.0, 0.29, 0.46, 0.65, 0.75, 0.83]
  for i in range(0,len(linchange)):
    ind = N
    for j in range(0,len(linchange_x)):
      if linchange[i] >= linchange_x[j]:
        ind = j
    if ind < len(linchange_x)-1:
      linchange[i] = linchange_y[ind]+(linchange_y[ind+1]-linchange_y[ind])*(linchange[i]-linchange_x[ind])/(linchange_x[ind+1]-linchange_x[ind])

  C = hsv2rgblist([int(255*x) for x in linchange],[1.0-whiteness]*N,[brightness]*N);
  return C


def hsv2rgblist(h, s, v):
  hs = []
  ss = []
  vs = []
  for i in range(0,len(h)):
    h1, s1, v1 = hsv2rgb(h[i],s[i],v[i])
    hs.append(h1)
    ss.append(s1)
    vs.append(v1)
  return zip(hs,ss,vs)

def hsv2rgb(h, s, v):
  h = float(h)
  s = float(s)
  v = float(v)
  h60 = h / 60.0
  h60f = math.floor(h60)
  hi = int(h60f) % 6
  f = h60 - h60f
  p = v * (1 - s)
  q = v * (1 - f * s)
  t = v * (1 - (1 - f) * s)
  r, g, b = 0, 0, 0
  if hi == 0: r, g, b = v, t, p
  elif hi == 1: r, g, b = q, v, p
  elif hi == 2: r, g, b = p, v, t
  elif hi == 3: r, g, b = p, q, v
  elif hi == 4: r, g, b = t, p, v
  elif hi == 5: r, g, b = v, p, q
  r, g, b = int(r * 255), int(g * 255), int(b * 255)
  return r, g, b
    
def rgb2hsv(r, g, b):
  r, g, b = r/255.0, g/255.0, b/255.0
  mx = max(r, g, b)
  mn = min(r, g, b)
  df = mx-mn
  if mx == mn:
    h = 0
  elif mx == r:
    h = (60 * ((g-b)/df) + 360) % 360
  elif mx == g:
    h = (60 * ((b-r)/df) + 120) % 360
  elif mx == b:
    h = (60 * ((r-g)/df) + 240) % 360
  if mx == 0:
    s = 0
  else:
    s = df/mx
  v = mx
  return h, s, v

Loading data, please wait...