Mean field model for Hodgkin Huxley networks of neurons (Carlu et al 2020)

 Download zip file 
Help downloading and running models
Accession:263259
"We present a mean-field formalism able to predict the collective dynamics of large networks of conductance-based interacting spiking neurons. We apply this formalism to several neuronal models, from the simplest Adaptive Exponential Integrate-and-Fire model to the more complex Hodgkin-Huxley and Morris-Lecar models. We show that the resulting mean-field models are capable of predicting the correct spontaneous activity of both excitatory and inhibitory neurons in asynchronous irregular regimes, typical of cortical dynamics. Moreover, it is possible to quantitatively predict the population response to external stimuli in the form of external spike trains. This mean-field formalism therefore provides a paradigm to bridge the scale between population dynamics and the microscopic complexity of the individual cells physiology."
Reference:
1. Carlu M, Chehab O, Dalla Porta L, Depannemaecker D, Héricé C, Jedynak M, Köksal Ersöz E, Muratore P, Souihel S, Capone C, Zerlaut Y, Destexhe A, di Volo M (2020) A mean-field approach to the dynamics of networks of complex neurons, from nonlinear Integrate-and-Fire to Hodgkin-Huxley models. J Neurophysiol 123:1042-1051 [PubMed]
Citations  Citation Browser
Model Information (Click on a link to find other models with that property)
Model Type:
Brain Region(s)/Organism:
Cell Type(s):
Channel(s):
Gap Junctions:
Receptor(s):
Gene(s):
Transmitter(s):
Simulation Environment: Python;
Model Concept(s): Methods;
Implementer(s): di Volo, Matteo [matteo.di-volo at cyu.fr];
import sys
sys.path.append('../code/')

import numpy as np
import matplotlib.pylab as plt
sys.path.append('../')
from mean_field.master_equation import find_fixed_point
from transfer_functions.theoretical_tools import mean_and_var_conductance, get_fluct_regime_vars, pseq_params
from scipy.signal import gaussian
from single_cell_models.cell_library import get_neuron_params
from synapses_and_connectivity.syn_and_connec_library import get_connectivity_and_synapses_matrix
from transfer_functions.tf_simulation import reformat_syn_parameters

def gaussian_func(x, mean, std):
    """Evaluate the normal probability density with the given mean and std at x.

    Works element-wise when ``x`` is a numpy array.
    """
    # exponent of the bell curve: -(x - mean)^2 / (2 std^2)
    exponent = -(x-mean)**2/2/std**2
    bell = np.exp(exponent)
    # normalisation by std * sqrt(2 pi)
    return bell/std/np.sqrt(2.*np.pi)

def skew(x):
    """Return the cube root of the absolute third central moment of x.

    The value is not normalised by the standard deviation. ``x`` must provide
    a ``.mean()`` method (e.g. a numpy array).
    """
    deviations = x.mean()-x
    third_moment = np.mean(deviations**3)
    return np.abs(third_moment)**(1./3.)

