"""Parallel run driver for the network simulation.

Allocates spike-recording vectors (``spikevec``/``idvec``), and provides
``prun`` to advance the simulation in chunks of ``checkpoint_interval``,
save the recorded spikes with ``binsave`` and print a timing summary
(``printperf``) on rank 0.

All print calls use the single-argument ``print(expr)`` form so the module
parses under both Python 2 and Python 3.
"""
from common import *

h.load_file("spike2file.hoc")

# Spike recording buffers; the __main__ driver below fills them through
# pc.spike_record(-1, spikevec, idvec).
idvec = h.Vector()
idvec.buffer_size(5000000)
spikevec = h.Vector()
spikevec.buffer_size(5000000)

# Output layout for the legacy hoc spike2file writer (its call is kept as a
# comment in prun): each file contains spikes from 64 ranks and each file is
# serialized from 8 ranks, so each sorting rank gathers spikes from
# nhost/n_spkout_sort ranks.  Floor division keeps these counts integral
# under Python 3 as well (identical to Python 2 for integer nhost).
n_spkout_files = max(nhost // 64, 1)
n_spkout_sort = min(n_spkout_files * 8, nhost)

# Length of each psolve chunk, in the same time units as h.t.
checkpoint_interval = 100000

# Periodic weight reset settings (interval in ms, per the message in prun).
clean_weights_active = False
clean_weights_interval = 10500.0
from weightsave import weight_reset as clean_weights, weight_file


def prun(tstop):
    """Initialize and run the parallel simulation up to ``tstop``.

    The run advances with ``pc.psolve`` in chunks of at most
    ``checkpoint_interval``.  If a chunk does not advance ``h.t`` the loop
    stops early, and rank 0 reports it, instead of spinning forever.
    Afterwards the recorded spikes are written with
    ``binsave.save(params.filename, spikevec, idvec)``.  Rank 0 then prints
    the wall-clock run time, followed by the ``printperf`` summary.

    Call on every rank: ``printperf`` performs ``pc.allreduce`` reductions
    that need all ranks to participate.

    :param tstop: simulation stop time, in the same units as ``h.t``.
    """
    cvode = h.CVode()
    cvode.cache_efficient(1)
    pc.setup_transfer()
    # set_maxstep returns the minimum delay across hosts (see NEURON's
    # ParallelContext docs); it is reported as 'mindelay'.
    mindelay = pc.set_maxstep(10)
    if rank == 0:
        print('mindelay = %g' % mindelay)

    # Wall-clock and spike-exchange wait baselines for the final report.
    runtime = h.startsw()
    exchtime = pc.wait_time()

    inittime = h.startsw()
    cvode.active(0)  # fixed-step integration
    h.stdinit()
    inittime = h.startsw() - inittime

    if rank == 0:
        if clean_weights_active:
            print('weights reset active at %g ms' % clean_weights_interval)
        else:
            print('weights reset not active')
        print('init time = %g' % inittime)

    # NOTE(review): the periodic weight reset (psolve to each multiple of
    # clean_weights_interval, then call clean_weights()) is currently
    # disabled, so clean_weights_active only changes the message above.
    while h.t < tstop:
        told = h.t
        tnext = min(h.t + checkpoint_interval, tstop)
        pc.psolve(tnext)
        if h.t == told:
            # No progress was made; bail out rather than loop forever.
            if rank == 0:
                print("psolve did not advance time from t=%.20g to tnext=%.20g\n"
                      % (h.t, tnext))
            break

    # Save spikes and their ids in a binary format so the output is more
    # compressible.  The legacy text writer was:
    #   h.spike2file(params.filename, spikevec, idvec,
    #                n_spkout_sort, n_spkout_files)
    import binsave
    binsave.save(params.filename, spikevec, idvec)

    runtime = h.startsw() - runtime
    comptime = pc.step_time()
    splittime = pc.vtransfer_time(1)
    gaptime = pc.vtransfer_time()
    exchtime = pc.wait_time() - exchtime
    if rank == 0:
        print('runtime = %g' % runtime)
    printperf([comptime, exchtime, splittime, gaptime])


def printperf(p):
    """Print the average and maximum over all ranks of each timing value.

    Every rank must call this, because each value is combined with
    ``pc.allreduce``: type 1, divided by ``nhost``, gives the average and
    type 2 gives the maximum.  Only rank 0 prints.

    :param p: this rank's timings, in the order comp, spk, split, gap.
    """
    header = ['comp', 'spk', 'split', 'gap']
    avgp = []
    maxp = []
    # Keep the per-value (sum, max) call order identical on every rank.
    for value in p:
        avgp.append(pc.allreduce(value, 1) / nhost)
        maxp.append(pc.allreduce(value, 2))
    if rank > 0:
        return

    # Load balance = mean / max computation time.  If no rank spent any time
    # computing, every rank did equal (zero) work; avoid ZeroDivisionError.
    b = avgp[0] / maxp[0] if maxp[0] > 0 else 1.0
    print('Load Balance = %g' % b)

    # Table layout: one space before every cell and a trailing space at the
    # end of each row (the bytes the old trailing-comma prints produced).
    cols = ''.join(' %12s' % name for name in header)
    avgs = ''.join(' %12.2f' % v for v in avgp)
    maxs = ''.join(' %12.2f' % v for v in maxp)
    print('\n ' + cols + ' \n avg ' + avgs + ' \n max ' + maxs + ' ')


if __name__ == '__main__':
    # Small stand-alone run: build the net, stimulate with the 'Apple' odor
    # and simulate to t = 200.
    import common
    import util
    # Set before importing net_mitral_centric, which presumably reads these
    # module globals when building the network -- TODO confirm.
    common.nmitral = 1
    common.ncell = 2
    import net_mitral_centric as nmc
    nmc.build_net_roundrobin(getmodel())
    # gid -1: record output spikes of all cells on this rank (NEURON docs).
    pc.spike_record(-1, spikevec, idvec)
    from odorstim import OdorStim
    from odors import odors
    ods = OdorStim(odors['Apple'])
    ods.setup(nmc.mitrals, 10., 20., 100.)
    prun(200.)
    util.finish()