Parallelizing large networks in NEURON (Lytton et al. 2016)

 Download zip file 
Help downloading and running models
Accession:188544
"Large multiscale neuronal network simulations and innovative neurotechnologies are required for development of these models requires development of new simulation technologies. We describe here the current use of the NEURON simulator with MPI (message passing interface) for simulation in the domain of moderately large networks on commonly available High Performance Computers (HPCs). We discuss the basic layout of such simulations, including the methods of simulation setup, the run-time spike passing paradigm and post-simulation data storage and data management approaches. We also compare three types of networks, ..."
Reference:
1 . Lytton WW, Seidenstein AH, Dura-Bernal S, McDougal RA, Schürmann F, Hines ML (2016) Simulation Neurotechnologies for Advancing Brain Research: Parallelizing Large Networks in NEURON. Neural Comput 28:2063-90 [PubMed]
Model Information (Click on a link to find other models with that property)
Model Type: Realistic Network;
Brain Region(s)/Organism:
Cell Type(s): Hodgkin-Huxley neuron; Abstract Izhikevich neuron;
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: NEURON; NetPyNE;
Model Concept(s): Simplified Models; Methods; Multiscale;
Implementer(s): Dura-Bernal, Salvador [salvadordura at gmail.com]; Lytton, William [bill.lytton at downstate.edu]; Seidenstein, Alexandra [ahs342 at nyu.edu];
from matplotlib import pyplot
import random
from datetime import datetime
import pickle
from neuron import h, gui


class Cell(object):
    def __init__(self):
        self.synlist = []
        self.createSections()
        self.buildTopology()
        self.defineGeometry()
        self.defineBiophysics()
        self.createSynapses()
        self.nclist = []

    def createSections(self):
        pass

    def buildTopology(self):
        pass

    def defineGeometry(self):
        """Set the 3D geometry of the cell."""
        self.soma.L = 18.8
        self.soma.diam = 18.8
        self.soma.Ra = 123.0

    def defineBiophysics(self):
        pass

    def createSynapses(self):
        """Add an exponentially decaying synapse """
        syn = h.ExpSyn(self.soma(0.5))
        syn.tau = 2
        syn.e = 0
        self.synlist.append(syn) # synlist is defined in Cell

    def associateGid (self):
        pc.set_gid2node(self.gid, idhost)
        nc = h.NetCon(self.soma(0.5)._ref_v, None, sec=self.soma)
        nc.threshold = 10
        pc.cell(self.gid, nc)
        del nc # discard netcon


    def createNetcon(self, thresh=10):
        """ created netcon to record spikes """
        nc = h.NetCon(self.soma(0.5)._ref_v, None, sec = self.soma)
        nc.threshold = thresh
        return nc

    def createStim(self, number=1e9, start=1, noise=0.5, rate=50, weight=1, delay=5):
        self.stim = h.NetStim()
        self.stim.number = number
        self.stim.start = start
        self.stim.noise = noise
        self.stim.interval = 1000.0/rate
        self.ncstim = h.NetCon(self.stim, self.synlist[0])
        self.ncstim.delay = delay
        self.ncstim.weight[0] = noise # NetCon weight is a vector.

    def connect2Source(self, sourceCell, thresh=10):
        """Make a new NetCon with the source cell's membrane
        potential at the soma as the source (i.e. the spike detector)
        onto the target (i.e. a synapse on this cell)."""
        nc = h.NetCon(sourceCell.soma(1)._ref_v, self.synlist[0], sec = sourceCell.soma)
        nc.threshold = thresh
        return nc

    def setRecording(self):
        """Set soma, dendrite, and time recording vectors on the cell. """
        self.soma_v_vec = h.Vector()   # Membrane potential vector at soma
        self.tVec = h.Vector()        # Time stamp vector
        self.soma_v_vec.record(self.soma(0.5)._ref_v)
        self.tVec.record(h._ref_t)

    def plotTraces(self):
        """Plot the recorded traces"""
        pyplot.figure(figsize=(8,4)) # Default figsize is (8,6)
        somaPlot = pyplot.plot(self.tVec, self.soma_v_vec, color='black')
        pyplot.legend(somaPlot, ['soma'])
        pyplot.xlabel('time (ms)')
        pyplot.ylabel('mV')
        pyplot.title('Cell %d voltage trace'%(self.gid))
        pyplot.show()
        #pyplot.savefig('traces')


class HHCell(Cell): 
    """HH cell: A soma with active channels""" 
    def createSections(self):
        """Create the sections of the cell."""
        self.soma = h.Section(name='soma', cell=self)
    
    def defineBiophysics(self):
        """Assign the membrane properties across the cell."""
        # Insert active Hodgkin-Huxley current in the soma
        self.soma.insert('hh')
        self.soma.gnabar_hh = 0.12  # Sodium conductance in S/cm2
        self.soma.gkbar_hh = 0.036  # Potassium conductance in S/cm2
        self.soma.gl_hh = 0.003    # Leak conductance in S/cm2
        self.soma.el_hh = -70       # Reversal potential in mV


