# code by sam neymotin & ernie forzano
from neuron import h
h.load_file("stdrun.hoc")
from pylab import *
import sys
import pickle
import numpy
h.install_vecst() # for samp and other NQS/vecst functions
from conf import *
import os
from scipy.stats import pearsonr # public location; scipy.stats.stats is a deprecated private module
from utils import dtrans
import shutil

ion()
rcParams['lines.markersize'] = 15
rcParams['lines.linewidth'] = 4
tl = tight_layout

# fitness-dimension switches; getfitdims() reads them at call time
useRMP = False # True # use RMP for fitness calculation?
useVoltDiff = False
useFI = False # F-I curve error (enabled below, before getfitdims() is first called)
useISI = False # this is for evaluation of full isi voltage
useISIFeat = False # this is for evaluation of isi voltage features
useISIDepth = False # this is for evaluation of isi voltage depth (min voltage)
useISIDur = False # this is for evaluation of isi voltage duration
useSag = False # whether to use sag for fitness
useSpikeAmp = False # spike amplitude (peak - threshold voltage) - do not need when using SpikeThresh and SpikePeak
useSpikePeak = False # spike peak (absolute voltage)
useSpikeW = False # spike widths at 25% and 50%
useSpikeSlope = False # min,max dv/dt
useSpikeThresh = False # spike threshold voltage
useSpikeShape = False # overall spike shape - uses features (peak,width,slope,thresh)
useSpikeTimes = useSpikeCoinc = False
useSFA = False # spike-frequency adaptation measure
useLVar = False
useInstRate = False
useTTFS = False # use time-to-first-spike for fitness

def getfitdims ():
  """Return the names of the enabled fitness dimensions.

  The order is fixed and matters: pop2nq() creates one NQS column per entry,
  in this order, and fills them positionally from each model's fitness values.
  """
  fitdims = []
  if useRMP: fitdims.append('RMP')
  if useSag: fitdims.append('Sag')
  if useFI: fitdims.append('FI')
  if useISI: fitdims.append('ISIVolt')
  if useISIFeat: fitdims.append('ISIFeat')
  if useSFA: fitdims.append('SFA')
  if useLVar: fitdims.append('LVar')
  if useInstRate: fitdims.append('InstRate')
  if useTTFS: fitdims.append('TTFS')
  if useSpikeTimes: fitdims.append('SpikeTimes')
  if useSpikeCoinc: fitdims.append('SpikeCoinc')
  if useSpikeAmp: fitdims.append('SpikeAmp')
  if useSpikePeak: fitdims.append('SpikePeak')
  if useSpikeW: fitdims.append('SpikeW')
  if useSpikeSlope: fitdims.append('SpikeSlope')
  if useSpikeThresh: fitdims.append('SpikeThresh')
  if useSpikeShape: fitdims.append('SpikeShape')
  if useVoltDiff: fitdims.append('VoltDiff')
  if useISIDepth: fitdims.append('ISIDepth')
  if useISIDur: fitdims.append('ISIDur')
  return fitdims

def setfcfg ():
  """Return the config file name: the last existing '*.cfg' command-line argument, else the default."""
  fcfg = "PTcell.BS0284.cfg" # default config file name
  for arg in sys.argv:
    if arg.endswith(".cfg") and os.path.exists(arg): fcfg = arg
  return fcfg

dmod = {} # simulation data, keyed by (config file name, archive row)
fcfg = setfcfg() # config file name
dconf = readconf(fcfg)
dprm = dconf['params']
dfixed = dconf['fixed']
sampr = dconf['sampr'] # sampling rate
I = numpy.load(dconf['lstimamp'])
evolts = numpy.load(dconf['evolts']) # experimental voltage traces
# time axis for the experimental traces; presumably ms if sampr is in Hz (TODO confirm)
tte = linspace(0, 1e3*evolts.shape[0]/sampr, evolts.shape[0])
useFI=useInstRate=useISI=useSpikeShape=useVoltDiff=True
fitdims=getfitdims()

def geterramp (nqa,row,lc):
  """Return sqrt(sum((value/column mean)**2)) over the columns in lc present in nqa, at the given row.

  Columns missing from the table (nqa.fi(c) == -1) are skipped.
  """
  err = 0.0
  for c in lc:
    if nqa.fi(c) != -1:
      col = nqa.getcol(c)
      err += (col.x[row] / col.mean())**2
  return sqrt(err)

def adderrampcol (nqa,lc):
  """Add (or overwrite) column 'erramp' in nqa with the per-row geterramp() value over columns lc."""
  nqa.tog('DB')
  if nqa.fi('erramp') == -1: nqa.resize('erramp'); nqa.pad()
  # fetch each present column and its mean once, rather than once per row (which was O(rows^2))
  lcol = [nqa.getcol(c) for c in lc if nqa.fi(c) != -1]
  lmean = [vec.mean() for vec in lcol]
  verr = nqa.getcol('erramp')
  for i in range(int(nqa.v[0].size())):
    err = 0.0
    for vec,m in zip(lcol,lmean): err += (vec.x[i] / m)**2
    verr.x[i] = sqrt(err)
  nqa.stat('erramp')