def plot_ntwk_sim_output(time_array, rate_array, rate_exc, rate_inh,\
                         Raster_exc, Raster_inh,\
                         Vm_exc, Vm_inh, Ge_exc, Ge_inh, Gi_exc, Gi_inh,\
                         BIN=5, min_time=2000):
    """Compare stationary network-simulation statistics with the mean-field fixed point.

    Three figures are built:

    * ``fig1`` -- bars of the mean membrane potential, excitatory conductance
      and inhibitory conductance (simulation next to mean-field prediction);
    * ``fig2`` -- the same layout for the standard deviations;
    * ``fig``  -- histograms of the binned excitatory / inhibitory population
      rates, overlaid with Gaussians centred on the fixed point.

    Parameters
    ----------
    time_array : array
        Simulation time axis; samples with ``time_array <= min_time`` are
        discarded before any statistic is computed.
    rate_array, Raster_exc, Raster_inh
        Accepted so the unpacked simulation file can be passed straight
        through; they are not used by this function.
    rate_exc, rate_inh : array
        Population rates sampled on ``time_array``.
    Vm_exc, Vm_inh, Ge_exc, Ge_inh, Gi_exc, Gi_inh : sequence of arrays
        One trace per recorded neuron, sampled on ``time_array``.
        Presumably in mV / nS (the predictions are scaled by 1e3 / 1e9 before
        being drawn next to them) -- TODO confirm against the simulation code.
    BIN : float
        Currently ignored, see the note at the top of the body.
    min_time : float
        Start of the analysed window, in the units of ``time_array``.

    Returns
    -------
    list
        ``[fig, fig1, fig2]``.

    Notes
    -----
    The neuron models are read from the module-level ``args.CONFIG``
    ("NRN_exc--NRN_inh--NTWK"), so ``args`` must exist as a global when this
    function is called. The network is hard-coded to ``'CONFIG1'``.
    """
    # NOTE(review): the BIN argument is overridden here, so callers cannot
    # change the bin width. Kept to preserve the figures this script produces.
    BIN = 15.

    cond_t = (time_array>min_time) # drop the initial transient (t <= min_time)

    config_fields = args.CONFIG.split('--')
    params = get_neuron_params(config_fields[0], SI_units=True)
    M = get_connectivity_and_synapses_matrix('CONFIG1', SI_units=True)
    EXC_AFF = M[0,0]['ext_drive']

    print('starting fixed point')
    fe0, fi0, sfe, _, sfi = find_fixed_point(config_fields[0], config_fields[1], 'CONFIG1',\
                                             exc_aff=EXC_AFF, Ne=8000, Ni=2000, verbose=True)
    print('end fixed point')

    reformat_syn_parameters(params, M) # merging those parameters

    # Gaussian rate densities spanning +/- 4 std around the fixed point
    xfe = fe0+np.linspace(-4,4)*sfe
    fe_pred = gaussian_func(xfe, fe0, sfe)
    xfi = fi0+np.linspace(-4,4)*sfi
    fi_pred = gaussian_func(xfi, fi0, sfi)

    # mean-field predictions evaluated at the fixed point
    mGe, mGi, sGe, sGi = mean_and_var_conductance(fe0+EXC_AFF, fi0, *pseq_params(params))
    muV, sV, _, _ = get_fluct_regime_vars(fe0+EXC_AFF, fi0, *pseq_params(params))

    ### MEMBRANE POTENTIAL
    MEAN_VM, STD_VM = [], []
    for vm_e, vm_i in zip(Vm_exc, Vm_inh):
        # mean over samples below -50, excluding samples exactly equal to -65
        # (presumably the reset / clamped value -- TODO confirm)
        MEAN_VM.append(vm_e[cond_t & (vm_e!=-65) & (vm_e<-50)].mean())
        MEAN_VM.append(vm_i[cond_t & (vm_i!=-65) & (vm_i<-50)].mean())
        for vv in [vm_e[cond_t], vm_i[cond_t]]:
            # indices where the trace goes from above -52 to below -60 within
            # one sample (presumably a spike followed by its reset)
            i0 = np.where((vv[:-1]>-52) & (vv[1:]<-60))[0]
            sv = []
            if len(i0)==0:
                STD_VM.append(vv.std())
            elif len(i0)==1:
                # NOTE(review): this is the std of a single sample, i.e. 0.
                # It looks unintended but the intended value is unclear.
                STD_VM.append(vv[0].std())
            else:
                # std of the stretches between consecutive events, trimmed by
                # 30 samples on each side; intervals of <= 60 samples are skipped
                for i1, i2 in zip(i0[:-1], i0[1:]):
                    if i2-i1>60:
                        sv.append(vv[i1+30:i2-30].std())
                # NOTE(review): yields nan when every interval was skipped
                STD_VM.append(np.array(sv).mean())
        STD_VM.append(vm_i[cond_t & (vm_i<-50)].std())

    fig1, AX1 = plt.subplots(1, 3, figsize=(3,2)) # for means
    fig2, AX2 = plt.subplots(1, 3, figsize=(3,2)) # for std

    # the means are drawn relative to -65
    AX1[0].bar([0], np.array(MEAN_VM).mean()+65, yerr=np.array(MEAN_VM).std(), color='w', edgecolor='k', lw=3, error_kw=dict(elinewidth=3,ecolor='k'))
    AX2[0].bar([0], np.array(STD_VM).mean(), yerr=np.array(STD_VM).std(), color='w', edgecolor='k', lw=3, error_kw=dict(elinewidth=3,ecolor='k'), label='$V_m$')
    AX1[0].bar([1], [1e3*muV+65], color='gray', alpha=.5, label='$V_m$')
    AX2[0].bar([1], [1e3*sV], color='gray', alpha=.5)

    ### EXCITATORY CONDUCTANCE
    MEAN_GE, STD_GE = [], []
    for ge_e, ge_i in zip(Ge_exc, Ge_inh):
        MEAN_GE.append(ge_e[cond_t].mean())
        MEAN_GE.append(ge_i[cond_t].mean())
        STD_GE.append(ge_e[cond_t].std())
        STD_GE.append(ge_i[cond_t].std())

    AX1[1].bar([0], np.array(MEAN_GE).mean(), yerr=np.array(MEAN_GE).std(), color='w', edgecolor='g', lw=3, error_kw=dict(elinewidth=3,ecolor='g'), label='num. sim.')
    AX2[1].bar([0], np.array(STD_GE).mean(), yerr=np.array(STD_GE).std(), color='w', edgecolor='g', lw=3, error_kw=dict(elinewidth=3,ecolor='g'), label='exc.')
    AX1[1].bar([1], [1e9*mGe], color='g', label='mean field \n pred.')
    AX2[1].bar([1], [1e9*sGe], color='g', label='exc.')

    ### INHIBITORY CONDUCTANCE
    MEAN_GI, STD_GI = [], []
    for gi_e, gi_i in zip(Gi_exc, Gi_inh):
        MEAN_GI.append(gi_e[cond_t].mean())
        MEAN_GI.append(gi_i[cond_t].mean())
        STD_GI.append(gi_e[cond_t].std())
        STD_GI.append(gi_i[cond_t].std())

    AX1[2].bar([0], np.array(MEAN_GI).mean(), yerr=np.array(MEAN_GI).std(), color='w', edgecolor='r', lw=3, error_kw=dict(elinewidth=3,ecolor='r'), label='num. sim.')
    AX2[2].bar([0], np.array(STD_GI).mean(), yerr=np.array(STD_GI).std(), color='w', edgecolor='r', lw=3, error_kw=dict(elinewidth=3,ecolor='r'), label='inh.')
    AX1[2].bar([1], [1e9*mGi], color='r', label='mean field \n pred.')
    AX2[2].bar([1], [1e9*sGi], color='r', label='inh.')

    ### POPULATION RATE ###

    fig, ax = plt.subplots(figsize=(4,3))
    # we bin the population rate: N1 bins of N0 consecutive samples each
    N0 = int(BIN/(time_array[1]-time_array[0]))
    N1 = int((time_array[cond_t][-1]-time_array[cond_t][0])/BIN)
    rate_exc = rate_exc[cond_t][:N0*N1].reshape((N1,N0)).mean(axis=1)
    rate_inh = rate_inh[cond_t][:N0*N1].reshape((N1,N0)).mean(axis=1)

    # `density=True` replaces the `normed=True` keyword, which no longer
    # exists in recent NumPy; the result is identical for equal-width bins.
    hh, bb = np.histogram(rate_exc, bins=3, density=True)
    ax.bar(.5*(bb[:-1]+bb[1:]), hh, color='w', width=bb[1]-bb[0], edgecolor='g', lw=3, label='exc.', alpha=.7)
    hh, bb = np.histogram(rate_inh, bins=5, density=True)
    ax.bar(.5*(bb[:-1]+bb[1:]), hh, color='w', width=bb[1]-bb[0], edgecolor='r', lw=3, label='inh', alpha=.7)

    ax.fill_between(xfe, 0*fe_pred, fe_pred, color='g', alpha=.8)
    ax.fill_between(xfi, 0*fi_pred, fi_pred, color='r', alpha=.8)

    for panel in AX1: panel.legend()
    for panel in AX2: panel.legend()

    return [fig, fig1, fig2]