class Net:
    """Creates Network of N neurons (using parallelContext)
    Connectivity and stimulation params provided as arguments
    Also ncludes methods to gather and plot spikes
    """
    def __init__(self, N=5, cellType=HHCell, connParams={}, stimParams={}):

        """
        N: Number of cells.
        cellType: class of cell type
        connParams: dict of connectivity params
        stimParams: dict of stimulation params

        """
        self.cellType = cellType
        self.N = N                      # number of cells
        self.connParams = connParams    # connectivity params
        self.stimParams = stimParams    # backgroudn stim params
        self.cells = []                 # List of Cell objects in the net
        self.nclist = []                # List of NetCon in the net
        self.tVec = h.Vector()         # spike time of all cells
        self.idVec = h.Vector()        # cell ids of spike times
        self.createNet()  # Actually build the net
    

    def createNet(self):
        """Create, layout, and connect N cells."""
        self.setGids() #### set global ids (gids), used to connect cells
        self.createCells()
        self.connectCells() 
        self.createStims()


    def setGids(self):
        self.gidList = []
        #### Round-robin counting. Each host as an id from 0 to nhost - 1.
        for i in range(idhost, self.N, nhost):
            self.gidList.append(i)

    
    def createCells(self):
        """Create and layout cells (in this host) in the network."""
        self.cells = []

        for i in self.gidList: #### Loop over cells in this node/host
            cell = self.cellType() # dynamically create cell object 
            self.cells.append(cell)  # add cell object to net cell list
            cell.gid = i # assign gid (can be any unique integer)
            cell.associateGid() # associated gid to each cell
            pc.spike_record(cell.gid, self.tVec, self.idVec) # Record spikes of this cell
            
            print 'Created cell %d on host %d out of %d'%(i, idhost, nhost) 

    def connectCells(self):
        """Connect cells"""
        connType = self.connParams['type']
        if connType == 'rand':
            weight = self.connParams['weight']
            delayMean = self.connParams['delayMean']
            delayVar = self.connParams['delayVar']
            delayMin = self.connParams['delayMin']
            maxConnsPerCell = self.connParams['maxConnsPerCell']
            self.nclist = []

            ## create random delays
            random.seed(randSeed)  # Reset random number generator  
            randDelays = [max(delayMin, random.gauss(delayMean, delayVar)) for pre in range(maxConnsPerCell*self.N)] # select random delays based on mean and var params    

            ## loop over postsyn gids in this host
            for postCell in self.cells:  
                preGids = [gid for gid in self.gidList if gid != postCell.gid] # get list of presyn cell gids (omit post to prevent self connection)
                randPreGids = random.sample(preGids, random.randint(0, min(maxConnsPerCell, len(preGids)))) 
                for preGid in randPreGids: # for each presyn cell
                    nc = pc.gid_connect(preGid, postCell.synlist[0]) # create NetCon by associating pre gid to post synapse
                    nc.weight[0] = weight
                    nc.delay = randDelays.pop()
                    nc.threshold = 10
                    self.nclist.append((preGid,postCell.gid,nc))
                    postCell.nclist.append((preGid,nc))
                    
                    print 'Created conn between pregid %d and postgid %d on host %d'%(preGid, postCell.gid, idhost) 


    def createStims(self):
        """Connect a spiking generator to the first cell to get
        the network going."""
        #### Only continue if the first cell is not on this host
        for cell in self.cells:
            cell.createStim(
                noise=self.stimParams['noise'], 
                rate=self.stimParams['rate'], 
                weight=self.stimParams['weight'],
                delay=self.stimParams['delay'])
            print 'Created stim on cell %d on host %d'%(cell.gid, idhost) 

    def gatherSpikes(self):
        """Gather spikes from all nodes/hosts"""
        if idhost==0: print 'Gathering spikes ...'
        data = [None]*nhost
        data[0] = {'tVec': self.tVec, 'idVec': self.idVec}
        pc.barrier()
        gather=pc.py_alltoall(data)
        pc.barrier()
        self.tVecAll = [] 
        self.idVecAll = [] 
        if idhost==0:
            for d in gather:
                self.tVecAll.extend(list(d['tVec']))
                self.idVecAll.extend(list(d['idVec']))

    def plotRaster(self):

        print 'Plotting raster ...'
        pyplot.figure()
        pyplot.scatter(self.tVecAll,self.idVecAll,marker=".",s=1,color='blue')
        pyplot.xlabel('Time (ms)')
        pyplot.ylabel('Cell ID')
        pyplot.title('Raster Plot of Network with 500 HH Cells')
        pyplot.xlim(0,max(self.tVecAll))
        pyplot.ylim(0,self.N)
        pyplot.show()
        #pyplot.savefig('raster')

    def saveData(self):
        print 'Savind data ...'
        dataSave = {'N': self.N, 'connParams': self.connParams, 'stimParams': self.stimParams, 'tVec': self.tVecAll, 'idVec': self.idVec}
        with open('output.pkl', 'wb') as f:
            pickle.dump(dataSave, f)

#### New ParallelContext object 
pc = h.ParallelContext()
pc.set_maxstep(10)
idhost = int(pc.id())
nhost = int(pc.nhost())

# set randomizer seed
randSeed = 1

# create network
net = Net(N=500, cellType=HHCell,
    connParams={'type': 'rand', 'weight': 0.004, 'delayMean': 13.0, 'delayVar': 1.4, 'delayMin': 0.2, 'maxConnsPerCell': 20}, 
    stimParams={'rate': 50, 'noise': 0.5, 'weight': 50, 'delay':5}) 

# set voltage recording for cell 0 in net
net.cells[0].setRecording()

# run sim and gather spikes
h.stdinit()
h.dt = 0.025
duration = 1000.0
if idhost==0: 
    print 'Running sim...'
    startTime = datetime.now() # store sim start time
pc.psolve(duration)  # actually run the sim in all nodes
if idhost==0: 
    runTime = (datetime.now() - startTime).total_seconds()  # calculate run time
    print "Run time for %d sec sim = %.2f sec"%(int(duration/1000.0), runTime)
net.gatherSpikes()  # gather spikes from all nodes onto master node

# plot net raster, save net data and plot cell 0 traces 
if idhost==0:
    net.plotRaster()
    net.saveData()
    net.cells[0].plotTraces()

pc.done()

Loading data, please wait...