def pop2nq (fpop,fitdims=None):
  """Convert a population to an NQS table.

  fpop is a sequence of models exposing .fitness (values ordered like fitdims)
  and .candidate (parameter values). The table gets one column per fitness
  dimension, then one per dprm key, then the derived 'erramp' column.
  """
  if fitdims is None: fitdims=getfitdims()
  try:
    nqa = h.NQS()
  except Exception: # presumably the NQS template is not defined until nqs.hoc is loaded
    h.load_file("nqs.hoc")
    nqa = h.NQS()
  # first setup the fitness dimensions
  for s in fitdims: nqa.resize(s)
  nqa.clear(len(fpop))
  for m in fpop:
    for i,val in enumerate(m.fitness): nqa.v[i].append(val)
  # then setup the parameter values
  # NOTE(review): assumes each m.candidate is ordered like dprm.keys() - confirm
  for k in dprm.keys(): nqa.resize(k)
  nqa.pad()
  idx = len(fitdims) # first parameter column
  ncol = int(nqa.m[0])
  for i,m in enumerate(fpop):
    prm = m.candidate
    for jdx in range(idx,ncol): nqa.v[jdx].x[i] = prm[jdx-idx]
  adderrampcol(nqa,fitdims)
  return nqa
# print out param values (nqa is table, idx is row)
def rowprmstr (nq,idx):
  """Return the parameter values in row idx of table nq as a space-separated string.

  Each value is followed by one space, so a non-empty result ends with a space.
  Fitness columns (the first len(fitdims)) and the derived 'erramp' column are
  skipped, so only model parameters are returned. repr() is used instead of str()
  because Python 2's str() keeps only 12 significant digits of a float. With str(),
  runmodel would re-simulate with parameters slightly different from the archive.
  """
  return ''.join(repr(nq.v[i].x[idx]) + ' ' for i in range(len(fitdims),int(nq.m[0])) if nq.s[i].s != 'erramp')

def loadark (fn):
  """Load the pickled model archive fn into global ark, and its NQS table (via pop2nq) into global nqa."""
  global ark,nqa
  # NOTE: unpickling can execute arbitrary code; only load trusted archives
  with open(fn,'rb') as fp: ark = pickle.load(fp)
  print('%d  models in  %s  archive.' % (len(ark),fn))
  nqa = pop2nq(ark,fitdims)

if fcfg == 'SPI6.cfg': # simplified model
  useVoltDiff=useFI=useInstRate=useSpikeW=useSpikeSlope=useSpikeThresh=useSpikePeak=useISI=True
  useSpikeShape=False
  fitdims=getfitdims() # reset fitness dimensions(fitdims), which differ from detailed model
  loadark(os.path.join('data','simparch.pkl')) # load simple model archive
else: # detailed model
  loadark(os.path.join('data','detarch.pkl')) # load detailed model archive

def addtext (row,col,lgn,ltxt,tx=-0.025,ty=1.03,c='k'):
  """Write each label in ltxt in bold at axes coords (tx,ty) of the matching subplot(row,col,gn) in lgn."""
  for gn,txt in zip(lgn,ltxt):
    ax = subplot(row,col,gn)
    text(tx,ty,txt,fontweight='bold',transform=ax.transAxes,color=c)

def naxbin (ax,nb):
  """Limit the number of tick bins on axes ax to nb."""
  ax.locator_params(nbins=nb)

def rowstr (nq,idx):
  """Return every column of row idx in nq (fitness and params) as 'name:value' lines, each newline-terminated."""
  return ''.join(nq.s[i].s + ':' + str(nq.v[i].x[idx]) + "\n" for i in range(int(nq.m[0])))

def rowprmvals (nq,idx):
  """Return the parameter values in row idx of nq as a list (fitness and 'erramp' columns excluded)."""
  return [nq.v[i].x[idx] for i in range(len(fitdims),int(nq.m[0])) if nq.s[i].s != 'erramp']

def indexof (a,f):
  """Return the index of the first element of a within 0.01 of f, or -1 if there is none."""
  for i,val in enumerate(a):
    if abs(val-f) < 0.01: return i
  return -1

ISubth = I[0:6] # subthreshold current injections
ISup = I[6:] # current injections for subthresh right before threshold & superthreshold traces
IAll = list(ISubth) + list(ISup)