from plot_single_sim import bin_array
from mean_field.euler_method import run_mean_field_extended,run_mean_field
from waveform_input import double_gaussian
from transfer_functions.theoretical_tools import get_fluct_regime_vars, pseq_params

def plot_ntwk_sim_output_for_waveform(args,\
                                      BIN=5, min_time=200, bar_ms=50,
                                      zoom_conditions=None, vpeak=-35,\
                                      raster_number=400):
    """Plot a network response to a double-Gaussian input next to the mean-field prediction.

    The simulation output is loaded from ``args.file`` and four figures are
    built: a spike raster, the excitatory traces (Vm and conductances), the
    inhibitory traces, and the binned population rates. The mean-field
    prediction is overlaid on the last three.

    Side effects: writes ``HH_response.npy`` in the current directory, prints
    the mean rates over the second half of the window, and calls ``plt.show()``.

    Parameters
    ----------
    args : namespace
        Must provide ``file``, ``CONFIG`` ("NRN_exc--NRN_inh--NTWK"), ``t0``,
        ``T1``, ``T2``, ``amp`` and ``tstop``. The time-like fields are scaled
        by 1e-3 before reaching the mean-field code, so they are presumably in
        ms while the mean field works in seconds -- TODO confirm.
    BIN : float
        Currently ignored, see the note at the top of the body.
    min_time, bar_ms
        Not used by the current implementation.
    zoom_conditions : sequence of two floats, optional
        ``(t_min, t_max)`` window on the simulation time axis; defaults to the
        whole recording.
    vpeak : float
        Upper end of the dashed vertical line drawn at each detected spike.
    raster_number : int
        Controls how many neuron indices around the exc./inh. boundary are
        shown in the raster.

    Returns
    -------
    list
        The four matplotlib figures, in creation order.
    """
    # NOTE(review): the BIN argument is overridden here, so callers cannot
    # change the bin width. Kept to preserve the figures this script produces.
    BIN = 8
    # NOTE(security): allow_pickle=True executes pickled objects on load;
    # only open simulation files from a trusted source.
    time_array, rate_array, rate_exc, rate_inh,\
        Raster_exc, Raster_inh,\
        Vm_exc, Vm_inh, Ge_exc, Ge_inh, Gi_exc, Gi_inh = np.load(args.file,allow_pickle = True)

    # NOTE(review): constant 0.2 offset added to the simulated excitatory
    # rate; its origin is not documented in this file. Kept unchanged.
    rate_exc = rate_exc+0.2

    if zoom_conditions is not None:
        z = zoom_conditions
    else:
        z = [time_array[0],time_array[-1]]
    cond_t = (time_array>z[0]) & (time_array<z[1])
    t_sim = time_array[cond_t]

    ###### =========================================== ######
    ############# adding the theoretical eval ###############
    ###### =========================================== ######

    # theory is only drawn from 4*T1 before the stimulus centre onwards
    t0 = args.t0-4*args.T1
    def rate_func(t):
        return double_gaussian(t, 1e-3*args.t0, 1e-3*args.T1, 1e-3*args.T2, args.amp)

    config_fields = args.CONFIG.split('--')
    t, fe, fi, sfe, sfei, sfi = run_mean_field_extended(config_fields[0],\
                               config_fields[1], config_fields[2],\
                               rate_func,dt=5e-4,
                               tstop=args.tstop*1e-3)

    tfirst, fefirst, fifirst = run_mean_field(config_fields[0], config_fields[1], config_fields[2],\
                                              rate_func, dt=5e-4, tstop=args.tstop*1e-3)

    params = get_neuron_params(config_fields[0], SI_units=True)
    M = get_connectivity_and_synapses_matrix('CONFIG1', SI_units=True)
    reformat_syn_parameters(params, M) # merging those parameters
    ext_drive = M[0,0]['ext_drive']

    # excitatory input: the first variant includes the afferent waveform,
    # the second only the constant external drive
    fe_e, fe_i = fe+ext_drive+rate_func(t), fe+ext_drive
    muV_e, sV_e, _, _ = get_fluct_regime_vars(fe_e, fi, *pseq_params(params))
    muV_i, sV_i, _, _ = get_fluct_regime_vars(fe_i, fi, *pseq_params(params))

    # part of the mean-field solution that is drawn, and its time axis
    # rescaled by 1e3 to match the simulation time axis
    shown = 1e3*t>t0
    t_ms = 1e3*t[shown]

    # smallest inhibitory index, used as the exc./inh. boundary in the raster
    Ne = Raster_inh[1].min()

    FIGS = []

    # --- raster plot ---
    FIGS.append(plt.figure(figsize=(5,4)))
    cond = (Raster_exc[0]>z[0]) & (Raster_exc[0]<z[1])& (Raster_exc[1]>Ne-raster_number)
    plt.plot(Raster_exc[0][cond], Raster_exc[1][cond], '.g')
    cond = (Raster_inh[0]>z[0]) & (Raster_inh[0]<z[1])& (Raster_inh[1]<Ne+.2*raster_number)
    plt.plot(Raster_inh[0][cond], Raster_inh[1][cond], '.r')
    # (the time scale bar based on bar_ms is disabled)

    def _trace_figure(vm_traces, ge_traces, gi_traces, color, spike_thr, mu_v, s_v, fe_input):
        # One figure per population: Vm traces on top, conductances below,
        # each overlaid with the mean-field mean +/- std band.
        FIGS.append(plt.figure(figsize=(5,5)))
        ax_v = plt.subplot2grid((3,1), (0,0), rowspan=2)
        ax_g = plt.subplot2grid((3,1), (2,0))
        for i in range(len(vm_traces)):
            vm = vm_traces[i][cond_t]
            ax_v.plot(t_sim, vm, color+'-', lw=.5)
            # a spike is detected where Vm falls from above spike_thr to below -55
            spk = np.where((vm[:-1]>spike_thr) & (vm[1:]<-55))[0]
            for ispk in spk:
                ax_v.plot(t_sim[ispk]*np.ones(2), [vm[ispk],vpeak], color+'--')
            ax_g.plot(t_sim, ge_traces[i][cond_t], 'g-', lw=.5)
            ax_g.plot(t_sim, gi_traces[i][cond_t], 'r-', lw=.5)
        # vm
        ax_v.plot(t_ms, 1e3*mu_v[shown], color+'-', lw=2)
        ax_v.fill_between(t_ms, 1e3*(mu_v[shown]-s_v[shown]),\
                          1e3*(mu_v[shown]+s_v[shown]), color=color, alpha=.4)
        # ge
        ge_th = 1e9*fe_input[shown]*params['Qe']*params['Te']*(1-params['gei'])*params['pconnec']*params['Ntot']
        sge_th = 1e9*params['Qe']*np.sqrt(fe_input[shown]*params['Te']*(1-params['gei'])*params['pconnec']*params['Ntot'])
        ax_g.plot(t_ms, ge_th, 'g-', lw=2)
        ax_g.fill_between(t_ms, ge_th-sge_th, ge_th+sge_th, color='g', alpha=.4)
        # gi
        gi_th = 1e9*fi[shown]*params['Qi']*params['Ti']*params['gei']*params['pconnec']*params['Ntot']
        sgi_th = 1e9*params['Qi']*np.sqrt(fi[shown]*params['Ti']*params['gei']*params['pconnec']*params['Ntot'])
        ax_g.plot(t_ms, gi_th, 'r-', lw=2)
        ax_g.fill_between(t_ms, gi_th-sgi_th, gi_th+sgi_th, color='r', alpha=.4)
        # vertical scale bar from -65 to -55
        ax_v.plot(t_sim[0]*np.ones(2), [-65, -55], lw=4, color='k')
        ax_v.annotate('10mV', (t_sim[0], -65))

    # --- excitatory then inhibitory traces (spike thresholds -50 / -51) ---
    _trace_figure(Vm_exc, Ge_exc, Gi_exc, 'g', -50, muV_e, sV_e, fe_e)
    _trace_figure(Vm_inh, Ge_inh, Gi_inh, 'r', -51, muV_i, sV_i, fe_i)

    # --- population rates ---
    fig, AX = plt.subplots(figsize=(5,3))
    # we bin the population rate
    binned_exc = bin_array(rate_exc[cond_t], BIN, time_array[cond_t])
    binned_inh = bin_array(rate_inh[cond_t], BIN, time_array[cond_t])
    binned_time = bin_array(time_array[cond_t], BIN, time_array[cond_t])

    AX.plot(binned_time, binned_exc, 'g-', lw=.6, label='$\\nu_e(t)$')
    AX.plot(binned_time, binned_inh, 'r-', lw=.6, label='$\\nu_i(t)$')
    AX.plot(1000*tfirst, fifirst, 'r--', lw=.6, label='$\\nu_i(t)$')

    np.save('HH_response',[1000*tfirst,fefirst,fifirst,t,fe,fi,sfe,sfi,rate_func(tfirst)])

    AX.plot(t_ms, fi[shown], 'r-', label='num. sim.')
    AX.plot(t_ms, rate_func(t[shown]), 'k:',\
            lw=2, label='$\\nu_e^{aff}(t)$ \n $\\nu_e^{drive}$')
    AX.plot(t_ms, fe[shown], 'g-', label='mean field \n pred.')
    AX.fill_between(t_ms, fe[shown]-sfe[shown], fe[shown]+sfe[shown],\
                        color='g', alpha=.3, label='mean field \n pred.')

    AX.fill_between(t_ms, fi[shown]-sfi[shown], fi[shown]+sfi[shown],\
                        color='r', alpha=.3)

    AX.legend(prop={'size':'xx-small'})

    FIGS.append(fig)

    # Mean rate over the second half of the binned window. Integer division
    # is required: a float slice index raises TypeError on Python 3.
    print('excitatory rate: ', binned_exc[len(binned_exc)//2:].mean(), 'Hz')
    print('inhibitory rate: ', binned_inh[len(binned_inh)//2:].mean(), 'Hz')
    print(sfe[0], sfe[-1])
    plt.show()
    return FIGS
    

if __name__=='__main__':

    import argparse
    import os

    # First a nice documentation
    parser=argparse.ArgumentParser(description=
     """ 
     ----------------------------------------------------------------------
     Plot the output of a network simulation next to the mean-field prediction

     Choose CELLULAR and NTWK PARAMETERS from the available libraries
     see ../single_cell_models/cell_library.py for the CELLS
     see ../synapses_and_connectivity/syn_and_connec_library.py for the NTWK

     Then construct the input as "NRN_exc--NRN_inh--NTWK"
     example: "LIF--LIF--Vogels-Abbott"
     ----------------------------------------------------------------------
     """
    ,formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--CONFIG",help="Cell and Network configuration !", default='RS-cell--FS-cell--CONFIG1')
    parser.add_argument("-f", "--file",help="simulation output file (.npy) to load", default='data/example.npy')
    parser.add_argument("-z", "--zoom",help="zoom for activity", type=float, nargs=2)
    parser.add_argument("-b", "--bar_ms",help="bar for legend", type=int, default=100)
    parser.add_argument("-r", "--raster_number",help="max neuron number", type=int, default=10000)
    parser.add_argument("-t_after_transient", type=float, default=500)
    parser.add_argument("-s", "--save",action='store_true')

    # `args` must stay a module-level name: plot_ntwk_sim_output reads
    # args.CONFIG from the global scope.
    args = parser.parse_args()

    # NOTE(security): allow_pickle=True executes pickled objects on load;
    # only open simulation files from a trusted source.
    time_array, rate_array, rate_exc, rate_inh,\
        Raster_exc, Raster_inh, Vm_exc, Vm_inh,\
        Ge_exc, Ge_inh, Gi_exc, Gi_inh = np.load(args.file,allow_pickle = True)

    FIG = plot_ntwk_sim_output(time_array, rate_array, rate_exc, rate_inh,\
                               Raster_exc, Raster_inh,\
                               Vm_exc, Vm_inh, Ge_exc, Ge_inh, Gi_exc, Gi_inh,\
                               min_time=args.t_after_transient)

    if args.save:
        # The helper put_list_of_figs_to_svg_fig is neither defined nor
        # imported in this file, so calling it raised NameError. Each figure
        # is written as an SVG next to the input file instead.
        base = os.path.splitext(args.file)[0]
        for index, figure in enumerate(FIG):
            figure.savefig('%s_fig%d.svg' % (base, index))
    else:
        plt.show()