def _traceyoffsets (offy):
  """Return the vertical offset of each trace in IAll, starting at offy.

  The gap after trace j is 95 when j > len(ISubth) and 15 otherwise. Per the
  ISup comment, the first ISup trace (j == len(ISubth)) is still subthreshold,
  so presumably only the later traces need room for spikes.
  """
  lof = []
  ypos = offy
  for j in range(len(IAll)):
    lof.append(ypos)
    if j > len(ISubth): ypos += 95
    else: ypos += 15
  return lof

def drawexptraces ():
  """Draw the experimental traces for each current in IAll in black, stacked vertically, with a scale bar."""
  # NOTE(review): tte[0] is the first time sample, so this offset is time-based - confirm intent
  offy = amin(tte[0]) - 30
  ax = gca(); ax.set_xticks([]); ax.set_yticks([])
  plot([1420,1520],[590,590],'k',linewidth=4) # scale bar: 100 along x (presumably ms)
  plot([1520,1520],[580,590],'k',linewidth=4) # scale bar: 10 along y (presumably mV)
  for i,ypos in zip(IAll,_traceyoffsets(offy)):
    plot(tte,evolts[:,indexof(I,i)] + ypos,'k')

cdx = 0 # index into color list; advanced after each model is drawn

def drawtraces (model):
  """Draw the simulated traces of dmod[model] for each current in IAll, cycling colors per call.

  If no figure exists yet, the experimental traces are drawn first.
  """
  global cdx
  lclr = ['r','g','b','c','m','y']
  tt = numpy.array(dmod[model]['vt'])
  # NOTE(review): tt[0] is the first time sample, so this offset is time-based - confirm intent
  offy = amin(tt[0]) - 30
  if len(get_fignums()) == 0: drawexptraces()
  ax = gca()
  clr = lclr[cdx % len(lclr)]
  for i,ypos in zip(IAll,_traceyoffsets(offy)):
    plot(tt, dmod[model][i] + ypos, clr)
  ax.set_xticks([]); ax.set_yticks([])
  xlim((400,1600)); ylim((-125,680))
  cdx += 1

def runmodel (idx):
  """Simulate archive row idx (unless its output already exists), then load it into dmod and draw it.

  Runs 'python sim.py <cfg> <params>' and moves the simulator's output pickle to
  data/<cfg>_<idx>.pkl. If that file is still missing afterwards, an error is
  printed and nothing is loaded.
  """
  global lastmodel
  import subprocess # local import keeps this function self-contained
  # should move pkl file to arch index location so dont have to rerun
  fnew = os.path.join('data', fcfg.split('.cfg')[0] + '_' + str(idx) + '.pkl')
  if os.path.exists(fnew):
    print('model ' + str(idx) + ' already ran, data in ' + fnew)
  else:
    prmstr = rowprmstr(nqa,idx)
    cmd = 'python sim.py ' + fcfg + ' ' + prmstr
    print(cmd)
    subprocess.call(['python','sim.py',fcfg] + prmstr.split()) # argument list: no shell involved
    if fcfg.startswith('PTcell'): fout = os.path.join('data','morph.pkl')
    else: fout = os.path.join('data','SPI6.pkl')
    # only move when the simulator produced output; otherwise fall through to the error below
    if os.path.exists(fout): shutil.move(fout,fnew)
  if not os.path.exists(fnew):
    print('ERROR: could not run model!')
    return
  lastmodel = (fcfg,idx)
  with open(fnew,'rb') as fp: dmod[lastmodel] = pickle.load(fp) # load the data
  print('model fitness error/params: ' + rowstr(nqa,idx))
  drawtraces((fcfg,idx))

def drtxt (ax,lett,tx=-0.075,ty=1.03,fsz=45):
  """Write panel letter lett in bold at axes coords (tx,ty) of ax."""
  text(tx,ty,lett,fontweight='bold',transform=ax.transAxes,fontsize=fsz)

def _drawprmcorr (nq,scc,prct,lprm,gdx,stxt):
  """Draw the getprmcors() matrix for column scc in subplot(2,2,gdx), labeled with panel letter stxt."""
  mc = getprmcors(nq,scc,prct,lprm)
  ax = subplot(2,2,gdx)
  ext = mc.shape[0]-1
  imshow(mc,interpolation='None',origin='lower',aspect='auto',extent=(0,ext,0,ext))
  colorbar(); ax.set_xticks([]); ax.set_yticks([])
  # NOTE(review): rows are the bottom ("good") percentile then the top ("bad") one; with origin='lower'
  # that puts the good rows at left/bottom if nq.sort is ascending - confirm the label order
  mytxt = 'Worst Best'; xlabel(mytxt); ylabel(mytxt)
  text(-0.025,1.03,stxt,fontweight='bold',transform=ax.transAxes,color='k')
  title('Parameter correlations')

def drawarchfig ():
  """Draw a 2x2 archive figure of parameter values and correlations for the best/worst 1% of models.

  Panels: models ranked by erramp (a, b) and by FI error (c, d).
  """
  if fcfg == 'SPI6.cfg':
    lprm = ['SPI6.gbar_kdmc','SPI6.cal_gcalbar','SPI6.can_gcanbar','SPI6.kBK_gpeak','SPI6.gbar_kap','SPI6.gbar_kdr','SPI6.gbar_nax','SPI6.kBK_caVhminShift','SPI6.cadad_taur','SPI6.cadad_depth','h.vhalfn_kdr','h.vhalfn_kap','h.vhalfl_kap','h.tq_kap']
  else:
    lprm = ['morph.nax_gbar', 'morph.kdmc_gbar','morph.kdr_gbar','morph.kap_gbar','morph.kBK_gpeak','morph.kBK_caVhminShift','morph.cal_gcalbar','morph.can_gcanbar','morph.cadad_taur','morph.cadad_depth']
  xmax = len(lprm) + 0.5 # show every parameter column (10 detailed, 14 SPI6)
  draw1dfig(nqa,'erramp',0.01,lprm,nrow=2,ncol=2,gdx=1,stxt='a')
  xlim((0.5,xmax)); ylim((-3,4.5))
  title('Rank by Error Amplitude')
  _drawprmcorr(nqa,'erramp',0.01,lprm,2,'b')
  draw1dfig(nqa,'FI',0.01,lprm,nrow=2,ncol=2,gdx=3,stxt='c')
  xlim((0.5,xmax)); ylim((-3,4.5))
  title('Rank by FI Error')
  _drawprmcorr(nqa,'FI',0.01,lprm,4,'d')

def draw1dfig (nq,scc,prct,lprm,nrow=1,ncol=1,gdx=1,stxt='a'):
  """Plot normalized values of params lprm for the bottom (magenta ^) and top (cyan v) prct of rows sorted by scc."""
  tx,ty = -0.025,1.03
  mbot,mtop = getprct(nq,scc,prct,lprm)
  ax = subplot(nrow,ncol,gdx)
  for pdx in range(len(lprm)):
    plot([pdx+1]*mbot.shape[0],mbot[:,pdx],'^',markeredgecolor='m',markerfacecolor='none',markersize=60,linewidth=8)
    plot([pdx+1]*mtop.shape[0],mtop[:,pdx],'v',markeredgecolor='c',markerfacecolor='none',markersize=60,linewidth=8)
  ax.set_xticks(linspace(1,len(lprm),len(lprm))) # set ticks before their labels
  ax.set_xticklabels([dtrans[prm] for prm in lprm])
  ylabel('Normalized parameter value')
  text(tx,ty,stxt,fontweight='bold',transform=ax.transAxes,color='k')

def _prctbounds (n,prct):
  """Return (botsidx,boteidx,topsidx,topeidx) slice bounds for the first and last int(prct*n) of n rows.

  Both slices have the same length, and the top slice runs through the last row.
  """
  k = int(prct*n)
  return 0,k,n-k,n

def _normcol (nq,prm):
  """Return column prm of nq as a numpy array shifted to mean 0 and scaled to std 1."""
  dat = numpy.array(nq.getcol(prm).to_python())
  dat = dat - mean(dat)
  return dat / std(dat)

def getprct (nq,scc,prct,lprm):
  """Return (mbot,mtop): normalized values of params lprm for the bottom ("good") and top ("bad") prct of rows sorted by scc.

  Each matrix has one row per selected model and one column per parameter in lprm.
  """
  nqt = h.NQS()
  try:
    nqt.cp(nq)
    nqt.sort(scc)
    botsidx,boteidx,topsidx,topeidx = _prctbounds(int(nqt.v[0].size()),prct)
    mbot = zeros((boteidx-botsidx,len(lprm)))
    mtop = zeros((topeidx-topsidx,len(lprm)))
    for pdx,prm in enumerate(lprm):
      dat = _normcol(nqt,prm)
      mbot[:,pdx] = dat[botsidx:boteidx]
      mtop[:,pdx] = dat[topsidx:topeidx]
  finally:
    h.nqsdel(nqt) # always free the temporary table
  return mbot,mtop

def getprmcors (nq,scc,prct,lprm):
  """Return the pairwise Pearson correlation matrix between model rows of getprct()'s stacked [mbot; mtop].

  Each entry correlates two models' normalized parameter vectors; the diagonal is 1.
  """
  mbot,mtop = getprct(nq,scc,prct,lprm)
  mprct = numpy.vstack((mbot,mtop))
  nr = mprct.shape[0]
  mc = ones((nr,nr))
  for i in range(nr):
    for j in range(i+1,nr):
      mc[i,j] = mc[j,i] = pearsonr(mprct[i,:],mprct[j,:])[0]
  return mc